# ISLES 2022 UNet


In [None]:
!pip install -q 'monai[all]'

In [None]:
import gc
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.handlers.utils import from_engine
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    MapTransform,
    Orientationd,
    RandCropByPosNegLabeld,
    Rotated,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    Rotate90d,
    RandRotated,
    RandShiftIntensityd,
    RandGaussianNoised
)
from monai.utils import first, set_determinism

from funcs import get_paths, load_metrics, save_metrics, rand_crop
from UNet3D import UNet
from AttUNet import AttUNet
from TransUNet import TransUNet

In [None]:
class ConvertToBinaryLabel(MapTransform):
    """
    Collapse a multi-class segmentation label into a single binary mask.

    For every key in ``self.keys``, voxels whose value is 1, 2 or 3 become
    1.0 and every other voxel becomes 0.0. The mask is returned with a new
    leading axis of size 1 (the result of stacking one mask along dim 0).

    NOTE(review): only the values 1, 2 and 3 are treated as foreground. If
    the segmentation files encode a class as 4 (as some BraTS releases are
    described as doing), that class ends up as background - confirm against
    the actual label values in the data.
    """

    def __call__(self, data):
        """
        Return a shallow copy of ``data`` with each configured key binarised.

        Args:
            data: mapping from key to label tensor. The entries are compared
                element-wise and stacked with ``torch.stack``, so they are
                assumed to be torch tensors - TODO confirm for the MONAI
                version in use.

        Returns:
            dict: copy of ``data`` whose ``self.keys`` entries are float
            masks with one extra leading axis.
        """
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # merge labels 1, 2 and 3 to construct the binary mask
            foreground = (label == 1) | (label == 2) | (label == 3)
            # stacking the single mask prepends the size-1 leading axis
            d[key] = torch.stack([foreground], dim=0).float()
        return d


# Select the compute device (CUDA when available, otherwise CPU)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# Extra diagnostics, only meaningful on a CUDA device (index 0 is queried)
if device.type == 'cuda':
    bytes_per_gb = 1024**3
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_properties(0))
    print('Memory Usage:')
    allocated_gb = round(torch.cuda.memory_allocated(0)/bytes_per_gb, 1)
    print('Allocated:', allocated_gb, 'GB')
    reserved_gb = round(torch.cuda.memory_reserved(0)/bytes_per_gb, 1)
    print('Cached:   ', reserved_gb, 'GB')

# Load Data

In [None]:
def get_paths(mask):
    """
    Collect every file under ``./BraTS2021_Training_Data`` matching ``mask``.

    The data directory is resolved relative to the current working
    directory and searched recursively.

    Args:
        mask: glob pattern handed to ``Path.rglob``, e.g. ``'*t1.nii.gz'``.

    Returns:
        list of resolved ``pathlib.Path`` objects in sorted order; empty
        when nothing matches.
    """
    data_dir = Path(os.getcwd()) / 'BraTS2021_Training_Data'
    # rglob yields entries in filesystem order, which is not guaranteed to be
    # stable. Sorting makes the result deterministic, so lists built from
    # different masks line up when they are paired by position.
    return sorted(path.resolve() for path in data_dir.rglob(mask))

In [None]:
# Folder (under the current working directory) named 'metrics'.
training_dir = os.path.join(os.getcwd(), 'metrics')

# Reload the data split that was stored earlier under the name
# "brats_datasplit".
# NOTE(review): the structure of the loaded object is not visible here; it is
# used as-is as the training list (presumably {"image", "label"} dicts) -
# confirm against what save_metrics stored.
saved_files = load_metrics("brats_datasplit")

# The original assigned from `saved_lists`, a name that is never defined
# (NameError); the loaded object is bound to `saved_files`.
# NOTE(review): `val_files`, used when the datasets are built, is not assigned
# anywhere in this part of the notebook - confirm where it should come from.
train_files = saved_files

In [None]:
# training_dir = os.path.join(os.getcwd(), 'metrics')

# image_paths = get_paths('*t1.nii.gz')
# label_paths = get_paths('*seg.nii.gz')

# # To ensure I am not pulling in the training set with no mask/labels
# assert(len(image_paths) == len(label_paths))
# data_length = len(image_paths)

# data_dicts = [
#     {"image": image_name, "label": label_name}
#     for image_name, label_name in zip(image_paths, label_paths)
# ]

# # Because, why not? 
# assert(len(data_dicts) == data_length)

In [None]:
# Sanity check: how many entries were loaded for training
print(len(train_files))

In [None]:
# # store this specific datasplit for future training - uses pickle
# save_metrics('brats_data', data_dicts)

# Transforms using MONAI

In [None]:
# Training pipeline: load -> binarise label -> orient -> scale -> augment.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),  # load image/label from the provided paths
        ConvertToBinaryLabel(keys="label"),  # merge classes 1-3 into one mask; also prepends the label's leading axis
        EnsureChannelFirstd(keys=["image"]),  # image only: the label already received its leading axis above

        Orientationd(keys=["image", "label"], axcodes="LPS"),
        Rotate90d(keys=["image", "label"], k=1, spatial_axes=(0,2)),  # rotate so the volume displays in the expected orientation
        ScaleIntensityRanged(
            keys=["image"], a_min=0.0, a_max=6000.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # Random rotations about each axis. `mode` is given per key: the image
        # is resampled bilinearly, the label with nearest-neighbour so that the
        # mask stays strictly binary (the default bilinear mode would blend
        # 0 and 1 into fractional values along the mask boundary).
        RandRotated(keys=["image", "label"], prob=0.2, range_x=0.3, mode=("bilinear", "nearest")),
        RandRotated(keys=["image", "label"], prob=0.2, range_y=0.3, mode=("bilinear", "nearest")),
        RandRotated(keys=["image", "label"], prob=0.2, range_z=0.3, mode=("bilinear", "nearest")),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        RandGaussianNoised(keys=["image"]),  # library-default probability and noise parameters
        EnsureTyped(keys=["image", "label"])  # converts the data to a pytorch tensor
    ]
)

        
# Validation pipeline: the deterministic steps only, no random augmentation.
val_transforms = Compose(
    transforms=[
        # Read image and label volumes from the paths in each data dict
        LoadImaged(keys=["image", "label"]),
        # Add a channel axis to the image if it does not have one
        EnsureChannelFirstd(keys=["image"]),
        # Merge classes 1-3 into a single mask with its own leading axis
        ConvertToBinaryLabel(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="LPS"),
        # Rotate so the volume displays in the expected orientation
        Rotate90d(keys=["image", "label"], k=1, spatial_axes=(0,2)),
        # Map image intensities from [0, 6000] to [0, 1], clipping outliers
        ScaleIntensityRanged(
            keys=["image"],
            a_min=0.0,
            a_max=6000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Convert the data to pytorch tensors
        EnsureTyped(keys=["image", "label"]),
    ]
)
        
# Patch sampler: draws fixed-size sub-volumes centred on foreground or
# background voxels of the label.
cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    image_key="image",
    image_threshold=0,
    # size of every patch that is cut out of the volume
    spatial_size=(96, 96, 96),
    # pos / (pos + neg) is the probability of centring a patch on a positive
    # (foreground) voxel; with pos = neg = 1 that is 0.5, i.e. positive- and
    # negative-centred patches are equally likely
    pos=1,
    neg=1,
    # number of patches drawn from each original volume
    num_samples=4,
)

In [None]:
# check_ds = Dataset(data=val_files, transform=val_transforms)
# check_loader = DataLoader(check_ds, batch_size=1)
# check_data = first(check_loader)
# image, label = (check_data["image"], check_data["label"])
# print(f"image shape: {image.shape}, label shape: {label.shape}")

In [None]:
# # Does the data look right?
# i = 50
# batch = 0
# plt.figure("check", (12, 6))
# plt.subplot(1, 2, 1)
# plt.title("image")
# plt.imshow(image[0][0][i,:,:], cmap="gray")
# plt.subplot(1, 2, 2)
# plt.title("label")
# plt.imshow(label[0][0][i,:,:])
# plt.show()

In [None]:
# Plain datasets on the CPU; cached datasets when running on CUDA.
# NOTE(review): `val_files` is not assigned in the visible part of this
# notebook - confirm it is defined before this cell runs.
if device.type != 'cuda':
    train_ds = Dataset(data=train_files, transform=train_transforms)
    val_ds = Dataset(data=val_files, transform=val_transforms)
else:
    # Cache half of the training items and all of the validation items
    train_ds = CacheDataset(
        data=train_files,
        transform=train_transforms,
        cache_rate=0.5,
        num_workers=4,
    )
    val_ds = CacheDataset(
        data=val_files,
        transform=val_transforms,
        cache_rate=1.0,
        num_workers=4,
    )


# The training loader collates 4 whole volumes per batch. The cropper
# (num_samples=4) is applied later through rand_crop, presumably giving
# 4 x 4 = 16 patches per step - confirm in funcs.rand_crop.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
# Validation is done one volume at a time, in dataset order
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

In [None]:
# Alternative architectures (swap in by uncommenting one of these):
# model = UNet().to(device)

# model = AttUNet().to(device)

# model = TransUNet().to(device)

# Active model: SwinUNETR on 96^3 patches, 1 input channel, 2 output channels
model = SwinUNETR(
    img_size=(96,96,96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
)
model = model.to(device)

# Wrap in DataParallel when more than one GPU is visible
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Using", gpu_count, "GPUs")
    model = nn.DataParallel(model)

# Kept as the final expression so the notebook displays the model summary
model.to(device)

In [None]:
# Dice loss: softmax over the network outputs, integer labels one-hot encoded
loss_function = DiceLoss(softmax=True, to_onehot_y=True)
# Adam with a learning rate of 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Mean Dice over foreground classes only (background excluded)
dice_metric = DiceMetric(reduction="mean", include_background=False)
# Gradient scaler used with mixed-precision training
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Train for `max_epochs` epochs with mixed precision and validate after every
# epoch. The histories collected here are read by the cells below (summary
# print, save_metrics call and loss plot).
max_epochs = 100
epoch_loss_values = []       # mean training loss, one entry per epoch
val_epoch_loss_values = []   # mean validation loss, one entry per epoch
metric_values = []           # mean validation Dice, one entry per epoch
best_metric = -1             # best validation Dice seen so far
best_metric_epoch = -1       # 1-based epoch at which best_metric was reached
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad(set_to_none=True)
        # Only the patch sampling, forward pass and loss run under autocast;
        # the backward pass and optimizer step are done outside it.
        with torch.cuda.amp.autocast():
            inputs, labels = rand_crop(inputs, labels, cropper)
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError if the loader yielded nothing
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation ----
    model.eval()
    val_epoch_loss = 0
    val_step = 0
    with torch.no_grad():
        for val_data in val_loader:
            val_step += 1
            val_inputs, val_labels = (
                val_data["image"].to(device),
                val_data["label"].to(device),
            )
            with torch.cuda.amp.autocast():
                # Whole volumes are too large for one pass, so predict them
                # in 96^3 windows (same size as the training patches).
                val_outputs = sliding_window_inference(
                    val_inputs, (96, 96, 96), 4, model
                )
                val_epoch_loss += loss_function(val_outputs, val_labels).item()
            # Discretise per item before accumulating the Dice score
            val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            dice_metric(y_pred=val_outputs, y=val_labels)

    val_epoch_loss /= max(val_step, 1)
    val_epoch_loss_values.append(val_epoch_loss)

    metric = dice_metric.aggregate().item()
    dice_metric.reset()  # clear the accumulated scores for the next epoch
    metric_values.append(metric)
    if metric > best_metric:
        best_metric = metric
        best_metric_epoch = epoch + 1
    print(
        f"epoch {epoch + 1} validation loss: {val_epoch_loss:.4f} "
        f"mean dice: {metric:.4f} "
        f"best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

In [None]:
# Report the best validation metric and the epoch it was reached at
summary = (
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}"
)
print(summary)

In [None]:
# Store the training history under the name 'pretrained-aunet-values'
history = (epoch_loss_values, metric_values, best_metric, best_metric_epoch)
save_metrics('pretrained-aunet-values', history)

# Store model and optimizer state so training can be resumed or reused.
# NOTE(review): the file names say "aunet" while the active model above is
# SwinUNETR - confirm the names are intended.
checkpoint = {
    'epoch': max_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, "pretrained-aunet_model_and_optim")

In [None]:
# Plot the per-epoch training and validation loss on a shared epoch axis
x = list(range(1, len(epoch_loss_values) + 1))  # 1-based epoch numbers
y = epoch_loss_values
z = val_epoch_loss_values

plt.figure("train", (12, 6))
plt.title("Epoch Average Loss")
plt.xlabel("epoch")
for series, name in ((y, "training"), (z, "validation")):
    plt.plot(x, series, label=name)
plt.legend()
plt.show()