## Deep Learning 23/24
### Project

Beatriz Moreira, FC54514 \
Rute Patuleia, FC51780 \
Tiago Assis, FC62609

In [None]:
import warnings 
warnings.filterwarnings('ignore')
import os
import shutil
import random
import time
from datetime import datetime
import gc 
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard
import torchvision.transforms as FT
!pip install torchinfo
from torchinfo import summary
import nibabel as nib
import glob
from tqdm import tqdm

## Data loading, preprocessing and augmenting

In [None]:
# Defining a class to handle data augmentations
class Augmentations:
    """
    Random augmentations for a single (image, mask) sample.

    `img` is indexed as (channel, x, y, z) with one channel per modality and
    `mask` is indexed as (channel, x, y, z). Geometric transforms (rotation,
    scaling, mirroring) are applied to both tensors, intensity transforms only
    to `img`. Several intensity transforms modify `img` in place.
    """

    def __init__(self, img, mask):
        self.img = img
        self.mask = mask

    def rotate_scale(self):
        """
        Rotates and scales the sample. Scaling and rotation are applied with a probability of 0.2 each.
        The angle of rotation (in degrees) is drawn from U(-30, 30).
        Scaling resizes every spatial dimension by the same factor, sampled from U(0.7, 1.4).
        Interpolation for images is done linearly, while masks are interpolated using the nearest pixels.
        NOTE(review): torchvision's rotate presumably acts on the last two dimensions only - confirm.

        Returns:
            The (possibly transformed) image and mask.
        """
        angle = random.uniform(-30, 30)
        if angle < 0:  # express the same rotation as an angle in [0, 360)
            angle += 360
        scale_factor = random.uniform(0.7, 1.4)
        if random.random() < 0.2:
            self.mask = FT.functional.rotate(self.mask, angle, interpolation=FT.InterpolationMode.NEAREST)
            self.img = FT.functional.rotate(self.img, angle, interpolation=FT.InterpolationMode.BILINEAR)
        if random.random() < 0.2:
            # F.interpolate expects a batch dimension, hence the unsqueeze/squeeze pair
            self.mask = F.interpolate(self.mask.unsqueeze(0),
                                      [int(shape * scale_factor) for shape in self.mask.shape[1:]],
                                      mode="nearest").squeeze(0)
            self.img = F.interpolate(self.img.unsqueeze(0),
                                     [int(shape * scale_factor) for shape in self.img.shape[1:]],
                                     mode="trilinear",
                                     align_corners=True).squeeze(0)
        return self.img, self.mask

    def gaussian_noise(self):
        """
        Zero centered additive Gaussian noise is added to each voxel in the sample independently.
        This augmentation is applied with a probability of 0.15. The variance of the noise is drawn from U(0, 0.1).
        The image is modified in place.
        """
        if random.random() < 0.15:
            variance = random.uniform(0, 0.1)
            noise = torch.randn_like(self.img) * torch.sqrt(torch.tensor(variance))
            self.img += noise
        return self.img

    def gaussian_blur(self):
        """
        Blurring is applied with a probability of 0.2 per sample.
        If this augmentation is triggered in a sample, blurring is applied with a probability of 0.5 for each of the associated modalities.
        The width (in voxels) of the Gaussian kernel, sigma, is sampled from U(0.5, 1.5) independently for each modality.
        """
        if random.random() < 0.2:
            for channel in range(self.img.shape[0]):
                if random.random() < 0.5:
                    sigma = random.uniform(0.5, 1.5)
                    self.img[channel, :, :, :] = FT.functional.gaussian_blur(self.img[channel, :, :, :], kernel_size=3, sigma=sigma)
        return self.img

    def brightness(self):
        """
        Voxel intensities are multiplied by x ~ U(0.7, 1.3) with a probability of 0.15.
        The image is modified in place.
        """
        if random.random() < 0.15:
            brightness_factor = torch.tensor(np.random.uniform(0.7, 1.3))
            self.img *= brightness_factor
        return self.img

    def contrast(self):
        """
        Voxel intensities are multiplied by x ~ U(0.65, 1.5) with a probability of 0.15.
        Following multiplication, the values are clipped to their original value range.
        """
        if random.random() < 0.15:
            contrast_factor = torch.tensor(np.random.uniform(0.65, 1.5))
            # Value range is recorded before scaling so the result can be clipped back to it
            min_val = torch.min(self.img)
            max_val = torch.max(self.img)
            self.img *= contrast_factor
            self.img = torch.clamp(self.img, min_val, max_val)
        return self.img

    def low_resolution(self):
        """
        This augmentation is applied with a probability of 0.25 per sample and 0.5 per associated modality.
        Triggered modalities are downsampled by a factor of U(1, 2) using nearest neighbor interpolation
        and then sampled back up to their original size with trilinear interpolation.
        """
        if random.random() < 0.25:
            for channel in range(self.img.shape[0]):
                if random.random() < 0.5:
                    scale_factor = np.random.uniform(1, 2)
                    scaled = F.interpolate(self.img[channel, :, :, :].unsqueeze(0).unsqueeze(0),
                                           [int(shape * 1 / scale_factor) for shape in self.img.shape[1:]],
                                           mode='nearest')
                    scaled = F.interpolate(scaled,
                                           self.img.shape[1:],
                                           mode='trilinear',
                                           align_corners=True).squeeze(0).squeeze(0)
                    self.img[channel, :, :, :] = scaled
        return self.img

    def gamma(self):
        """
        This augmentation is applied with a probability of 0.15.
        The patch intensities are scaled to a factor of [0, 1] of their respective value range.
        Then, a nonlinear intensity transformation is applied per voxel: { i_new = i_old ** gamma } with gamma ~ U(0.7, 1.5).
        The voxel intensities are subsequently scaled back to their original value range.
        With a probability of 0.15, this augmentation is applied with the voxel intensities being inverted
        prior to transformation: { (1 - i_new) = (1 - i_old) ** gamma }
        A constant image (zero value range) is returned unchanged.
        """
        if random.random() < 0.15:
            gamma = torch.tensor(random.uniform(0.7, 1.5))
            min_val = torch.min(self.img)
            max_val = torch.max(self.img)
            value_range = max_val - min_val
            if value_range == 0:
                # Rescaling a constant image would divide by zero and fill it with NaNs
                return self.img
            self.img = (self.img - min_val) / value_range
            if random.random() < 0.15:
                self.img = 1 - self.img
            self.img = self.img ** gamma
            self.img = self.img * value_range + min_val
        return self.img

    def mirror(self):
        """
        With a probability of 0.5, the image and the mask are mirrored along all spatial axes.
        The channel axis (dimension 0) is left untouched so that modalities and labels keep their order.
        """
        if random.random() < 0.5:
            # Dimension 0 is the channel axis; only the remaining (spatial) axes are flipped
            spatial_axes = tuple(range(1, self.img.dim()))
            self.img = torch.flip(self.img, dims=spatial_axes)
            self.mask = torch.flip(self.mask, dims=spatial_axes)
        return self.img, self.mask

    def transforms(self):
        """
        Runs every augmentation once, in a fixed order, and returns the resulting (image, mask) pair.
        """
        self.img, self.mask = self.rotate_scale()
        self.img = self.gaussian_noise()
        self.img = self.gaussian_blur()
        self.img = self.brightness()
        self.img = self.contrast()
        self.img = self.low_resolution()
        self.img = self.gamma()
        self.img, self.mask = self.mirror()
        return self.img, self.mask

In [2]:
# Defining a class to handle image loading, preprocessing and augmentations
class BratsDataset(Dataset):
    """
    Dataset of case folders, each holding four modality volumes and (for training) a segmentation.

    Every item is a 4-channel float32 image tensor (one channel per modality) paired with
    either a 3-channel mask (ET, TC, WT) or, when `test=True`, the path of the case folder.

    Args:
        root_paths: iterable of directories whose sub-folders are the individual cases.
        test: when True, no mask is loaded and volumes are not cropped.
    """
    # File suffixes of the four modalities, in channel order
    MODALITIES = ("-t1c.nii", "-t1n.nii", "-t2w.nii", "-t2f.nii")
    # Crop applied to training volumes and masks to drop irrelevant parts of the volume
    CROP = (slice(35, -25), slice(40, -20), slice(None, 140))

    def __init__(self, root_paths, test=False):
        self.test = test
        self.root_paths = root_paths
        # Getting all case folders from every root directory
        self.folder_paths = sorted(
            f"{root_path}/{folder}" for root_path in root_paths for folder in os.listdir(root_path)
        )

    def __getitem__(self, index):
        folder = self.folder_paths[index]
        combined_img = self.preprocess_images(folder)
        # If not test set, work with images and mask, otherwise only the images
        if not self.test:
            mask = self.preprocess_masks(folder)
            return combined_img, mask
        return combined_img, folder

    @staticmethod
    def _case_file(folder, suffix):
        """Builds the path `<folder>/<folder name><suffix>`."""
        return os.path.join(folder, os.path.basename(folder) + suffix)

    @staticmethod
    def _load_volume(path):
        """
        Loads a NIfTI volume as an array.

        To handle Kaggle messing up zipped folders: when `path` cannot be read as an
        image it is treated as a directory and its first entry is loaded instead.
        """
        try:
            return nib.load(path).get_fdata()
        except nib.filebasedimages.ImageFileError:
            nested = os.path.join(path, os.listdir(path)[0])
            return nib.load(nested).get_fdata()

    @staticmethod
    def _standardize_brain(img):
        """
        Standardizes (zero mean, unit std) the voxels > 0 of `img` in place; other voxels are untouched.

        If there are no voxels > 0 nothing is changed; if they all share the same value they
        are only centered, which avoids a division by zero.
        """
        brain = img > 0
        if brain.any():
            voxels = img[brain]
            centered = voxels - voxels.mean()
            std = voxels.std()
            img[brain] = centered / std if std > 0 else centered
        return img

    @staticmethod
    def _encode_regions(mask):
        """
        Converts a label map into a 3-channel binary mask, following the challenge targets:
        channel 0 - ET: label 3 only
        channel 1 - TC: labels 1 and 3
        channel 2 - WT: labels 1, 2 and 3 (every part of the tumor)
        """
        mask_ET = (mask == 3).to(mask.dtype)
        mask_TC = ((mask == 1) | (mask == 3)).to(mask.dtype)
        mask_WT = ((mask == 1) | (mask == 2) | (mask == 3)).to(mask.dtype)
        return torch.stack([mask_ET, mask_TC, mask_WT])

    def preprocess_images(self, folder):
        """Loads, crops (training only) and standardizes the four modalities of one case."""
        imgs = []
        for modality in self.MODALITIES:
            img = self._load_volume(self._case_file(folder, modality))
            # Crop irrelevant parts of the volume
            if not self.test:
                img = img[self.CROP]
            # Standardizing considering only the brain volume
            img = self._standardize_brain(img)
            imgs.append(torch.tensor(img, dtype=torch.float32))
        # Stack as a 4 channel input; 1 channel per modality
        return torch.stack(imgs)

    def preprocess_masks(self, folder):
        """Loads the segmentation of one case, crops it like the images and encodes it as ET/TC/WT channels."""
        mask = self._load_volume(self._case_file(folder, "-seg.nii"))
        mask = mask[self.CROP]
        mask = torch.tensor(mask, dtype=torch.uint8)
        return self._encode_regions(mask)

    def __len__(self):
        return len(self.folder_paths)

## Model loss function and evaluation metrics

In [None]:
class TotalDiceLoss(nn.Module):
    """
    Soft Dice loss averaged over the channels of the prediction.

    Args:
        smooth: constant added to numerator and denominator of the Dice ratio,
            which keeps the ratio defined when both prediction and target are empty.
    """
    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, outputs, targets):
        """
        Args:
            outputs: raw logits indexed as (batch, channel, ...); a sigmoid is applied here.
            targets: binary targets indexed like `outputs`.

        Returns:
            Scalar tensor: mean over channels of (1 - Dice), each Dice computed over the whole batch.
        """
        probabilities = torch.sigmoid(outputs)
        num_channels = probabilities.shape[1]
        total_loss = 0
        for channel in range(num_channels):
            o_flat = probabilities[:, channel].reshape(-1)
            t_flat = targets[:, channel].reshape(-1)
            intersection = torch.sum(o_flat * t_flat)
            sums = torch.sum(o_flat) + torch.sum(t_flat)
            dice = (2. * intersection + self.smooth) / (sums + self.smooth)
            total_loss += 1 - dice
        return total_loss / num_channels

In [None]:
def dice_per_class(outputs, targets, mean_only=False, d_slice=None, smooth=1.):
    """
    Dice score of the thresholded prediction for each of the three channels (ET, TC, WT).

    Args:
        outputs: raw logits; a sigmoid followed by a 0.5 threshold is applied here.
            Indexed as (batch, channel, x, y, z), or (channel, x, y, z) when `d_slice` is given.
        targets: binary targets indexed like `outputs`.
        mean_only: when True, return only the mean of the three scores.
        d_slice: optional index along the last dimension; restricts the score to that slice.
        smooth: constant added to numerator and denominator of the Dice ratio.

    Returns:
        The tuple (dice_et, dice_tc, dice_wt) of Python floats, or their mean.
    """
    probabilities = F.sigmoid(outputs)
    binary = torch.where(probabilities > 0.5, torch.tensor(1), torch.tensor(0))

    def _flat_channel(volume, channel):
        # Single-slice scoring works on an un-batched volume, otherwise on the whole batch
        if d_slice is None:
            selected = volume[:, channel, :, :, :]
        else:
            selected = volume[channel, :, :, d_slice]
        return selected.contiguous().view(-1)

    scores = []
    for channel in range(3):
        pred_flat = _flat_channel(binary, channel)
        true_flat = _flat_channel(targets, channel)
        overlap = torch.sum(pred_flat * true_flat)
        denominator = torch.sum(pred_flat) + torch.sum(true_flat)
        scores.append(((2. * overlap + smooth) / (denominator + smooth)).item())

    dice_et, dice_tc, dice_wt = scores
    if mean_only:
        return (dice_et + dice_tc + dice_wt) / 3
    return dice_et, dice_tc, dice_wt

## Dilated U-Net

In [3]:
class DoubleConvEncoder(nn.Module):
    """Two 3x3x3 convolutions, each followed by instance normalization and ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First convolution changes the channel count, the second keeps it
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv3d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.InstanceNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class DoubleConvDecoder(nn.Module):
    """
    Two 3x3x3 convolutions, each followed by instance normalization and ReLU.

    The input is expected to carry `in_channels * 2` channels; the first convolution
    reduces it to `in_channels`, the second maps to `out_channels`.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = ((in_channels * 2, in_channels), (in_channels, out_channels))
        layers = []
        for conv_in, conv_out in stages:
            layers.extend([
                nn.Conv3d(conv_in, conv_out, kernel_size=3, padding=1),
                nn.InstanceNorm3d(conv_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
    

class DoubleDilatedConv(nn.Module):
    """
    Two dilated (dilation 2) 3x3x3 convolutions, each followed by instance normalization and ReLU.

    The forward pass returns the block output concatenated with its input along the
    channel dimension, together with the block output on its own.
    """
    def __init__(self, in_channels):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.extend([
                nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=2, dilation=2),
                nn.InstanceNorm3d(in_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_dilated_conv = nn.Sequential(*layers)

    def forward(self, x):
        dilated = self.double_dilated_conv(x)
        # Residual-style concatenation: dilated features first, then the untouched input
        return torch.cat([dilated, x], 1), dilated
        
        
class DilatedConvBottleneck(nn.Module):
    """
    Bottleneck: a double convolution, a double dilated convolution (whose output is
    concatenated with its input) and a final convolution mapping back to `in_channels`.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConvEncoder(in_channels, out_channels)
        self.double_dilated_conv = DoubleDilatedConv(out_channels)
        # The dilated block doubles the channel count through concatenation
        projection = [
            nn.Conv3d(out_channels*2, in_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(in_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_out = nn.Sequential(*projection)

    def forward(self, x):
        """Returns the projected features and the raw dilated features (used for deep supervision)."""
        features = self.double_conv(x)
        merged, supervision1 = self.double_dilated_conv(features)
        return self.conv_out(merged), supervision1


class DownSample(nn.Module):
    """Encoder stage: a double convolution followed by 2x max pooling."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConvEncoder(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        """Returns the pooled features and the pre-pooling features (the skip connection)."""
        skip = self.double_conv(x)
        return self.pool(skip), skip


class UpSample(nn.Module):
    """Decoder stage: 2x trilinear upsampling, concatenation with the skip connection, double convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2., mode="trilinear", align_corners=True)
        self.double_conv = DoubleConvDecoder(in_channels, out_channels)

    def forward(self, x, concat):
        upsampled = self.up(x)
        # Resize to the skip connection when the sizes do not line up
        if upsampled.size() != concat.size():
            upsampled = F.interpolate(upsampled, size=concat.size()[2:], mode='trilinear', align_corners=True)
        merged = torch.cat([upsampled, concat], 1)
        return self.double_conv(merged)
    

# Initialize weights with He initialization
def weights_init_he(m):
    """
    He (Kaiming normal, fan-in, ReLU gain) initialization for `nn.Conv3d` modules.

    Biases, when present, are set to 1e-4. Any other module type is left untouched.
    Intended to be passed to `nn.Module.apply`.
    """
    if not isinstance(m, nn.Conv3d):
        return
    nn.init.kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='relu')
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 1e-4)

In [4]:
class DilatedUNet(nn.Module):
    """
    3-level U-Net with a dilated-convolution bottleneck and deep supervision heads.

    The forward pass returns the final prediction followed by four auxiliary
    predictions, all resized to the spatial size of the final prediction.
    """
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Encoder
        self.down_conv1 = DownSample(in_channels, 48)
        self.down_conv2 = DownSample(48, 96)
        self.down_conv3 = DownSample(96, 192)
        # Bottleneck
        self.bottle_neck = DilatedConvBottleneck(192, 384)
        # Decoder
        self.up_conv1 = UpSample(192, 96)
        self.up_conv2 = UpSample(96, 48)
        self.up_conv3 = UpSample(48, 48)
        # 1x1x1 heads: out1..out4 for deep supervision, final_out for the main prediction
        self.out1 = nn.Conv3d(384, num_classes, kernel_size=1)
        self.out2 = nn.Conv3d(192, num_classes, kernel_size=1)
        self.out3 = nn.Conv3d(96, num_classes, kernel_size=1)
        self.out4 = nn.Conv3d(48, num_classes, kernel_size=1)
        self.final_out = nn.Conv3d(48, num_classes, kernel_size=1)
        self.apply(weights_init_he)

    def forward(self, x):
        x, skip1 = self.down_conv1(x)
        x, skip2 = self.down_conv2(x)
        x, skip3 = self.down_conv3(x)
        bottleneck_feats, dilated_feats = self.bottle_neck(x)
        decoded1 = self.up_conv1(bottleneck_feats, skip3)
        decoded2 = self.up_conv2(decoded1, skip2)
        decoded3 = self.up_conv3(decoded2, skip1)

        final_out = self.final_out(decoded3)
        target_size = final_out.size()[2:]
        # Heads ordered from the shallowest decoder stage down to the dilated bottleneck features
        heads = (
            (self.out4, decoded2),
            (self.out3, decoded1),
            (self.out2, bottleneck_feats),
            (self.out1, dilated_feats),
        )
        auxiliary = [
            F.interpolate(head(features), target_size, mode="trilinear", align_corners=True)
            for head, features in heads
        ]
        return (final_out, *auxiliary)

## Baseline nnU-Net

In [5]:
class nnUNet_DoubleConv(nn.Module):
    """Two 3x3x3 convolutions, each followed by instance normalization and LeakyReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First convolution changes the channel count, the second keeps it
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv3d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.InstanceNorm3d(out_channels),
                nn.LeakyReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class nnUNet_DownSample(nn.Module):
    """Encoder stage: a double convolution followed by 2x max pooling."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nnUNet_DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        """Returns the pooled features and the pre-pooling features (the skip connection)."""
        skip = self.double_conv(x)
        return self.pool(skip), skip


class nnUNet_UpSample(nn.Module):
    """Decoder stage: 2x transposed convolution, concatenation with the skip connection, double convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up_conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        # After concatenation with the skip connection the channel count is doubled
        self.double_conv = nnUNet_DoubleConv(out_channels*2, out_channels)

    def forward(self, x, concat):
        upsampled = self.up_conv(x)
        # Resize to the skip connection when the sizes do not line up
        if upsampled.size() != concat.size():
            upsampled = F.interpolate(upsampled, size=concat.size()[2:], mode='trilinear', align_corners=True)
        merged = torch.cat([upsampled, concat], 1)
        return self.double_conv(merged)
    

# Initialize weights with He initialization
def weights_init_he_nnunet(m):
    """
    He (Kaiming normal, fan-in) initialization for `nn.Conv3d` modules followed by LeakyReLU.

    The gain is computed for a negative slope of 0.01, the default slope of `nn.LeakyReLU`
    (without `a`, kaiming_normal_ assumes a slope of 0). Biases, when present, are set to 1e-4.
    Any other module type is left untouched. Intended to be passed to `nn.Module.apply`.
    """
    if isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight.data, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 1e-4)

In [6]:
class nnUNet(nn.Module):
    """
    5-level U-Net baseline with deep supervision heads.

    The forward pass returns the final prediction followed by three auxiliary
    predictions, all resized to the spatial size of the final prediction.
    """
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Encoder
        self.encoder1 = nnUNet_DownSample(in_channels, 32)
        self.encoder2 = nnUNet_DownSample(32, 64)
        self.encoder3 = nnUNet_DownSample(64, 128)
        self.encoder4 = nnUNet_DownSample(128, 256)
        self.encoder5 = nnUNet_DownSample(256, 320)
        # Bottleneck
        self.bottle_neck = nnUNet_DoubleConv(320, 320)
        # Decoder
        self.decoder1 = nnUNet_UpSample(320, 320)
        self.decoder2 = nnUNet_UpSample(320, 256)
        self.decoder3 = nnUNet_UpSample(256, 128)
        self.decoder4 = nnUNet_UpSample(128, 64)
        self.decoder5 = nnUNet_UpSample(64, 32)
        # 1x1x1 heads: out1..out3 for deep supervision, final_out for the main prediction
        self.out1 = nn.Conv3d(256, num_classes, kernel_size=1)
        self.out2 = nn.Conv3d(128, num_classes, kernel_size=1)
        self.out3 = nn.Conv3d(64, num_classes, kernel_size=1)
        self.final_out = nn.Conv3d(32, num_classes, kernel_size=1)
        self.apply(weights_init_he_nnunet)

    def forward(self, x):
        x, skip1 = self.encoder1(x)
        x, skip2 = self.encoder2(x)
        x, skip3 = self.encoder3(x)
        x, skip4 = self.encoder4(x)
        x, skip5 = self.encoder5(x)
        x = self.bottle_neck(x)
        x = self.decoder1(x, skip5)
        decoded2 = self.decoder2(x, skip4)
        decoded3 = self.decoder3(decoded2, skip3)
        decoded4 = self.decoder4(decoded3, skip2)
        decoded5 = self.decoder5(decoded4, skip1)

        final_out = self.final_out(decoded5)
        target_size = final_out.size()[2:]
        # Heads ordered from the shallowest supervised decoder stage to the deepest
        heads = ((self.out3, decoded4), (self.out2, decoded3), (self.out1, decoded2))
        auxiliary = [
            F.interpolate(head(features), target_size, mode="trilinear", align_corners=True)
            for head, features in heads
        ]
        return (final_out, *auxiliary)

## Train and validation functions

In [None]:
def plot_pred_vs_truth(preds, mask, img, dices, d_slice):
    """
    Plots one slice of every prediction next to the ground truth and the last image channel.

    Args:
        preds: list of raw-logit predictions indexed as (channel, x, y, z); the last
            entry is the final prediction, the others are the supervision levels.
        mask: ground-truth mask indexed as (channel, x, y, z).
        img: input image indexed as (channel, x, y, z); channel 3 is shown as "T2-FLAIR".
        dices: one Dice score per entry of `preds`, shown in the panel titles.
        d_slice: index along the last dimension of the slice to draw.

    Returns:
        The matplotlib figure.
    """
    # Logits are turned into probabilities before thresholding, so the picture shows the
    # same binary prediction that dice_per_class scores. The caller's list is not modified.
    binarized = [
        torch.where(torch.sigmoid(pred) > 0.5, torch.tensor(255), torch.tensor(0))
        for pred in preds
    ]
    # Two extra panels (ground truth and image) laid out on a two-row grid
    num_panels = len(binarized) + 2
    num_cols = (num_panels + 1) // 2

    fig = plt.figure(figsize=(15,5))
    last = len(binarized) - 1
    for i, pred in enumerate(binarized):
        ax = fig.add_subplot(2, num_cols, i+1)
        plt.imshow(pred[:,:,:,d_slice].permute(1,2,0))
        if i == last:
            ax.set_title(f"Dice: {dices[-1]:.4f} / Final prediction")
        else:
            ax.set_title(f"Dice: {dices[i]:.4f} / Supervision level {i+1}")

    ax = fig.add_subplot(2, num_cols, len(binarized)+1)
    plt.imshow(mask[:,:,:,d_slice].permute(1,2,0) * 255)
    ax.set_title("Ground truth")

    ax = fig.add_subplot(2, num_cols, len(binarized)+2)
    plt.imshow(img[3,:,:,d_slice], cmap="gray")
    ax.set_title("T2-FLAIR")

    plt.tight_layout()
    return fig

In [None]:
def save_metrics(writer, running_loss, running_dice, epoch, loader_size, i, step_size, val=False):
    """
    Writes the running metrics, averaged over `step_size` batches, to TensorBoard.

    For validation, `running_dice` holds four running sums (mean, ET, TC, WT) and
    `running_loss` is not used; for training, `running_loss` and `running_dice` are
    single running sums. The global step is derived from the epoch, the module-level
    CHECKPOINT offset and the batch index `i`.
    """
    global CHECKPOINT
    global_step = (epoch + CHECKPOINT - 1) * loader_size + i + 1
    if val:
        val_tags = ('Dice/Val/Mean', 'Dice/Val/ET', 'Dice/Val/TC', 'Dice/Val/WT')
        for position, tag in enumerate(val_tags):
            writer.add_scalar(tag, running_dice[position] / step_size, global_step)
    else:
        writer.add_scalar('Loss/Train/', running_loss / step_size, global_step)
        writer.add_scalar('Dice/Train/', running_dice / step_size, global_step)
    writer.flush()


def save_plots(writer, preds, mask, img, epoch, loader_size, i, d_slice=85, val=False):
    """
    Logs a prediction-vs-truth figure for the first sample of the batch to TensorBoard.

    Only slice `d_slice` (along the last dimension) is scored and drawn. The global step
    is derived from the epoch, the module-level CHECKPOINT offset and the batch index `i`.
    """
    global CHECKPOINT
    # Keep only the first sample of the batch, detached and on the CPU for plotting
    sample_preds = [pred[0].cpu().detach() for pred in preds]
    sample_mask = mask[0].cpu().detach()
    sample_img = img[0].cpu().detach()
    slice_dices = [
        dice_per_class(pred, sample_mask, mean_only=True, d_slice=d_slice)
        for pred in sample_preds
    ]
    tag = 'Pred-Truth/Val' if val else 'Pred-Truth/Train'
    figure = plot_pred_vs_truth(sample_preds, sample_mask, sample_img, slice_dices, d_slice=d_slice)
    writer.add_figure(tag, figure, global_step=(epoch + CHECKPOINT - 1) * loader_size + i + 1)
    writer.flush()

In [None]:
def save_checkpoint(epoch, prev_point, model, optimizer, scheduler):
    """
    Saves the model, optimizer and scheduler states together with the absolute epoch.

    The absolute epoch is `epoch + prev_point`. The file is written under the module-level
    MODEL_SAVE_PATH and named after MODEL_NAME, the absolute epoch and the current time.
    """
    absolute_epoch = epoch + prev_point
    states = dict(
        epoch=absolute_epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
    )
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M')
    torch.save(states, f"{MODEL_SAVE_PATH}/{MODEL_NAME}-epoch{absolute_epoch}-{timestamp}.pth")
    
    
def load_checkpoint(model, optimizer, scheduler, load_path):
    """
    Restores model, optimizer and scheduler from a checkpoint file, in place.

    Args:
        model: module whose parameters are overwritten with the saved ones.
        optimizer: optimizer whose state is overwritten with the saved one.
        scheduler: scheduler whose state is overwritten and then advanced by one step.
        load_path: path of a file holding the keys "epoch", "model_state_dict",
            "optimizer_state_dict" and "scheduler_state_dict".

    Returns:
        The epoch stored in the checkpoint.
    """
    # Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a CPU-only
    # machine; load_state_dict then copies the values onto the device of `model`.
    states = torch.load(load_path, map_location="cpu")
    epoch = states["epoch"]
    model.load_state_dict(states["model_state_dict"])
    optimizer.load_state_dict(states["optimizer_state_dict"])
    scheduler.load_state_dict(states["scheduler_state_dict"])
    # Advance the restored scheduler by one step before training resumes
    scheduler.step()
    return epoch

In [None]:
def train_model(writer, model, loader, loss_fn1, loss_fn2, optimizer, epoch):
    """
    Trains `model` for one pass over `loader`.

    Every sample is augmented and resized to 128x128x128 before the forward pass. The loss
    is the sum, over every output of the model, of `loss_fn1` (plus `loss_fn2` when given).
    Running loss and Dice are logged every 20 batches, a figure every 100 batches.
    """
    target_size = (128, 128, 128)
    loss_accumulator = 0.0
    dice_accumulator = 0.0

    model.train()
    for i, (imgs, masks) in enumerate(loader):
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)

        # Augment each sample on its own, then bring it to the network input size
        augmented_imgs, augmented_masks = [], []
        for j in range(imgs.shape[0]):
            img, mask = Augmentations(imgs[j], masks[j]).transforms()
            augmented_imgs.append(
                F.interpolate(img.unsqueeze(0), target_size, mode="trilinear", align_corners=True).squeeze(0))
            augmented_masks.append(
                F.interpolate(mask.unsqueeze(0), target_size, mode="nearest").squeeze(0))
        imgs = torch.stack(augmented_imgs)
        masks = torch.stack(augmented_masks)

        optimizer.zero_grad()

        preds = model(imgs)

        if loss_fn2 is None:
            losses = [loss_fn1(pred, masks) for pred in preds]
        else:
            float_masks = masks.to(torch.float32)
            losses = [loss_fn1(pred, masks) + loss_fn2(pred, float_masks) for pred in preds]
        total_loss = sum(losses)

        loss_accumulator += total_loss.item()
        # Dice of the final prediction only (first output of the model)
        dice_accumulator += 1 - loss_fn1(preds[0], masks).item()

        if i % 20 == 19:
            save_metrics(writer, loss_accumulator, dice_accumulator, epoch, len(loader), i, step_size=20)
            loss_accumulator = 0.0
            dice_accumulator = 0.0
        if i % 100 == 99:
            save_plots(writer, preds[::-1], masks, imgs, epoch, len(loader), i)

        total_loss.backward()
        optimizer.step()
        

def eval_model(writer, model, loader, epoch):
    """
    Evaluates `model` on `loader` without gradients.

    Inputs are resized to 128x128x128. The Dice scores of the final prediction (mean, ET,
    TC, WT) are logged every 5 batches, a figure every 20 batches.
    """
    log_every = 5
    plot_every = 20
    # Running sums in the order expected by save_metrics: mean, ET, TC, WT
    running_dices = [0.0, 0.0, 0.0, 0.0]

    model.eval()
    with torch.no_grad():
        for i, (imgs, masks) in enumerate(loader):
            imgs = F.interpolate(imgs.to(DEVICE), (128,128,128), mode="trilinear", align_corners=True)
            masks = F.interpolate(masks.to(DEVICE), (128,128,128), mode="nearest")

            preds = model(imgs)

            final_dices = dice_per_class(preds[0], masks)
            batch_scores = (sum(final_dices) / 3, final_dices[0], final_dices[1], final_dices[2])
            running_dices = [total + score for total, score in zip(running_dices, batch_scores)]

            if i % log_every == log_every - 1:
                save_metrics(writer, None, running_dices, epoch, len(loader), i, step_size=log_every, val=True)
                running_dices = [0.0, 0.0, 0.0, 0.0]
            if i % plot_every == plot_every - 1:
                save_plots(writer, preds[::-1], masks, imgs, epoch, len(loader), i, val=True)

In [None]:
def trainer(CHECKPOINT, writer, model, train_dataloader, val_dataloader, loss_fn1, loss_fn2, optimizer, scheduler):
    """
    Runs the training loop for the module-level number of EPOCHS.

    On every odd epoch a checkpoint is saved and the model is evaluated. The learning rate
    is logged and the scheduler stepped once per epoch. `CHECKPOINT` is the epoch offset of
    a resumed run; note that the logging helpers read the module-level CHECKPOINT instead
    of this argument.
    """
    last_epoch = EPOCHS + 1
    for epoch in range(1, last_epoch):
        train_model(writer, model, train_dataloader, loss_fn1, loss_fn2, optimizer, epoch)
        is_odd_epoch = epoch % 2 == 1
        if is_odd_epoch:
            save_checkpoint(epoch, CHECKPOINT, model, optimizer, scheduler)
            eval_model(writer, model, val_dataloader, epoch)
        current_lr = scheduler.get_last_lr()[0]
        writer.add_scalar('LR', current_lr, epoch + CHECKPOINT)
        writer.flush()
        scheduler.step()

## Configs

In [None]:
# Selects which of the configuration/training cells below run: "dilatedunet" or "nnunet"
MODEL_TO_TRAIN = "nnunet"

In [None]:
# Configuration and object construction for training the Dilated U-Net
if MODEL_TO_TRAIN == "dilatedunet":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    SPLIT = [0.8, 0.2]
    LEARNING_RATE = 1e-2
    BATCH_SIZE = 2
    EPOCHS = 15
    TRAIN_PATHS = ["/kaggle/input/brats2023-part1", "/kaggle/input/brats2023-part2", "/kaggle/input/brats2023-part3"]
    MODEL_NAME = "dilatedunet_default"
    MODEL_SAVE_PATH = f"/kaggle/working/{MODEL_NAME}/"
    # Set to the path of a saved checkpoint to resume training from it
    RESUME_CHECKPOINT_PATH = None
    # exist_ok=True so re-running the cell does not fail on an existing directory
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    train_dataset = BratsDataset(TRAIN_PATHS)

    # Fixed seed so the train/validation split is the same on every run
    train_dataset, val_dataset = random_split(train_dataset, SPLIT, generator=torch.Generator().manual_seed(13))

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # No shuffling for validation: the logged batches and figures then show the same cases every time
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    CHECKPOINT = 0
    model = DilatedUNet(in_channels=4, num_classes=3).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # The annealing period follows the configured number of epochs
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)
    loss_fn1 = TotalDiceLoss()
    loss_fn2 = nn.BCEWithLogitsLoss()
    if RESUME_CHECKPOINT_PATH is not None:
        CHECKPOINT = load_checkpoint(model, optimizer, scheduler, RESUME_CHECKPOINT_PATH)

In [None]:
if MODEL_TO_TRAIN == "dilatedunet":
    # TensorBoard writer; the log directory name carries the model name, the starting checkpoint epoch and a timestamp
    writer_dilated = SummaryWriter(f'/kaggle/working/runs/{MODEL_NAME}_{CHECKPOINT}_{datetime.now().strftime("%Y-%m-%d_%H-%M")}')

In [None]:
if MODEL_TO_TRAIN == "dilatedunet":
    # Launch training of the Dilated U-Net with the objects configured above
    trainer(CHECKPOINT, writer_dilated, model, train_dataloader, val_dataloader, loss_fn1, loss_fn2, optimizer, scheduler)

---

In [None]:
# Configuration and object construction for training the baseline nnU-Net
if MODEL_TO_TRAIN == "nnunet":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    SPLIT = [0.8, 0.2]
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 2
    EPOCHS = 15
    TRAIN_PATHS = ["/kaggle/input/brats2023-part1", "/kaggle/input/brats2023-part2", "/kaggle/input/brats2023-part3"]
    MODEL_NAME = "nnunet_default"
    MODEL_SAVE_PATH = f"/kaggle/working/{MODEL_NAME}/"
    # Set to the path of a saved checkpoint to resume training from it
    RESUME_CHECKPOINT_PATH = None
    # exist_ok=True so re-running the cell does not fail on an existing directory
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    train_dataset = BratsDataset(TRAIN_PATHS)

    # Fixed seed so the train/validation split is the same on every run
    train_dataset, val_dataset = random_split(train_dataset, SPLIT, generator=torch.Generator().manual_seed(13))

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # No shuffling for validation: the logged batches and figures then show the same cases every time
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = nnUNet(in_channels=4, num_classes=3).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    # The decay horizon follows the configured number of epochs
    scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=EPOCHS, power=0.9)
    loss_fn1 = TotalDiceLoss()
    loss_fn2 = nn.BCEWithLogitsLoss()
    CHECKPOINT = 0
    if RESUME_CHECKPOINT_PATH is not None:
        CHECKPOINT = load_checkpoint(model, optimizer, scheduler, RESUME_CHECKPOINT_PATH)

In [None]:
if MODEL_TO_TRAIN == "nnunet":
    # TensorBoard writer; the log directory name carries the model name, the starting checkpoint epoch and a timestamp
    writer_nnunet = SummaryWriter(f'/kaggle/working/runs/{MODEL_NAME}_{CHECKPOINT}_{datetime.now().strftime("%Y-%m-%d_%H-%M")}')

In [None]:
if MODEL_TO_TRAIN == "nnunet":
    # Launch training of the nnU-Net with the objects configured above
    trainer(CHECKPOINT, writer_nnunet, model, train_dataloader, val_dataloader, loss_fn1, loss_fn2, optimizer, scheduler)

## Inference

In [8]:
def plot_test_pred(img, pred_mask, file_name, d_slice=80):
    """
    Plots one z-slice of the four image channels next to the predicted label map.

    Args:
        img: batched image indexed as (1, channel, x, y, z); it is resized to 240x240x155 for display.
        pred_mask: predicted label map indexed as (x, y, z).
        file_name: case name shown in the figure title.
        d_slice: index of the z-slice to draw.

    Returns:
        The matplotlib figure.
    """
    volume = F.interpolate(img, (240, 240, 155), mode="trilinear", align_corners=True)
    volume = volume.squeeze(0).cpu().detach()

    fig = plt.figure(figsize=(15,5))
    plt.suptitle(f'Prediction for case <{file_name}> | z-slice no. {d_slice}', y=0.85)

    # (channel index, panel title) in display order
    panels = ((1, "T1-native"), (0, "T1-weighted"), (2, "T2-weighted"), (3, "T2-FLAIR"))
    for position, (channel, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(1,5,position)
        plt.imshow(volume[channel,:,:,d_slice], cmap="gray")
        ax.set_title(title)

    # Last panel: the predicted label map, drawn without a title
    fig.add_subplot(1,5,5)
    plt.imshow(pred_mask[:,:,d_slice])

    return fig

In [9]:
def inference(test_path, model_state, model_name, device, writer):
    """
    Predicts a label map for every test case, saves it as a NIfTI file and logs a figure.

    Args:
        test_path: list of root directories holding the test cases.
        model_state: state dict loaded into the freshly built model.
        model_name: "dilatedunet" or "nnunet"; also the sub-folder of the module-level
            SAVE_PATH the predictions are written to.
        device: device the model and the images are moved to.
        writer: TensorBoard writer receiving one figure per case.

    Raises:
        NameError: if `model_name` is not one of the two known models.
    """
    if model_name == "dilatedunet":
        model = DilatedUNet(in_channels=4, num_classes=3).to(device)
    elif model_name == "nnunet":
        model = nnUNet(in_channels=4, num_classes=3).to(device)
    else:
        raise NameError("Model not defined")
    model.load_state_dict(model_state)
    model.eval()
    test_dataset = BratsDataset(test_path, test=True)

    # No gradients are needed for prediction; this avoids building the autograd graph
    with torch.no_grad():
        for i, (img, folder) in tqdm(enumerate(test_dataset), total=len(test_dataset)):
            img = img.to(device)
            img = F.interpolate(img.unsqueeze(0), (128,128,128), mode="trilinear", align_corners=True)

            # The first output of the model is the final prediction (raw logits)
            logits = model(img)[0]

            # Upscale to final required resolution for evaluation
            logits = F.interpolate(logits, (240, 240, 155), mode="nearest").squeeze(0)
            # Logits are turned into probabilities before thresholding, as in dice_per_class.
            # The binary mask is moved to the CPU because it indexes a CPU tensor below.
            pred_mask = (torch.sigmoid(logits) > 0.5).long().cpu()

            # Channels are ET, TC, WT: TC without ET becomes label 1, WT without TC becomes label 2
            one_channel_pred = torch.zeros((240,240,155))
            one_channel_pred[(pred_mask[1] - pred_mask[0]) == 1] = 1
            one_channel_pred[(pred_mask[2] - pred_mask[1]) == 1] = 2
            # Post-processing: If number of voxels with ET < 200, turn ET to NCR to avoid bad online evaluation scores for ET
            if torch.sum(pred_mask[0]) < 200:
                one_channel_pred[pred_mask[0] == 1] = 1
            else:
                one_channel_pred[pred_mask[0] == 1] = 3

            # Saves the prediction mask as a nii file for challenge evaluation
            affine = np.eye(4)
            affine[:3, 3] = np.array([0, 239, 0])
            nib_pred = nib.Nifti1Image(one_channel_pred.numpy(), affine=affine, dtype=np.float64)
            file_name = f"{folder.split('/')[-1]}.nii.gz"
            nib.save(nib_pred, f"{SAVE_PATH}/{model_name}/{file_name}")

            # Store the images of the predictions
            writer.add_figure(f'Test-Images/{model_name}',
                              plot_test_pred(img, one_channel_pred, file_name.replace(".nii.gz", "")),
                              global_step=i)
            writer.flush()

In [13]:
# Configuration for inference with a trained model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEST_PATH = ["/kaggle/input/brats2023-validation"]
MODEL_PTH = "/kaggle/input/nnunet/pytorch/epoch20/1/nnunet-training_checkpoints-epoch20.pth"
SAVE_PATH = "/kaggle/working/test_preds"
MODEL_TO_TEST = "nnunet"
# exist_ok=True so re-running the cell does not fail on an existing directory
os.makedirs(f"{SAVE_PATH}/{MODEL_TO_TEST}", exist_ok=True)

writer_test = SummaryWriter(f'/kaggle/working/runs/{MODEL_TO_TEST}/')
# map_location makes a checkpoint saved on another device loadable on the current one
states = torch.load(MODEL_PTH, map_location=DEVICE)

In [None]:
# Predict every test case with the loaded weights, then zip the folder holding the saved predictions
inference(TEST_PATH, states["model_state_dict"], MODEL_TO_TEST, DEVICE, writer_test)
shutil.make_archive('nnunet-epoch20', 'zip', f'/kaggle/working/test_preds/{MODEL_TO_TEST}')