In [1]:
import os
import sys
sys.path.append("../")
import logging as log
import tqdm

import numpy as np

from src import LOG_DIR, DATA_DIR, MODEL_DIR
from src.utils.utils import setup

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
    EnsureType,
    Activations,

)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)


import torch

In [2]:
# Mirror Python `warnings` into the log, then (re)create train.log under LOG_DIR
# (filemode "w" truncates the file on every run) at DEBUG verbosity.
log.captureWarnings(True)
log.basicConfig(
    filename=os.path.join(LOG_DIR, "train.log"),
    filemode="w",
    level=log.DEBUG,
    format='%(asctime)s %(levelname)6s %(message)7s',
    datefmt='%d-%m-%Y %I:%M:%S %p',
)

# Project-level environment setup, provided by src.utils.utils.
log.info("Running Setup..")
setup()

In [3]:
# Defining Hyperparameters

# --- Intensity window passed to ScaleIntensityRanged as a_min / a_max ---
a_min, a_max = -200, 200

# --- UNETR architecture & patch sampling ---
num_samples = 4              # patches drawn per volume by RandCropByPosNegLabeld
roi_size = (96, 96, 96)      # patch / sliding-window size; also UNETR's img_size
feature_size = 16
hidden_size = 768
mlp_dim = 3072
num_heads = 12
dropout_rate = 0.0

# --- Optimisation ---
# NOTE(review): max_epochs is not referenced elsewhere in the visible notebook;
# the training loop is bounded by max_iterations instead.
max_epochs = 1300
batch_size = 1
lr = 1e-4

# Snapshot of the configuration above (the UNETR input size is recorded as "img_size").
hyperparams = dict(
    a_min=a_min,
    a_max=a_max,
    num_samples=num_samples,
    img_size=roi_size,
    feature_size=feature_size,
    hidden_size=hidden_size,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    max_epochs=max_epochs,
    batch_size=batch_size,
    lr=lr,
)

In [4]:
def _shared_preprocessing():
    """Build fresh instances of the deterministic steps common to both pipelines.

    Steps: load image/label, add a channel dimension, reorient to RAS, resample
    to (1.5, 1.5, 2.0) spacing (bilinear for the image, nearest for the label),
    map image intensities [a_min, a_max] -> [0, 1] with clipping, and crop both
    keys to the image foreground.
    """
    paired = ("image", "label")
    return [
        LoadImaged(keys=paired),
        AddChanneld(keys=paired),
        Orientationd(keys=paired, axcodes="RAS"),
        Spacingd(keys=paired, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["image"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True
        ),
        CropForegroundd(keys=paired, source_key="image"),
    ]


log.info("Defining training transforms..")
train_transforms = Compose(
    _shared_preprocessing()
    + [
        # Draw num_samples roi_size patches per volume, centred on label-positive
        # and label-negative voxels in a 1:1 ratio.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=roi_size,
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        # Independent 10% flip along each of the three spatial axes, in order 0, 1, 2.
        *[
            RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10)
            for axis in (0, 1, 2)
        ],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ]
)

log.info("Defining validation transforms..")
# Validation uses only the deterministic preprocessing: no cropping or augmentation.
val_transforms = Compose(_shared_preprocessing() + [ToTensord(keys=["image", "label"])])

In [None]:
log.info("Loading Data ...")

data_dir = DATA_DIR
split_JSON = "dataset.json"
# os.path.join inserts the separator when DATA_DIR has no trailing slash and also
# accepts a pathlib.Path. Plain `+` concatenation builds a wrong path in the first
# case and raises TypeError in the second.
datasets = os.path.join(data_dir, split_JSON)
datalist = load_decathlon_datalist(datasets, is_segmentation=True, data_list_key="train")
val_files = load_decathlon_datalist(datasets, is_segmentation=True, data_list_key="val")

# Keep the deterministic part of each pipeline cached in memory; cache_num caps how
# many items are cached (24 train / 6 val).
train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=8,
)
# Honour the configured `batch_size` hyperparameter instead of a hard-coded 1.
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
)

val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4
)
# One volume per batch: the validation pipeline crops to the foreground without
# padding, so volumes presumably differ in size and could not be collated together.
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
)

In [6]:
log.info("Initializing Model ...")

# Enumerate GPUs in PCI bus order. NOTE(review): CUDA only honours this if it is set
# before CUDA is initialised in the process; confirm setup() does not touch the GPU first.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single input channel, single output logit channel (binary segmentation; the sigmoid
# is applied inside the loss and in post_pred).
model = UNETR(
    in_channels=1,
    out_channels=1,
    img_size=roi_size,
    feature_size=feature_size,
    hidden_size=hidden_size,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
    pos_embed="conv",
    norm_name="instance",
    res_block=True,
    conv_block=True,
    dropout_rate=dropout_rate,
).to(device)


log.info("Defining Loss ...")
# NOTE(review): with out_channels=1, some MONAI releases compute DiceCELoss's CE term as
# a softmax cross-entropy over a single channel, which is identically zero, so only the
# Dice term would drive training. Newer releases use BCE for single-channel input.
# Confirm against the installed MONAI version.
loss_function = DiceCELoss(sigmoid=True)
# Training patches and inference windows share the fixed roi_size, so cuDNN autotuning pays off.
torch.backends.cudnn.benchmark = True

log.info("Defining Optimizer ...")
# Use the `lr` hyperparameter rather than a duplicated literal, so `hyperparams`
# actually describes this run.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

In [None]:
def validation(epoch_iterator_val):
    """Run sliding-window inference over the validation set and return the mean Dice.

    Args:
        epoch_iterator_val: tqdm-wrapped iterable over validation batches; each batch
            is a dict with "image" and "label" tensors (as yielded by ``val_loader``).

    Returns:
        float: Dice averaged over every validation sample seen, or ``nan`` if the
        iterator yields no batches.

    Side effects:
        Uses and resets the module-level ``dice_metric``. The model's train/eval mode
        is restored to what it was on entry.
    """
    was_training = model.training
    model.eval()
    dice_metric.reset()  # start from an empty buffer even if an earlier run was interrupted
    num_batches = 0
    # tqdm exposes the wrapped iterable's length as `.total` (None when unknown).
    total_steps = getattr(epoch_iterator_val, "total", None) or "?"

    try:
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val, start=1):
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)

                # 4 = sw_batch_size: number of roi_size windows per forward pass.
                val_outputs = sliding_window_inference(val_inputs, roi_size, 4, model)
                val_output_convert = [post_pred(pred) for pred in decollate_batch(val_outputs)]
                val_labels_list = decollate_batch(val_labels)

                # The metric buffers per-sample scores; aggregate() averages everything
                # buffered so far (used here only for the progress display).
                dice_metric(y_pred=val_output_convert, y=val_labels_list)
                num_batches += 1
                running_dice = dice_metric.aggregate().item()

                epoch_iterator_val.set_description(
                    "Validate (%d / %s Steps) (dice=%2.5f)" % (step, total_steps, running_dice)
                )

            # Aggregate once over all samples. Averaging the per-step running means
            # would over-weight the earliest batches.
            mean_dice_val = dice_metric.aggregate().item() if num_batches else float("nan")
    finally:
        dice_metric.reset()
        model.train(was_training)

    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    """Run one pass over ``train_loader``, validating every ``eval_num`` global steps.

    Validation also runs at ``global_step == max_iterations``, and the pass stops as
    soon as that iteration budget is consumed.

    Args:
        global_step: number of optimizer steps taken before this call.
        train_loader: iterable of dict batches with "image" and "label" tensors.
        dice_val_best: best mean validation Dice seen so far.
        global_step_best: global step at which ``dice_val_best`` was reached.

    Returns:
        tuple: updated ``(global_step, dice_val_best, global_step_best)``.

    Side effects:
        Appends to the module-level ``epoch_loss_values`` / ``metric_values`` lists.
        Overwrites ``best_metric_model.pth`` in MODEL_DIR whenever validation Dice improves.
    """
    # The notebook imports `tqdm` as a *module*, and calling a module raises TypeError;
    # bind the progress-bar class locally instead.
    from tqdm import tqdm

    model.train()

    epoch_loss = 0.0
    step = 0

    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
    )

    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = batch["image"].to(device), batch["label"].to(device)
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        loss_value = loss.item()
        epoch_loss += loss_value
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss_value)
        )

        is_eval_step = (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations
        if is_eval_step:
            epoch_iterator_val = tqdm(
                val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True
            )
            dice_val = validation(epoch_iterator_val)
            # Validation switches the model to eval mode; make sure the remaining
            # steps of this epoch train in training mode again.
            model.train()

            # Mean training loss over this epoch so far. `epoch_loss` is deliberately
            # not divided in place: it remains the running sum for later steps.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step

                torch.save(
                    model.state_dict(), os.path.join(MODEL_DIR, "best_metric_model.pth")
                )

                print(f"Model Was Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")
            else:
                print(f"Model Was Not Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")

        global_step += 1
        # Stop right after the final (max_iterations) step instead of finishing the
        # epoch with extra, never-evaluated updates.
        if global_step > max_iterations:
            break

    return global_step, dice_val_best, global_step_best


# Iteration budget and validation cadence; train() reads both as module-level globals.
max_iterations = 25000
eval_num = 500

# NOTE(review): post_label is not used by the active code in this notebook, and its
# 14-class one-hot does not match the single-channel model output. It is kept only so
# that any other cell referring to it keeps working.
post_label = AsDiscrete(to_onehot=14)
# Per-sample post-processing of model logits: sigmoid, then binarise at 0.5.
post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# The model has a single output channel, so that channel must be scored
# (include_background=False would drop channel 0, i.e. the only one).
dice_metric = DiceMetric(include_background=True, reduction="mean")

# Training-loop state threaded through / appended to by train().
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []

while global_step < max_iterations:
    previous_step = global_step
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best
    )
    # An empty train_loader would otherwise make this loop spin forever.
    if global_step == previous_step:
        raise RuntimeError("train_loader yielded no batches; cannot make training progress")

log.info(
    "Training finished: best mean Dice %.4f at global step %d", dice_val_best, global_step_best
)
# map_location restores the checkpoint onto whichever device is active now, so the
# cell also works when re-run on a machine without the GPU the weights were saved from.
model.load_state_dict(
    torch.load(os.path.join(MODEL_DIR, "best_metric_model.pth"), map_location=device)
)