# Phase II: Unsupervised Approach

### Import Required Modules

In [1]:
import os     
import glob             
import random                         
import nibabel as nib                    

import numpy as np
import matplotlib.pyplot as plt              
from sklearn.preprocessing import MinMaxScaler 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from generator import Generator
from discriminator import Discriminator
from dataset import MRIDataset

### Unzip Data Archive
This cell unpacks the archive.zip file in the data directory. If this file has already been unpacked, leave this cell commented out, as it takes several minutes to run.

In [2]:
# import shutil
# shutil.unpack_archive('data/archive.zip', 'data/archive')

### Training Function

This function is used to train the cycle-GAN model. It does this in three parts:
1. Train the T1 discriminator by feeding it a real T1 and a fake T1, and measuring the loss.
2. Train the T2 discriminator by feeding it a real T2 and a fake T2, and measuring the loss.
3. Train the T1 and T2 generators by measuring their adversarial losses (each generator's ability to fool the corresponding discriminator) and their cycle losses (the generators' ability to create cycle-consistent images).

The parameters to this function are as follows:
- gen1: generates a T1 using T2 as input (T2->T1)
- gen2: generates a T2 using T1 as input (T1->T2)
- genOpt: optimizer shared by both generators (gen1 and gen2)
- disc1: classifies real/fake T1 
- disc2: classifies real/fake T2 
- discOpt: optimizer shared by both discriminators (disc1 and disc2)
- mse: mean-square error loss function
- mae: mean-absolute error loss function
- loader: data loader
- device: target device to run on
- plot: set to True if you wish to display real and generated images on the final iteration, otherwise False.


In [5]:
def _showImagePair(leftImage, leftTitle, rightImage, rightTitle):
    """Display two images side by side in a single 12x8 matplotlib figure."""
    plt.figure(figsize=(12,8))
    plt.subplot(121)
    plt.imshow(leftImage)
    plt.title(leftTitle)

    plt.subplot(122)
    plt.imshow(rightImage)
    plt.title(rightTitle)
    plt.show()


def train(gen1, gen2, genOpt, disc1, disc2, discOpt, mse, mae, loader, device, plot=False, cycleWeight=10):
    """Train the cycle-GAN for one full pass over ``loader``.

    For every (t1, t2) batch the discriminators are updated first (real images
    are scored against a target of ones, generated images against zeros), then
    the generators are updated with the sum of their adversarial loss and the
    weighted cycle-consistency loss.

    Args:
        gen1: generator producing a fake T1 from a T2 (T2->T1).
        gen2: generator producing a fake T2 from a T1 (T1->T2).
        genOpt: optimizer stepped once per batch on the generator loss.
        disc1: discriminator scoring real/fake T1 images.
        disc2: discriminator scoring real/fake T2 images.
        discOpt: optimizer stepped once per batch on the discriminator loss.
        mse: loss used for the discriminator and adversarial terms.
        mae: loss used for the cycle-consistency terms.
        loader: iterable of (t1, t2) batches that supports ``len()``.
        device: device the batches are moved to before the forward passes.
        plot: when True, show real and generated images on the final batch.
        cycleWeight: multiplier applied to the summed cycle losses
            (defaults to 10, the value previously hard-coded).

    Returns:
        None. The losses of the final batch are printed.
    """
    # Index of the last batch. len(loader) counts batches, so this is correct
    # for any batch size; len(loader.dataset) counts samples and only matches
    # the batch index when batch_size == 1.
    lastIndex = len(loader) - 1

    # iterate through the entire dataset
    for index, (t1, t2) in enumerate(loader):
        t1 = t1.to(device)
        t2 = t2.to(device)

        # ------------------- t1 discriminator forward pass ------------------- #
        # use the t2 -> t1 generator to generate a fake t1
        t1GenFake = gen1(t2)

        # send the real t1 and fake t1 (from above) through the t1 discriminator;
        # detach() keeps the discriminator loss from back-propagating into gen1
        t1DiscReal = disc1(t1)
        t1DiscFake = disc1(t1GenFake.detach())

        # calculate t1 discriminator loss (targets are ones for real and zeros for fake)
        t1DiscRealLoss = mse(t1DiscReal, torch.ones_like(t1DiscReal))
        t1DiscFakeLoss = mse(t1DiscFake, torch.zeros_like(t1DiscFake))
        t1DiscLoss = t1DiscRealLoss + t1DiscFakeLoss

        # ------------------- t2 discriminator forward pass ------------------- #
        # use the t1 -> t2 generator to generate a fake t2
        t2GenFake = gen2(t1)

        # send the real t2 and fake t2 (from above) through the t2 discriminator;
        # detach() keeps the discriminator loss from back-propagating into gen2
        t2DiscReal = disc2(t2)
        t2DiscFake = disc2(t2GenFake.detach())

        # calculate t2 discriminator loss (targets are ones for real and zeros for fake)
        t2DiscRealLoss = mse(t2DiscReal, torch.ones_like(t2DiscReal))
        t2DiscFakeLoss = mse(t2DiscFake, torch.zeros_like(t2DiscFake))
        t2DiscLoss = t2DiscRealLoss + t2DiscFakeLoss

        # ------------------- backward pass for discriminators ------------------- #
        discLoss = t1DiscLoss + t2DiscLoss
        discOpt.zero_grad()
        discLoss.backward()
        discOpt.step()

        # ------------------- forward pass for generators ------------------- #
        # calculate adversarial loss (generator's ability to fool discriminator);
        # the fakes are NOT detached here so gradients reach the generators
        t1DiscFake = disc1(t1GenFake)
        t2DiscFake = disc2(t2GenFake)

        t1AdvLoss = mse(t1DiscFake, torch.ones_like(t1DiscFake))
        t2AdvLoss = mse(t2DiscFake, torch.ones_like(t2DiscFake))
        advLoss = t1AdvLoss + t2AdvLoss

        # calculate cycle loss (ability of generator to create cycle-consistent images):
        # t1 -> fake t2 -> reconstructed t1, and t2 -> fake t1 -> reconstructed t2
        t1Cycle = gen1(t2GenFake)
        t2Cycle = gen2(t1GenFake)

        t1CycleLoss = mae(t1, t1Cycle)
        t2CycleLoss = mae(t2, t2Cycle)
        cycleLoss = (t1CycleLoss + t2CycleLoss) * cycleWeight

        # ------------------- backward pass for generators ------------------- #
        genLoss = advLoss + cycleLoss
        genOpt.zero_grad()
        genLoss.backward()
        genOpt.step()

        # show losses and (optionally) real and generated images on final iteration
        if index == lastIndex:
            print(f'disc. loss: {discLoss.item()}\tgen. loss: {genLoss.item()}')

            if plot:
                # NOTE(review): indexing with [0,:,:] presumably assumes batches
                # shaped (batch, height, width) - confirm against MRIDataset.

                # true t1 (from dataset) next to fake t2 (generated using true t1 as input)
                _showImagePair(t1[0,:,:].cpu(), 'True T1',
                               t2GenFake[0,:,:].detach().cpu(), 'Fake T2')

                # true t2 (from dataset) next to fake t1 (generated using true t2 as input)
                _showImagePair(t2[0,:,:].cpu(), 'True T2',
                               t1GenFake[0,:,:].detach().cpu(), 'Fake T1')


### Testing Function

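The following function can be used to evaluate a T1->T2 generator of the cycle-GAN: it feeds each T1 image to the model and compares the output against the true T2 image using the given loss function.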

The parameters to this function are as follows:
- modelDir: the path of the saved generator model to be evaluated
- loader: the validation data loader
- lossFunction: the loss function with which to evaluate the model


In [6]:
def test(modelDir, loader, lossFunction, device=None):
    """Evaluate a saved T1->T2 generator on a validation loader.

    Each T1 batch is passed through the model and the output is compared with
    the matching T2 batch using ``lossFunction``. Every 20th batch the loss is
    printed and the T1 input, true T2 and predicted T2 are plotted.

    Args:
        modelDir: path of the saved generator. Either a whole pickled
            ``nn.Module`` or a ``state_dict`` (the format written by the
            training cell), which is loaded into a fresh ``Generator``.
        loader: validation data loader yielding (t1, t2) batches.
        lossFunction: callable ``lossFunction(output, t2)`` returning a scalar
            tensor.
        device: device to evaluate on; defaults to the first CUDA device when
            available, otherwise the CPU.

    Returns:
        The loss averaged over all batches, or None when ``loader`` is empty.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source. Recent PyTorch releases default to weights_only=True, in
    # which case only state_dict checkpoints load here - confirm for your version.
    loaded = torch.load(modelDir, map_location=device)
    if isinstance(loaded, nn.Module):
        model = loaded
    else:
        # a state_dict holds only the weights, so rebuild the network around it
        model = Generator()
        model.load_state_dict(loaded)
    model = model.to(device)
    # inference mode for layers such as dropout/batch norm, if the generator has any
    model.eval()

    print(f'\n\n\n\nBeginning Testing on {model.__class__.__name__}...')

    totalLoss = 0.0
    numBatches = 0
    with torch.no_grad():
        for index, (t1, t2) in enumerate(loader):
            t1 = t1.to(device)
            t2 = t2.to(device)

            # get model output, compute loss, add to running loss
            output = model(t1)
            loss = lossFunction(output, t2)
            totalLoss += loss.item()
            numBatches += 1

            # display every 20th image pair
            if index % 20 == 0:
                print(f'loss for b={index}:', loss.item())

                # plot the first image of the batch: input, label and prediction
                # NOTE(review): [0,:,:] presumably assumes (batch, height, width)
                # tensors - confirm against MRIDataset.
                plt.figure(figsize=(12,8))

                plt.subplot(131)
                plt.imshow((t1[0,:,:]).cpu())
                plt.title('T1 Image')

                plt.subplot(132)
                plt.imshow((t2[0,:,:]).cpu())
                plt.title('T2 Image')

                plt.subplot(133)
                plt.imshow((output[0,:,:]).cpu())
                plt.title('Predicted T2 Image')
                plt.show()

    if numBatches == 0:
        # avoid dividing by zero when there is nothing to evaluate
        print('No validation images were provided; nothing to evaluate.')
        return None

    averageLoss = totalLoss / numBatches
    print(f'Average loss across all validation images: {averageLoss}')
    return averageLoss
            

### Train the Model
This cell defines the device, training data loader, generators, discriminators, loss functions, and optimizers, and then invokes the training function to train the model. After every 10th epoch, starting at epoch 50, the two generators and discriminators will be saved in model/epoch{epoch_number}.

In [7]:
# target device to run training on (first CUDA GPU when available, otherwise CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'device: {device}')

# training loader, used to extract and preprocess images from the archive;
# the same directory is passed for both the T1 and the T2 images
trainDir = 'data/archive/MICCAI_BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
dataset = MRIDataset(t1Dir=trainDir, t2Dir=trainDir, train=True)
loader = DataLoader(dataset,batch_size=1,shuffle=True)

# cycle-GAN components
gen1 = Generator().to(device)        # t2->t1 generator
gen2 = Generator().to(device)        # t1->t2 generator
disc1 = Discriminator().to(device)   # t1 discriminator
disc2 = Discriminator().to(device)   # t2 discriminator

# loss functions
mse = nn.MSELoss()     # used for discriminator losses and adversarial loss
mae = nn.L1Loss()      # used for cycle loss

# discriminator and generator optimizers (one optimizer per pair of networks)
discOpt = optim.Adam(list(disc1.parameters()) + list(disc2.parameters()),lr=2e-4,betas=(0.5,0.999))
genOpt = optim.Adam(list(gen1.parameters()) + list(gen2.parameters()),lr=2e-4,betas=(0.5,0.999))

# train the model for 151 epochs (epochs 0..150, so epoch 150 is the last one saved)
nepochs = 151
for epoch in range(nepochs):
    print(f'\n\n\n* =================================== EPOCH {epoch} =================================== *')
    train(gen1, gen2, genOpt, disc1, disc2, discOpt, mse, mae, loader, device, True)

    # every 10th epoch from epoch 50 onwards, save the models
    if epoch >= 50 and epoch % 10 == 0:
        modelDir = f'model/epoch{epoch}/'
        # exist_ok=True makes re-running the cell safe without a separate existence check
        os.makedirs(modelDir, exist_ok=True)

        # only the weights (state_dict) of each network are written
        for name, network in (('gen1', gen1), ('gen2', gen2), ('disc1', disc1), ('disc2', disc2)):
            torch.save(network.state_dict(), modelDir + f'{name}epoch{epoch}.pth')
        