# Exercises-MRI-segmentation

Coding exercises for applying to the position at the Paris Brain Institute. The base code is taken from the following tutorial: https://colab.research.google.com/github/fepegar/torchio-notebooks/blob/main/notebooks/TorchIO_MONAI_PyTorch_Lightning.ipynb#scrollTo=QixbF3koO99H.

TODO: write an introduction to the problem, the type of data that is going to be used, the number of labels, training strategy, etc...

## Original code

In [None]:
# imports 
import time
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import random_split, DataLoader
import monai
import gdown
import pandas as pd
import torchio as tio
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns

# Plot styling: activate seaborn's theme and enlarge matplotlib's default figure.
sns.set()
plt.rcParams['figure.figsize'] = 12, 8
# Ask MONAI to configure deterministic behaviour (no explicit seed is passed here).
monai.utils.set_determinism()

# Record when the notebook was last executed.
print('Last run on', time.ctime())

%load_ext tensorboard

In [None]:
class MedicalDecathlonDataModule(pl.LightningDataModule):
    """Lightning data module for one Medical Segmentation Decathlon task.

    The archive ``<task>.tar`` is downloaded from Google Drive with ``gdown``
    (only when the ``<task>`` directory does not exist yet), its NIfTI files
    are wrapped in TorchIO subjects, and the labelled subjects are randomly
    split into a training and a validation set.

    Args:
        task: Task name; also used as dataset directory and archive stem.
        google_id: Google Drive file id of the task archive.
        batch_size: Batch size passed to every ``DataLoader``.
        train_val_ratio: Fraction of the labelled subjects used for training.
    """

    def __init__(self, task, google_id, batch_size, train_val_ratio):
        super().__init__()
        self.task = task
        self.google_id = google_id
        self.batch_size = batch_size
        self.dataset_dir = Path(task)
        self.train_val_ratio = train_val_ratio
        # Filled in by prepare_data().
        self.subjects = None
        self.test_subjects = None
        # Filled in by setup().
        self.preprocess = None
        self.transform = None
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def get_max_shape(self, subjects):
        """Return the per-axis maximum spatial shape over ``subjects``."""
        import numpy as np
        dataset = tio.SubjectsDataset(subjects)
        shapes = np.array([s.spatial_shape for s in dataset])
        return shapes.max(axis=0)

    def download_data(self):
        """Download/extract the dataset if needed and list its NIfTI files.

        Returns:
            A tuple ``(image_training_paths, label_training_paths,
            image_test_paths)`` of sorted path lists.

        Raises:
            subprocess.CalledProcessError: If extracting the archive fails.
        """
        import subprocess

        if not self.dataset_dir.is_dir():
            url = f'https://drive.google.com/uc?id={self.google_id}'
            output = f'{self.task}.tar'
            gdown.download(url, output, quiet=False)
            # Plain-Python equivalent of the IPython-only `!tar xf` shell
            # escape, so the class also works in a regular script.
            # check=True surfaces a failed extraction instead of silently
            # continuing with an empty dataset directory.
            subprocess.run(['tar', 'xf', output], check=True)

        def get_niis(d):
            # Sorted so that images and labels can be paired by position;
            # dot-prefixed (hidden) files are ignored.
            return sorted(p for p in d.glob('*.nii*') if not p.name.startswith('.'))

        image_training_paths = get_niis(self.dataset_dir / 'imagesTr')
        label_training_paths = get_niis(self.dataset_dir / 'labelsTr')
        image_test_paths = get_niis(self.dataset_dir / 'imagesTs')
        return image_training_paths, label_training_paths, image_test_paths

    def prepare_data(self):
        """Build ``self.subjects`` (image + label) and ``self.test_subjects`` (image only)."""
        image_training_paths, label_training_paths, image_test_paths = self.download_data()

        self.subjects = []
        for image_path, label_path in zip(image_training_paths, label_training_paths):
            # 'image' and 'label' are arbitrary names for the images
            subject = tio.Subject(
                image=tio.ScalarImage(image_path),
                label=tio.LabelMap(label_path)
            )
            self.subjects.append(subject)

        self.test_subjects = []
        for image_path in image_test_paths:
            subject = tio.Subject(image=tio.ScalarImage(image_path))
            self.test_subjects.append(subject)

    def get_preprocessing_transform(self):
        """Return the deterministic preprocessing applied to every split.

        Requires ``prepare_data()`` to have run, because the crop/pad target
        is the maximum shape over all training and test subjects.
        """
        preprocess = tio.Compose([
            tio.RescaleIntensity((-1, 1)),
            tio.CropOrPad(self.get_max_shape(self.subjects + self.test_subjects)),
            tio.EnsureShapeMultiple(8),  # for the U-Net
            tio.OneHot(),
        ])
        return preprocess

    def get_augmentation_transform(self):
        """Return the random augmentation pipeline used for training only."""
        augment = tio.Compose([
            tio.RandomAffine(),
            tio.RandomGamma(p=0.5),
            tio.RandomNoise(p=0.5),
            tio.RandomMotion(p=0.1),
            tio.RandomBiasField(p=0.25),
        ])
        return augment

    def setup(self, stage=None):
        """Split the labelled subjects and create the three datasets.

        Args:
            stage: Accepted for Lightning compatibility; not used here.
        """
        num_subjects = len(self.subjects)
        num_train_subjects = int(round(num_subjects * self.train_val_ratio))
        num_val_subjects = num_subjects - num_train_subjects
        splits = num_train_subjects, num_val_subjects
        train_subjects, val_subjects = random_split(self.subjects, splits)

        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        self.transform = tio.Compose([self.preprocess, augment])

        # Only the training set is augmented; validation and test are just preprocessed.
        self.train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)

    def train_dataloader(self):
        """Return the training ``DataLoader``."""
        # NOTE(review): the training loader is not shuffled; consider
        # shuffle=True if that is not intentional.
        return DataLoader(self.train_set, self.batch_size)

    def val_dataloader(self):
        """Return the validation ``DataLoader``."""
        return DataLoader(self.val_set, self.batch_size)

    def test_dataloader(self):
        """Return the test ``DataLoader`` (images only, no labels)."""
        return DataLoader(self.test_set, self.batch_size)

In [None]:
# Data module for the 'Task04_Hippocampus' archive: batches of 16 subjects,
# with 80 % of the labelled subjects used for training and the rest for validation.
data = MedicalDecathlonDataModule(
    task='Task04_Hippocampus',
    google_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
    batch_size=16,
    train_val_ratio=0.8,
)

In [None]:
# Fetch the data, build the splits, then report the size of each subset.
data.prepare_data()
data.setup()
for caption, subset in (
    ('Training:  ', data.train_set),
    ('Validation: ', data.val_set),
    ('Test:      ', data.test_set),
):
    print(caption, len(subset))

# Lightning Model

In [None]:
class Model(pl.LightningModule):
    """Lightning wrapper pairing a network with a loss and an optimizer class.

    Args:
        net: Network called on the image tensor of each batch.
        criterion: Callable ``criterion(prediction, target)`` returning the loss.
        learning_rate: Learning rate handed to the optimizer.
        optimizer_class: Optimizer constructor, called as
            ``optimizer_class(parameters, lr=learning_rate)``.
    """

    def __init__(self, net, criterion, learning_rate, optimizer_class):
        super().__init__()
        self.lr = learning_rate
        self.net = net
        self.criterion = criterion
        self.optimizer_class = optimizer_class

    def configure_optimizers(self):
        """Instantiate the optimizer over all parameters of this module."""
        return self.optimizer_class(self.parameters(), lr=self.lr)

    def prepare_batch(self, batch):
        """Extract the ``(image, label)`` tensors from a TorchIO batch."""
        images = batch['image'][tio.DATA]
        labels = batch['label'][tio.DATA]
        return images, labels

    def infer_batch(self, batch):
        """Run the network on a batch and return ``(prediction, target)``."""
        images, targets = self.prepare_batch(batch)
        return self.net(images), targets

    def _compute_loss(self, batch):
        """Forward a batch through the network and evaluate the criterion."""
        predictions, targets = self.infer_batch(batch)
        return self.criterion(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (shown in the progress bar)."""
        step_loss = self._compute_loss(batch)
        self.log('train_loss', step_loss, prog_bar=True)
        return step_loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        step_loss = self._compute_loss(batch)
        self.log('val_loss', step_loss)
        return step_loss

In [None]:
# U-Net model from monai
# NOTE(review): newer MONAI releases appear to use `spatial_dims` instead of
# `dimensions` - confirm against the installed version.
# 3-D network with 1 input channel and 3 output channels (presumably
# background plus two structures - TODO confirm against the label maps).
unet = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
)

# Wrap the network with its loss and optimizer class for Lightning.
model = Model(
    net=unet,
    criterion=monai.losses.DiceCELoss(softmax=True),
    learning_rate=1e-2,
    optimizer_class=torch.optim.AdamW,
)

# Stop training based on 'val_loss', the metric logged in Model.validation_step.
early_stopping = pl.callbacks.early_stopping.EarlyStopping(
    monitor='val_loss',
)
# Use one GPU when CUDA is available, otherwise run on CPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    # precision=16,
    callbacks=[early_stopping],
)
# NOTE(review): sets a private logger attribute; presumably disables the
# default `hp_metric` entry - verify against the installed Lightning version.
trainer.logger._default_hp_metric = False

In [None]:
# Run the Lightning training and report its wall-clock duration.
start = datetime.now()
print('Training started at', start)
trainer.fit(model=model, datamodule=data)
print('Training duration:', datetime.now() - start)

# Exercise 1
- Write training code that reproduces the training from the tutorial, but without the
pytorch_lightning library.
- Make one script with a command-line interface for training.
- In the training loop, use PyTorch's automatic mixed precision (with autocast and
GradScaler) in order to train with FP16 precision instead of the default FP32.

In [None]:
# extra imports 
import numpy as np
import os
from tqdm import tqdm

# hyperparameters
config = {
    'task': 'Task04_Hippocampus',
    'google_id': '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
    'batch_size': 16,
    'train_val_ratio': 0.8,

    'epochs': 100,
    'lr': 1e-2,
    'early_stopping': 10, # -1 to disable, else insert patience

    'best_models_dir': 'best_models',
}


def validate_config(cfg):
    """Check that the hyperparameters in ``cfg`` are within their valid ranges.

    Explicit exceptions are raised instead of using ``assert`` because
    assertions are stripped when Python runs with ``-O``, which would let an
    invalid configuration through silently.

    Args:
        cfg: Mapping with at least the keys ``early_stopping``,
            ``train_val_ratio`` and ``batch_size``.

    Raises:
        ValueError: If one of the checked values is out of range.
    """
    if cfg['early_stopping'] != -1 and not cfg['early_stopping'] > 0:
        raise ValueError('early_stopping must be -1 or > 0')
    if not 0 < cfg['train_val_ratio'] < 1:
        raise ValueError('train_val_ratio must be > 0 and < 1')
    if not cfg['batch_size'] > 0:
        raise ValueError('batch_size must be > 0')


validate_config(config)

In [None]:
# data download & preparation
# The data module takes exactly these four hyperparameters from the config.
datamodule_keys = ('task', 'google_id', 'batch_size', 'train_val_ratio')
data = MedicalDecathlonDataModule(**{key: config[key] for key in datamodule_keys})

data.prepare_data()
data.setup()

# One loader per split, created in train / validation / test order.
train_data_loader, val_data_loader, test_data_loader = (
    data.train_dataloader(),
    data.val_dataloader(),
    data.test_dataloader(),
)

In [None]:
# visualize a training example
batch = next(iter(train_data_loader))
batch_image, batch_label = (batch[name]['data'] for name in ('image', 'label'))

print(f'The shape of the data is {batch_image.shape}')

# plot a slice and the corresponding label
slice_idx = 30

# Side-by-side panels: first sample, first channel, one slice along the last axis.
plt.figure('image', (12, 6))
for position, (title, volume) in enumerate(
    (('image', batch_image), ('label', batch_label)), start=1
):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(volume[0, 0, :, :, slice_idx], cmap='gray')
plt.show()

# Define the model

In [None]:
# U-Net model from monai
# NOTE(review): newer MONAI releases appear to use `spatial_dims` instead of
# `dimensions` - confirm against the installed version.
unet = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
)

# Plain PyTorch setup: the network itself is the model (no Lightning wrapper).
model = unet
# Same loss and optimizer family as the Lightning version above;
# presumably softmax=True makes the loss apply softmax to the raw logits - TODO confirm.
criterion=monai.losses.DiceCELoss(softmax=True)
optimizer=torch.optim.AdamW(model.parameters(), lr=config['lr'])

# Train on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# As the last expression of the cell, the returned module is echoed by the notebook.
model.to(device)

# Training Loop


In [None]:
# Automatic mixed precision (FP16) is only enabled on CUDA; on CPU both
# autocast and the gradient scaler are created disabled, so the loop runs in
# plain FP32.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# The checkpoint directory must exist before torch.save is called.
os.makedirs(config['best_models_dir'], exist_ok=True)
best_model_path = os.path.join(config['best_models_dir'], 'best_model.pth')
last_model_path = os.path.join(config['best_models_dir'], 'last_model.pth')

# for early stopping
best_val_loss = np.inf
patience_counter = 0

for epoch in range(config['epochs']):
    # training loop
    model.train()
    train_loss_sum, train_samples = 0.0, 0
    for batch in train_data_loader:
        x, y = batch['image']['data'].to(device), batch['label']['data'].to(device)

        optimizer.zero_grad()
        # Forward pass and loss run under autocast; backward runs outside it.
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)

        # Scale the loss so that FP16 gradients do not underflow; step()
        # unscales them and skips the update if they contain inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight each batch by its size so a smaller last batch does not
        # skew the epoch mean.
        train_loss_sum += loss.item() * x.shape[0]
        train_samples += x.shape[0]

    # validation loop
    model.eval()
    val_loss_sum, val_samples = 0.0, 0
    # No gradients are needed for validation; this avoids building the
    # autograd graph and keeps memory usage down.
    with torch.no_grad():
        for batch in val_data_loader:
            x, y = batch['image']['data'].to(device), batch['label']['data'].to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(x)
                batch_val_loss = criterion(logits, y)

            val_loss_sum += batch_val_loss.item() * x.shape[0]
            val_samples += x.shape[0]

    # Epoch means as plain floats (NaN when a loader yielded no batches), so
    # early stopping looks at the whole validation set rather than only its
    # last batch and no tensor is kept alive between epochs.
    train_loss = train_loss_sum / train_samples if train_samples else float('nan')
    val_loss = val_loss_sum / val_samples if val_samples else float('nan')

    print(f'Epoch {epoch + 1}/{config["epochs"]}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)

    elif config['early_stopping'] != -1:
        patience_counter += 1
        if patience_counter >= config['early_stopping']:
            print('Training stopped due to early stopping')
            break

# Always keep the weights of the final epoch as well.
torch.save(model.state_dict(), last_model_path)

# test loop
# TODO: not implemented yet.
pass

        