# ISLES 2022 UNet


In [None]:
# !pip install -q 'monai[all]'

In [10]:
import os
import gc
import glob
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

from funcs import get_paths,load_metrics, save_metrics, rand_crop

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    Rotated,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    Rotate90d,
    RandRotated,
    RandShiftIntensityd,
    RandGaussianNoised
)
from monai.handlers.utils import from_engine

from UNet3D import UNet
from AttUNet import AttUNet
from TransUNet import TransUNet
from monai.networks.nets import SwinUNETR
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch

# Select Device (CUDA if available, otherwise CPU)

In [11]:
# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
_BYTES_PER_GB = 1024 ** 3
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# Extra diagnostics that only make sense on a GPU (device index 0).
if device.type == 'cuda':
    gpu_index = 0
    print(torch.cuda.get_device_name(gpu_index))
    print(torch.cuda.get_device_properties(gpu_index))
    print('Memory Usage:')
    allocated_gb = round(torch.cuda.memory_allocated(gpu_index) / _BYTES_PER_GB, 1)
    print('Allocated:', allocated_gb, 'GB')
    reserved_gb = round(torch.cuda.memory_reserved(gpu_index) / _BYTES_PER_GB, 1)
    print('Cached:   ', reserved_gb, 'GB')

Using device: cpu



# Load Data

In [None]:
# Reuse a previously saved train/val/test split instead of building a new one.
training_dir = os.path.join(os.getcwd(), 'metrics')

# NOTE(review): the split-saving cell below writes 'datasplit', not
# 'brats_datasplit' -- confirm which file this ISLES notebook should read.
saved_files = load_metrics("brats_datasplit")

# load_metrics returns whatever tuple was saved: (train, val, test) file lists.
train_files, val_files, test_files = saved_files

In [None]:
# Build a fresh random 70 / 15 / 15 train / validation / test split.
import random
from math import ceil

training_dir = os.path.join(os.getcwd(), 'metrics')

image_paths = get_paths('*T1w*')
label_paths = get_paths('*mask*')

# Every image must have a mask, so we don't pull in cases without labels.
# NOTE(review): pairing by position assumes get_paths returns both lists in the
# same subject order (e.g. sorted) -- confirm in funcs.get_paths.
assert len(image_paths) == len(label_paths)

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(image_paths, label_paths)
]

# Integer numerators keep the size arithmetic exact. The test set takes the
# remainder so the three sizes always sum to the number of cases; rounding each
# part independently does not (e.g. 250 cases -> 175 + 37 + 37 = 249).
num_of_cases = len(data_dicts)
train_size = ceil(num_of_cases * 70 / 100)
validation_size = num_of_cases * 15 // 100
test_size = num_of_cases - train_size - validation_size

# Unseeded shuffle: the split differs between runs, so save it (cell below) to reuse it.
random.shuffle(data_dicts)

train_files = data_dicts[:train_size]
val_files = data_dicts[train_size:train_size + validation_size]
test_files = data_dicts[train_size + validation_size:]

In [None]:
# Show how many cases ended up in each split (train, validation, test).
for split_files in (train_files, val_files, test_files):
    print(len(split_files))

In [None]:
# # store this specific datasplit for future training
# save_metrics('datasplit', (train_files, val_files, test_files))

# Transforms using MONAI

In [None]:
def _base_transforms():
    """Return fresh instances of the deterministic load / orient / scale steps.

    Shared by the training and validation pipelines so the two cannot drift apart.
    """
    return [
        LoadImaged(keys=["image", "label"]),  # load the file(s) named under each key
        EnsureChannelFirstd(keys=["image", "label"]),  # add a leading channel dim if missing
        Orientationd(keys=["image", "label"], axcodes="LPS"),
        # Rotate so the volumes display upright when viewed; chosen by eye.
        Rotate90d(keys=["image", "label"], k=1, spatial_axes=(0, 2)),
        # Clip image intensities to [0, 302] and rescale linearly to [0, 1].
        ScaleIntensityRanged(
            keys=["image"], a_min=0.0, a_max=302.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
    ]


# Per-key interpolation for random rotations: smooth for the image, nearest for
# the label so it keeps integer class values instead of being blurred into fractions.
_rotate_mode = ("bilinear", "nearest")

train_transforms = Compose(
    _base_transforms()
    + [
        RandRotated(keys=["image", "label"], prob=0.2, range_x=0.3, mode=_rotate_mode),
        RandRotated(keys=["image", "label"], prob=0.2, range_y=0.3, mode=_rotate_mode),
        RandRotated(keys=["image", "label"], prob=0.2, range_z=0.3, mode=_rotate_mode),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        RandGaussianNoised(keys=["image"]),
        EnsureTyped(keys=["image", "label"]),  # convert to PyTorch tensors
    ]
)

val_transforms = Compose(
    _base_transforms()
    + [
        EnsureTyped(keys=["image", "label"]),  # convert to PyTorch tensors
    ]
)

# Patch sampler; it is passed to rand_crop inside the training loop rather than
# being part of the dataset transforms above.
cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=(96, 96, 96),  # size of each sampled patch
    pos=1,  # pos / (pos + neg) is the fraction of patches centred on a positive (stroke) voxel;
    neg=1,  # with pos = neg = 1, half are centred on lesion and half on background.
    num_samples=4,  # number of patches drawn from each volume
    image_key="image",
    image_threshold=0,
)

In [None]:
# On a GPU machine, cache every preprocessed volume in memory up front;
# on CPU, load and transform each case lazily.
if device.type == 'cuda':
    def _make_dataset(files, transform):
        return CacheDataset(data=files, transform=transform, cache_rate=1.0, num_workers=4)
else:
    def _make_dataset(files, transform):
        return Dataset(data=files, transform=transform)

train_ds = _make_dataset(train_files, train_transforms)
val_ds = _make_dataset(val_files, val_transforms)

# Each training batch holds 4 full volumes. The training loop then samples patches
# with `cropper` (num_samples=4) via rand_crop, presumably giving 16 patches of
# 96^3 per step -- TODO confirm rand_crop's output shape. NOTE(review): batching
# full volumes requires every case in a batch to share one spatial size.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

In [None]:
# Alternative architectures (swap in as needed):
# model = UNet().to(device)
# model = AttUNet().to(device)
# model = TransUNet().to(device)

network = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)

# Split each batch across GPUs when more than one is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Using", gpu_count, "GPUs")
    network = nn.DataParallel(network)

model = network.to(device)
model  # display the architecture in the notebook

In [None]:
# Dice loss on softmax outputs against one-hot-encoded integer labels.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# Mean Dice over foreground channels only (background channel excluded).
dice_metric = DiceMetric(include_background=False, reduction="mean")
# Loss scaling is a CUDA mixed-precision feature. Enabling it only on CUDA makes the
# CPU no-op explicit instead of relying on GradScaler to warn and disable itself.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

In [None]:
# Resume from the checkpoint written by the final cell of this notebook, if present.
# Model and optimizer state are restored together; optimizer moments without the
# matching weights would be meaningless.
checkpoint_path = "unet_model_and_optim"
if os.path.exists(checkpoint_path):
    # torch.load unpickles the file: only load checkpoints this notebook produced.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # NOTE(review): the training loop below still counts epochs from 0.
    print(f"resumed from {checkpoint_path} (epoch {checkpoint['epoch']})")
else:
    print(f"no checkpoint at {checkpoint_path}; starting from scratch")

In [None]:
# Training configuration and history buffers (the later cells read these).
max_epochs = 300
val_interval = 2  # compute full-volume Dice every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []      # mean training loss per epoch
val_epoch_loss_values = []  # mean validation loss on random patches, per epoch
train_metric_values = []    # full-volume Dice on the training set, every val_interval epochs
metric_values = []          # full-volume Dice on the validation set, every val_interval epochs
# Prediction -> argmax one-hot; label -> one-hot; both with 2 classes.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
roi_size = (96, 96, 96)  # sliding-window patch size, same as the training crops
sw_batch_size = 4        # windows evaluated per forward pass during inference
use_amp = device.type == "cuda"  # autocast only has an effect on CUDA

os.makedirs(training_dir, exist_ok=True)  # the best-model checkpoint is written here


def _mean_dice(loader):
    """Return the mean foreground Dice of `model` over every volume in `loader`.

    Runs full-volume sliding-window inference. The caller must put the model in
    eval mode and disable gradients. `dice_metric` is reset before returning.
    """
    for data in loader:
        inputs = data["image"].to(device)
        labels = data["label"].to(device)
        outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, model)
        outputs = [post_pred(i) for i in decollate_batch(outputs)]
        labels = [post_label(i) for i in decollate_batch(labels)]
        dice_metric(y_pred=outputs, y=labels)
    score = dice_metric.aggregate().item()
    dice_metric.reset()
    return score


for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training pass on random patches ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = rand_crop(batch_data["image"], batch_data["label"], cropper)
        # Only the forward pass and loss run under autocast; PyTorch recommends
        # running backward and the optimizer step outside the autocast region.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(inputs.to(device))
            loss = loss_function(outputs, labels.to(device))
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation loss on random patches ----
    # Eval mode turns off training-only layer behaviour (e.g. dropout), and no_grad
    # avoids building an autograd graph that would never be used.
    model.eval()
    step = 0
    val_epoch_loss = 0
    with torch.no_grad():
        for batch_data in val_loader:
            step += 1
            inputs, labels = rand_crop(batch_data["image"], batch_data["label"], cropper)
            outputs = model(inputs.to(device))
            val_epoch_loss += loss_function(outputs, labels.to(device)).item()
    val_epoch_loss /= step
    val_epoch_loss_values.append(val_epoch_loss)

    if (epoch + 1) % val_interval == 0:
        # ---- full-volume Dice (model is still in eval mode) ----
        with torch.no_grad():
            metric = _mean_dice(val_loader)
            metric_values.append(metric)
            # train_loader applies the random augmentations, so this scores augmented volumes.
            train_metric = _mean_dice(train_loader)
            train_metric_values.append(train_metric)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            # NOTE(review): under nn.DataParallel the saved keys carry a "module." prefix.
            torch.save(model.state_dict(), os.path.join(
                training_dir, "best_metric_model.pth"))
            print("saved new best metric model")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f"\nbest mean dice: {best_metric:.4f} "
            f"at epoch: {best_metric_epoch}"
        )

In [None]:
# One-line summary of the best validation Dice and the epoch it was reached.
summary = f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
print(summary)

In [None]:
# Persist the loss/metric history and a resumable model + optimizer checkpoint.
history = (
    epoch_loss_values,
    val_epoch_loss_values,
    train_metric_values,
    metric_values,
    best_metric,
    best_metric_epoch,
)
save_metrics('unet-values', history)

final_checkpoint = dict(
    epoch=max_epochs,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)
torch.save(final_checkpoint, "unet_model_and_optim")