# 测试

In [None]:
from dataclasses import dataclass
import time
import logging
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torch
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from utils.loader import data_generator, DenoisingDataset
from utils.run import DnCNN, sum_squared_error


# Send every record (DEBUG and above) from the root logger to logs/test.log.
# filemode="w" truncates the file, so each run starts with an empty log.
# NOTE: the "logs" directory must already exist; basicConfig does not create it.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s-%(funcName)s'

logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    filemode="w",
    filename='logs/test.log',
)

In [None]:
@dataclass
class Config:
    """Hyper-parameters and paths for one training run.

    Attributes:
        epochs: Number of passes the training loop makes over the data.
        root: Directory handed to ``data_generator`` as the training-data root.
        lr: Learning rate given to the Adam optimizer.
        batch_size: Batch size used by both ``data_generator`` and the
            ``DataLoader``.
        sigma: Noise level passed to ``DenoisingDataset``.
    """

    epochs: int = 180
    root: str = "data/Train400"
    lr: float = 1e-3
    batch_size: int = 128
    sigma: int = 25  # noise level


# Single shared instance read by the rest of the notebook.
config = Config()

In [None]:
# Build everything the training loop needs: the network, the loss, the data
# pipeline, the optimizer and the learning-rate schedule.
cuda = torch.cuda.is_available()

# Loss function. The original author noted nn.MSELoss(reduction='sum') as the
# PyTorch 0.4.1 alternative to this project helper.
criterion = sum_squared_error()

# Network in training mode, moved to the GPU when one is available. The move
# happens before the optimizer is created so it tracks the final parameters.
model = DnCNN()
model.train()
if cuda:
    model = model.cuda()

# Clean patches: scale to [0, 1] as float32 (assuming 8-bit input, hence the
# division by 255) and reorder the axes with (0, 3, 1, 2) to get the
# N x C x H x W tensor of clean patches.
data = data_generator(config.root, batch_size=config.batch_size)
xs = torch.from_numpy(
    (data.astype(np.float32) / 255.0).transpose(0, 3, 1, 2)
)

# Dataset built from the clean patches and the configured noise level.
DDataset = DenoisingDataset(xs, config.sigma)
DLoader = DataLoader(
    DDataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,  # discard the final incomplete batch
)

optimizer = optim.Adam(model.parameters(), lr=config.lr)
# Learning-rate schedule: multiplies the rate by 0.2 at milestones 30, 60 and 90
# (only takes effect when scheduler.step() is called).
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)

In [4]:
# Train the model for `config.epochs` epochs. Every 10th batch the per-sample
# loss is logged; after each epoch the mean batch loss and wall-clock time are
# logged, written to build/train_result.txt, and the model is checkpointed.
device = torch.device('cuda' if cuda else 'cpu')

# Number of batches per epoch. The epoch mean must be divided by this count,
# not by the last zero-based batch index (which is one too small and is zero
# when there is a single batch).
num_batches = len(DLoader)
if num_batches == 0:
    raise ValueError(
        'DLoader yields no batches; with drop_last=True the dataset must '
        'contain at least config.batch_size samples'
    )

for epoch in range(config.epochs):
    epoch_loss = 0.0
    start_time = time.time()
    for n_count, batch_yx in enumerate(DLoader):
        # Element 0 is fed to the model and element 1 is the loss target
        # (presumably noisy input and clean patch -- confirm against
        # DenoisingDataset).
        batch_y = batch_yx[0].to(device)
        batch_x = batch_yx[1].to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch_y), batch_x)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call synchronises with the device.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        if n_count % 10 == 0:
            logging.info(f'{epoch+1:4d} {n_count:4d} / {xs.size(0)//config.batch_size:4d} loss = {batch_loss/config.batch_size:2.4f}')

    # NOTE(review): `scheduler` is never stepped in this loop, so the learning
    # rate stays at config.lr for the whole run and the MultiStepLR milestones
    # have no effect. Call `scheduler.step()` here if the decay is intended.
    elapsed_time = time.time() - start_time
    mean_loss = epoch_loss / num_batches
    logging.info(f'epcoh = {epoch+1:4d} , loss =  {mean_loss:4.4f} , time = {elapsed_time:4.2f} s')
    # The file is rewritten every epoch, so it only holds the latest result.
    np.savetxt('build/train_result.txt', np.hstack((epoch+1, mean_loss, elapsed_time)), fmt='%2.4f')
    # Saves the whole pickled model (not just the state_dict). The file index
    # is zero-based (model_000.pth is the first epoch) while logs are one-based.
    torch.save(model, f"models/model_{epoch:03d}.pth")