In [None]:
!pip install monai
!pip install nibabel

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torch import nn
from torchvision import datasets, transforms
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import lr_scheduler
import numpy as np

import os
import monai
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ResizeWithPadOrCropd, NormalizeIntensityd, ToTensor
from torch.utils.data import DataLoader, random_split, Dataset

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# NumPy random generator; no seed is given, so its draws differ between runs.
rng = np.random.default_rng()

In [None]:
# Hyperparameters shared by the cells below.
num_decoders = 8  # number of noise blocks in diffusion_forward and of U-Nets in diffusion_backward
decoder_depth = 5  # passed to each Unet as num_layers (pooling / upscaling levels)
decoder_channels = 32  # passed to each Unet as initial_channels; doubled at every level
img_dim = 64  # side length of the square slices; matches spatial_size=(64, 64, 5) in the transforms
img_channels = 5  # channels per sample (the dataset stacks 5 consecutive slices)
batch_size = 16  # batch size of the loaders that feed the training loop
corruption_channels = [0, 1, 2, 3, 4]  # channel indices perturbed by each normal_distribute_block

In [None]:
# List a shared project directory. The absolute path is specific to one
# machine/cluster -- presumably where the data lives; adjust for your environment.
!ls /scratch/ras10116/shared/cvproject/

# Dataloader with transforms after slicing

In [None]:
# Resolve the torch.device used for tensors and models: GPU when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory that holds the .nii.gz volumes.
# TODO: set this to a real data directory before running the cells below.
data_path = ""

class CustomDataset(Dataset):
    """Dataset of short slice stacks cut from volumes that are loaded up front.

    Every file in ``file_paths`` is read once at construction time and kept in
    memory. A sample is a window of up to five consecutive slices along the
    third axis of one volume, passed through ``transforms`` and returned with
    the slice axis moved to the front.

    Args:
        file_paths: Paths of the volumes to load (read with MONAI ``LoadImaged``).
        transforms: Callable applied to ``{'T2': window}``; it must return a
            mapping that contains a ``'T2'`` entry.
        slice_range: ``(start, stop)`` of the window start indices, interpreted
            like the arguments of ``range``.
    """

    def __init__(self, file_paths, transforms, slice_range=(40, 76)):
        self.file_paths = file_paths
        self.transforms = transforms
        self.slice_range = slice_range
        self.images = self._load_images()
        self.slice_sets = self._create_slice_sets()

    def _load_images(self):
        """Read every file once and return the list of loaded 'T2' images."""
        loader = LoadImaged(keys=['T2'])
        return [loader({'T2': file_path})['T2'] for file_path in self.file_paths]

    def _create_slice_sets(self):
        """Enumerate all ``(image_index, slice_start)`` pairs, image by image."""
        return [
            (image_idx, slice_start)
            for image_idx in range(len(self.images))
            for slice_start in range(*self.slice_range)
        ]

    def __len__(self):
        return len(self.slice_sets)

    def __getitem__(self, idx):
        """Return the transformed window for sample ``idx``, slice axis first."""
        image_idx, slice_start = self.slice_sets[idx]
        # Slicing clips at the end of the volume, so a window that starts close
        # to the last slice holds fewer than five slices before the transforms.
        sequential_slices = self.images[image_idx][:, :, slice_start:slice_start + 5]
        transformed_slices = self.transforms({'T2': sequential_slices})['T2']
        # Remove only the leading channel axis. A bare squeeze() would also drop
        # any spatial axis of size 1 and make the permute below fail.
        if transformed_slices.ndim == 4:
            transformed_slices = transformed_slices.squeeze(0)
        return transformed_slices.permute(2, 0, 1)

# Define transformations, excluding LoadImaged: the volumes are already loaded by
# CustomDataset, which calls this pipeline on each {'T2': window} dictionary.
# NOTE: this name shadows the `transforms` module imported from torchvision above.
transforms = Compose([
    # Make the channel axis explicit and first (CustomDataset squeezes it away afterwards).
    EnsureChannelFirstd(keys=['T2']),
    # Pad or crop every window to 64 x 64 x 5 so that all samples share one shape.
    ResizeWithPadOrCropd(keys=['T2'], spatial_size=(64, 64, 5)),
    # Intensity normalization with nonzero=True and channel_wise=True
    # (see the MONAI documentation for the exact statistics these flags select).
    NormalizeIntensityd(keys=['T2'], nonzero=True, channel_wise=True),
    # NOTE(review): ToTensor is the non-dictionary transform while the rest of the
    # pipeline is dictionary-based -- confirm it converts the dict entries as intended.
    ToTensor()
])


# Load all images into memory once, when the dataset is constructed.

In [None]:
# Instantiate the dataset.
# os.listdir returns entries in arbitrary order, so the names are sorted to keep
# the volume -> dataset-index mapping stable between runs and machines.
file_paths = [
    os.path.join(data_path, f)
    for f in sorted(os.listdir(data_path))
    if f.endswith('.nii.gz')
]
dataset = CustomDataset(file_paths, transforms)

# Split dataset into training and validation sets (80% / 20%).
# NOTE: num_train is the total number of samples, not just the training part.
num_train = len(dataset)
split = int(num_train * 0.2)
train_set, valid_set = random_split(dataset, [num_train - split, split])

# DataLoaders for training and validation
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=4, shuffle=False, num_workers=4)

In [None]:
# Retrieve a batch for demonstration
train_loader_iterator = iter(train_loader)
batch = next(train_loader_iterator)

print(batch.shape)

# Show every stacked slice (channel) of the first sample in the batch.
# The loop index selects the slice; previously slice 0 was drawn on every pass.
# matplotlib.pyplot is already imported as plt in the imports cell.
for i in range(batch.shape[1]):
    first_item = batch[0][i]
    plt.imshow(first_item.squeeze(), cmap='gray')
    plt.show()


In [None]:
class normal_distribute_block(nn.Module):
    """Multiplicative Gaussian corruption of selected image channels.

    One ``(img_dim, img_dim)`` noise map is drawn per call and every channel
    listed in ``corruption_channels`` is multiplied by ``1 + sd * noise``. The
    same map is used for all samples in the batch and for all corrupted
    channels; the other channels pass through unchanged.

    Args:
        img_dim: Height and width of the (square) input images.
        sd: Standard deviation of the multiplicative noise.
        cc: Indices of the channels to corrupt. ``None`` (the default) means
            no channel is corrupted. A list passed in is stored by reference.
    """

    def __init__(self, img_dim, sd, cc=None):
        super().__init__()
        self.img_dim = img_dim
        self.sd = sd
        # A None sentinel avoids sharing one mutable default list between instances.
        self.corruption_channels = cc if cc is not None else []

    def up_sd(self):
        """Raise the noise standard deviation by 0.05."""
        self.sd = self.sd + 0.05

    def forward(self, x):
        """Return ``x`` with the configured channels scaled by a fresh noise map.

        ``x`` is indexed as (batch, channel, height, width).
        """
        # Allocate on the input's own device rather than a global one, so the
        # block works wherever x lives. The noise is therefore drawn from that
        # device's random generator.
        deviations = torch.ones(x.shape, device=x.device)
        perturbation = self.sd * torch.randn((self.img_dim, self.img_dim), device=x.device)
        scale = perturbation + 1
        for channel in self.corruption_channels:
            deviations[:, channel, :, :] = scale
        # The additive mean of the original formulation was identically zero,
        # so the result is just the element-wise scaling.
        return deviations * x

In [None]:
class diffusion_forward(nn.Module):
    """Forward (corruption) process: a chain of ``normal_distribute_block`` stages.

    Args:
        img_dim: Height and width of the (square) input images.
        num_layers: Number of corruption stages in the chain.
    """

    def __init__(self, img_dim, num_layers):
        super().__init__()

        # Every stage starts with sd = 0.05 and corrupts the module-level
        # `corruption_channels`.
        stages = [
            normal_distribute_block(img_dim, 0.05, corruption_channels)
            for _ in range(num_layers)
        ]
        self.transformation = nn.ModuleList(stages)

        self.num_layers = num_layers
        self.img_dim = img_dim
        self.rm_rec = 0
        self.rm_sq = 0

    def inc_diff(self):
        """Increase the noise level of every stage by one `up_sd` step."""
        for stage in self.transformation:
            stage.up_sd()

    def up_removed_box(self):
        """Increment both removed-box counters by one."""
        self.rm_sq += 1
        self.rm_rec += 1

    def forward(self, x):
        """Return all corruption levels stacked along a new leading axis.

        Index 0 holds the input itself; index ``k`` holds the result of
        applying the first ``k`` stages. ``x`` must be 4-dimensional.
        """
        levels = x.unsqueeze(0).repeat(self.num_layers + 1, 1, 1, 1, 1)
        for step, blur in enumerate(self.transformation, start=1):
            levels[step] = blur(levels[step - 1])
        return levels

In [None]:
class Unet(nn.Module):
    """Small U-Net style encoder/decoder built from 3x3 convolutions.

    Each level owns three convolutions: an entry conv, a conv that is applied
    twice in a row (shared weights) and a conv that halves the channel count
    on the way back up. Downsampling is 2x2 max pooling, upsampling repeats
    every pixel twice along height and width, and skip connections are added
    (not concatenated).

    Args:
        img_dim: Unused; kept so that existing callers keep working.
        num_layers: Number of pooling / upscaling levels. The input height and
            width must be divisible by ``2 ** num_layers``.
        in_channels: Number of channels of the input.
        initial_channels: Channel width of the first level; doubled per level.
        out_channels: Number of channels of the output. Defaults to
            ``in_channels`` so the network maps an image to an image of the
            same shape without relying on a module-level constant.
    """

    def __init__(self, img_dim, num_layers, in_channels, initial_channels, out_channels=None):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        self.upscale_list = []  # unused; retained for compatibility
        self.num_layers = num_layers
        self.relu = nn.LeakyReLU()
        self.max_pool = nn.MaxPool2d(2, 2)
        # Created before the loop below mutates initial_channels: the 1x1 output
        # conv consumes the width of the first (outermost) level.
        self.output_conv = nn.Conv2d(initial_channels, out_channels, 1)

        self.convolution_list = nn.ModuleList([])

        for _ in range(num_layers):
            self.convolution_list.append(nn.ModuleList([
                nn.Conv2d(in_channels, initial_channels, 3, padding=1),       # [0] encoder entry
                nn.Conv2d(initial_channels, initial_channels, 3, padding=1),  # [1] applied twice, both paths
                nn.Conv2d(initial_channels * 2, initial_channels, 3, padding=1),  # [2] decoder channel reduction
            ]))
            in_channels = initial_channels
            initial_channels = initial_channels * 2

        # Bottleneck: widen to twice the deepest level's width, then refine.
        self.intermediate_conv = nn.Conv2d(in_channels, initial_channels, 3, padding=1)
        self.middle_conv = nn.Conv2d(initial_channels, initial_channels, 3, padding=1)

    def upscale(self, x):
        """Double height and width by repeating every pixel (nearest neighbour)."""
        x = x.repeat_interleave(2, dim=2)
        x = x.repeat_interleave(2, dim=3)
        return x

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on a (batch, channel, H, W) tensor."""
        skips = []
        for layer in range(self.num_layers):
            level = self.convolution_list[layer]
            x = level[0](x)
            # The same convolution is applied twice, each time followed by LeakyReLU.
            for _ in range(2):
                x = self.relu(level[1](x))
            skips.append(x)
            x = self.max_pool(x)

        x = self.intermediate_conv(x)
        x = self.middle_conv(x)

        for layer in range(self.num_layers - 1, -1, -1):
            level = self.convolution_list[layer]
            x = self.upscale(x)
            x = level[2](x)
            x = x + skips[layer]
            # Decoder reuses the level's shared conv twice, without an activation.
            for _ in range(2):
                x = level[1](x)

        return self.output_conv(x)


In [None]:
class diffusion_backward(nn.Module):
    """Reverse (restoration) process: a chain of independently trained U-Nets.

    Args:
        img_dim: Forwarded to every ``Unet``.
        num_layers: Forwarded to every ``Unet``.
        in_channels: Forwarded to every ``Unet``.
        initial_channels: Forwarded to every ``Unet``.
        num_decoders: Number of U-Nets in the chain.
    """

    def __init__(self, img_dim, num_layers, in_channels, initial_channels, num_decoders):
        super().__init__()

        self.num_decoders = num_decoders
        self.unets = nn.ModuleList([
            Unet(img_dim, num_layers, in_channels, initial_channels) for _ in range(num_decoders)
        ])

    def forward(self, x):
        """Restore step by step and return every intermediate result.

        The result has a new leading axis of length ``num_decoders + 1``. The
        last index holds the input itself; index ``num_decoders - 1 - i`` holds
        the output of the i-th U-Net, so index 0 is the final restoration.
        """
        x_record = x.clone().unsqueeze(0).repeat(self.num_decoders + 1, 1, 1, 1, 1)
        for i, unet in enumerate(self.unets):
            # The input is detached so each U-Net only receives gradients from
            # its own output.
            x = unet(x.clone().detach())
            # Use the instance's own decoder count instead of a module-level
            # global, so the module also works outside this notebook.
            x_record[self.num_decoders - i - 1] = x

        return x_record

In [None]:
# Dataset preparation: reuse the `dataset` built earlier. Constructing
# CustomDataset a second time would read every volume from disk again and hold
# a second copy of all images in memory.

# Split dataset into training and validation sets (80% / 20%).
# NOTE: num_train is the total number of samples, not just the training part.
num_train = len(dataset)
split = int(num_train * 0.2)
train_set, valid_set = random_split(dataset, [num_train - split, split])

# DataLoaders for training and validation, now with the configured batch size
# and in-process loading (num_workers=0).
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0)



In [None]:
# Pull one batch from a fresh iterator over the training loader.
# The iterator is kept because a later cell draws another batch from it.
train_loader_iterator = iter(train_loader)
batch = next(train_loader_iterator)
print(batch.shape)
import matplotlib.pyplot as plt

# Channel 2 is the middle of the five stacked slices, i.e. the target slice.
# Shown for the first four samples of the batch.
for batch_idx in (0, 1, 2, 3):
    image = batch[batch_idx, 2]
    plt.imshow(image, cmap='gray')
    plt.show()

print("All samples: ", num_train )


In [None]:
# Forward (corruption) process with one noise stage per decoder.
model1 = diffusion_forward(img_dim, num_decoders).to(device)
# Reverse (restoration) process. Moved with .to(device) like model1 instead of
# .cuda(), so the notebook also runs on machines without a GPU.
model2 = diffusion_backward(img_dim, decoder_depth, img_channels, decoder_channels, num_decoders).to(device)
loss = nn.MSELoss()

model2.train()

In [None]:
# One Adam optimizer per U-Net; optimizers[i] drives model2.unets[num_decoders - i - 1].
optimizers = []
for i in range(len(model2.unets)):
    optimizers.append(torch.optim.Adam(model2.unets[num_decoders - i - 1].parameters(), lr=1, eps=1))

# NOTE(review): anomaly detection is a debugging aid that slows training down;
# consider disabling it once the loop is known to be stable.
torch.autograd.set_detect_anomaly(True)

# One StepLR scheduler per optimizer. Previously a single variable was
# overwritten in the loop, so only the last optimizer's learning rate decayed.
schedulers = [
    lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
    for optimizer in optimizers
]

# Curriculum: every level raises the noise standard deviation of the forward process.
for level in range(8):
    model1.inc_diff()
    if level % 2 == 0:
        model1.up_removed_box()
    for i in range(1, 500):
        cumulative_loss = 0  # summed batch losses of this epoch
        for batch_idx, data in enumerate(train_loader):
            images = data.to(device)
            # corrupted[k] is the image after k noise stages.
            corrupted = model1(images)
            # Restoration starts from the most corrupted level; restored[k] is
            # compared against corrupted[k] for every level at once.
            restored = model2(corrupted[num_decoders])
            total_loss = loss(restored, corrupted)

            for optimizer in optimizers:
                optimizer.zero_grad()

            cumulative_loss += total_loss.item()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model2.parameters(), 0.1)

            for optimizer in optimizers:
                optimizer.step()

        # Advance every scheduler once per epoch.
        for scheduler in schedulers:
            scheduler.step()

        print(f'Difficulty Level: {level} , Epoch no: {i} , Loss: {cumulative_loss}')

In [None]:
loss = nn.MSELoss()
# Validation step
model2.eval()  # Set the model to evaluation mode
validation_loss = 0.0

with torch.no_grad():  # Disable gradient computation
    for batch_idx, data in enumerate(valid_loader):
        images = data.to(device)
        corrupted = model1(images)
        restored = model2(corrupted[num_decoders])
        # Axes are (level, batch, channel, H, W): only channel 2, the middle
        # slice, is scored.
        loss_val = loss(restored[:, :, 2:3, :, :], corrupted[:, :, 2:3, :, :])
        validation_loss += loss_val.item()  # Sum up the per-batch mean losses

# Each loss_val is already a mean over its batch, so the summed value is
# averaged over the number of batches (not the number of samples). max() guards
# against an empty validation loader.
validation_loss /= max(len(valid_loader), 1)
print(f'Epoch: 1: Validation Loss: {validation_loss}')




In [None]:
# Qualitative check on the next training batch: original vs. most corrupted vs.
# fully restored middle slice of the first sample, followed by their MSEs.
batch = next(train_loader_iterator)
print(batch.shape)
images = batch.to(device)
import matplotlib.pyplot as plt
corrupted = model1(images)
restored = model2(corrupted[num_decoders].clone())
print(restored.shape)

original_slice = images[0, 2]        # sample 0, channel 2
noisy_slice = corrupted[-1, 0, 2]    # last corruption level
restored_slice = restored[0, 0, 2]   # index 0 = output of the final U-Net

for panel in (original_slice, noisy_slice, restored_slice):
    image = panel.cpu().detach()
    plt.imshow(image, cmap='gray')
    plt.show()

mse = nn.MSELoss()
print(mse(noisy_slice, original_slice).item())
print(mse(restored_slice, original_slice).item())

In [None]:
# Persist the weights of the reverse-process model in the current working directory.
checkpoint_name = "diffusion_model.pth"
file_path = os.path.join(os.getcwd(), checkpoint_name)
model_state = model2.state_dict()
torch.save(model_state, file_path)