# Train a Baseline Segmentation Model
In this notebook we will learn how to:

- Use specific MONAI APIs to write our training workflow, including a state-of-the-art (SoTA) neural network architecture, a loss function, and metrics for our task.
- Use Weights & Biases to track our experiments and to log and version our model checkpoints.

## 🌴 Setup and Installation

First, let us install the latest versions of both MONAI and Weights & Biases.

In [None]:
!pip install -q -U monai wandb

## 🌳 Initialize a W&B Run

We start a new W&B run to track our experiment.

In [None]:
import wandb

wandb.init(
    project="brain-tumor-segmentation",
    entity="lifesciences",
    job_type="train_baseline",
)

## 🌼 Reproducibility and Configuration Management

`wandb.config` allows us to easily define and manage the configurations of our experiments. This includes hyperparameters, model settings, and any other experiment variables that we use in a particular run. By centralizing this information, we can ensure consistency across runs and make our experiments more organized and reproducible.

In [None]:
config = wandb.config

Next, we make training deterministic by setting a global random seed with `monai.utils.set_determinism`. By setting a random seed (or multiple random seeds) and storing it in the configuration, we can make sure that a particular run is reproducible.

In [None]:
from monai.utils import set_determinism

config.seed = 0
set_determinism(seed=config.seed)
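To see why a fixed seed makes random augmentations repeatable, here is a minimal sketch that uses only Python's built-in `random` module as a stand-in. `set_determinism` seeds Python's, NumPy's, and PyTorch's generators together; the helper `sample_flip_decisions` below is hypothetical and only mimics the coin flips that a transform such as `RandFlipd(prob=0.5)` makes.

```python
import random


def sample_flip_decisions(seed, n=10, prob=0.5):
    # Hypothetical helper: mimics n coin flips like RandFlipd(prob=0.5) would make
    rng = random.Random(seed)  # independent generator, seeded explicitly
    return [rng.random() < prob for _ in range(n)]


run_a = sample_flip_decisions(seed=0)
run_b = sample_flip_decisions(seed=0)
print(run_a == run_b)  # True: same seed, same augmentation decisions
```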

## 💿 Loading and Transforming the Data

We will now learn how to use the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API to create and apply transforms to our data.

### Creating a Custom Transform

First, we demonstrate the creation of a custom transform `ConvertToMultiChannelBasedOnBratsClassesd` using [`monai.transforms.MapTransform`](https://docs.monai.io/en/stable/transforms.html#maptransform) that converts labels to multi-channel tensors based on the BraTS classes:
- label 1 is the peritumoral edema
- label 2 is the GD-enhancing tumor
- label 3 is the necrotic and non-enhancing tumor core

The target classes for the semantic segmentation task after applying this transform to the dataset are:
- Tumor core
- Whole tumor
- Enhancing tumor
In [None]:
import torch
from monai.transforms import MapTransform


class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi-channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor), and ET (Enhancing tumor).

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        data_dict = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct Tumor Core
            result.append(torch.logical_or(data_dict[key] == 2, data_dict[key] == 3))
            # merge labels 1, 2 and 3 to construct Whole Tumor
            result.append(
                torch.logical_or(
                    torch.logical_or(data_dict[key] == 2, data_dict[key] == 3),
                    data_dict[key] == 1,
                )
            )
            # label 2 is Enhancing Tumor
            result.append(data_dict[key] == 2)
            data_dict[key] = torch.stack(result, axis=0).float()
        return data_dict
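The per-voxel logic of `ConvertToMultiChannelBasedOnBratsClassesd` can be checked without MONAI. This sketch, built around a hypothetical helper `to_brats_channels`, applies the same rules as the `__call__` method above to a flat list of label values and returns one binary list per target channel (TC, WT, ET).

```python
def to_brats_channels(labels):
    # Same rules as ConvertToMultiChannelBasedOnBratsClassesd.__call__:
    # TC = labels 2 or 3, WT = labels 1, 2 or 3, ET = label 2
    tumor_core = [float(v in (2, 3)) for v in labels]
    whole_tumor = [float(v in (1, 2, 3)) for v in labels]
    enhancing_tumor = [float(v == 2) for v in labels]
    return [tumor_core, whole_tumor, enhancing_tumor]


channels = to_brats_channels([0, 1, 2, 3])
for name, channel in zip(["TC", "WT", "ET"], channels):
    print(name, channel)
# TC [0.0, 0.0, 1.0, 1.0]
# WT [0.0, 1.0, 1.0, 1.0]
# ET [0.0, 0.0, 1.0, 0.0]
```

Because the channels are nested (ET is contained in TC, which is contained in WT), a voxel can be active in several channels at once. This is why the task is treated as multi-label, with a sigmoid per channel, rather than as multi-class.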

Next, we compose the transforms for the training and validation splits using [`monai.transforms.Compose`](https://docs.monai.io/en/stable/transforms.html#monai.transforms.Compose).

**Note:** The validation transforms skip the random cropping, flipping, and intensity augmentations; validation images are only loaded, relabeled, reoriented, resampled, and normalized.

In [None]:
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)


config.roi_size = [224, 224, 144]

train_transform = Compose(
    [
        # Load the image (4 MRI modalities) and the label from NIfTI files
        LoadImaged(keys=["image", "label"]),
        # Ensure loaded images are in channels-first format
        EnsureChannelFirstd(keys="image"),
        # Convert the data to PyTorch tensors
        EnsureTyped(keys=["image", "label"]),
        # Convert labels to multi-channels based on the BraTS classes
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Reorient the volumes to the RAS axis codes
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample the volumes to 1.0 x 1.0 x 1.0 voxel spacing
        # (bilinear interpolation for the image, nearest for the label)
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Augmentation: crop a fixed-size ROI (random_size=False) at a random location
        RandSpatialCropd(
            keys=["image", "label"], roi_size=config.roi_size, random_size=False
        ),
        
        # Augmentation: Randomly flip the image on the specified axes
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        
        # Normalize input image intensity
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        
        # Augmentation: randomly scale and shift the image intensity
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        # Load the image (4 MRI modalities) and the label from NIfTI files
        LoadImaged(keys=["image", "label"]),
        # Ensure loaded images are in channels-first format
        EnsureChannelFirstd(keys="image"),
        # Convert the data to PyTorch tensors
        EnsureTyped(keys=["image", "label"]),
        # Convert labels to multi-channels based on the BraTS classes
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Reorient the volumes to the RAS axis codes
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample the input images to the specified pixel dimension
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Normalize input image intensity
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
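A minimal pure-Python sketch of what `NormalizeIntensityd(nonzero=True, channel_wise=True)` does to each channel: the mean and standard deviation are computed over the nonzero (brain) voxels only, those voxels are z-scored, and background zeros stay at zero. The helper names are hypothetical, and using the population standard deviation is an assumption about the exact estimator. `rand_scale_shift` mirrors the idea of the two random intensity augmentations, which, as we understand the MONAI transforms, multiply by `1 + f` and add `o`, with `f` and `o` drawn uniformly from `[-0.1, 0.1]`.

```python
import random
import statistics


def normalize_nonzero(channel):
    # z-score only the nonzero voxels; zeros (background) stay zero
    nonzero = [v for v in channel if v != 0]
    if not nonzero:
        return list(channel)
    mean = statistics.fmean(nonzero)
    std = statistics.pstdev(nonzero) or 1.0  # assumed estimator; avoid dividing by 0
    return [(v - mean) / std if v != 0 else 0.0 for v in channel]


def rand_scale_shift(image, rng, factor=0.1, offset=0.1):
    # One random factor f and offset o for the whole image (all channels):
    # v -> v * (1 + f), then v -> v + o
    f = rng.uniform(-factor, factor)
    o = rng.uniform(-offset, offset)
    return [[v * (1 + f) + o for v in channel] for channel in image]


image = [[0.0, 100.0, 200.0, 600.0], [0.0, 0.0, 50.0, 150.0]]  # 2 toy channels
normalized = [normalize_nonzero(c) for c in image]  # channel_wise=True
augmented = rand_scale_shift(normalized, random.Random(0))
print(normalized[1])  # [0.0, 0.0, -1.0, 1.0]
```

Normalizing over nonzero voxels keeps the large empty background of each MRI volume from dragging the statistics toward zero.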

To load the dataset, we first fetch it from the W&B dataset artifact that we created earlier. This lets us use the dataset as an input artifact to our training run and establishes the necessary lineage for our experiment.

![](./assets/artifact_usage.png)

In [None]:
artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0", type="dataset"
)
artifact_dir = artifact.download()

We now use [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) to load our dataset and apply the transforms we defined to the data samples, so that we can use them for training and validation.

In [None]:
from monai.apps import DecathlonDataset

config.num_workers = 4

# Create the dataset for the training split
# of the brain tumor segmentation dataset
train_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=False,
    cache_rate=0.0,
    num_workers=config.num_workers,
)

# Create the dataset for the validation split
# of the brain tumor segmentation dataset
val_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=config.num_workers,
)

We now create DataLoaders for the training and validation datasets using [`monai.data.DataLoader`](https://docs.monai.io/en/stable/data.html#dataloader), which provides an iterable over a given dataset.

In [None]:
from monai.data import DataLoader

config.batch_size = 2

# create the train_loader
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

# create the val_loader
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
)
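Because `shuffle=True` only reorders the samples and `drop_last` keeps its default of `False`, each loader yields `ceil(N / batch_size)` batches per epoch, and the final batch may be smaller. The sketch below, using a hypothetical helper `batch_sizes_per_epoch` and an arbitrary dataset size of 7, makes this arithmetic explicit; `len(train_loader)` reports the same count, which makes it the safest divisor when averaging the loss over an epoch.

```python
import math
import random


def batch_sizes_per_epoch(num_samples, batch_size, shuffle=True, seed=0):
    # Mimics a PyTorch-style loader with drop_last=False
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # the order changes, the count does not
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    return [len(b) for b in batches]


sizes = batch_sizes_per_epoch(num_samples=7, batch_size=2)
print(sizes)             # [2, 2, 2, 1]
print(math.ceil(7 / 2))  # 4 batches
print(7 // 2)            # 3: floor division undercounts the batches
```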

## 🤖 Creating the Model, Loss, and Optimizer

We will be training a **SegResNet** model based on the paper [3D MRI brain tumor segmentation using auto-encoder regularization](https://arxiv.org/pdf/1810.11654.pdf). The [SegResNet](https://docs.monai.io/en/stable/networks.html#segresnet) model is implemented as a PyTorch module in the [`monai.networks.nets`](https://docs.monai.io/en/stable/networks.html#nets) API, which provides out-of-the-box implementations of SoTA neural network models for different medical imaging tasks.

In [None]:
from monai.networks.nets import SegResNet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

config.model_blocks_down = [1, 2, 2, 4]
config.model_blocks_up = [1, 1, 1]
config.model_init_filters = 16
config.model_in_channels = 4
config.model_out_channels = 3
config.model_dropout_prob = 0.2

# create model
model = SegResNet(
    blocks_down=config.model_blocks_down,
    blocks_up=config.model_blocks_up,
    init_filters=config.model_init_filters,
    in_channels=config.model_in_channels,
    out_channels=config.model_out_channels,
    dropout_prob=config.model_dropout_prob,
).to(device)
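A quick back-of-the-envelope check of the configuration above, assuming SegResNet's usual layout in which the first encoder level keeps full resolution and each later level halves the spatial size and doubles the number of filters. With four entries in `blocks_down`, the encoder downsamples three times (a factor of 8), so both the training ROI and the inference ROI used later should be divisible by 8. The helper below is hypothetical and only does the arithmetic; it does not build the model.

```python
def segresnet_encoder_shapes(spatial_size, init_filters, blocks_down):
    # Level 0 keeps full resolution; each further level halves it (stride 2)
    # and uses init_filters * 2**level channels.
    shapes = []
    for level in range(len(blocks_down)):
        factor = 2 ** level
        size = tuple(s / factor for s in spatial_size)
        shapes.append((init_filters * 2 ** level, size))
    return shapes


for roi in [(224, 224, 144), (240, 240, 160)]:  # training and inference ROIs
    deepest_channels, deepest_size = segresnet_encoder_shapes(roi, 16, [1, 2, 2, 4])[-1]
    print(roi, "->", deepest_channels, "channels at", deepest_size)
# (224, 224, 144) -> 128 channels at (28.0, 28.0, 18.0)
# (240, 240, 160) -> 128 channels at (30.0, 30.0, 20.0)
```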

We will use the [Adam optimizer](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) together with a [cosine annealing schedule](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) for our learning rate. With `T_max=config.max_train_epochs`, `CosineAnnealingLR` decays the learning rate from its initial value toward zero along a half cosine curve over the course of training, allowing larger updates early on and progressively finer updates near the end. Unlike cosine annealing with warm restarts, this schedule does not periodically reset the learning rate.

In [None]:
config.initial_learning_rate = 1e-4
config.weight_decay = 1e-5
config.max_train_epochs = 25

# create optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# create learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.max_train_epochs
)
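With the default `eta_min=0`, `CosineAnnealingLR` follows the closed form given in the PyTorch documentation, `lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2`, where `t` is the number of scheduler steps taken so far. The sketch below evaluates that formula in plain Python for this configuration (`lr_0 = 1e-4`, `T_max = 25`); it reproduces the schedule's shape but is not PyTorch's implementation, which updates the rate recursively.

```python
import math


def cosine_annealing_lr(step, base_lr=1e-4, t_max=25, eta_min=0.0):
    # Closed-form cosine annealing without warm restarts
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


schedule = [cosine_annealing_lr(epoch) for epoch in range(26)]
print(f"epoch 0: {schedule[0]:.2e}")    # 1.00e-04
print(f"epoch 12: {schedule[12]:.2e}")
print(f"epoch 25: {schedule[25]:.2e}")  # 0.00e+00
```

Because the training loop logs `lr_scheduler.get_last_lr()` before calling `lr_scheduler.step()`, the rate logged for epoch `e` corresponds to `schedule[e]`.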

Next, we define the loss as a multi-label Dice loss, as proposed in the paper [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797), using the [`monai.losses`](https://docs.monai.io/en/stable/losses.html) API, and the corresponding Dice metrics using the [`monai.metrics`](https://docs.monai.io/en/stable/metrics.html) API.

In [None]:
config.dice_loss_smoothen_numerator = 0
config.dice_loss_smoothen_denominator = 1e-5
config.dice_loss_squared_prediction = True
config.dice_loss_target_onehot = False
config.dice_loss_apply_sigmoid = True

from monai.losses import DiceLoss

loss_function = DiceLoss(
    smooth_nr=config.dice_loss_smoothen_numerator,
    smooth_dr=config.dice_loss_smoothen_denominator,
    squared_pred=config.dice_loss_squared_prediction,
    to_onehot_y=config.dice_loss_target_onehot,
    sigmoid=config.dice_loss_apply_sigmoid,
)
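For a single channel of a single sample, `DiceLoss` with `squared_pred=True` computes `1 - (2 * sum(p * g) + smooth_nr) / (sum(p^2) + sum(g^2) + smooth_dr)`, where `p` are the sigmoid probabilities and `g` the binary targets; with MONAI's defaults this value is computed per sample and per channel and then averaged. The plain-Python function below is a sketch of that formula with the smoothing values from our config, not MONAI's implementation.

```python
import math


def soft_dice_loss(logits, targets, smooth_nr=0.0, smooth_dr=1e-5, squared_pred=True):
    # sigmoid=True: turn raw logits into per-voxel probabilities first
    probs = [1 / (1 + math.exp(-x)) for x in logits]
    intersection = sum(p * g for p, g in zip(probs, targets))
    if squared_pred:
        denominator = sum(p * p for p in probs) + sum(g * g for g in targets)
    else:
        denominator = sum(probs) + sum(targets)
    return 1 - (2 * intersection + smooth_nr) / (denominator + smooth_dr)


target = [1, 1, 0, 0]
good = soft_dice_loss([8.0, 8.0, -8.0, -8.0], target)  # confident and correct
bad = soft_dice_loss([-8.0, -8.0, 8.0, 8.0], target)   # confident and wrong
print(good, bad)  # close to 0 and close to 1, respectively
```

Since `to_onehot_y=False`, the targets must already be multi-channel binary masks, which is exactly what `ConvertToMultiChannelBasedOnBratsClassesd` produces.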

In [None]:
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
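After `post_trans` turns logits into binary masks (a sigmoid, then a threshold at 0.5), `DiceMetric` scores each sample and channel with `2|A ∩ B| / (|A| + |B|)`. The two reductions differ only in how they average that per-sample, per-channel table: `"mean"` collapses it to one number, while `"mean_batch"` averages over the batch and keeps one value per channel, which is how the training loop gets separate TC, WT, and ET scores. This is a plain-Python sketch with hypothetical helper names and toy 4-voxel masks; it ignores MONAI's special handling of empty ground truth.

```python
import math


def post_process(logits, threshold=0.5):
    # Activations(sigmoid=True) followed by AsDiscrete(threshold=0.5);
    # tie-breaking at exactly 0.5 does not matter for these toy values
    return [1 if 1 / (1 + math.exp(-x)) >= threshold else 0 for x in logits]


def dice(pred, target):
    intersection = sum(p * t for p, t in zip(pred, target))
    return 2 * intersection / (sum(pred) + sum(target))


# 2 samples x 3 channels (TC, WT, ET), each a toy 4-voxel mask
logits = [
    [[3.0, 2.0, -1.0, -2.0], [3.0, 2.0, 1.0, -2.0], [2.0, -3.0, -1.0, -2.0]],
    [[-3.0, 2.0, 1.0, -2.0], [3.0, 2.0, 1.0, 1.0], [-2.0, 3.0, -1.0, -2.0]],
]
targets = [
    [[1, 1, 0, 0], [1, 1, 1, 0], [1, 0, 0, 0]],
    [[0, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 0]],
]
# scores[sample][channel]
scores = [[dice(post_process(l), t) for l, t in zip(ls, ts)] for ls, ts in zip(logits, targets)]
mean = sum(sum(row) for row in scores) / (len(scores) * len(scores[0]))  # reduction="mean"
mean_batch = [sum(row[c] for row in scores) / len(scores) for c in range(3)]  # "mean_batch"
print(mean, mean_batch)
```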

## 🚀 Automatic Mixed Precision

Mixed precision training is a technique used in training neural networks that utilizes both 16-bit and 32-bit floating-point types for different parts of the computation, rather than using a single precision type throughout the entire process. This method is primarily aimed at accelerating the training process while also reducing the memory usage of the models.

We will use [`torch.amp`](https://pytorch.org/docs/stable/amp.html#module-torch.amp), which provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype and other operations use a lower-precision floating-point datatype such as `torch.float16` or `torch.bfloat16`.

### ⚖️ Gradient and Loss Scaling

If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. Such values underflow to zero, so the updates for the corresponding parameters are lost.

To counteract this gradient underflow, gradient and loss scaling is applied: the loss is multiplied by a scale factor before backpropagation, so the gradients are scaled up by the same factor, and the gradients are then unscaled before the optimizer updates the parameters. We will be using [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#gradient-scaling) to perform the scaling.

In [None]:
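# use automatic mixed-precision to accelerate training
scaler = torch.cuda.amp.GradScaler()
# Let cuDNN benchmark and cache the fastest convolution algorithms for our input sizes.
# Note: this overrides the `benchmark=False` set by `set_determinism`, trading some
# reproducibility for speed.
torch.backends.cudnn.benchmark = True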
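The underflow problem, and why scaling fixes it, can be reproduced with the standard library alone: `struct` can round a Python float through IEEE 754 half precision using the `'e'` format. The sketch below uses `2**16` as the scale factor, which is, to our knowledge, `GradScaler`'s default initial scale; the real scaler also adjusts the factor dynamically during training, which is not modelled here.

```python
import struct


def to_float16(x):
    # Round-trip a Python float through IEEE 754 half precision
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8      # a small gradient, below float16's smallest subnormal (~6e-8)
scale = 2.0**16  # loss-scale factor

unscaled_fp16 = to_float16(grad)        # underflows to 0.0: the update is lost
scaled_fp16 = to_float16(grad * scale)  # ~6.55e-4 is representable in float16
recovered = scaled_fp16 / scale         # unscale in higher precision before the step

print(unscaled_fp16)  # 0.0
print(recovered)      # close to 1e-8
```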

Next, we write a utility function to perform sliding window inference using [`monai.inferers.sliding_window_inference`](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference-function) and AMP autocast. This function will be used during the validation step of our training and validation loop.

In [None]:
from monai.inferers import sliding_window_inference

config.inference_roi_size = (240, 240, 160)


def inference(model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=config.inference_roi_size,
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    with torch.cuda.amp.autocast():
        return _compute(input)
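To make the `overlap=0.5` setting concrete, here is a simplified one-dimensional sketch of sliding-window inference: windows of length `roi` are placed every `roi * (1 - overlap)` voxels, the last window is shifted back so it ends at the edge, inputs shorter than the window are padded, and overlapping predictions are averaged. The helpers `window_starts` and `sliding_window_1d` are hypothetical. MONAI's `sliding_window_inference` follows the same idea on N-D volumes, but it pads symmetrically and supports other blending modes, so treat this as an illustration rather than a reimplementation.

```python
def window_starts(size, roi, overlap):
    # Start index of every window along one axis
    if size <= roi:
        return [0]  # the (padded) input fits into a single window
    interval = max(int(roi * (1 - overlap)), 1)
    starts = []
    idx = 0
    while True:
        start = idx * interval
        is_last = start + roi >= size
        starts.append(start - max(start + roi - size, 0))  # keep the window inside
        if is_last:
            return starts
        idx += 1


def sliding_window_1d(signal, roi, overlap, predictor):
    size = len(signal)
    padded = list(signal) + [0.0] * max(roi - size, 0)  # pad short inputs (at the end here)
    total = [0.0] * len(padded)
    count = [0] * len(padded)
    for start in window_starts(len(padded), roi, overlap):
        window_pred = predictor(padded[start:start + roi])
        for offset, value in enumerate(window_pred):
            total[start + offset] += value
            count[start + offset] += 1
    # Average the overlapping predictions, then crop the padding away
    return [t / c for t, c in zip(total, count)][:size]


print(window_starts(300, 160, 0.5))  # [0, 80, 140]
print(window_starts(155, 160, 0.5))  # [0]
```

For example, an axis of length 155 with a 160-voxel window yields a single padded window, so along that axis the overlap setting has no effect.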

## 🦾 Training the Model
Let's finally get to training the model!

### 🐝 Customize Log Axes on W&B

We will use [`wandb.define_metric`](https://docs.wandb.ai/guides/track/log/customize-logging-axes) to set custom x-axes for our W&B charts. Custom x-axes are useful when different metrics are logged at different frequencies or time steps during training. For example, when training our brain tumor segmentation model, we log the training loss at every training step but log the validation metrics only once per validation interval (every epoch in this notebook).

In [None]:
wandb.define_metric("epoch/epoch_step")
wandb.define_metric("epoch/*", step_metric="epoch/epoch_step")
wandb.define_metric("batch/batch_step")
wandb.define_metric("batch/*", step_metric="batch/batch_step")
wandb.define_metric("validation/validation_step")
wandb.define_metric("validation/*", step_metric="validation/validation_step")

### 🏋️ Training and Validation Loop

Next, we write the training and validation loop for the brain tumor segmentation model. Each epoch of the loop consists of three logical steps:

1. **The training step**: In this step, we train the model by looping over the `train_loader`. We use autocast to speed up the forward pass and loss calculation, and during backpropagation we use the gradient scaler to prevent small FP16 gradients from underflowing to zero. At the end of each batch, we log the batch step under `batch/batch_step` and the training loss under `batch/train_loss`. This ensures that the training loss is plotted in its own section against the batch step on the x-axis in the W&B workspace. Here's how the training step is written:
   
   ```python
   for batch_data in train_loader:
       inputs, labels = (
           batch_data["image"].to(device),
           batch_data["label"].to(device),
       )
       optimizer.zero_grad()
       with torch.cuda.amp.autocast():
           outputs = model(inputs)
           loss = loss_function(outputs, labels)
       scaler.scale(loss).backward()
       scaler.step(optimizer)
       scaler.update()
       epoch_loss += loss.item()
       batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:")
       ## Log batch-wise training loss to W&B
       wandb.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()})
       batch_step += 1
   ```

2. **The epoch-wise logging step:** In this step, we log the learning rate and the mean training loss for the epoch under the section `epoch/*`. After logging, we update the learning rate using our learning rate scheduler.

   ```python
   epoch_loss /= total_batch_steps
   wandb.log(
       {
           "epoch/epoch_step": epoch,
           "epoch/mean_train_loss": epoch_loss,
           "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
       }
   )
   lr_scheduler.step()
   ```

3. **The validation step:** This step runs every `config.validation_intervals` epochs. In this step, we use the `inference` function defined above to predict the segmentation masks for the images from the validation dataloader, use `dice_metric` and `dice_metric_batch` to compute the mean Dice coefficient overall and for each of our target classes, and log these values under the `validation/*` section. We also save our model checkpoint to W&B using `wandb.log_model`.

   ```python
   for val_data in val_loader:
       val_inputs, val_labels = (
           val_data["image"].to(device),
           val_data["label"].to(device),
       )
       val_outputs = inference(model, val_inputs)
       val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
       dice_metric(y_pred=val_outputs, y=val_labels)
       dice_metric_batch(y_pred=val_outputs, y=val_labels)

   # Aggregate the metrics over the whole validation set, then reset them
   metric_values.append(dice_metric.aggregate().item())
   metric_batch = dice_metric_batch.aggregate()
   metric_values_tumor_core.append(metric_batch[0].item())
   metric_values_whole_tumor.append(metric_batch[1].item())
   metric_values_enhanced_tumor.append(metric_batch[2].item())
   dice_metric.reset()
   dice_metric_batch.reset()

   wandb.log(
       {
           "validation/validation_step": validation_step,
           "validation/mean_dice": metric_values[-1],
           "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
           "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
           "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
       }
   )

   checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
   torch.save(model.state_dict(), checkpoint_path)
   wandb.log_model(
       checkpoint_path,
       name=f"{wandb.run.id}-checkpoint",
       aliases=[f"epoch_{epoch}"],
   )
   ```
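The three x-axes defined with `wandb.define_metric` advance at different rates. This sketch simulates only the counters of the loop, with an illustrative 3 epochs, 4 batches per epoch, and a validation interval of 1, and collects what would be logged into a plain list in place of calling `wandb.log`; the helper `simulate_logging` is hypothetical and the numbers are not from a real run.

```python
def simulate_logging(num_epochs=3, batches_per_epoch=4, validation_intervals=1):
    logged = []  # stands in for the wandb.log calls
    batch_step = 0
    validation_step = 0
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            logged.append({"batch/batch_step": batch_step})
            batch_step += 1
        logged.append({"epoch/epoch_step": epoch})
        if (epoch + 1) % validation_intervals == 0:
            logged.append({"validation/validation_step": validation_step})
            validation_step += 1
    return logged


logs = simulate_logging()
batch_steps = [d["batch/batch_step"] for d in logs if "batch/batch_step" in d]
epoch_steps = [d["epoch/epoch_step"] for d in logs if "epoch/epoch_step" in d]
print(batch_steps[-1], epoch_steps[-1])  # 11 2
```

Each chart section is plotted against its own counter, so 12 batch points, 3 epoch points, and 3 validation points line up correctly instead of being plotted against the global step that `wandb.log` increments on every call.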

In [None]:
import os
from tqdm.auto import tqdm
from monai.data import decollate_batch

config.validation_intervals = 1
config.checkpoint_dir = "./checkpoints"

# Create checkpoint directory
os.makedirs(config.checkpoint_dir, exist_ok=True)

batch_step = 0
validation_step = 0
metric_values = []
metric_values_tumor_core = []
metric_values_whole_tumor = []
metric_values_enhanced_tumor = []

epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

for epoch in epoch_progress_bar:
    model.train()
    epoch_loss = 0

    # Number of batches per epoch (the last batch may be smaller than batch_size)
    total_batch_steps = len(train_loader)
    batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)

    # Training Step
    for batch_data in batch_progress_bar:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:")
        ## Log batch-wise training loss to W&B
        wandb.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()})
        batch_step += 1

    epoch_loss /= total_batch_steps
    ## Log epoch-wise mean training loss and learning rate to W&B
    wandb.log(
        {
            "epoch/epoch_step": epoch,
            "epoch/mean_train_loss": epoch_loss,
            "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
        }
    )
    lr_scheduler.step()
    epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:")

    # Validation and model checkpointing step
    if (epoch + 1) % config.validation_intervals == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(model, val_inputs)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            metric_values.append(dice_metric.aggregate().item())
            metric_batch = dice_metric_batch.aggregate()
            metric_values_tumor_core.append(metric_batch[0].item())
            metric_values_whole_tumor.append(metric_batch[1].item())
            metric_values_enhanced_tumor.append(metric_batch[2].item())
            dice_metric.reset()
            dice_metric_batch.reset()

            # Log and version model checkpoints using W&B artifacts.
            checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
            torch.save(model.state_dict(), checkpoint_path)
            wandb.log_model(
                checkpoint_path,
                name=f"{wandb.run.id}-checkpoint",
                aliases=[f"epoch_{epoch}"],
            )

            # Log validation metrics to W&B dashboard.
            wandb.log(
                {
                    "validation/validation_step": validation_step,
                    "validation/mean_dice": metric_values[-1],
                    "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                    "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                    "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                }
            )
            validation_step += 1

Now we end the experiment by calling `wandb.finish()`.

In [None]:
wandb.finish()