In [None]:
!pip install gdown
!pip install monai[all]==1.1.0
!pip install torch trochvision

In [None]:
import os
import time
import random
import gdown
from glob import glob
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from monai.config import print_config
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    Resized,
    LoadImaged,
    Orientationd,
    RandGaussianNoised,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    ScaleIntensityd,
    Spacingd,
    RandRotated,
    ToTensord,
    RandAffined,
    Activations,
    AsDiscrete,
    EnsureTyped
)

from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.utils import first
from monai.losses import DiceLoss, DiceFocalLoss
from monai.metrics import DiceMetric, MeanIoU
from monai.networks.nets import SwinUNETR, UNETR
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference

import torch

# %env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

# Seed the Python, NumPy and PyTorch RNGs so shuffling, weight initialisation
# and random augmentation should be repeatable between runs. GPU kernels can
# still be nondeterministic (e.g. when cuDNN autotuning is enabled).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

is_cuda = torch.cuda.is_available()
device = "cuda" if is_cuda else "cpu"
if is_cuda:
    # release cached blocks left over from a previous run in this kernel
    torch.cuda.empty_cache()
print_config()

## <span style="color:red">Note</span>

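Download the data from this [Google Drive](https://drive.google.com/file/d/1yU_MeDoMHnLg0vNAy4Hb7_liJBogWl_N/view?usp=share_link) link and extract it into a `data/` folder in the notebook's working directory (a relative path, not `/data` at the filesystem root).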

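Expected folder structure. All volumes are `.nii.gz` files. Images in `imagesTr` are paired with labels in `labelsTr` by sorted filename order, so every training image needs exactly one matching label.

```txt
data
├── imagesTr
├── labelsTr
└── test
```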

In [None]:
# Create the data and checkpoint folders; exist_ok keeps this cell re-runnable.
for required_dir in ("data", "model"):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
# Folder layout, relative to the notebook's working directory.
ROOT_DATA_FOLDER = "data"   # parent of the three data sub-folders below
IMAGE_FOLDER = "imagesTr"   # training images
LABEL_FOLDER = "labelsTr"   # training labels, paired with IMAGE_FOLDER
TEST_FOLDER = "test"        # images without labels
ROOT_MODEL_DIR = "model"    # where checkpoints are written

In [None]:
# Collect NIfTI volumes; sorting makes the image/label pairing below rely on
# matching filename order in the two folders.
image_path = sorted(glob(os.path.join(ROOT_DATA_FOLDER, IMAGE_FOLDER, "*.nii.gz")))
label_path = sorted(glob(os.path.join(ROOT_DATA_FOLDER, LABEL_FOLDER, "*.nii.gz")))
test_files = sorted(glob(os.path.join(ROOT_DATA_FOLDER, TEST_FOLDER, "*.nii.gz")))

# zip() would silently drop unmatched trailing files, and one missing file
# would shift every later image onto the wrong label - fail loudly instead.
if len(image_path) != len(label_path):
    raise ValueError(
        f"Found {len(image_path)} images but {len(label_path)} labels; "
        "every training image needs exactly one label."
    )

# MONAI dictionary-style samples.
data_path = [{"image": image_name, "label": label_name} for image_name, label_name in zip(image_path, label_path)]
test_path = [{"image": image_name} for image_name in test_files]

print("Total Train Data: ", len(data_path))
print("Total Test Data: ", len(test_path))

In [None]:
# Hold out the first three cases and the last case for validation; train on
# the rest. Selecting by index (instead of slicing + dict membership) avoids
# listing the same case twice when there are 3 or fewer cases.
n_cases = len(data_path)
val_indices = sorted({0, 1, 2, n_cases - 1} & set(range(n_cases)))
val_index_set = set(val_indices)

val_path = [data_path[i] for i in val_indices]
train_path = [case for i, case in enumerate(data_path) if i not in val_index_set]

In [None]:
# Every volume is resampled to this grid; it must agree with the UNETR
# `img_size` and the sliding-window `roi_size` used further down.
SPATIAL_SIZE = (128, 128, 64)


def _preprocessing(keys, spacing_mode, resize_mode):
    """Deterministic loading/normalisation steps shared by all pipelines.

    Args:
        keys: dictionary keys to transform, e.g. ["image", "label"].
        spacing_mode: interpolation mode(s) for `Spacingd`, one per key.
        resize_mode: interpolation mode(s) for `Resized`, one per key.

    Returns:
        A list of MONAI dictionary transforms: load -> channel-first ->
        tensor type -> RAS orientation -> 1 mm isotropic spacing ->
        intensity scaling of "image" -> foreground crop driven by "image" ->
        resize to SPATIAL_SIZE.
    """
    return [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        EnsureTyped(keys=keys),
        Orientationd(keys=keys, axcodes="RAS"),
        Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode=spacing_mode),
        ScaleIntensityd(keys=["image"]),
        CropForegroundd(keys=keys, source_key="image"),
        Resized(keys=keys, spatial_size=SPATIAL_SIZE, mode=resize_mode),
    ]


# Training: shared preprocessing followed by random augmentation. Labels always
# use nearest-neighbour interpolation so they remain an integer class map.
train_transforms = Compose(
    _preprocessing(["image", "label"], ("bilinear", "nearest"), ("trilinear", "nearest"))
    + [
        RandAffined(
            keys=["image", "label"],
            mode=("bilinear", "nearest"),
            prob=0.6,
            spatial_size=SPATIAL_SIZE,
            rotate_range=(0, 0, np.pi / 12),
            scale_range=(0.1, 0.1, 0.1),
        ),
        # `range_x` is in radians: +/-10 degrees around the first axis.
        RandRotated(
            keys=["image", "label"],
            prob=0.5,
            range_x=np.deg2rad(10.0),
            mode=("bilinear", "nearest"),
        ),
        RandGaussianNoised(keys="image", prob=0.5),
        ToTensord(keys=["image", "label"]),
    ]
)

# Validation: same deterministic preprocessing, no augmentation.
val_transforms = Compose(
    _preprocessing(["image", "label"], ("bilinear", "nearest"), ("trilinear", "nearest"))
    + [ToTensord(keys=["image", "label"])]
)

# Test: images only (no labels available).
test_transforms = Compose(
    _preprocessing(["image"], "bilinear", "trilinear")
    + [ToTensord(keys=["image"])]
)

In [None]:
batch_size = 1
worker = 0  # DataLoader worker processes (0 = load in the main process)

train_dataset = Dataset(data=train_path, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=worker)

# Validation order does not change the aggregated Dice, so keep it fixed. This
# also makes `first(val_loader)` return the same case on every run.
val_dataset = Dataset(data=val_path, transform=val_transforms)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=worker)

test_dataset = Dataset(data=test_path, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=worker)

In [None]:
def save_model(path, model, optimizer, train_loss, val_metric):
    """Write a training checkpoint to `path` with `torch.save`.

    Args:
        path: destination file.
        model: module whose `state_dict()` is stored under 'model'.
        optimizer: optimizer whose `state_dict()` is stored under 'optimizer'.
        train_loss: per-epoch training losses (array-like of floats).
        val_metric: per-validation metric values (array-like of floats).

    The loss/metric histories are stored as plain Python lists of floats, not
    NumPy arrays. The checkpoint then contains only tensors and builtins, which
    `torch.load(..., weights_only=True)` (the default in recent PyTorch) can read.
    """
    data = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': np.asarray(train_loss, dtype=float).tolist(),
        'val_metric': np.asarray(val_metric, dtype=float).tolist(),
    }

    torch.save(data, path)


def load_model(path, map_location=None):
    """Read a checkpoint written by `save_model`.

    Args:
        path: checkpoint file.
        map_location: forwarded to `torch.load`. When None, falls back to
            the notebook-level `device`.

    Returns:
        (model_state, optimizer_state, train_loss, val_metric), with both
        histories returned as float NumPy arrays.

    Security: on PyTorch versions where `weights_only` defaults to False,
    `torch.load` unpickles arbitrary objects - only load trusted checkpoints.
    """
    if map_location is None:
        map_location = device
    data = torch.load(path, map_location=map_location)

    model_state = data['model']
    optimizer_state = data['optimizer']
    train_loss = np.asarray(data['train_loss'], dtype=float)
    val_metric = np.asarray(data['val_metric'], dtype=float)

    return (model_state, optimizer_state, train_loss, val_metric)

In [None]:
max_epochs = 100
val_interval = 1  # validate every `val_interval` epochs
# NOTE(review): not referenced by the training loop; each batch calls
# optimizer.step() directly, so no gradient accumulation takes place.
gradient_accumulations = 2

model = UNETR(
    img_size=(128, 128, 64),
    in_channels=1,
    out_channels=2,
    dropout_rate=0.5
).to(device)

optimizer = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

# Two output channels (background / foreground) with one-hot targets, decoded
# via argmax in `post_pred`: the channels compete, so use softmax. With
# sigmoid + include_background=False the background logit would get no
# gradient yet still decide the argmax.
loss_function = DiceLoss(softmax=True, to_onehot_y=True, include_background=False)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=False, reduction="mean")

# Post-processing for the metric: argmax + one-hot for predictions, one-hot for labels.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

if is_cuda:
    # Mixed-precision loss scaling; cudnn autotuning for fixed-size inputs.
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

model

In [None]:
# Optionally resume from a checkpoint saved by `save_model`.
model_file = ""  ## model name if saved
path = os.path.join(ROOT_MODEL_DIR, model_file)

if model_file != "" and os.path.isfile(path):
    model_state, optimizer_state, train_loss, val_metric = load_model(path)
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)
    # Normalise to float arrays so later np.append / .shape work regardless of
    # the container stored in the checkpoint.
    train_loss = np.asarray(train_loss, dtype=float)
    val_metric = np.asarray(val_metric, dtype=float)
else:
    train_loss, val_metric = np.array([]), np.array([])

# max() of an empty history would raise; start from -1 in that case.
best_metric = float(np.max(val_metric)) if val_metric.size > 0 else -1

best_metric_epoch = -1

In [None]:
print(f"Total Train: {train_loss.shape[0]}")
print(f"Best Metric: {best_metric}")

# Sliding-window settings for validation. The validation transforms already
# resize each volume to this size, so one window covers the whole input.
roi_size = (128, 128, 64)
sw_batch_size = 1

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 20)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---------------- training ----------------
    model.train()
    epoch_loss = 0
    steps = 0

    step_start = time.time()
    for batch_data in train_loader:
        steps += 1
        inputs = batch_data['image'].to(device)
        labels = batch_data['label'].to(device)

        optimizer.zero_grad()
        # forward pass (autocast mixed precision only on GPU)
        if is_cuda:
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
        else:
            outputs = model(inputs)
            loss = loss_function(outputs, labels)

        epoch_loss += loss.item()

        # backward pass. On GPU the loss is scaled; GradScaler.step() unscales
        # the gradients itself (an explicit unscale_() is only needed before
        # gradient clipping).
        if is_cuda:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if steps % 20 == 0:
            print(
                f"{steps}/{len(train_dataset) // train_loader.batch_size}"
                f", train_loss: {loss.item():.4f}"
                f", step time: {(time.time() - step_start):.4f}"
            )
            step_start = time.time()

    lr_scheduler.step()
    # max(..., 1) avoids ZeroDivisionError when the train loader is empty.
    epoch_loss /= max(steps, 1)
    train_loss = np.append(train_loss, epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---------------- validation ----------------
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_input = val_data['image'].to(device)
                val_labels = val_data['label'].to(device)

                if is_cuda:
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_input, roi_size, sw_batch_size, model)
                else:
                    val_outputs = sliding_window_inference(val_input, roi_size, sw_batch_size, model)

                # per-sample one-hot prediction (argmax) and one-hot label
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]

                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            val_metric = np.append(val_metric, metric)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1

                # saving model
                path = os.path.join(ROOT_MODEL_DIR, f"unetr_model_without_{metric:.4f}.pth")
                save_model(path, model, optimizer, train_loss, val_metric)
                print("Saved best model and Optimizer.")

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )

    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

total_time = time.time() - total_start
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

In [None]:
# Visual sanity check of one validation volume, run on CPU.
model.cpu()
model.eval()  # disable dropout; in train mode predictions would be stochastic

checker = first(val_loader)
image, label = checker['image'], checker['label']

with torch.no_grad():
    output = model(image)
    # channel-wise probabilities, thresholded at 0.5
    output = torch.nn.Softmax(dim=1)(output)
    output = torch.round(output)

    print(f"Output shape: {output.shape}")

    # Tensors are 5-D [batch, channel, *spatial]; iterate over the LAST axis,
    # which is the one sliced below (shape[2] is a different spatial axis).
    for slide in range(output.shape[-1]):
        fig, (ax_in, ax_label, ax_out) = plt.subplots(1, 3, figsize=(12, 12))

        ax_in.set_title("Input")
        ax_in.imshow(image[0, 0, :, :, slide], cmap="gray")

        ax_label.set_title("Label")
        ax_label.imshow(image[0, 0, :, :, slide], cmap="gray")
        ax_label.imshow(label[0, 0, :, :, slide], cmap="jet", alpha=0.5)

        ax_out.set_title("Output")
        ax_out.imshow(image[0, 0, :, :, slide], cmap="gray")
        ax_out.imshow(output[0, 1, :, :, slide], cmap="jet", alpha=0.5)

        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per slice