# Train set average PSNR = 25.2899

# Test set average PSNR = 26.9565

# Domains: A = full dose, B = quarter dose

In [1]:
import torch
import itertools
from model import Discriminator,UNetGenerator,get_norm_layer,init_weight
from utils import Buffer,LambdaLR
from torchvision import transforms
import os
import torch.nn as nn
from tqdm.auto import tqdm
from dataset import makeDataset
from torch.utils.data import DataLoader
from torchmetrics import PeakSignalNoiseRatio
from torch.utils.tensorboard import SummaryWriter
import pickle
import numpy as np

In [2]:
def normalize(tensor,zero2one=False):
    """Min-max rescale every (dim0, dim1) slice of ``tensor`` independently.

    The tensor is viewed as ``(shape[0], shape[1], -1)``. Each ``tensor[i, j]``
    slice is shifted and scaled with its own minimum and maximum.

    Args:
        tensor: tensor with at least two dimensions (e.g. N x C x H x W).
        zero2one: if True map each slice to [0, 1]; otherwise map it to [-1, 1].

    Returns:
        A tensor with the same shape as ``tensor``.

    A constant slice (max == min) would otherwise compute 0 / 0 = NaN. Such
    slices are mapped to the lower end of the target range (0, or -1).
    """
    shape=tensor.shape
    flat=tensor.reshape([shape[0],shape[1],-1])
    lo=flat.min(dim=-1,keepdim=True)[0]
    hi=flat.max(dim=-1,keepdim=True)[0]
    span=hi-lo
    # Replace zero spans by 1 so constant slices become 0 instead of NaN;
    # non-degenerate slices are divided by exactly the same value as before.
    span=torch.where(span>0,span,torch.ones_like(span))
    scaled=(flat-lo)/span
    if not zero2one:
        scaled=(scaled-0.5)*2.0
    return scaled.reshape(shape)

In [3]:
# ---- Training schedule ----
epochs=80
batch_size=8
lr=0.0002

# ---- Loss weights ----
# lambda_A / lambda_B scale the per-domain cycle losses; lambda_identity is an
# extra multiplier applied on top of them for the identity losses.
lambda_A=lambda_B=10.0
lambda_identity=0.5

# ---- Network configuration ----
ngf=ndf=64
use_droplayer=False
norm_type='batch'
norm_layer=get_norm_layer(norm_type)

# ---- Device selection ----
if torch.cuda.is_available():
    DEVICE='cuda'
else:
    DEVICE='cpu'

In [4]:
# Two identically configured single-channel U-Net generators, built A2B first,
# then B2A (same construction order as before).
G_A2B,G_B2A=(
    UNetGenerator(in_c=1,out_c=1,ngf=ngf,norm_layer=norm_layer,use_drop=use_droplayer)
    for _ in range(2)
)
# Two identically configured 3-layer discriminators, built D_A first, then D_B.
D_A,D_B=(
    Discriminator(in_c=1,ndf=ndf,n_layers=3,norm_layer=norm_layer)
    for _ in range(2)
)

In [5]:
# Move every network to DEVICE first, then apply the project's init_weight
# hook to each module (same order as the original cell).
for net in (G_A2B,G_B2A,D_A,D_B):
    net.to(DEVICE)
for net in (G_A2B,G_B2A,D_A,D_B):
    net.apply(init_weight)
del net

# When several GPUs are visible, wrap all four networks in DataParallel.
if torch.cuda.device_count() >1:
    print('training with {} GPUs'.format(torch.cuda.device_count()))
    G_A2B,G_B2A,D_A,D_B=(torch.nn.DataParallel(module) for module in (G_A2B,G_B2A,D_A,D_B))

In [6]:
# One Buffer per domain; the training loop passes generated images through push_pop.
BufferA,BufferB=Buffer(),Buffer()
# PSNR reduced over the last three dims (presumably channel, height, width — i.e.
# one value per image) with an assumed data range of 1.
psnr=PeakSignalNoiseRatio(dim=[-1,-2,-3],data_range=1)
# TensorBoard writer using its default log directory.
writer=SummaryWriter()

In [7]:
# Training augmentation: a single random 256x256 crop.
trainTransforms=transforms.Compose([transforms.RandomCrop(size=(256,256))])

In [8]:
# alignB / sameTransformB are forwarded to the project's makeDataset
# (presumably pairing each A image with its B counterpart — confirm in dataset.py).
trainDataset=makeDataset(trainTransforms,alignB=True,sameTransformB=True)
validDataset=makeDataset(transform=None,mode='test',alignB=True,sameTransformB=True)

# Only the training split is shuffled; validation order stays fixed.
trainLoader,validLoader=(
    DataLoader(trainDataset,batch_size=batch_size,shuffle=True),
    DataLoader(validDataset,batch_size=batch_size,shuffle=False),
)

In [9]:
# MSE (least-squares) adversarial loss; L1 for cycle-consistency and identity.
lossGAN,lossCycle,lossIdentity=nn.MSELoss(),nn.L1Loss(),nn.L1Loss()

# One Adam optimiser over both generators, another over both discriminators.
optGAN=torch.optim.Adam([*G_A2B.parameters(),*G_B2A.parameters()],lr=lr,betas=(0.5,0.999))
optD=torch.optim.Adam([*D_A.parameters(),*D_B.parameters()],lr=lr,betas=(0.5,0.999))

# Learning-rate multipliers come from the project's LambdaLR helper
# (a separate helper instance per optimiser).
lr_schedular_G=torch.optim.lr_scheduler.LambdaLR(optGAN,lr_lambda=LambdaLR(epochs).step)
lr_schedular_D=torch.optim.lr_scheduler.LambdaLR(optD,lr_lambda=LambdaLR(epochs).step)

# Per-epoch averages, appended by the training loop (each key gets its own list).
history={key:[] for key in ('loss_G_A','loss_Cycle_A','loss_idt_A',
                            'loss_G_B','loss_Cycle_B','loss_idt_B',
                            'loss_D_A','loss_D_B','PSNR')}

In [10]:
# Create the checkpoint directories. makedirs also creates the parent
# './final_result', and exist_ok=True makes this idempotent without a racy
# exists-then-create check.
for subdir in ('GAN_FD_to_QD','GAN_QD_to_FD','Discriminator_A','Discriminator_B'):
    os.makedirs(os.path.join('./final_result',subdir),exist_ok=True)

In [11]:
# Test-set PSNR quoted in the notebook header; the progress bar shows the
# current validation PSNR relative to it.
baseline_psnr=26.9565
# len(trainLoader) already counts a partial last batch. The old
# len(dataset)//batch_size+1 over-counted by one when the dataset size was a
# multiple of batch_size, which made next() raise StopIteration.
trainstep=len(trainLoader)
outtertqdm=tqdm(range(epochs))
best_psnr=float('-inf')
for epoch in outtertqdm:
    G_A2B.train()
    G_B2A.train()
    D_A.train()
    D_B.train()

    # Running sums are python floats, not tensors. This keeps autograd history
    # from being chained across steps, and the sums stay valid when a term is
    # the plain int 0 (identity loss disabled).
    stepcnt=0
    stepValidcnt=0
    totalLoss_G_A=0.0
    totalLoss_G_Cycle_A=0.0
    totalLoss_G_idt_A=0.0
    totalLoss_G_B=0.0
    totalLoss_G_Cycle_B=0.0
    totalLoss_G_idt_B=0.0
    totalLoss_D_A=0.0
    totalLoss_D_B=0.0
    totalPSNR=0.0

    innertqdm=tqdm(trainLoader,total=trainstep,leave=False)
    for (realA,realB) in innertqdm:
        stepcnt+=1
        (realA,realB)=(realA.to(DEVICE),realB.to(DEVICE))
        ######################################Generator#############################################
        optGAN.zero_grad()
        #####Identity Loss#####
        # Each generator should leave an image that is already in its output domain unchanged.
        if lambda_identity>0.0:
            loss_idt_A=lossIdentity(G_B2A(realA),realA)*lambda_A*lambda_identity
            loss_idt_B=lossIdentity(G_A2B(realB),realB)*lambda_B*lambda_identity
        else:
            loss_idt_A=0
            loss_idt_B=0
        #####GAN Loss#####
        fakeA=G_B2A(realB)
        fakeB=G_A2B(realA)
        pred_D_A_fakeA=D_A(fakeA)
        pred_D_B_fakeB=D_B(fakeB)

        # The targets are shaped like D_A's output and reused for D_B. This
        # relies on D_A and D_B being built with identical arguments.
        targetReal=torch.ones_like(pred_D_A_fakeA)
        targetFake=torch.zeros_like(pred_D_A_fakeA)

        # Generators are rewarded when the discriminators predict "real" on fakes.
        loss_G_A2B=lossGAN(pred_D_B_fakeB,targetReal)
        loss_G_B2A=lossGAN(pred_D_A_fakeA,targetReal)
        #####Cycle Loss#####
        # A -> B -> A and B -> A -> B should reconstruct the original inputs.
        cycleA=G_B2A(fakeB)
        cycleB=G_A2B(fakeA)

        loss_cycle_A=lossCycle(cycleA,realA)*lambda_A
        loss_cycle_B=lossCycle(cycleB,realB)*lambda_B
        #####Final Loss#####
        lossG=loss_G_A2B+loss_G_B2A+loss_cycle_A+loss_cycle_B+loss_idt_A+loss_idt_B
        lossG.backward()
        optGAN.step()

        ###################################Discriminator##########################################
        # zero_grad also discards the gradients lossG.backward() left on D_A / D_B.
        optD.zero_grad()
        #####Discriminator A Loss#####
        # Detach before handing fakes to the buffers. The discriminator losses
        # detach them anyway, and stored tensors then cannot keep the generator
        # graph alive (Buffer internals are not visible here).
        fakeA=BufferA.push_pop(fakeA.detach())
        loss_D_A_Real=lossGAN(D_A(realA),targetReal)
        loss_D_A_Fake=lossGAN(D_A(fakeA.detach()),targetFake)
        loss_D_A=(loss_D_A_Fake+loss_D_A_Real)*0.5
        #####Discriminator B Loss#####
        fakeB=BufferB.push_pop(fakeB.detach())
        loss_D_B_Real=lossGAN(D_B(realB),targetReal)
        loss_D_B_Fake=lossGAN(D_B(fakeB.detach()),targetFake)
        loss_D_B=(loss_D_B_Real+loss_D_B_Fake)*0.5
        #####Final Loss#####
        loss_D_A.backward()
        loss_D_B.backward()
        optD.step()
        ########################################Finish###############################################
        # float() accepts both 0-dim tensors and the int 0 used when identity loss is off.
        totalLoss_G_A += float(loss_G_A2B)
        totalLoss_G_Cycle_A += float(loss_cycle_A)
        totalLoss_G_idt_A += float(loss_idt_A)
        totalLoss_G_B += float(loss_G_B2A)
        totalLoss_G_Cycle_B += float(loss_cycle_B)
        totalLoss_G_idt_B += float(loss_idt_B)
        totalLoss_D_A += float(loss_D_A)
        totalLoss_D_B += float(loss_D_B)

    # Validation: translate B (quarter dose) to A (full dose) and compare it with real A.
    G_B2A.eval()
    # The metric module keeps state from every call; clear it so it does not grow across epochs.
    psnr.reset()
    with torch.no_grad():
        for (realA,realB) in validLoader:
            (realA,realB)=(realA.to(DEVICE),realB.to(DEVICE))
            # Rescale both images per slice to [0, 1], matching psnr's data_range=1.
            fakeA=normalize(G_B2A(realB),zero2one=True)
            realA=normalize(realA,zero2one=True)
            totalPSNR+=float(psnr(fakeA,realA))
            stepValidcnt+=1

    avgLoss_G_A=totalLoss_G_A/stepcnt
    avgLoss_G_Cycle_A=totalLoss_G_Cycle_A/stepcnt
    avgLoss_G_idt_A=totalLoss_G_idt_A/stepcnt

    avgLoss_G_B=totalLoss_G_B/stepcnt
    avgLoss_G_Cycle_B=totalLoss_G_Cycle_B/stepcnt
    avgLoss_G_idt_B=totalLoss_G_idt_B/stepcnt

    avgLoss_D_A=totalLoss_D_A/stepcnt
    avgLoss_D_B=totalLoss_D_B/stepcnt
    # Mean of per-batch PSNR values.
    avgPSNR=totalPSNR/stepValidcnt

    history['loss_G_A'].append(avgLoss_G_A)
    history['loss_Cycle_A'].append(avgLoss_G_Cycle_A)
    history['loss_idt_A'].append(avgLoss_G_idt_A)
    history['loss_G_B'].append(avgLoss_G_B)
    history['loss_Cycle_B'].append(avgLoss_G_Cycle_B)
    history['loss_idt_B'].append(avgLoss_G_idt_B)
    history['loss_D_A'].append(avgLoss_D_A)
    history['loss_D_B'].append(avgLoss_D_B)
    history['PSNR'].append(avgPSNR)

    writer.add_scalar('GAN_A',avgLoss_G_A,epoch)
    writer.add_scalar('Cycle_A',avgLoss_G_Cycle_A,epoch)
    writer.add_scalar('Idt_A',avgLoss_G_idt_A,epoch)
    writer.add_scalar('GAN_B',avgLoss_G_B,epoch)
    writer.add_scalar('Cycle_B',avgLoss_G_Cycle_B,epoch)
    writer.add_scalar('Idt_B',avgLoss_G_idt_B,epoch)
    writer.add_scalar('D_A',avgLoss_D_A,epoch)
    writer.add_scalar('D_B',avgLoss_D_B,epoch)
    writer.add_scalar('PSNR',avgPSNR,epoch)

    lr_schedular_G.step()
    lr_schedular_D.step()

    outtertqdm.set_postfix({'PSNR':'{:.03f}'.format(avgPSNR),'Diff':'{:.03f}'.format(avgPSNR-baseline_psnr)})

    # Checkpoint all four networks whenever validation PSNR improves.
    if best_psnr<avgPSNR:
        best_psnr=avgPSNR
        torch.save(G_A2B.state_dict(),'./final_result/GAN_FD_to_QD/GAN.pt')
        torch.save(G_B2A.state_dict(),'./final_result/GAN_QD_to_FD/GAN.pt')
        torch.save(D_A.state_dict(),'./final_result/Discriminator_A/Disc_A.pt')
        torch.save(D_B.state_dict(),'./final_result/Discriminator_B/Disc_B.pt')

    # Save the history every epoch, not only on improvement, so the pickle
    # always includes the epochs after the last best checkpoint.
    with open('./final_result/history.pkl','wb') as f:
        pickle.dump(history,f)

writer.flush()
writer.close()
        

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]

  0%|          | 0/480 [00:00<?, ?it/s]