# Train set AVG PSNR = 25.2899

# Test set AVG PSNR = 26.9565 (reference value; the training loop reports validation PSNR minus this as PSNR_Diff)

# A: full dose, B: quarter dose

In [11]:
import torch
import itertools
from model import UNetGenerator,get_norm_layer,init_weight
from utils import LambdaLR
from torchvision import transforms
import os
import torch.nn as nn
from tqdm.auto import tqdm
from dataset import makeDataset
from torch.utils.data import DataLoader
from torchmetrics import PeakSignalNoiseRatio
from torch.utils.tensorboard import SummaryWriter
import pickle
import numpy as np

In [12]:
def normalize(tensor,zero2one=False):
    """Min-max rescale every (dim0, dim1) slice of ``tensor`` independently.

    The tensor is viewed as ``(shape[0], shape[1], -1)``, so it needs at least
    two dimensions (presumably ``(N, C, H, W)`` images — confirm at call
    sites). Each slice over the flattened trailing axes is mapped to
    ``[0, 1]``. With ``zero2one=False`` the result is then mapped to
    ``[-1, 1]``.

    Args:
        tensor: Input tensor with at least two dimensions.
        zero2one: If True, return values in ``[0, 1]``; otherwise in
            ``[-1, 1]``.

    Returns:
        A tensor with the same shape as ``tensor``.

    A slice whose values are all equal has no range to stretch. It is mapped
    to 0 (``zero2one=True``) or -1 (``zero2one=False``) instead of producing
    NaN from a 0/0 division.
    """
    shape = tensor.shape
    flat = tensor.reshape([shape[0], shape[1], -1])
    lo = flat.min(dim=-1, keepdim=True)[0]
    hi = flat.max(dim=-1, keepdim=True)[0]
    span = hi - lo
    # A zero span (constant slice) would give 0/0 = NaN, which then poisons
    # any average it is summed into. Dividing by 1 instead leaves those slices
    # at 0 and keeps every other slice bit-identical.
    span = torch.where(span > 0, span, torch.ones_like(span))
    unit = (flat - lo) / span
    if not zero2one:
        unit = (unit - 0.5) * 2.0
    return unit.reshape(shape)

In [13]:
# ---- Training hyperparameters ----
epochs = 80
batch_size = 8
lr = 0.0002            # Adam learning rate
ngf = 64               # passed to UNetGenerator as ngf (presumably base filter count)
use_droplayer = False  # passed to UNetGenerator as use_drop
norm_type = 'batch'    # resolved to a layer factory by get_norm_layer

norm_layer = get_norm_layer(norm_type)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [14]:
# Single-channel in / single-channel out U-Net; the training loop feeds it the
# quarter-dose image and compares its output with the full-dose image.
denoising_Unet = UNetGenerator(
    in_c=1,
    out_c=1,
    ngf=ngf,
    norm_layer=norm_layer,
    use_drop=use_droplayer,
)

In [15]:
# Move the network to DEVICE, then apply the project's weight initialiser to
# every submodule (Module.to and Module.apply both return the module itself).
denoising_Unet.to(DEVICE).apply(init_weight)

# Spread batches across GPUs when more than one is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f'training with {gpu_count} GPUs')
    denoising_Unet = nn.DataParallel(denoising_Unet)

In [16]:
# Per-image PSNR, reduced over the last three axes; data_range=1 assumes the
# compared images are scaled to [0, 1].
psnr = PeakSignalNoiseRatio(data_range=1, dim=[-1, -2, -3])

# TensorBoard writer; no log_dir is given, so the library default is used.
writer = SummaryWriter()

# Training-time augmentation: a single random 256x256 crop.
trainTransforms = transforms.Compose([transforms.RandomCrop((256, 256))])

In [17]:
# Both splits share the same pairing options. alignB / sameTransformB are
# interpreted by dataset.makeDataset — presumably they pair each B image with
# its A image and apply A's transform to it; confirm in dataset.py.
paired_opts = dict(alignB=True, sameTransformB=True)

trainDataset = makeDataset(trainTransforms, **paired_opts)
validDataset = makeDataset(transform=None, mode='test', **paired_opts)

# Shuffle only the training split; validation order stays fixed.
trainLoader, validLoader = (
    DataLoader(trainDataset, batch_size=batch_size, shuffle=True),
    DataLoader(validDataset, batch_size=batch_size, shuffle=False),
)

In [18]:
# Pixel-wise MSE between the network output and the full-dose target.
loss_fn = nn.MSELoss()

# Adam with beta1 = 0.5 over all model parameters.
optimizer = torch.optim.Adam(
    denoising_Unet.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# LR multiplier supplied by utils.LambdaLR(epochs).step; the training loop
# steps this scheduler once per epoch.
lr_schedular = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=LambdaLR(epochs).step,
)

In [19]:
# Directory that receives history.pkl and the best checkpoint.
os.makedirs('./final_result_denoising_unet', exist_ok=True)

# Per-epoch records: average training loss and average validation PSNR.
history = dict(loss=[], PSNR=[])

In [20]:
# Training loop: one pass over trainLoader per epoch, then PSNR on
# validLoader. Per-epoch averages are logged to TensorBoard and history.pkl,
# and the weights with the best validation PSNR so far are written to disk.
BASELINE_TEST_PSNR = 26.9565  # test-set AVG PSNR from the header notes; shown as PSNR_Diff
trainstep = len(trainLoader)  # number of batches the loader actually yields per epoch
outtertqdm = tqdm(range(epochs))
best_psnr = float('-inf')
try:
    for epoch in outtertqdm:
        denoising_Unet.train()

        stepcnt = 0        # training batches this epoch
        stepValidcnt = 0   # validation batches this epoch
        nValidImages = 0   # validation images this epoch
        total_loss = 0.0
        total_PSNR = 0.0

        # Iterate the loader itself rather than a precomputed step count, so
        # the epoch ends exactly when the loader is exhausted.
        innertqdm = tqdm(trainLoader, total=trainstep, leave=False)
        for fd, qd in innertqdm:
            stepcnt += 1
            fd, qd = fd.to(DEVICE), qd.to(DEVICE)
            fd_predicted = denoising_Unet(qd)
            loss = loss_fn(fd_predicted, fd)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a Python float: summing the loss tensors themselves
            # would chain every batch's autograd graph until the epoch ends.
            total_loss += loss.item()

        denoising_Unet.eval()
        # With `dim` set, the torchmetrics PSNR keeps per-call state until
        # reset(); clear it so that state does not grow across epochs.
        psnr.reset()
        with torch.no_grad():
            for fd, qd in validLoader:
                fd, qd = fd.to(DEVICE), qd.to(DEVICE)
                fd_predicted = denoising_Unet(qd)

                # Rescale prediction and target per sample/channel to [0, 1]
                # to match the metric's data_range=1.
                fd_predicted = normalize(fd_predicted, zero2one=True)
                fd = normalize(fd, zero2one=True)
                # psnr(...) is the mean over this batch; weight it by batch
                # size so a smaller final batch is not over-counted.
                batch_n = fd.shape[0]
                total_PSNR += psnr(fd_predicted, fd).item() * batch_n
                nValidImages += batch_n
                stepValidcnt += 1

        avgLoss = total_loss / stepcnt
        avgPSNR = total_PSNR / nValidImages

        history['loss'].append(avgLoss)
        history['PSNR'].append(avgPSNR)

        writer.add_scalar('loss', avgLoss, epoch)
        writer.add_scalar('PSNR', avgPSNR, epoch)

        lr_schedular.step()

        outtertqdm.set_postfix({
            'PSNR': '{:.03f}'.format(avgPSNR),
            'PSNR_Diff': '{:.03f}'.format(avgPSNR - BASELINE_TEST_PSNR),
        })
        with open('./final_result_denoising_unet/history.pkl', 'wb') as f:
            pickle.dump(history, f)

        if best_psnr < avgPSNR:
            best_psnr = avgPSNR
            # Save the unwrapped network so checkpoint keys carry no
            # 'module.' prefix and load into a plain UNetGenerator whether or
            # not DataParallel was used during training.
            model_to_save = (denoising_Unet.module
                             if isinstance(denoising_Unet, nn.DataParallel)
                             else denoising_Unet)
            torch.save(model_to_save.state_dict(),
                       './final_result_denoising_unet/denoising_unet.pt')
finally:
    # Flush and close TensorBoard even if training is interrupted.
    writer.flush()
    writer.close()

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]