In [1]:
import os
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
from collections import deque


# Device configuration
device = torch.device('cuda:1')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
bs = 1
valid_size = 200

def npy_loader(path):
    sample = np.load(path)
    return sample

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.165, 0.275)
]) 


# Datasets
train_dataset_Y = datasets.DatasetFolder(root='./AAPM/train/full_dose',
                                            loader=npy_loader,
                                            extensions=['.npy'],
                                           transform = transform)
train_dataset_X = datasets.DatasetFolder(root='./AAPM/train/quarter_dose',
                                               loader=npy_loader,
                                               extensions=['.npy'],
                                              transform = transform)


test_dataset_Y = datasets.DatasetFolder(root='./AAPM/test/full_dose',
                                            loader=npy_loader,
                                            extensions=['.npy'],
                                           transform = transform)
test_dataset_X = datasets.DatasetFolder(root='./AAPM/test/quarter_dose',
                                               loader=npy_loader,
                                               extensions=['.npy'],
                                              transform = transform)

# valid_dataset_Y = torch.utils.data.Subset(valid_test_dataset_Y, range(0, 200))
# valid_dataset_X = torch.utils.data.Subset(valid_test_dataset_X, range(0, 200))

# test_dataset_Y = torch.utils.data.Subset(valid_test_dataset_Y, range(200, 421))
# test_dataset_X = torch.utils.data.Subset(valid_test_dataset_X, range(200, 421))

# Dataloaders
train_loader_Y = torch.utils.data.DataLoader(dataset=train_dataset_Y, batch_size=bs, shuffle=True, drop_last = True)
train_loader_X = torch.utils.data.DataLoader(dataset=train_dataset_X, batch_size=bs, shuffle=True, drop_last = True)

# valid_loader_Y = torch.utils.data.DataLoader(dataset=valid_dataset_Y, batch_size=1, shuffle=False)
# valid_loader_X = torch.utils.data.DataLoader(dataset=valid_dataset_X, batch_size=1, shuffle=False)

test_loader_Y = torch.utils.data.DataLoader(dataset=test_dataset_Y, batch_size=1, shuffle=False)
test_loader_X = torch.utils.data.DataLoader(dataset=test_dataset_X, batch_size=1, shuffle=False)

train_dataset_length = len(train_dataset_X)
# valid_dataset_length = len(valid_dataset_X)
test_dataset_length = len(test_dataset_X)

In [3]:
# edited from https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py

class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


In [4]:
# edited from https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.inc = DoubleConv(1, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024,256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, 1, kernel_size=1)
        self.act = nn.Tanh()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return self.act(x)

In [5]:
# https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
                                   nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                   nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                                   nn.InstanceNorm2d(128),
                                   nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                   nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                                   nn.InstanceNorm2d(256),
                                   nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                   nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
                                   nn.InstanceNorm2d(512),
                                   nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                   nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
                                  )
        

    def forward(self, input):
        
        return self.model(input)

In [6]:
# build network
# img_dim = train_dataset.data.shape[1] * train_dataset.data.shape[2] * train_dataset.data.shape[3]

# G_Y:X->Y, D_Y:Y->R
# G_X:Y->X, D_X:X->R
G_Y = Generator().to(device)
G_X = Generator().to(device)
D_Y = Discriminator().to(device)
D_X = Discriminator().to(device)

In [7]:
# parameters
lambd_cyc = 10
lambd_id = 0.5
lr = 2e-4
patch_size = 62

In [8]:
# optimizer
G_Y_optimizer = optim.Adam(G_Y.parameters(), lr = lr)
G_X_optimizer = optim.Adam(G_X.parameters(), lr = lr)
D_Y_optimizer = optim.Adam(D_Y.parameters(), lr = lr)
D_X_optimizer = optim.Adam(D_X.parameters(), lr = lr)

In [9]:
def D_train_adv(D, G, x, y, D_optimizer):
    criterion = nn.MSELoss()
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    D_output_real = D(y)
    D_real_loss = criterion(D_output_real, torch.ones(bs, 1, patch_size, patch_size).to(device))

    # train discriminator on fake
    D_output_fake = D(G(x))
    D_fake_loss = criterion(D_output_fake,  torch.zeros(bs, 1, patch_size, patch_size).to(device))
    
    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item()

In [10]:
def G_train_adv(G, D, x, y, G_optimizer):
    criterion = nn.MSELoss()
    #=======================Train the generator=======================#
    G.zero_grad()

    D_output_fake = D(G(x))
    G_loss = criterion(D_output_fake, torch.ones(bs, 1, patch_size, patch_size).to(device))

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()
        
    return G_loss.data.item()

In [11]:
def G_train_cyc(G_X, G_Y, x, y, G_X_optimizer, G_Y_optimizer):
    criterion = nn.L1Loss()
    
    G_X.zero_grad()
    G_Y.zero_grad()
    
    cyc_loss = lambd_cyc * (criterion(G_X(G_Y(x)), x) + criterion(G_Y(G_X(y)), y))
    
    cyc_loss.backward()
    G_X_optimizer.step()
    G_Y_optimizer.step()
    
    return cyc_loss.data.item()
    

In [12]:
def G_train_id(G_X, G_Y, x, y, G_X_optimizer, G_Y_optimizer):
    criterion = nn.L1Loss()
    
    G_X.zero_grad()
    G_Y.zero_grad()
    
    id_loss = lambd_id * (criterion(G_Y(x), x) + criterion(G_X(y), y))
    
    id_loss.backward()
    G_X_optimizer.step()
    G_Y_optimizer.step()
    
    return id_loss.data.item()
    

In [13]:
def calc_orig_PSNR(x, y):
    criterion = nn.MSELoss()
    with torch.no_grad():
        MSE = criterion(x, y)
        orig_PSNR = 10 * torch.log((y.max() - y.min())**2 / MSE) / math.log(10)
    return orig_PSNR

def calc_orig_PSNR_loader(loader_X, loader_Y):
    orig_PSNR_list = []
    for x, y in zip(loader_X, loader_Y):
        x = x[0].to(device) # remove labels
        y = y[0].to(device)
        orig_PSNR_list.append(calc_orig_PSNR(x,y))
    orig_PSNR_loader = torch.mean(torch.stack(orig_PSNR_list)).item()
    return orig_PSNR_loader

train_orig_PSNR = calc_orig_PSNR_loader(train_loader_X, train_loader_Y)
test_orig_PSNR = calc_orig_PSNR_loader(test_loader_X, test_loader_Y)

In [14]:
def calc_PSNR(G_Y, x, y):
    criterion = nn.MSELoss()
    with torch.no_grad():
        MSE = criterion(G_Y(x), y)
        PSNR = 10 * torch.log((y.max() - y.min())**2 / MSE) / math.log(10)
    return PSNR

def calc_PSNR_loader(G_Y, loader_X, loader_Y):
    PSNR_list = []
    for x, y in zip(loader_X, loader_Y):
        x = x[0].to(device) # remove labels
        y = y[0].to(device)
        PSNR_list.append(calc_PSNR(G_Y, x,y))
    PSNR_loader = torch.mean(torch.stack(PSNR_list)).item()
    return PSNR_loader

In [None]:
n_epoch = 100
if 'epoch' in locals():
    if complete:
        init_epoch = epoch+1
    else:
        init_epoch = epoch
else:
    init_epoch = 1
if not 'max_PSNR_diff' in locals():
    max_PSNR_diff = float('-inf')

writer = SummaryWriter(log_dir = 'runs/' + str(lambd_cyc) + '_' + str(lambd_id))

D_X_losses,G_X_losses,D_Y_losses,G_Y_losses,cyc_losses, id_losses, train_PSNRs, valid_PSNRs = deque(),deque(),deque(),deque(),deque(),deque(),deque(),deque()

for epoch in range(init_epoch, init_epoch+n_epoch):
    complete = False # marker for whether the training completed the last epoch
    
    for step, (x, y) in tqdm(enumerate(zip(train_loader_X, train_loader_Y)),total=train_dataset_length):
        x = x[0].to(device) # remove labels
        y = y[0].to(device)
        
        D_X_losses.append(D_train_adv(D_X, G_X, x, y, D_X_optimizer))
        G_X_losses.append(G_train_adv(G_X, D_X, x, y, G_X_optimizer))
        D_Y_losses.append(D_train_adv(D_Y, G_Y, x, y, D_Y_optimizer))
        G_Y_losses.append(G_train_adv(G_Y, D_Y, x, y, G_Y_optimizer))
        cyc_losses.append(G_train_cyc(G_X, G_Y, x, y, G_X_optimizer, G_Y_optimizer))
        id_losses.append(G_train_id(G_X, G_Y, x, y, G_X_optimizer, G_Y_optimizer))
        train_PSNRs.append(calc_PSNR(G_Y, x, y))
        
        if len(D_X_losses) > train_dataset_length:
            D_X_losses.popleft()
            G_X_losses.popleft()
            D_Y_losses.popleft()
            G_Y_losses.popleft()
            cyc_losses.popleft()
            id_losses.popleft()
            train_PSNRs.popleft()
        
        if step % 50 == 0:
            loss_D = torch.mean(torch.FloatTensor(D_X_losses) + torch.FloatTensor(D_Y_losses)).item()
            loss_G = torch.mean(torch.FloatTensor(G_X_losses)+torch.FloatTensor(G_Y_losses)).item()
            loss_cyc = torch.mean(torch.FloatTensor(cyc_losses)).item()
            loss_id = torch.mean(torch.FloatTensor(id_losses)).item()
            train_PSNR = torch.mean(torch.FloatTensor(train_PSNRs)).item()
            train_PSNR_diff = train_PSNR - train_orig_PSNR
            
            step_total = (epoch-1)*train_dataset_length + step
            writer.add_scalar('loss_D', loss_D, step_total)
            writer.add_scalar('loss_G', loss_G, step_total)
            writer.add_scalar('loss_cyc', loss_cyc, step_total)
            writer.add_scalar('loss_id', loss_id, step_total)
            writer.add_scalar('train_PSNR_diff', train_PSNR_diff, step_total)
    
    test_PSNR = calc_PSNR_loader(G_Y, test_loader_X, test_loader_Y)
    test_PSNR_diff = test_PSNR - test_orig_PSNR
    
    # valid_PSNR = calc_PSNR_loader(G_Y, valid_loader_X, valid_loader_Y)
    # valid_PSNR_diff = valid_PSNR - valid_orig_PSNR
    
    writer.add_scalar('test_PSNR_diff', test_PSNR_diff, epoch)
    
    if test_PSNR_diff > max_PSNR_diff:
        max_PSNR_diff = test_PSNR_diff
        torch.save({
            'G_X': G_X.state_dict(),
            'G_Y': G_Y.state_dict(),
            'D_X': D_X.state_dict(),
            'D_Y': D_Y.state_dict(),
            'G_X_optimizer': G_X_optimizer.state_dict(),
            'G_Y_optimizer': G_Y_optimizer.state_dict(),
            'D_X_optimizer': D_X_optimizer.state_dict(),
            'D_Y_optimizer': D_Y_optimizer.state_dict(),
        }, "unet_best_" + str(lambd_cyc) + '_' + str(lambd_id) + ".pt")
        
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f, loss_cyc: %.3f, loss_id: %.3f, test_PSNR_diff: %.3f' % (
            epoch, n_epoch, loss_D, loss_G, loss_cyc, loss_id, test_PSNR_diff))
    complete = True
    

100%|█████████████████████████████████████| 3839/3839 [1:21:18<00:00,  1.27s/it]


[1/100]: loss_d: 0.999, loss_g: 0.503, loss_cyc: 0.023, loss_id: 0.003, test_PSNR_diff: 3.338


 31%|████████████                           | 1184/3839 [25:08<56:03,  1.27s/it]

In [15]:
checkpoint = torch.load("unet_best_10_5.pt")
G_X.load_state_dict(checkpoint['G_X'])
G_Y.load_state_dict(checkpoint['G_Y'])
D_X.load_state_dict(checkpoint['D_X'])
D_Y.load_state_dict(checkpoint['D_Y'])
G_X_optimizer.load_state_dict(checkpoint['G_X_optimizer'])
G_Y_optimizer.load_state_dict(checkpoint['G_Y_optimizer'])
D_X_optimizer.load_state_dict(checkpoint['D_X_optimizer'])
D_Y_optimizer.load_state_dict(checkpoint['D_Y_optimizer'])


In [16]:
test_PSNR = calc_PSNR_loader(G_Y, test_loader_X, test_loader_Y)
print("The PSNR diffrence on test set is " + str(test_PSNR - test_orig_PSNR))

The PSNR diffrence on test set is 4.366001129150391


In [None]:
def PSNR_list_loader(G_Y, loader_X, loader_Y):
    PSNR_list = []
    for x, y in zip(loader_X, loader_Y):
        x = x[0].to(device) # remove labels
        y = y[0].to(device)
        PSNR_list.append(calc_PSNR(G_Y, x,y))
    return PSNR_list

def orig_PSNR_list_loader(loader_X, loader_Y):
    orig_PSNR_list = []
    for x, y in zip(loader_X, loader_Y):
        x = x[0].to(device) # remove labels
        y = y[0].to(device)
        orig_PSNR_list.append(calc_orig_PSNR(x,y))
    return orig_PSNR_list

In [None]:
test_PSNR_list = PSNR_list_loader(G_Y, test_loader_X, test_loader_Y)
test_orig_PSNR_list = orig_PSNR_list_loader(test_loader_X, test_loader_Y)

In [None]:
plt.hist((torch.stack(test_PSNR_list) - torch.stack(test_orig_PSNR_list)).cpu())

In [None]:
diff_list = torch.stack(test_PSNR_list) - torch.stack(test_orig_PSNR_list)

In [None]:
best_index = torch.argmax(diff_list)
worst_index = torch.argmin(diff_list)

In [None]:
for i in [best_index, worst_index]:
    criterion = nn.MSELoss()

    x = test_dataset_X[i][0].view(1,1,512,512)
    y_pred = G_Y(x.to(device)).cpu()
    x = x.view(1,512,512)
    y_pred = y_pred.view(1,512,512)
    y = test_dataset_Y[i][0].view(1,512,512)

    diff = y_pred - x

    with torch.no_grad():
        MSE = criterion(y_pred, y)
        PSNR = 10 * torch.log((y.max() - y.min())**2 / MSE) / math.log(10)
        MSE_orig = criterion(x, y)
        PSNR_orig = 10 * torch.log((y.max() - y.min())**2 / MSE_orig) / math.log(10)


    save_folder = str(round(PSNR.item(),1)) + '_' + str(round(PSNR_orig.item(),1))
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    save_image(x, save_folder + '/x.png', normalize=True)
    save_image(y_pred, save_folder + '/y_pred.png', normalize=True)
    save_image(y, save_folder + '/y.png', normalize=True)
    save_image(diff, save_folder + '/diff.png', normalize=True)

