# Evaluate Brain Tumor Segmentation Data

In this notebook we will learn:
- how we can evaluate a pre-trained model checkpoint for brain tumor segmentation using MONAI and Weights & Biases.
- how we can visually compare the ground-truth labels with the predicted labels.

## 🌴 Setup and Installation

First, let us install the latest version of both MONAI and Weights and Biases.

In [None]:
!pip install -q -U monai wandb

## 🌳 Initialize a W&B Run

We will start a new W&B run to start tracking our experiment. Note that we set the job type for this run as `evaluate`.

In [None]:

import wandb
from monai.utils import set_determinism

# Start the W&B run that tracks this evaluation; the job type labels it as an
# evaluation run.
wandb.init(
    project="brain-tumor-segmentation",
    entity="lifesciences",
    job_type="evaluate"
)

# Handle to the run's configuration; later cells store their settings on it
config = wandb.config

# Ensure deterministic behavior and reproducibility
config.seed = 0
set_determinism(seed=config.seed)

## 💿 Loading and Transforming the Data

We will use the validation transforms from the previous lessons to load and transform the validation dataset using the Decathlon dataset artifact on W&B.

In [None]:
from utils import ConvertToMultiChannelBasedOnBratsClassesd
from monai.apps import DecathlonDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)


# Preprocessing pipeline applied to every validation sample
# (Compose runs the transforms in the order they are listed)
transforms = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        # Ensure loaded images are in channels-first format
        EnsureChannelFirstd(keys="image"),
        # Ensure the input data to be a PyTorch Tensor or numpy array
        EnsureTyped(keys=["image", "label"]),
        # Convert labels to multi-channels based on brats18 classes
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Change the orientation of the image and label to the one given by the axis codes
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample image and label to the specified pixel dimension
        # (bilinear interpolation for the image, nearest-neighbour for the label)
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Normalize input image intensity
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)


# Fetch the brain tumor segmentation dataset artifact from W&B
artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:latest",
    type="dataset",
)
artifact_dir = artifact.download()


# Create the dataset for the validation split
# of the brain tumor segmentation dataset.
# download=False: the data is read from the artifact directory downloaded above.
val_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=transforms,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)

## 🤖 Loading the Model Checkpoint

We are going to fetch the model checkpoints from the training run and load them.

In [None]:
import os
import torch
from monai.networks.nets import SegResNet

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyperparameters, stored on the W&B run config.
# NOTE(review): these presumably have to match the training run that produced
# the checkpoint, otherwise load_state_dict below will reject it.
config.model_blocks_down = [1, 2, 2, 4]
config.model_blocks_up = [1, 1, 1]
config.model_init_filters = 16
config.model_in_channels = 4
config.model_out_channels = 3
config.model_dropout_prob = 0.2

# create model
model = SegResNet(
    blocks_down=config.model_blocks_down,
    blocks_up=config.model_blocks_up,
    init_filters=config.model_init_filters,
    in_channels=config.model_in_channels,
    out_channels=config.model_out_channels,
    dropout_prob=config.model_dropout_prob,
).to(device)


# Fetch the latest model checkpoint artifact from the training run
model_artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/8vmqcqao-checkpoint:latest",
    type="model",
)
model_artifact_dir = model_artifact.download()


# Load the model checkpoint. map_location remaps the stored tensors onto the
# selected device, so a checkpoint saved from a GPU also loads on a CPU-only machine.
checkpoint_path = os.path.join(model_artifact_dir, "model.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to evaluation mode so layers such as dropout behave deterministically
model.eval()

## 📈 Evaluating the Model

First we define some instances of `monai.metrics.DiceMetric` for all the metrics that we will be evaluating the model against on the validation split of our dataset.

In [None]:
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete

# One Dice metric per tumor class; these are used for the per-slice scores
# that are logged to the evaluation table
tumor_core_dice_metric = DiceMetric(include_background=True, reduction="mean")
enhancing_tumor_dice_metric = DiceMetric(include_background=True, reduction="mean")
whole_tumor_dice_metric = DiceMetric(include_background=True, reduction="mean")

# Dice score on whole volumes: "mean_batch" averages over the batch only, so the
# aggregated result holds one value per class (the evaluation loop indexes it by class)
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# transforms to postprocess the outputs of the model for evaluation and visualization:
# a sigmoid activation followed by thresholding at 0.5
postprocessing_transforms = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)

Next, we write some utility functions for evaluating each data-point from the validation dataset by logging dice score for each target class and the ground-truth and predicted segmentation labels (for granular visual comparison and analysis) to a W&B Table.

In [None]:
from tqdm.auto import tqdm


def get_target_area_percentage(segmentation_map):
    """Return the percentage of elements of ``segmentation_map`` that equal 1.

    Args:
        segmentation_map: array exposing ``flatten()`` and element-wise
            comparison (the callers in this notebook pass 2D NumPy slices).

    Returns:
        The share of elements equal to 1 as a float between 0 and 100;
        0.0 when the map contains no elements.
    """
    flat_map = segmentation_map.flatten()
    total_elements = len(flat_map)
    if total_elements == 0:
        # Nothing to measure; avoids dividing by zero
        return 0.0
    # Vectorized count instead of materializing every element in a Python list
    target_elements = int((flat_map == 1.0).sum())
    return target_elements * 100 / total_elements


def get_class_wise_dice_scores(sample_label, predicted_label, slice_idx):
    """Compute the Dice score of every tumor class on a single slice.

    Args:
        sample_label: NumPy array with the ground-truth masks, indexed as
            ``[class, :, :, slice]`` with the classes ordered Tumor-Core (0),
            Whole-Tumor (1), Enhancing-Tumor (2).
        predicted_label: NumPy array with the predicted masks, same layout.
        slice_idx: index of the slice along the last axis.

    Returns:
        Dict mapping "Tumor-Core", "Enhancing-Tumor" and "Whole-Tumor" to the
        Dice score (a Python float) of that class on the slice.
    """
    # from_numpy shares memory with the arrays; only the selected slice is
    # copied to the device below, not the whole volume
    sample_label = torch.from_numpy(sample_label)
    predicted_label = torch.from_numpy(predicted_label)

    def slice_as_batch(volume, channel):
        # DiceMetric reads dim 0 as batch and dim 1 as channel, so the 2D slice
        # is given explicit batch and channel dimensions of size 1; without the
        # channel dimension every image row would be scored as its own class
        return volume[channel, :, :, slice_idx][None, None].to(device)

    # Each class is compared against the prediction channel with the same index
    class_metrics = (
        (0, tumor_core_dice_metric),
        (1, whole_tumor_dice_metric),
        (2, enhancing_tumor_dice_metric),
    )
    for channel, metric in class_metrics:
        metric(
            y_pred=slice_as_batch(predicted_label, channel),
            y=slice_as_batch(sample_label, channel),
        )

    dice_scores = {
        "Tumor-Core": tumor_core_dice_metric.aggregate().item(),
        "Enhancing-Tumor": enhancing_tumor_dice_metric.aggregate().item(),
        "Whole-Tumor": whole_tumor_dice_metric.aggregate().item(),
    }
    # The metrics accumulate across calls; clear them so the next slice starts fresh
    for _, metric in class_metrics:
        metric.reset()
    return dice_scores


def log_predictions_into_tables(
    sample_image,
    sample_label,
    predicted_label,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    """Add one row per slice of a sample to a W&B table.

    Each row holds the split name, the sample index, the slice index, the
    per-class Dice scores, the per-class tumor area percentages (ground truth
    and prediction) and one image per class overlaid with the ground-truth
    and predicted masks.

    Args:
        sample_image: 4D tensor whose last axis indexes the slices; channel 0
            is used as the background of the overlays.
        sample_label: tensor of ground-truth masks indexed as
            ``[class, :, :, slice]`` (Tumor-Core, Whole-Tumor, Enhancing-Tumor).
        predicted_label: tensor of predicted masks with the same layout.
        split: name of the dataset split written to the row.
        data_idx: index of the sample within the split.
        table: W&B table the rows are appended to.

    Returns:
        The table that was passed in, with the new rows added.
    """
    sample_image = sample_image.cpu().numpy()
    sample_label = sample_label.cpu().numpy()
    predicted_label = predicted_label.cpu().numpy()
    _, _, _, num_slices = sample_image.shape

    # (channel index, key used in the table, name shown in the mask legend),
    # listed in the order in which the image columns are added to the row
    target_classes = (
        (0, "Tumor-Core", "Tumor Core"),
        (1, "Whole-Tumor", "Whole Tumor"),
        (2, "Enhancing-Tumor", "Enhancing Tumor"),
    )
    labelled_volumes = (
        ("Ground-Truth", sample_label),
        ("Prediction", predicted_label),
    )

    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            wandb_images = [
                wandb.Image(
                    sample_image[0, :, :, slice_idx],
                    masks={
                        f"ground-truth/{class_key}": {
                            "mask_data": sample_label[channel, :, :, slice_idx],
                            "class_labels": {0: "background", 1: class_name},
                        },
                        # The prediction is scaled to class id 2 so that it gets
                        # its own legend entry, distinct from the ground truth (id 1)
                        f"prediction/{class_key}": {
                            "mask_data": predicted_label[channel, :, :, slice_idx] * 2,
                            "class_labels": {0: "background", 2: class_name},
                        },
                    },
                )
                for channel, class_key, class_name in target_classes
            ]
            tumor_area_percentage = {
                source: {
                    class_key: get_target_area_percentage(
                        volume[channel, :, :, slice_idx]
                    )
                    for channel, class_key, _ in target_classes
                }
                for source, volume in labelled_volumes
            }
            dice_scores = get_class_wise_dice_scores(
                sample_label, predicted_label, slice_idx
            )
            table.add_data(
                split,
                data_idx,
                slice_idx,
                dice_scores,
                tumor_area_percentage,
                *wandb_images
            )
            progress_bar.update(1)
    return table

Next, we create the prediction table.

In [None]:
# Table that receives one row per slice of every validation sample.
# The column order has to match the order of the values passed to
# `table.add_data` in `log_predictions_into_tables`.
evaluation_table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Dice-Score",
        "Tumor-Area-Pixel-Percentage",
        "Prediction/Tumor-Core",
        "Prediction/Whole-Tumor",
        "Prediction/Enhancing-Tumor",
    ]
)

Finally, we loop over the validation dataset and log the evaluation table and the mean dice scores for each class across the entire validation set to W&B.

In [None]:
from utils import inference

# Running sums of the per-sample Dice score of each class
total_tumor_core_dice_score = 0.0
total_whole_tumor_dice_score = 0.0
total_enhancing_tumor_dice_score = 0.0

# Region-of-interest size handed to the inference helper
config.inference_roi_size = (240, 240, 160)

# Perform inference and visualization
with torch.no_grad():
    for data_idx, sample in tqdm(enumerate(val_dataset), total=len(val_dataset), desc="Evaluating:"):
        # Add a leading batch dimension of size 1 and move the sample to the device
        test_input, test_labels = (
            torch.unsqueeze(sample["image"], 0).to(device),
            torch.unsqueeze(sample["label"], 0).to(device),
        )
        test_output = inference(model, test_input, config.inference_roi_size)
        # Drop the batch dimension, then apply sigmoid + thresholding
        test_output = postprocessing_transforms(test_output[0])
        dice_metric_batch(y_pred=torch.unsqueeze(test_output, dim=0), y=test_labels)
        metric_batch = dice_metric_batch.aggregate()
        # The metric accumulates results across calls; without a reset the next
        # aggregate() would return a running mean over all samples seen so far,
        # and summing those running means would not give the per-class average
        dice_metric_batch.reset()
        evaluation_table = log_predictions_into_tables(
            sample_image=torch.squeeze(test_input),
            sample_label=torch.squeeze(test_labels),
            predicted_label=test_output,
            data_idx=data_idx,
            split="validation",
            table=evaluation_table,
        )
        # metric_batch holds one score per class: Tumor-Core, Whole-Tumor, Enhancing-Tumor
        total_tumor_core_dice_score += metric_batch[0].item()
        total_whole_tumor_dice_score += metric_batch[1].item()
        total_enhancing_tumor_dice_score += metric_batch[2].item()

    wandb.log({"Tumor-Segmentation-Evaludation": evaluation_table})
    # Mean Dice score of each class over the validation set
    wandb.summary["Tumor-Core-Dice-Score"] = total_tumor_core_dice_score / len(val_dataset)
    wandb.summary["Whole-Tumor-Dice-Score"] = total_whole_tumor_dice_score / len(val_dataset)
    wandb.summary["Enhancing-Tumor-Dice-Score"] = total_enhancing_tumor_dice_score / len(val_dataset)

# End the experiment
wandb.finish()