**Example of GAN algorithm:** https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

**Kaggle Dataset:** https://www.kaggle.com/datasets/chenghanpu/brain-tumor-mri-and-ct-scan


In [None]:
# If running on google colab, uncomment the following line
# !pip install -r requirements.txt

In [None]:
# Get dataset
import gdown 

# Shared Google Drive folder that is downloaded as the dataset.
url = "https://drive.google.com/drive/folders/1s5Y-JimbgWDvQy5XW8xt7HK1RerpPAmv?usp=drive_link"

# Download the folder into the current working directory; quiet=False keeps
# gdown's own output visible.
gdown.download_folder(url=url, output=".", quiet=False)

In [None]:
# Imports
import torch
import numpy
import pandas
import numpy as np

# Load the MRI inputs and CT targets that are later zipped into training pairs.
# Both arrays must come from the same (training) split so that sample i of
# MRI_train lines up with sample i of CT_train.
# NOTE(review): this previously read 'test_input.npy', which would pair test
# MRIs with training CTs. Assumes 'train_input.npy' sits next to
# 'train_output.npy' in the downloaded folder — confirm the file name.
MRI_train = np.load('data/final_project/train_input.npy')
CT_train = np.load('data/final_project/train_output.npy')

In [None]:
# prompt: visualize the data above the data points are 2D images of size: size 65536

import matplotlib.pyplot as plt

# Side length used to reshape every sample into a square image.
# Assumes each sample holds 256 * 256 = 65536 values — TODO confirm.
image_size = 256

# Show the first five samples: the MRI image, then the CT image, then the
# shape of the CT sample as stored in the array.
for i in range(5):
    for source in (MRI_train, CT_train):
        plt.imshow(source[i].reshape(image_size, image_size), cmap='gray')
        plt.show()
    print(CT_train[i].shape)

**Training**  
Adversarial Loss (Discriminator Loss):

Purpose: This loss measures how well the discriminators can distinguish between real and fake samples.
Indicator: A decreasing adversarial loss indicates that the generators are creating more realistic images that fool the discriminators.

Cycle Consistency Loss:

Purpose: This loss ensures that the mappings from one domain to another are consistent.
Indicator: A decreasing cycle consistency loss indicates that the generators are learning to generate images that are consistent when translated back and forth between the two domains.

Identity Loss (Optional):

Purpose: If included, identity loss ensures that the generators do not change the input image unnecessarily.
Indicator: A low identity loss indicates that the generators are preserving the input image structure when it belongs to the target domain.

Overall Generator Loss:

Purpose: This is the sum of the adversarial loss, cycle consistency loss, and identity loss (if used).
Indicator: Monitoring the overall generator loss provides an overview of how well the generators are performing across all aspects.

Discriminator Loss: Each discriminator is trained to minimize `D(fake) - D(real)`, i.e. to score real samples higher than generated ones. As the discriminators get better at telling real from fake, this loss therefore becomes negative and grows in magnitude.

In [None]:
# Necessary Imports for training:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import mean_squared_error
from gen_model import Generator
from dis_model import Discriminator

## Define generators and discriminators

# MRI images and CT scan images are grayscale therefore our input_channel = 1
# and output_channel = 1 (with this Kaggle dataset).
mri_channel = 1
ct_channel = 1

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generators (constructed from the channel variables so both directions stay
# consistent if the channel counts ever change)
generator_mri2ct = Generator(mri_channel, ct_channel)
generator_ct2mri = Generator(ct_channel, mri_channel)

# Discriminators
discriminator_mri = Discriminator(mri_channel)
discriminator_ct = Discriminator(ct_channel)

# Move models to the selected device *before* building the optimizers, as the
# torch.optim documentation recommends.
generator_mri2ct.to(device)
generator_ct2mri.to(device)
discriminator_mri.to(device)
discriminator_ct.to(device)

# Define Optimizers *betas control the exponential decay rates*
# Both generators share one optimizer because they are updated from a single
# combined generator loss.
optimizer_gen = torch.optim.Adam(
    list(generator_mri2ct.parameters()) + list(generator_ct2mri.parameters()),
    lr=1e-4,
    betas=(0.5, 0.999),
)
# Each discriminator optimizer owns only its own network's parameters, so
# calling step() on both optimizers updates every discriminator weight exactly
# once per iteration (sharing the full parameter list would update each weight
# twice).
optimizer_dis_mri = torch.optim.Adam(discriminator_mri.parameters(), lr=1e-6, betas=(0.5, 0.999))
optimizer_dis_ct = torch.optim.Adam(discriminator_ct.parameters(), lr=1e-6, betas=(0.5, 0.999))


# Training Loop parameters
num_epochs = 10
batch_size = 4  # Adjust this based on system's memory
lambda_cycle = 10  # Adjust this value based on experimentation
lambda_identity = 5  # If you use identity loss, adjust this as needed

# Create data loader for training
# Sample i of MRI_train is paired with sample i of CT_train.
dataset = list(zip(MRI_train, CT_train))  # Combine MRI_train and CT_train
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Criterion
criterion = nn.BCEWithLogitsLoss()
criterion_cycle = nn.SmoothL1Loss()
# Needed by the training loop whenever use_identity_loss is True; without it
# enabling the flag raises NameError.
criterion_identity = nn.L1Loss()

# Identity Loss
use_identity_loss = False
# Placeholder values used while the identity loss is disabled.
loss_identity_ct = "NA"
loss_identity_mri = "NA"

# Training Loop
# Every iteration first updates both discriminators on detached fakes, then
# updates both generators with adversarial + cycle (+ optional identity) terms.
for epoch in range(num_epochs):
    # Switch all four networks to training mode at the start of each epoch
    for network in (generator_mri2ct, generator_ct2mri, discriminator_mri, discriminator_ct):
        network.train()

    for i, (real_mri, real_ct) in enumerate(dataloader):
        # Move the batch to the device, cast to float, then add a leading axis
        # and swap it with the batch axis so the new axis ends up at position 1
        # (the shape the generators are fed).
        real_mri = real_mri.to(device).float().unsqueeze(0).permute(1, 0, 2, 3)
        real_ct = real_ct.to(device).float().unsqueeze(0).permute(1, 0, 2, 3)

        # Translate each real batch into the other domain
        fake_ct = generator_mri2ct(real_mri)
        fake_mri = generator_ct2mri(real_ct)

        ### Discriminator Training ###
        optimizer_dis_mri.zero_grad()
        optimizer_dis_ct.zero_grad()

        # Mean discriminator score on real vs. generated MRI; the fake is
        # detached so this backward pass does not reach the generators.
        loss_dis_mri_real = discriminator_mri(real_mri).mean()
        loss_dis_mri_fake = discriminator_mri(fake_mri.detach()).mean()
        loss_dis_mri = loss_dis_mri_fake - loss_dis_mri_real

        # Same for the CT discriminator
        loss_dis_ct_real = discriminator_ct(real_ct).mean()
        loss_dis_ct_fake = discriminator_ct(fake_ct.detach()).mean()
        loss_dis_ct = loss_dis_ct_fake - loss_dis_ct_real

        # Average of the two discriminator losses
        loss_dis = (loss_dis_mri + loss_dis_ct) / 2

        # Backpropagate once, then step both discriminator optimizers
        loss_dis.backward()
        optimizer_dis_mri.step()
        optimizer_dis_ct.step()

        ### Generator Training ###
        optimizer_gen.zero_grad()

        # Adversarial terms: negated mean discriminator score on the fakes
        # (not detached here, so gradients flow back into the generators)
        loss_gen_mri2ct_adv = -discriminator_ct(fake_ct).mean()
        loss_gen_ct2mri_adv = -discriminator_mri(fake_mri).mean()

        # Cycle consistency: map each fake back to its source domain and
        # compare with the original batch
        reconstructed_mri = generator_ct2mri(fake_ct)
        reconstructed_ct = generator_mri2ct(fake_mri)
        loss_cycle_mri = criterion_cycle(reconstructed_mri, real_mri)
        loss_cycle_ct = criterion_cycle(reconstructed_ct, real_ct)

        # Total generator loss: adversarial part plus weighted cycle part
        adversarial_term = loss_gen_mri2ct_adv + loss_gen_ct2mri_adv
        cycle_term = lambda_cycle * (loss_cycle_mri + loss_cycle_ct)
        loss_gen = adversarial_term + cycle_term

        # Identity losses (optional): feed each generator an image that is
        # already in its output domain and compare the output with the input
        if use_identity_loss:
            identity_mri = generator_ct2mri(real_mri)
            identity_ct = generator_mri2ct(real_ct)

            loss_identity_mri = criterion_identity(identity_mri, real_mri)
            loss_identity_ct = criterion_identity(identity_ct, real_ct)

            loss_gen += lambda_identity * (loss_identity_mri + loss_identity_ct)

        # Backpropagation and optimizer step for generator
        loss_gen.backward()
        optimizer_gen.step()

        # Report the individual loss terms every 100 iterations
        if i % 100 == 0:
            stats = (
                epoch, num_epochs, i, len(dataloader),
                loss_dis_mri.item(), loss_dis_ct.item(),
                loss_gen_mri2ct_adv.item(), loss_gen_ct2mri_adv.item(),
                loss_cycle_mri.item(), loss_cycle_ct.item(),
            )
            print('[%d/%d][%d/%d] Loss_D_MRI: %.4f Loss_D_CT: %.4f Loss_G_MRI2CT_adv: %.4f Loss_G_CT2MRI_adv: %.4f Loss_Cycle_MRI: %.4f Loss_Cycle_CT: %.4f' % stats)