# Colab Configuration

## Git config

In [None]:
!git config --global user.email "your@email.com"
!git config --global user.name "username"

In [None]:
!rm -rf "/content/Lung-Cancer-Segmentation"
!git clone https://github.com/shubranshugupta/Lung-Cancer-Segmentation.git

In [None]:
%cd Lung-Cancer-Segmentation

In [None]:
!git checkout -b unetr
!git merge unetr remotes/origin/unetr
!git branch -a

## Installing requirements

In [None]:
!pip install -r "/content/Lung-Cancer-Segmentation/requirements.txt"

In [None]:
%cd Lung-Cancer-Segmentation

## AWS Config

In [None]:
# Interactive prompt for AWS credentials, so no secrets are hardcoded in the notebook.
# NOTE(review): presumably needed by the MLflow setup further down (e.g. an S3 artifact store) -- confirm.
!aws configure

# Training

In [1]:
import os
import sys
import time
sys.path.append("../")
import logging as log

from src.data.unetr import LungDataModule
from src import ROOT_FOLDER
from src.model.unetr import MyUNETR
from src.utils.utils import setup
from src.utils.mlflow import mlflow_setup
from src.utils.callback import MyModelCheckpoint

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
)

from monai.transforms import (
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

import mlflow

In [2]:
log.captureWarnings(True)

# basicConfig cannot create missing directories, and an empty `logs/` folder is not
# preserved by git, so make sure it exists before opening the log file.
log_dir = os.path.join(ROOT_FOLDER, "logs")
os.makedirs(log_dir, exist_ok=True)

# force=True replaces handlers left by a previous run of this cell (or by an import);
# without it basicConfig silently does nothing once the root logger is configured.
log.basicConfig(filename=os.path.join(log_dir, "train.log"), encoding='utf-8', level=log.DEBUG,
                filemode="w", force=True,
                format='%(asctime)s %(levelname)6s %(message)7s', datefmt='%d-%m-%Y %I:%M:%S %p')

# Setup data
log.info("Running Setup..")
setup()

In [3]:
# All tunable settings for the run, grouped by where they are consumed.
# They are also gathered into `hyperparams`, which is handed to mlflow_setup.

# Intensity window applied by ScaleIntensityRanged (mapped onto [0, 1]).
a_min, a_max = -200, 200

# Patch sampling (RandCropByPosNegLabeld) and UNETR architecture.
num_samples = 4
img_size = (96, 96, 96)
feature_size = 16
hidden_size = 768
mlp_dim = 3072
num_heads = 12
dropout_rate = 0.0

# Optimisation / trainer settings.
max_epochs = 1300
batch_size = 1
lr = 1e-4

hyperparams = dict(
    a_min=a_min,
    a_max=a_max,
    num_samples=num_samples,
    img_size=img_size,
    feature_size=feature_size,
    hidden_size=hidden_size,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    max_epochs=max_epochs,
    batch_size=batch_size,
    lr=lr,
)

In [4]:
def _shared_preprocessing():
    """Build the deterministic preprocessing steps used by both train and val pipelines.

    Returns a new list of transform instances on every call:
    - load image and label;
    - add a channel axis;
    - reorient to RAS;
    - resample to 1.5 x 1.5 x 2.0 spacing (bilinear for the image, nearest for the label);
    - window image intensities from [a_min, a_max] onto [0, 1] with clipping;
    - crop both image and label to the image foreground.

    Keeping these in one place guarantees the two pipelines cannot drift apart.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=a_min,
            a_max=a_max,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]


log.info("Defining training transforms..")
train_transforms = Compose(
    _shared_preprocessing()
    + [
        # Draw num_samples patches of img_size per volume, with pos=1 / neg=1 weighting
        # between label-positive and label-negative patch centres.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=img_size,
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        # Independent 10% random flip along each of the three spatial axes (0, 1, 2).
        *[
            RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10)
            for axis in range(3)
        ],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ]
)

log.info("Defining validation transforms..")
val_transforms = Compose(_shared_preprocessing() + [ToTensord(keys=["image", "label"])])

In [5]:
log.info("Initializing MLFlow Logger..")
experiment_name = "Test"
# NOTE(review): the 2nd axis is filled with batch_size. With batch_size == 1 it coincides with
# a single input channel, so these describe (num_samples, 1, *img_size) tensors -- confirm the
# intended layout (and the model's output channel count) before changing batch_size.
input_data_size = (num_samples, batch_size, *img_size)
output_data_size = (num_samples, batch_size, *img_size)

# mlflow_setup may depend on external services (presumably a tracking server / artifact store),
# so transient failures are retried. The number of attempts is bounded and the pause grows, so a
# permanent misconfiguration surfaces as an error instead of an endless loop.
max_setup_attempts = 5
for attempt in range(1, max_setup_attempts + 1):
    try:
        mlf_logger, model_log, signature = mlflow_setup(experiment_name, input_data_size, output_data_size, hyperparams)
        break
    except Exception as e:
        log.error(f"MLFlow Logger setup failed with error: {e}")
        # End any active run left behind by the failed attempt before trying again.
        mlflow.end_run()
        if attempt == max_setup_attempts:
            raise
        log.info("Trying it again..")
        time.sleep(2 ** attempt)

In [6]:
log.info("Initializing Data Module..")
# Use six loader workers as before, but never more than the machine has CPUs.
# Colab runtimes often have only a couple of CPUs, and oversubscribing them slows loading
# and makes PyTorch warn about the worker count.
num_workers = min(6, os.cpu_count() or 1)
loader = LungDataModule(
    file_path=os.path.join(ROOT_FOLDER, "data", "raw", "dataset.json"),
    batch_size=batch_size,
    val_transform=val_transforms,
    train_transform=train_transforms,
    num_workers=num_workers,
)

In [7]:
log.info("Initializing Model..")
# Model settings (architecture plus lr) taken from the hyperparameter cell.
model_kwargs = dict(
    img_size=img_size,
    feature_size=feature_size,
    hidden_size=hidden_size,
    mlp_dim=mlp_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    lr=lr,
)
model = MyUNETR(**model_kwargs)

model_summary = model.summary()
print(model_summary)

UNETR(
  (vit): ViT(
    (patch_embedding): PatchEmbeddingBlock(
      (patch_embeddings): Sequential(
        (0): Rearrange('b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1=16, p2=16, p3=16)
        (1): Linear(in_features=4096, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (blocks): ModuleList(
      (0): TransformerBlock(
        (mlp): MLPBlock(
          (linear1): Linear(in_features=768, out_features=3072, bias=True)
          (linear2): Linear(in_features=3072, out_features=768, bias=True)
          (fn): GELU()
          (drop1): Dropout(p=0.0, inplace=False)
          (drop2): Dropout(p=0.0, inplace=False)
        )
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): SABlock(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (drop_output): Dropout(p=0.0, inplace=False)
          (d

In [8]:
log.info("Initializing Callbacks.")
# Early stopping and checkpointing watch the same validation metric in the same direction.
monitor_kwargs = dict(monitor="val_epoch_loss", mode="min", verbose=True)

progress_bar = TQDMProgressBar()
early_stopping = EarlyStopping(patience=10, **monitor_kwargs)
# Keep only the best checkpoint; the MLflow model flavour and signature come from mlflow_setup.
checkpoint_callback = MyModelCheckpoint(
    dirpath=os.path.join(ROOT_FOLDER, "logs", "checkpoints"),
    filename="checkpoint",
    save_top_k=1,
    mlflow_model=model_log,
    signature=signature,
    **monitor_kwargs,
)
lr_logger = LearningRateMonitor(logging_interval="epoch")

In [9]:
log.info("Initializing Trainer.")
# Pass every callback built in the previous cell, including the progress bar,
# which was previously created but never handed to the Trainer.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    callbacks=[progress_bar, early_stopping, checkpoint_callback, lr_logger],
    logger=mlf_logger,
)

# start training
log.info("Starting Training.")

try:
    trainer.fit(model, loader)
except Exception as e:
    # Write the full traceback to the log file and close the active MLflow run.
    # Then re-raise so the failure is visible in the notebook instead of the cell
    # appearing to finish normally.
    log.exception(f"Training failed with error: {e}")
    mlflow.end_run()
    raise


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Loading dataset: 100%|██████████| 4/4 [00:52<00:00, 13.15s/it]
Loading dataset: 100%|██████████| 3/3 [00:21<00:00,  7.16s/it]

  | Name          | Type       | Params
---------------------------------------------
0 | _model        | UNETR      | 92.8 M
1 | loss_function | DiceCELoss | 0     
---------------------------------------------
92.8 M    Trainable params
0         Non-trainable params
92.8 M    Total params
371.163   Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]