In [None]:
!pip install -q -U monai wandb

In [None]:
import wandb

# Start the W&B run for this baseline-training job.
wandb.init(project="brain-tumor-segmentation", entity="lifesciences", job_type="train_baseline")

# Short alias: the cells below attach every hyperparameter to this object.
config = wandb.config

In [None]:
from monai.utils import set_determinism

# Keep the seed on the W&B config so it is stored next to the other
# hyperparameters, then hand that same value to MONAI's set_determinism.
config.seed = 0
set_determinism(seed=config.seed)

In [None]:
import torch
from monai.transforms import MapTransform


class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Dictionary transform that replaces an integer BraTS label map with three
    stacked binary masks, one per evaluation region.

    Input label values (as used by this notebook):
        1 -- peritumoral edema
        2 -- GD-enhancing tumor
        3 -- necrotic and non-enhancing tumor core

    Output channels, stacked along a new leading axis and cast to float:
        0 -- TC (Tumor core):      label 2 or label 3
        1 -- WT (Whole tumor):     label 1, 2 or 3
        2 -- ET (Enhancing tumor): label 2

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every configured key converted."""
        converted = dict(data)
        for key in self.keys:
            label_map = converted[key]

            # One boolean mask per region, each built from the raw label values.
            enhancing_tumor = label_map == 2
            tumor_core = torch.logical_or(label_map == 2, label_map == 3)
            whole_tumor = torch.logical_or(tumor_core, label_map == 1)

            # Channel order is TC, WT, ET.
            channels = [tumor_core, whole_tumor, enhancing_tumor]
            converted[key] = torch.stack(channels, dim=0).float()
        return converted

In [None]:
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)


config.roi_size = [224, 224, 144]


def _load_and_resample():
    """Build fresh instances of the deterministic steps shared by both pipelines."""
    return [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
    ]


def _normalize_image():
    """Build the intensity normalisation step applied to the image only."""
    return NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)


# Training: shared preprocessing, then a fixed-size random crop, a random flip
# along each of the three spatial axes, normalisation, and intensity jitter.
train_transform = Compose(
    _load_and_resample()
    + [
        RandSpatialCropd(
            keys=["image", "label"], roi_size=config.roi_size, random_size=False
        )
    ]
    + [
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=flip_axis)
        for flip_axis in (0, 1, 2)
    ]
    + [
        _normalize_image(),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

# Validation: shared preprocessing and normalisation only, no random steps.
val_transform = Compose(_load_and_resample() + [_normalize_image()])

In [None]:
# Declare the dataset artifact as an input of this run, then download its files.
dataset_artifact_name = "lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0"
artifact = wandb.use_artifact(dataset_artifact_name, type="dataset")
# Local directory holding the downloaded files; used as root_dir for the datasets.
artifact_dir = artifact.download()

In [None]:
from monai.apps import DecathlonDataset

config.num_workers = 4

# Training split. It is built directly with the augmenting training pipeline so
# that the dataset is correct as soon as this cell has run, instead of being
# created with the validation pipeline and patched by a later cell.
train_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=False,  # the files come from the W&B artifact downloaded above
    cache_rate=0.0,
    num_workers=config.num_workers,
)

# Validation split, using the deterministic (augmentation-free) pipeline.
val_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=config.num_workers,
)

In [None]:
from monai.data import DataLoader

# Make sure the training dataset runs the augmenting training pipeline.
train_dataset.transform = train_transform

config.batch_size = 2

# Settings that are identical for both loaders.
loader_settings = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)

# Training batches are reshuffled; validation batches keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_settings)

In [None]:
from monai.networks.nets import SegResNet

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyperparameters, kept on the W&B config.
config.model_in_channels = 4
config.model_out_channels = 3
config.model_init_filters = 16
config.model_blocks_down = [1, 2, 2, 4]
config.model_blocks_up = [1, 1, 1]
config.model_dropout_prob = 0.2

# Build the network from the config values and move it onto the chosen device.
model = SegResNet(
    blocks_down=config.model_blocks_down,
    blocks_up=config.model_blocks_up,
    init_filters=config.model_init_filters,
    in_channels=config.model_in_channels,
    out_channels=config.model_out_channels,
    dropout_prob=config.model_dropout_prob,
)
model = model.to(device)

In [None]:
# Optimisation hyperparameters, kept on the W&B config.
config.max_train_epochs = 25
config.initial_learning_rate = 1e-4
config.weight_decay = 1e-5

# Adam over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# Cosine annealing with a period equal to the number of training epochs;
# the training loop steps it once per epoch.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.max_train_epochs,
)

In [None]:
from monai.losses import DiceLoss

# Dice-loss hyperparameters, kept on the W&B config.
config.dice_loss_apply_sigmoid = True
config.dice_loss_target_onehot = False
config.dice_loss_squared_prediction = True
config.dice_loss_smoothen_numerator = 0
config.dice_loss_smoothen_denominator = 1e-5

# Map the config entries onto DiceLoss's keyword arguments.
dice_loss_settings = dict(
    smooth_nr=config.dice_loss_smoothen_numerator,
    smooth_dr=config.dice_loss_smoothen_denominator,
    squared_pred=config.dice_loss_squared_prediction,
    to_onehot_y=config.dice_loss_target_onehot,
    sigmoid=config.dice_loss_apply_sigmoid,
)
loss_function = DiceLoss(**dice_loss_settings)

In [None]:
from monai.metrics import DiceMetric

# Turn raw network outputs into binary masks: sigmoid, then threshold at 0.5.
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

# Two Dice accumulators fed with the same predictions during validation:
# "mean" yields the single overall score, "mean_batch" is indexed per
# channel (0, 1, 2) by the training loop.
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

In [None]:
# Let cuDNN benchmark and pick convolution algorithms.
# NOTE(review): this may override what set_determinism configured in the
# seeding cell -- confirm the reproducibility/speed trade-off is intended.
torch.backends.cudnn.benchmark = True

# Gradient scaler used by the mixed-precision (autocast) training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
from monai.inferers import sliding_window_inference

config.inference_roi_size = (240, 240, 160)


def inference(model, input):
    """
    Run sliding-window inference on ``input`` under CUDA autocast.

    Args:
        model: network used as the sliding-window predictor.
        input: batch passed to ``sliding_window_inference`` as ``inputs``.

    Returns:
        Whatever ``sliding_window_inference`` returns for the batch, using a
        window of ``config.inference_roi_size``, one window per forward pass
        and 50% overlap between neighbouring windows.
    """
    with torch.cuda.amp.autocast():
        return sliding_window_inference(
            inputs=input,
            roi_size=config.inference_roi_size,
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

In [None]:
# Give each metric family its own x-axis: every "<family>/*" metric is plotted
# against the "<family>/<family>_step" counter logged alongside it.
for metric_family in ("epoch", "batch", "validation"):
    step_key = f"{metric_family}/{metric_family}_step"
    wandb.define_metric(step_key)
    wandb.define_metric(f"{metric_family}/*", step_metric=step_key)

# Counters advanced by the training loop and logged as the custom x-axes.
batch_step = 0
validation_step = 0

# Validation history, one entry appended per validation pass.
metric_values = []
metric_values_tumor_core = []
metric_values_whole_tumor = []
metric_values_enhanced_tumor = []

In [None]:
import os
from tqdm.auto import tqdm
from monai.data import decollate_batch

config.validation_intervals = 1
config.checkpoint_dir = "./checkpoints"

# Create checkpoint directory
os.makedirs(config.checkpoint_dir, exist_ok=True)

epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

for epoch in epoch_progress_bar:
    model.train()
    epoch_loss = 0

    # Ask the loader for its batch count. Unlike
    # `len(train_dataset) // batch_size`, this includes a trailing partial
    # batch, so the progress-bar total and the mean loss below stay correct
    # when the dataset size is not a multiple of the batch size.
    total_batch_steps = len(train_loader)
    batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)

    # Training Step
    for batch_data in batch_progress_bar:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        # Forward pass and loss under autocast; the backward pass and the
        # optimizer step go through the gradient scaler.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once and reuse it for the running sum, the
        # progress bar and the W&B log.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        batch_progress_bar.set_description(f"train_loss: {batch_loss:.4f}:")
        ## Log batch-wise training loss to W&B
        wandb.log({"batch/batch_step": batch_step, "batch/train_loss": batch_loss})
        batch_step += 1

    # Capture the learning rate that was in effect during this epoch *before*
    # stepping the scheduler; reading it afterwards would report next epoch's.
    epoch_learning_rate = lr_scheduler.get_last_lr()[0]
    lr_scheduler.step()
    epoch_loss /= total_batch_steps
    ## Log epoch-wise mean training loss and learning rate to W&B
    wandb.log(
        {
            "epoch/epoch_step": epoch,
            "epoch/mean_train_loss": epoch_loss,
            "epoch/learning_rate": epoch_learning_rate,
        }
    )
    epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:")

    # Validation and model checkpointing step
    if (epoch + 1) % config.validation_intervals == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(model, val_inputs)
                # Split the batch into per-sample tensors and binarise each one.
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                # Both metric objects accumulate across the validation batches.
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            # Overall mean Dice, then the per-channel values in the order
            # produced by the label conversion: tumor core, whole tumor,
            # enhancing tumor.
            metric_values.append(dice_metric.aggregate().item())
            metric_batch = dice_metric_batch.aggregate()
            metric_values_tumor_core.append(metric_batch[0].item())
            metric_values_whole_tumor.append(metric_batch[1].item())
            metric_values_enhanced_tumor.append(metric_batch[2].item())
            # Clear the accumulators so the next validation pass starts fresh.
            dice_metric.reset()
            dice_metric_batch.reset()

            # The same local file is overwritten on every validation pass.
            checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
            torch.save(model.state_dict(), checkpoint_path)

            # Log and version model checkpoints using W&B artifacts.
            wandb.log_model(
                checkpoint_path,
                name=f"{wandb.run.id}-checkpoint",
                aliases=[f"epoch_{epoch}"],
            )

            # Log validation metrics to W&B dashboard.
            wandb.log(
                {
                    "validation/validation_step": validation_step,
                    "validation/mean_dice": metric_values[-1],
                    "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                    "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                    "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                }
            )
            validation_step += 1

In [None]:
# Training is done: finish the W&B run that was started with wandb.init above.
wandb.finish()