In [5]:
import monai
import time
from monai.transforms import (
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    Spacingd,
    RandWeightedCrop,
    RandRotate,
    Rand3DElasticd,
    RandRotated,
    EnsureChannelFirstd,
    RandFlip,
    ScaleIntensityd,
    RandFlipd)
import tqdm
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, Dataset ,nifti_saver, PatchDataset
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
from glob import glob
from monai.networks.blocks import Convolution
from monai.networks.nets import Discriminator, Generator
from monai.utils import progress_bar
import torch.nn as nn
from torch.utils.data import DataLoader 
import torchmetrics #need to download torchmetrics for CNN, so just gonna continue editing gan script and then run a kbatch script

In [6]:
# gad images whose corresponding nongad images underwent a rigid transform
gad_t1 = sorted(glob('/home/fogunsan/scratch/degad/derivatives/passing_dataset/*/*_acq-gad_resampled_T1w.nii.gz'))
# nongad images which underwent a rigid transform and FCM normalization
nongad_t1 = sorted(glob('/home/fogunsan/scratch/degad/derivatives/normalized_fcm/*/*_acq-nongad_normalized_fcm.nii.gz'))
# Images are paired purely by sorted position, so both globs must return the same number of files;
# otherwise zip() silently drops the extras and pairs can end up misaligned.
assert gad_t1, "no gad images found - check the passing_dataset path"
assert len(gad_t1) == len(nongad_t1), f"found {len(gad_t1)} gad images but {len(nongad_t1)} nongad images"
# one dict per subject: "image" = gad input, "label" = nongad target
image_dict = [{"image": gad_name, "label": nongad_name} for gad_name, nongad_name in zip(gad_t1, nongad_t1)]
# training and test splits (entries 2-37 are currently not used by either split)
train_files, test_files = image_dict[0:2], image_dict[38:]

In [7]:
num_train_files = len(train_files)
num_patches = 25  # patches per image (TODO: change back to 5000 for the full run)
batch_size = 2
date = "April27"  # names the output folder for this run
load_images = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0),  # min-max normalization only on gad images
    ]
)

# Loads every training volume once and keeps the (deterministic) loading transforms' output in memory.
train_imgs_cache = CacheDataset(data=train_files, transform=load_images, cache_rate=1.0, num_workers=1)

patching_func = RandCropByPosNegLabeld(  # used to cut patches out of each cached volume
    keys=["image", "label"],
    label_key="image",  # foreground/background is decided from the gad image itself
    spatial_size=(32, 32, 32),
    pos=1,
    neg=0.01,  # much larger probability of sampling foreground
    num_samples=num_patches,
)
# Compose takes a *sequence* of transforms. Passing RandFlipd as a second positional argument bound it
# to Compose's `map_items` parameter, so the flip was silently never applied.
patch_transforms = Compose(
    [
        RandRotated(keys=["image", "label"], range_x=0.8, range_y=0.8, range_z=0.8, prob=0.4),
        RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),  # flip along the y-axis
    ]
)

train_patches_dataset = PatchDataset(
    data=train_imgs_cache,
    patch_func=patching_func,
    samples_per_image=num_patches,
    transform=patch_transforms,
)
# NOTE(review): caching the PatchDataset stores one random draw of crops/rotations/flips, and the same
# patches are then reused every epoch. Remove this wrapper if fresh random patches per epoch are wanted.
train_patches_dataset = CacheDataset(data=train_patches_dataset, cache_rate=1.0, num_workers=1, copy_cache=True)


Loading dataset: 100%|██████████| 2/2 [00:07<00:00,  3.73s/it]
Loading dataset: 100%|██████████| 50/50 [00:18<00:00,  2.77it/s]


In [8]:
# 3D U-Net generator mapping a gad patch to a synthetic (degad) nongad patch.
# Assigned straight to `gen`: the old name `Generator` shadowed the monai Generator class imported above.
gen = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256, 512, 512, 512),
    strides=(2, 2, 2, 2, 1, 1, 1),  # four stride-2 stages (16x downsampling), then stride-1 stages
    dropout=0.2,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CPU fallback when no GPU is visible
gen.apply(monai.networks.normal_init)  # initialize weights before moving the model to the device
gen_model = gen.to(device)

In [9]:
class GANDiscriminator(nn.Module):
    """
    Patch-based 3D discriminator for channel-stacked (gad, nongad/degad) image pairs.

    Three stride-2 Conv3d -> InstanceNorm3d -> PReLU stages (64, 128, 256 channels) are
    followed by a single-channel Conv3d projection (`conv_out`, no padding). For a 32^3
    input the output is a (batch, 1, 2, 2, 2) grid of raw scores (logits).

    No squashing activation is applied: the training losses apply the sigmoid themselves.
    The previous version skipped `conv_out` entirely and applied tanh, so sigmoid(tanh(x))
    could never leave roughly [0.27, 0.73].

    in_channels: number of stacked input channels (default 2: gad + nongad/degad)
    kernel_size: kernel size used by every convolution (default 3)
    """

    def __init__(self, in_channels=2, kernel_size=3):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size, stride=2, padding=1),
            nn.InstanceNorm3d(64),
            nn.PReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size, stride=2, padding=1),
            nn.InstanceNorm3d(128),
            nn.PReLU()
        )

        self.conv3 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size, stride=2, padding=1),
            nn.InstanceNorm3d(256),
            nn.PReLU()
        )

        # projects the 256-channel feature map to one score per output voxel
        self.conv_out = nn.Conv3d(256, 1, kernel_size, stride=1, padding=0)

    def forward(self, x):
        """Return the discriminator logits for the stacked input `x`."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv_out(x)
# Build the discriminator, initialize its weights with MONAI's normal_init, then move it to `device`.
# nn.Module.apply returns the module itself, so initialization and the device move chain together.
disc = GANDiscriminator()
disc_model = disc.apply(monai.networks.normal_init).to(device)

In [10]:
def GeneratorLoss(nongad_images, degad_images, fake_preds, coeff=0.01):
    """
    Generator loss: adversarial BCE term plus a weighted L1 reconstruction term.

    The adversarial term is the binary cross entropy between the discriminator's output on
    (gad, degad) pairs and a target of 1. The generator is rewarded when the discriminator
    labels its output as real. The reconstruction term is the mean absolute error (L1, not
    RMSE) between the real nongad images and the generated degad images, scaled by `coeff`.

    nongad_images: real nongad images from the sample (ground truth)
    degad_images: generated nongad images from the generator (same shape as nongad_images)
    fake_preds: discriminator output when fed fake data, treated as logits (sigmoid applied here)
    coeff: weight of the L1 term (default 0.01)

    Returns a scalar tensor: coeff * L1 + BCE_fake.
    """
    # BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable way (same value as an explicit
    # torch.sigmoid followed by BCELoss). ones_like keeps the target on the predictions' device.
    adversarial_criterion = torch.nn.BCEWithLogitsLoss()
    bce_fake = adversarial_criterion(fake_preds, torch.ones_like(fake_preds))
    l1_loss = torch.nn.L1Loss()(degad_images, nongad_images)
    return l1_loss * coeff + bce_fake

In [11]:
def DiscriminatorLoss(real_preds, fake_preds):
    """
    Discriminator loss: binary cross entropy on real pairs (target 1) plus on fake pairs (target 0).

    The discriminator minimizes this sum. It is rewarded for scoring (gad, nongad) pairs as real
    and (gad, degad) pairs as fake.

    real_preds: discriminator output when fed real data, treated as logits
    fake_preds: discriminator output when fed fake data, treated as logits

    Returns a scalar tensor BCE_real + BCE_fake, each term averaged over all elements.
    """
    # Targets are built with ones_like/zeros_like so they follow the predictions' device. The old
    # hard-coded .cuda() calls crashed whenever the notebook fell back to the CPU.
    bce = torch.nn.BCEWithLogitsLoss()  # numerically stable sigmoid + BCE
    bce_real = bce(real_preds, torch.ones_like(real_preds))
    bce_fake = bce(fake_preds, torch.zeros_like(fake_preds))
    return bce_real + bce_fake
   

In [12]:
learning_rate = 2e-4
betas = (0.5, 0.999)
gen_opt = torch.optim.Adam(gen_model.parameters(), lr=learning_rate, betas=betas)
disc_opt = torch.optim.Adam(disc_model.parameters(), lr=learning_rate, betas=betas)

# Both lists start with a 0 placeholder so the progress bar has something to print before the first
# epoch finishes; the real per-epoch averages start at index 1.
epoch_loss_values = [0]  # average generator loss per epoch
disc_loss_values = [0]  # average discriminator loss per epoch
gen_step_loss = []  # (global step, generator loss) for every batch
disc_step_loss = []  # (global step, mean discriminator loss over that batch's inner steps)
disc_train_steps = 10  # discriminator updates performed for each batch
max_epochs = 2

train_loader = DataLoader(train_patches_dataset, batch_size=batch_size, shuffle=True)

# Take the step counts from the loader itself so the averages stay correct when the number of patches
# is not a multiple of batch_size (the final, smaller batch is still iterated).
gen_training_steps = len(train_loader)  # generator steps per epoch
disc_training_steps = disc_train_steps * gen_training_steps  # discriminator steps per epoch

start = time.time()
global_step = 0
for epoch in range(max_epochs):
    gen_model.train()
    disc_model.train()
    epoch_loss = 0.0
    disc_epoch_loss = 0.0
    progress_bar(
        index=epoch + 1,
        count=max_epochs,
        desc=f"epoch {epoch + 1}, avg gen loss: {epoch_loss_values[-1]:.4f}, avg disc loss: {disc_loss_values[-1]:.4f}",
    )
    for train_batch in train_loader:
        global_step += 1
        # .to(device) instead of .cuda() so the loop also works on the CPU fallback
        gad_images = train_batch["image"].to(device)
        nongad_images = train_batch["label"].to(device)

        # ---- generator update ----
        gen_opt.zero_grad()
        degad_images = gen_model(gad_images)
        disc_fake_pred = disc_model(torch.cat([gad_images, degad_images], dim=1))
        gen_loss = GeneratorLoss(nongad_images, degad_images, disc_fake_pred)
        gen_loss.backward()
        gen_opt.step()
        epoch_loss += gen_loss.item()
        gen_step_loss.append((global_step, gen_loss.item()))

        # ---- discriminator updates ----
        # The generator is not stepped inside this loop, so its output is produced once without tracking
        # gradients. Previously each step re-ran the generator and back-propagated the discriminator loss
        # into the generator's parameters for nothing. Only the dropout noise now differs.
        with torch.no_grad():
            degad_images = gen_model(gad_images)
        real_pairs = torch.cat([gad_images, nongad_images], dim=1)
        fake_pairs = torch.cat([gad_images, degad_images], dim=1)
        batch_disc_loss = 0.0
        for _ in range(disc_train_steps):
            disc_opt.zero_grad()
            disc_loss = DiscriminatorLoss(disc_model(real_pairs), disc_model(fake_pairs))
            disc_loss.backward()
            disc_opt.step()
            batch_disc_loss += disc_loss.item()
        disc_epoch_loss += batch_disc_loss
        disc_step_loss.append((global_step, batch_disc_loss / disc_train_steps))

    epoch_loss_values.append(epoch_loss / gen_training_steps)
    disc_loss_values.append(disc_epoch_loss / disc_training_steps)
end = time.time()
# Named training_time (not `time`) so the time module is not shadowed; the old name broke
# time.time() when the cell was re-run.
training_time = end - start





In [None]:
# Summary of the run, written into the folder for this run's `date` (previously hard-coded to April21).
stats_dir = f"/home/fogunsan/scratch/degad/derivatives/GAN_network/{date}"
os.makedirs(stats_dir, exist_ok=True)  # the dated folder may not exist yet
with open(os.path.join(stats_dir, "model_stats.txt"), "w") as file:
    # end/start are the timestamps taken around the training loop. Computing the duration here avoids
    # relying on a variable named `time`, which shadows the time module.
    file.write(f"training time: {end - start} \n")
    file.write(f"generator loss: {epoch_loss_values[-1]} discriminator loss: {disc_loss_values[-1]}")

In [None]:
# Save the trained generator and discriminator weights into the folder for this run's `date`
# (previously hard-coded to April21).
model_dir = f"/home/fogunsan/scratch/degad/derivatives/GAN_network/{date}"
os.makedirs(model_dir, exist_ok=True)  # the dated folder may not exist yet
torch.save(gen_model.state_dict(), os.path.join(model_dir, "trained_generator.pt"))
torch.save(disc_model.state_dict(), os.path.join(model_dir, "trained_discriminator.pt"))

In [None]:
# Plot the per-epoch average losses recorded by the training loop. Index 0 of each list is the
# 0 placeholder used for the first progress-bar message, so it is skipped (and could not be drawn
# on a log axis anyway). The old cell referenced gen_step_loss/disc_step_loss, which were never defined.
plot_dir = f"/home/fogunsan/scratch/degad/derivatives/GAN_network/{date}"  # same root as the stats/model files
os.makedirs(plot_dir, exist_ok=True)
epochs = range(1, len(epoch_loss_values))
fig, ax = plt.subplots(figsize=(12, 5))
ax.semilogy(epochs, epoch_loss_values[1:], label="Generator Loss")
ax.semilogy(epochs, disc_loss_values[1:], label="Discriminator Loss")
ax.set(xlabel="Epoch", ylabel="Average loss", title="GAN training losses")
ax.grid(True, "both", "both")
ax.legend()
fig.savefig(os.path.join(plot_dir, "lossfunctions.png"))
