# Visualize Brain Tumor Segmentation Data

In this notebook we will learn:
- MONAI transform API:
  - MONAI Transforms for dictionary format data.
  - Creating custom transforms using [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API.
- How to visualize the brain tumor segmentation dataset using W&B image overlays.
- How to analyze our data using W&B Tables.

## 🌴 Setup and Installation

First, let us install the latest versions of MONAI and Weights & Biases.

In [None]:
!pip install -q -U monai wandb

## 🌳 Initialize a W&B Run

We will start a new W&B run to track our experiment.

In [None]:
import wandb

wandb.init(
    project="brain-tumor-segmentation",
    entity="lifesciences",
    job_type="visualize_dataset"
)

## 💿 Loading and Transforming the Data

We will now use the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API to create and apply transforms to our data.

### Creating a Custom Transform

First, we demonstrate the creation of a custom transform `ConvertToMultiChannelBasedOnBratsClassesd` using [`monai.transforms.MapTransform`](https://docs.monai.io/en/stable/transforms.html#maptransform) that converts labels to multi-channel tensors based on brats18 classes:
- label 1 is the peritumoral edema
- label 2 is the GD-enhancing tumor
- label 3 is the necrotic and non-enhancing tumor core.

The target classes for the semantic segmentation task after applying this transform to the dataset will be:
- Tumor core (labels 2 and 3 merged)
- Whole tumor (labels 1, 2 and 3 merged)
- Enhancing tumor (label 2)

In [None]:
import torch
from monai.transforms import MapTransform


class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi-channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor), and ET (Enhancing tumor).

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        data_dict = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct Tumor Core
            result.append(torch.logical_or(data_dict[key] == 2, data_dict[key] == 3))
            # merge labels 1, 2 and 3 to construct Whole Tumor
            result.append(
                torch.logical_or(
                    torch.logical_or(data_dict[key] == 2, data_dict[key] == 3), data_dict[key] == 1
                )
            )
            # label 2 is Enhancing Tumor
            result.append(data_dict[key] == 2)
            data_dict[key] = torch.stack(result, axis=0).float()
        return data_dict
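
To make the channel construction concrete, here is a minimal sketch that applies the same merging logic to a tiny, made-up label map. It uses NumPy in place of `torch` only to keep the example lightweight; the array values are hypothetical and not taken from the dataset.

```python
import numpy as np

# Hypothetical 2x3 label map using the raw label values 0-3
label = np.array([[0, 1, 2],
                  [3, 2, 0]])

# Same boolean logic as the transform, expressed with NumPy
tumor_core = np.logical_or(label == 2, label == 3)    # labels 2 and 3
whole_tumor = np.logical_or(tumor_core, label == 1)   # labels 1, 2 and 3
enhancing_tumor = label == 2                          # label 2 only

# Stack into a channels-first array: (TC, WT, ET)
multi_channel = np.stack(
    [tumor_core, whole_tumor, enhancing_tumor], axis=0
).astype(np.float32)

print(multi_channel.shape)  # (3, 2, 3)
print(multi_channel)
```

Every voxel that is set in the tumor core or enhancing tumor channel is also set in the whole tumor channel, which is why the three targets overlap.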

Next, we compose all the necessary transforms for visualizing the data using [`monai.transforms.Compose`](https://docs.monai.io/en/stable/transforms.html#monai.transforms.Compose).

**Note:** During training, we will apply a different set of transforms to the data.

In [None]:
from monai.transforms import (
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)


transforms = Compose(
    [
        # Load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        # Ensure loaded images are in channels-first format
        EnsureChannelFirstd(keys="image"),
        # Ensure the input data is a PyTorch Tensor or numpy array
        EnsureTyped(keys=["image", "label"]),
        # Convert labels to multi-channels based on brats18 classes
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Reorient the image and label to the specified axis codes (here RAS)
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample the input images to the specified pixel dimension
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Normalize input image intensity
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
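
Of the transforms above, `NormalizeIntensityd(nonzero=True, channel_wise=True)` is the one most likely to surprise when you inspect voxel values. The sketch below approximates the behavior as we understand it: each channel is z-scored using only its non-zero voxels, and zero (background) voxels are left untouched. `normalize_nonzero_channel_wise` is a hypothetical helper written for illustration with NumPy; it is not MONAI's implementation.

```python
import numpy as np


def normalize_nonzero_channel_wise(image):
    """Z-score each channel using only its non-zero voxels (illustrative helper)."""
    out = image.astype(np.float32).copy()
    for channel_idx in range(out.shape[0]):
        mask = out[channel_idx] != 0
        if not mask.any():
            continue  # nothing to normalize in an all-zero channel
        mean = out[channel_idx][mask].mean()
        std = out[channel_idx][mask].std()
        # Guard against division by zero for constant channels
        out[channel_idx][mask] = (out[channel_idx][mask] - mean) / (std if std > 0 else 1.0)
    return out


# Hypothetical 2-channel, 2x2 image; zeros stand in for background
image = np.array([[[0.0, 2.0], [4.0, 6.0]],
                  [[0.0, 0.0], [10.0, 30.0]]])
normalized = normalize_nonzero_channel_wise(image)
print(normalized)
```

Because the statistics are computed per channel, each MRI modality is scaled independently of the others.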

To load the dataset, we first fetch it from the W&B dataset artifact that we created earlier. This lets us use the dataset as an input artifact to our visualization run and establishes the necessary lineage for our experiment.

![](./assets/artifact_usage.png)

In [None]:
artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0", type="dataset"
)
artifact_dir = artifact.download()

We now use [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) to load our dataset and apply the transforms we defined to the data samples so that we can visualize them.

In [None]:
from monai.apps import DecathlonDataset


# Create the dataset for the training split
# of the brain tumor segmentation dataset
train_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=transforms,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)

# Create the dataset for the validation split
# of the brain tumor segmentation dataset
val_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=transforms,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)

## 📸 Visualizing the Dataset

Weights & Biases supports images, video, audio, and more. You can log rich media to explore your results and visually compare your runs, models, and datasets. Here, we use the [segmentation mask overlay](https://docs.wandb.ai/guides/track/log/media#image-overlays-in-tables) system to visualize our data volumes. To log segmentation masks in [W&B Tables](https://docs.wandb.ai/guides/tables), you must provide a [`wandb.Image`](https://docs.wandb.ai/ref/python/data-types/image) object containing the segmentation annotations for each row in the table.

![](https://docs.wandb.ai/assets/images/viz-2-e3652d015abbf1d6d894e8edb1424eac.gif)

An example is provided in the pseudocode below:

```python
table = wandb.Table(columns=["ID", "Image"])

for id, img, label in zip(ids, images, labels):
    mask_img = wandb.Image(
        img,
        masks={
            "ground-truth": {"mask_data": label, "class_labels": class_labels}
            # ...
        },
    )

    table.add_data(id, mask_img)

wandb.log({"Table": table})
```

However, in our case the target classes overlap: the whole tumor contains the tumor core, which in turn contains the enhancing tumor. We therefore log them as separate overlays on the same image so that no class hides another.

An example is provided in the pseudocode below:

```python
mask_img = wandb.Image(
    img,
    masks={
        "ground-truth/Tumor-Core": {
            "mask_data": label_tumor_core,
            "class_labels": {0: "background", 1: "Tumor Core"}
        },
        "ground-truth/Whole-Tumor": {
            "mask_data": label_whole_tumor,
            "class_labels": {0: "background", 2: "Whole Tumor"}
        },
        "ground-truth/Enhancing-Tumor": {
            "mask_data": label_enhancing_tumor,
            "class_labels": {0: "background", 3: "Enhancing Tumor"}
        },
    },
)
```
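
The same dictionary can be built programmatically from a multi-channel label slice. The sketch below assumes a `(3, H, W)` array ordered as tumor core, whole tumor, enhancing tumor; `build_overlay_masks` and the toy `label_slice` are hypothetical names used only for illustration. Each binary channel is scaled by its class id so that the mask values match the keys in `class_labels`.

```python
import numpy as np


def build_overlay_masks(label_slice):
    """Build one overlay entry per target class from a (3, H, W) binary array."""
    names = ["Tumor-Core", "Whole-Tumor", "Enhancing-Tumor"]
    masks = {}
    for channel_idx, name in enumerate(names):
        class_id = channel_idx + 1  # 1, 2, 3
        masks[f"ground-truth/{name}"] = {
            # Scale the binary mask so foreground pixels equal the class id
            "mask_data": (label_slice[channel_idx] * class_id).astype(np.uint8),
            "class_labels": {0: "background", class_id: name.replace("-", " ")},
        }
    return masks


# Hypothetical 2x2 slice in which pixel (0, 0) belongs to all three classes
label_slice = np.zeros((3, 2, 2), dtype=np.float32)
label_slice[0, 0, 0] = 1.0  # tumor core
label_slice[1, 0, :] = 1.0  # whole tumor
label_slice[2, 0, 0] = 1.0  # enhancing tumor

masks = build_overlay_masks(label_slice)
for key, value in masks.items():
    print(key, value["mask_data"].tolist())
```

The resulting dictionary could then be passed as the `masks` argument of `wandb.Image`. Pixel (0, 0) appears in all three overlays, which a single-label mask could not express.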

In [None]:
import numpy as np
from tqdm.auto import tqdm


def get_target_area_percentage(segmentation_map):
    """Return the percentage of pixels in `segmentation_map` that equal 1."""
    foreground_pixels = np.count_nonzero(segmentation_map == 1)
    return float(foreground_pixels * 100 / segmentation_map.size)


def log_data_samples_into_tables(
    sample_image: np.ndarray,
    sample_label: np.ndarray,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    """Utility function for logging a data sample into a W&B Table"""
    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            ground_truth_wandb_images, tumor_area_percentages = [], []
            for channel_idx in range(num_channels):
                masks = {
                    "ground-truth/Tumor-Core": {
                        "mask_data": sample_label[0, :, :, slice_idx],
                        "class_labels": {0: "background", 1: "Tumor Core"},
                    },
                    "ground-truth/Whole-Tumor": {
                        "mask_data": sample_label[1, :, :, slice_idx] * 2,
                        "class_labels": {0: "background", 2: "Whole Tumor"},
                    },
                    "ground-truth/Enhancing-Tumor": {
                        "mask_data": sample_label[2, :, :, slice_idx] * 3,
                        "class_labels": {0: "background", 3: "Enhancing Tumor"},
                    },
                }

                ground_truth_wandb_images.append(
                    wandb.Image(
                        sample_image[channel_idx, :, :, slice_idx],
                        masks=masks,
                    )
                )
                tumor_area_percentages.append(
                    {
                        "Tumor-Core-Area-Percentage": get_target_area_percentage(
                            sample_label[0, :, :, slice_idx]
                        ),
                        "Whole-Tumor-Area-Percentage": get_target_area_percentage(
                            sample_label[1, :, :, slice_idx]
                        ),
                        "Enhancing-Tumor-Area-Percentage": get_target_area_percentage(
                            sample_label[2, :, :, slice_idx]
                        ),
                    }
                )
            table.add_data(
                split,
                data_idx,
                slice_idx,
                *tumor_area_percentages,
                *ground_truth_wandb_images
            )
            progress_bar.update(1)
    return table
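
As a worked example of the per-slice statistics stored in each percentage column, the sketch below computes the three area percentages for a made-up label slice. The array contents and the `names` list are hypothetical and chosen only so the arithmetic is easy to check by hand.

```python
import numpy as np

# Hypothetical label slice: 3 target channels, 4x4 pixels each
label_slice = np.zeros((3, 4, 4), dtype=np.float32)
label_slice[0, 1:3, 1:3] = 1.0  # tumor core: 4 of 16 pixels
label_slice[1, 0:3, 0:3] = 1.0  # whole tumor: 9 of 16 pixels
label_slice[2, 1, 1] = 1.0      # enhancing tumor: 1 of 16 pixels

names = ["Tumor-Core", "Whole-Tumor", "Enhancing-Tumor"]
percentages = {
    # Fraction of pixels equal to 1, expressed as a percentage
    f"{name}-Area-Percentage": float((label_slice[i] == 1).mean() * 100)
    for i, name in enumerate(names)
}

print(percentages)
# {'Tumor-Core-Area-Percentage': 25.0, 'Whole-Tumor-Area-Percentage': 56.25, 'Enhancing-Tumor-Area-Percentage': 6.25}
```

Since the labels do not depend on the image channel, the same dictionary is repeated for every channel of a given slice.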

Next, we iterate over our respective datasets and populate the table on our W&B dashboard.

In [None]:
# Define the schema of the table
table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Tumor-Area-Pixel-Percentages-Channel-0",
        "Tumor-Area-Pixel-Percentages-Channel-1",
        "Tumor-Area-Pixel-Percentages-Channel-2",
        "Tumor-Area-Pixel-Percentages-Channel-3",
        "Image-Channel-0",
        "Image-Channel-1",
        "Image-Channel-2",
        "Image-Channel-3",
    ]
)
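
The schema has three metadata columns followed by one percentage column and one image column per image channel, in the same order in which `log_data_samples_into_tables` unpacks its row values. As a sketch, the column list could also be generated from the channel count; `num_channels = 4` is an assumption that mirrors the four channel columns listed above.

```python
num_channels = 4  # assumed number of image channels, matching the schema above

columns = ["Split", "Data Index", "Slice Index"]
# One dictionary of area percentages per channel
columns += [f"Tumor-Area-Pixel-Percentages-Channel-{i}" for i in range(num_channels)]
# One overlaid image per channel
columns += [f"Image-Channel-{i}" for i in range(num_channels)]

print(len(columns))  # 11
```

Generating the list this way keeps the number of columns in step with the number of values added per row if the channel count changes.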

In [None]:
# Generate visualizations for train_dataset
for data_idx, sample in tqdm(
    enumerate(train_dataset),
    total=len(train_dataset),
    desc="Generating Train Dataset Visualizations",
):
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="train",
        data_idx=data_idx,
        table=table,
    )

# Generate visualizations for val_dataset
for data_idx, sample in tqdm(
    enumerate(val_dataset),
    total=len(val_dataset),
    desc="Generating Validation Dataset Visualizations",
):
    sample_image = sample["image"].detach().cpu().numpy()
    sample_label = sample["label"].detach().cpu().numpy()
    table = log_data_samples_into_tables(
        sample_image,
        sample_label,
        split="val",
        data_idx=data_idx,
        table=table,
    )

# Log the table to your dashboard
wandb.log({"tumor_segmentation_data": table})

In [None]:
wandb.finish()