# Brain Tumor Classification with PyTorch⚡Lightning & EfficientNet 3D

The goal of this challenge is to predict the status of a genetic biomarker important for brain cancer treatment.

All the code is taken from the public repository: https://github.com/Borda/kaggle_brain-tumor-3D
Contributions are welcome!

In [None]:
! pip install -q https://github.com/Borda/kaggle_brain-tumor-3D/archive/refs/heads/main.zip
! pip install -q https://github.com/shijianjian/EfficientNet-PyTorch-3D/archive/refs/heads/master.zip
! pip install -q "pytorch-lightning==1.3.8"
! pip uninstall -q -y wandb
! ls -l /kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification
! nvidia-smi
! mkdir -p /kaggle/temp

%matplotlib inline
%load_ext autoreload
%autoreload 2

import kaggle_brain3d
print(kaggle_brain3d.__version__)

## Data exploration

The data comes in three cohorts (training, validation, and testing), all structured the same way: each independent case has a dedicated folder identified by a five-digit number.
Within each “case” folder there are four sub-folders, one for each of the structural multi-parametric MRI (mpMRI) scans, in DICOM format.
The exact mpMRI scans included are:

- **FLAIR**: Fluid Attenuated Inversion Recovery
- **T1w**: T1-weighted pre-contrast
- **T1Gd**: T1-weighted post-contrast
- **T2**: T2-weighted
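The folder layout described above can be indexed with a few lines of standard-library code. This is a minimal sketch, not part of `kaggle_brain3d`: the helper `index_cases` is hypothetical, and it assumes the DICOM files carry a trailing slice number such as `Image-12.dcm`. Sorting numerically (rather than alphabetically) keeps the slices in stacking order.

```python
import re
from pathlib import Path
from typing import Dict, List


def _slice_number(path: Path) -> int:
    """Extract the trailing integer from names like 'Image-12.dcm' (assumed naming)."""
    match = re.search(r"(\d+)\.dcm$", path.name)
    return int(match.group(1)) if match else -1


def index_cases(root: str) -> Dict[str, Dict[str, List[Path]]]:
    """Map each five-digit case ID to its scan sub-folders and their DICOM files in slice order."""
    index: Dict[str, Dict[str, List[Path]]] = {}
    for case_dir in sorted(Path(root).iterdir()):
        if not (case_dir.is_dir() and re.fullmatch(r"\d{5}", case_dir.name)):
            continue  # skip anything that is not a case folder
        scans: Dict[str, List[Path]] = {}
        for scan_dir in sorted(case_dir.iterdir()):
            if scan_dir.is_dir():
                # numeric sort so that 'Image-10.dcm' comes after 'Image-9.dcm'
                scans[scan_dir.name] = sorted(scan_dir.glob("*.dcm"), key=_slice_number)
        index[case_dir.name] = scans
    return index
```

For example, `index_cases(os.path.join(PATH_DATASET, "train"))["00005"]` would list the scan folders of case 00005 with their ordered slice files.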

In [None]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

PATH_DATASET = "/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification"
PATH_TEMP = "/kaggle/temp"
SCAN_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")  # folder names of the four scan types

df_train = pd.read_csv(os.path.join(PATH_DATASET, "train_labels.csv"))
df_train["BraTS21ID"] = df_train["BraTS21ID"].apply(lambda i: "%05d" % i)
display(df_train.head())

Let's look at the label distribution in the dataset:

In [None]:
_= df_train["MGMT_value"].value_counts().plot(kind="pie", title="label distribution", autopct="%.1f%%")

Almost all cases include all four scan types:

In [None]:
scans = [os.path.basename(p) for p in glob.glob(os.path.join(PATH_DATASET, "train", "*", "*"))]
_= pd.Series(scans).value_counts().plot(kind="bar", grid=True)
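The bar chart only shows totals; to see *which* cases lack a scan type, the same glob results can be grouped per case. A minimal sketch with hypothetical helpers `group_by_case` and `missing_scan_types`; the expected folder names are passed in explicitly, so use the ones shown on the bar chart.

```python
import os
from typing import Dict, Iterable, List, Set


def group_by_case(scan_paths: Iterable[str]) -> Dict[str, Set[str]]:
    """Group '<root>/<case_id>/<scan_type>' paths into {case_id: {scan_type, ...}}."""
    case_scans: Dict[str, Set[str]] = {}
    for path in scan_paths:
        case_dir, scan_type = os.path.split(os.path.normpath(path))
        case_scans.setdefault(os.path.basename(case_dir), set()).add(scan_type)
    return case_scans


def missing_scan_types(case_scans: Dict[str, Set[str]], expected: Iterable[str]) -> Dict[str, List[str]]:
    """Return {case_id: [missing scan types]} for every case lacking an expected scan type."""
    expected = list(expected)
    return {
        case: [s for s in expected if s not in scans]
        for case, scans in sorted(case_scans.items())
        if any(s not in scans for s in expected)
    }
```

Usage: `missing_scan_types(group_by_case(glob.glob(os.path.join(PATH_DATASET, "train", "*", "*"))), SCAN_TYPES)`.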

### Interactive view

Show a particular scan as slices along the X, Y, and Z dimensions.

In [None]:
from ipywidgets import interact, IntSlider

from kaggle_brain3d.utils import load_volume, interpolate_volume, show_volume
from kaggle_brain3d.transforms import crop_volume

def interactive_show(volume_path: str, crop_thr: float):
    print(f"loading: {volume_path}")
    volume = load_volume(volume_path, percentile=0)
    print(f"sample shape: {volume.shape} >> {volume.dtype}")
    volume = interpolate_volume(volume)
    print(f"interp shape: {volume.shape} >> {volume.dtype}")
    volume = crop_volume(volume, crop_thr)
    print(f"crop shape: {volume.shape} >> {volume.dtype}")
    vol_shape = volume.shape
    interact(
        lambda x, y, z: plt.show(show_volume(volume, x, y, z)),
        # the last valid slice index is shape - 1
        x=IntSlider(min=0, max=vol_shape[0] - 1, step=5, value=vol_shape[0] // 2),
        y=IntSlider(min=0, max=vol_shape[1] - 1, step=5, value=vol_shape[1] // 2),
        z=IntSlider(min=0, max=vol_shape[2] - 1, step=5, value=vol_shape[2] // 2),
    )


PATH_SAMPLE_VOLUME = os.path.join(PATH_DATASET, "train", "00005", "FLAIR")

interactive_show(PATH_SAMPLE_VOLUME, crop_thr=1e-6)
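The `crop_thr=1e-6` argument suggests that near-zero background is trimmed away before display. The exact behavior of `kaggle_brain3d.transforms.crop_volume` is not shown here, so the following is only a hypothetical NumPy re-implementation of the idea: crop the volume to the bounding box of voxels above the threshold.

```python
import numpy as np


def crop_to_foreground(volume: np.ndarray, thr: float = 1e-6) -> np.ndarray:
    """Crop a volume to the bounding box of voxels whose value exceeds `thr`.

    Hypothetical illustration; not the kaggle_brain3d implementation.
    """
    mask = volume > thr
    if not mask.any():
        return volume  # nothing above the threshold: keep the volume unchanged
    slices = []
    for axis in range(volume.ndim):
        # collapse all other axes to find which indices along `axis` contain foreground
        other_axes = tuple(a for a in range(volume.ndim) if a != axis)
        hits = np.flatnonzero(mask.any(axis=other_axes))
        slices.append(slice(hits[0], hits[-1] + 1))
    return volume[tuple(slices)]
```

Cropping first means the later resize to `input_size` spends its resolution on the brain rather than on empty background.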

## Prepare dataset

### Pytorch Dataset

The basic building block is transforming the raw data into a PyTorch Dataset.
Here we load the individual DICOM images of each scan into a single volume and save it to a temporary cache, so the very time-consuming DICOM loading does not have to be repeated next time; this cuts the I/O time from about 2 h to 8 min.

At the end, we show a few sample images from the prepared dataset.

In [None]:
import os
import pandas as pd
import torch
from tqdm.auto import tqdm

from kaggle_brain3d.data import BrainScansDataset
from kaggle_brain3d.transforms import resize_volume

# ==============================

ds = BrainScansDataset(
    image_dir=os.path.join(PATH_DATASET, "train"),
    df_table=os.path.join(PATH_DATASET, "train_labels.csv"),
    crop_thr=None, cache_dir=PATH_TEMP,
)
for i in tqdm(range(2)):
    img = ds[i * 10]["data"]
    img = resize_volume(img[0])
    show_volume(img, fig_size=(12, 8))
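The caching idea behind `cache_dir` can be sketched in a few lines: run the slow loader once, store the result as a `.npy` file, and read that file on later calls. This is a minimal illustration with a hypothetical helper `cached_volume`; `BrainScansDataset` may name and store its cache files differently.

```python
import hashlib
import os
from typing import Callable

import numpy as np


def cached_volume(path: str, loader: Callable[[str], np.ndarray], cache_dir: str) -> np.ndarray:
    """Load a volume via `loader` once, then reuse a .npy copy on later calls."""
    os.makedirs(cache_dir, exist_ok=True)
    # derive a stable cache file name from the source path
    key = hashlib.md5(path.encode("utf-8")).hexdigest()
    cache_file = os.path.join(cache_dir, f"{key}.npy")
    if os.path.isfile(cache_file):
        return np.load(cache_file)
    volume = loader(path)  # the slow part, e.g. reading and stacking DICOM slices
    np.save(cache_file, volume)
    return volume
```

Reading one contiguous `.npy` array is much cheaper than parsing hundreds of small DICOM files, which is where the speed-up comes from.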

### Lightning DataModule

It is a construct that wraps all data-related pieces and defines the PyTorch DataLoaders for the training, validation, and testing phases.

At the end, we show a few sample images from the first training batch.

In [None]:
from functools import partial
import rising.transforms as rtr
from rising.loading import DataLoader, default_transform_call
from rising.random import DiscreteParameter, UniformParameter

from kaggle_brain3d.data import BrainScansDM  # , TRAIN_TRANSFORMS, VAL_TRANSFORMS
from kaggle_brain3d.transforms import RandomAffine, rising_zero_mean

# ==============================

# Dataset >> mean: 0.13732214272022247 STD: 0.24326834082603455
rising_norm = partial(rising_zero_mean, mean=0.137, std=0.243)

# define transformations
TRAIN_TRANSFORMS = [
    rtr.Rot90((0, 1, 2), keys=["data"], p=0.5),
    rtr.Mirror(dims=DiscreteParameter([0, 1, 2]), keys=["data"]),
    RandomAffine(scale_range=(0.9, 1.1), rotation_range=(-10, 10), translation_range=(-0.1, 0.1)),
    rising_norm,
]
VAL_TRANSFORMS = [
    rising_norm,
]

# ==============================

dm = BrainScansDM(
    data_dir=PATH_DATASET,
    scan_types=["FLAIR"],
    input_size=224,
    crop_thr=1e-6,
    batch_size=3,
    cache_dir=PATH_TEMP,
    # in_memory=True,
    num_workers=2,
    train_transforms=rtr.Compose(TRAIN_TRANSFORMS, transform_call=default_transform_call),
    valid_transforms=rtr.Compose(VAL_TRANSFORMS, transform_call=default_transform_call),
)
dm.prepare_data(3)
dm.setup()
print(f"Training batches: {len(dm.train_dataloader())} and Validation {len(dm.val_dataloader())}")

# Quick view
for batch in dm.train_dataloader():
    for i in range(2):
        show_volume(batch["data"][i][0], fig_size=(9, 6), v_min_max=(-1., 3.))
    break
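The comment in the cell above records a dataset mean of about 0.137 and a standard deviation of about 0.243. One way such statistics could be computed is a single pass over all (cached) volumes that accumulates sums and sums of squares. The sketch below does that with NumPy and applies the usual `(x - mean) / std` shift. Both helpers are hypothetical, and we assume, without confirming it here, that `rising_zero_mean` performs an equivalent normalization.

```python
from typing import Iterable, Tuple

import numpy as np


def dataset_mean_std(volumes: Iterable[np.ndarray]) -> Tuple[float, float]:
    """Voxel-wise mean and standard deviation over all volumes, accumulated in one pass."""
    count, total, total_sq = 0, 0.0, 0.0
    for vol in volumes:
        vol = np.asarray(vol, dtype=np.float64)  # float64 keeps the sums accurate
        count += vol.size
        total += vol.sum()
        total_sq += np.square(vol).sum()
    mean = total / count
    var = max(total_sq / count - mean ** 2, 0.0)  # guard against tiny negative rounding errors
    return float(mean), float(np.sqrt(var))


def zero_mean_normalize(volume: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Shift and scale intensities so the dataset has roughly zero mean and unit variance."""
    return (volume - mean) / std
```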

## Prepare 3D model

LightningModule is the core of PL; it wraps all model-related pieces, mainly:

- the model/architecture/weights
- evaluation metrics
- configs for the optimizer and LR scheduler
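As an illustration of the evaluation metric, binary AUROC equals the probability that a randomly chosen positive case gets a higher score than a randomly chosen negative one (ties count as half). The pure-Python sketch below (hypothetical helper `binary_auroc`, useful for sanity checks, not a replacement for `torchmetrics.AUROC`) takes the softmax probability of class 1 as the score. It also shows why AUROC is undefined when a batch contains only one class, which is what the `try/except ValueError` in the module below works around.

```python
from typing import Sequence


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC as P(score of a random positive > score of a random negative); ties count 0.5.

    O(P * N) pairwise comparison, fine for small validation sets.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC is undefined when only one class is present")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```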

In [None]:
import logging
from typing import Any, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from monai.networks.nets import EfficientNetBN
from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import AUROC, F1
from torchsummary import summary


class LitBrainMRI(LightningModule):

    def __init__(
        self,
        net: Union[nn.Module, str] = "efficientnet-b0",
        lr: float = 1e-3,
        optimizer: Optional[Optimizer] = None,
    ):
        super().__init__()
        if isinstance(net, str):
            self.name = net
            # 3D EfficientNet taking a single-channel volume and predicting 2 classes
            net = EfficientNetBN(net, spatial_dims=3, in_channels=1, num_classes=2)
        else:
            self.name = net.__class__.__name__
        self.net = net
        # make sure the whole network is trainable
        for _, param in self.net.named_parameters():
            param.requires_grad = True
        self.learning_rate = lr
        # the default optimizer is built in `configure_optimizers`, so that a learning rate
        # updated by the LR finder (`trainer.tune`) is actually used
        self.optimizer = optimizer

        self.train_auroc = AUROC(num_classes=2, compute_on_step=False)
        self.train_f1_score = F1()
        self.val_auroc = AUROC(num_classes=2, compute_on_step=False)
        self.val_f1_score = F1()

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

    def compute_loss(self, y_hat: Tensor, y: Tensor):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        img, y = batch["data"], batch["label"]
        y_hat = self(img)
        loss = self.compute_loss(y_hat, y)
        self.log("train/loss", loss, prog_bar=False)
        # class probabilities: softmax over the class dimension
        y_hat = F.softmax(y_hat, dim=1)
        self.log("train/f1", self.train_f1_score(y_hat, y), prog_bar=True)
        self.train_auroc.update(y_hat, y)
        try:  # ToDo: use balanced sampler
            self.log('train/auroc', self.train_auroc, on_step=False, on_epoch=True)
        except ValueError:
            pass
        return loss

    def validation_step(self, batch, batch_idx):
        img, y = batch["data"], batch["label"]
        y_hat = self(img)
        loss = self.compute_loss(y_hat, y)
        self.log("valid/loss", loss, prog_bar=False)
        y_hat = F.softmax(y_hat, dim=1)
        self.log("valid/f1", self.val_f1_score(y_hat, y), prog_bar=True)
        self.val_auroc.update(y_hat, y)
        try:  # ToDo: use balanced sampler
            self.log('valid/auroc', self.val_auroc, on_step=False, on_epoch=True)
        except ValueError:
            pass

    def configure_optimizers(self):
        if self.optimizer is not None:
            optimizer = self.optimizer
        else:
            optimizer = AdamW(self.net.parameters(), lr=self.learning_rate)
        # decay the LR along a cosine curve down to 0 over all training epochs
        scheduler = CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs, eta_min=0)
        return [optimizer], [scheduler]


# ==============================

model = LitBrainMRI(lr=5e-4)
# summary(model, input_size=(1, 128, 128, 128))
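`configure_optimizers` pairs the optimizer with `CosineAnnealingLR`, whose schedule PyTorch documents in closed form as η_t = η_min + ½(η_max − η_min)(1 + cos(π·t / T_max)). The small sketch below (hypothetical helper `cosine_annealing_lr`) evaluates that formula so you can preview the per-epoch learning rate for `lr=5e-4` and `max_epochs=10`. Keep in mind that the SWA callback used during training may swap in its own schedule once weight averaging starts, so later epochs can differ.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at epoch 0, decaying to eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# per-epoch preview for the settings used in this notebook
schedule = [cosine_annealing_lr(e, 5e-4, 10) for e in range(11)]
```

The learning rate halves at the midpoint (epoch 5) and falls smoothly toward zero, which tends to be gentler at the end of training than step decay.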

## Train a model

Lightning imposes the following structure on your code, which makes it reusable and shareable:

- Research code (the LightningModule).
- Engineering code (you delete it; it is handled by the Trainer).
- Non-essential research code (logging, etc.; this goes in Callbacks).
- Data (use PyTorch DataLoaders or organize them into a LightningDataModule).

Once you do this, you can train on multiple GPUs, TPUs, or CPUs, and even in 16-bit precision, without changing your code!

In [None]:
import pytorch_lightning as pl

logger = pl.loggers.CSVLogger(save_dir='logs/', name=model.name)
swa = pl.callbacks.StochasticWeightAveraging(swa_epoch_start=0.6)
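ckpt = pl.callbacks.ModelCheckpoint(
    monitor='valid/auroc',
    save_top_k=1,
    # the monitored key contains "/", so reference it by its exact name and
    # disable the automatic "name=" prefix, which would otherwise create a nested folder
    filename='checkpoint/epoch={epoch:02d}-valid_auroc={valid/auroc:.4f}',
    auto_insert_metric_name=False,
    mode='max',
)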

# ==============================

trainer = pl.Trainer(
    # overfit_batches=5,
    # fast_dev_run=True,
    gpus=1,
    callbacks=[ckpt, swa],
    logger=logger,
    max_epochs=10,
    precision=16,
    accumulate_grad_batches=24,
    # val_check_interval=0.5,
    progress_bar_refresh_rate=1,
    log_every_n_steps=5,
    weights_summary='top',
    auto_lr_find=True,
#     auto_scale_batch_size='binsearch',
)

# ==============================

# trainer.tune(
#     model, 
#     datamodule=dm, 
#     lr_find_kwargs=dict(min_lr=1e-5, max_lr=1e-3, num_training=20),
#     # scale_batch_size_kwargs=dict(max_trials=5),
# )
# print(f"Batch size: {dm.batch_size}")
# print(f"Learning Rate: {model.learning_rate}")

# ==============================

trainer.fit(model=model, datamodule=dm)
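Two Trainer settings above are easy to sanity-check by hand. With `batch_size=3` in the DataModule and `accumulate_grad_batches=24` on one GPU, each optimizer step sees 3 × 24 = 72 volumes. `swa_epoch_start=0.6` is a fraction of `max_epochs`, so with 10 epochs the weight averaging should begin around epoch 6. The sketch below (hypothetical helpers, scalar "weights" for readability) shows that arithmetic and the equal-weight running mean that stochastic weight averaging keeps, the same incremental update that PyTorch's `torch.optim.swa_utils.AveragedModel` uses by default.

```python
from typing import Dict, List


def effective_batch_size(batch_size: int, accumulate_grad_batches: int, num_devices: int = 1) -> int:
    """Number of samples contributing to each optimizer step."""
    return batch_size * accumulate_grad_batches * num_devices


def swa_average(snapshots: List[Dict[str, float]]) -> Dict[str, float]:
    """Equal-weight running average of parameter snapshots (e.g. one per epoch after SWA starts)."""
    avg: Dict[str, float] = {}
    for n, weights in enumerate(snapshots):
        for name, value in weights.items():
            # incremental mean: avg_{n+1} = avg_n + (w - avg_n) / (n + 1)
            avg[name] = value if n == 0 else avg[name] + (value - avg[name]) / (n + 1)
    return avg
```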

### Training progress

In [None]:
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
display(metrics.head())

aggreg_metrics = []
agg_col = "epoch"
for i, dfg in metrics.groupby(agg_col):
    agg = dict(dfg.mean())
    agg[agg_col] = i
    aggreg_metrics.append(agg)

df_metrics = pd.DataFrame(aggreg_metrics)
df_metrics[['train/loss', 'valid/loss']].plot(grid=True, legend=True, xlabel=agg_col)
df_metrics[['train/f1', 'train/auroc', 'valid/f1', 'valid/auroc']].plot(grid=True, legend=True, xlabel=agg_col)
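The per-epoch grouping above is needed because the CSV logger typically writes training and validation metrics on separate rows, leaving the other columns empty (read by pandas as NaN). The standard-library sketch below (hypothetical helper `epoch_means`) makes that explicit by skipping empty cells when averaging. In pandas, `metrics.groupby("epoch").mean()` gives the same result, since `mean` skips NaN by default.

```python
import csv
from collections import defaultdict
from typing import Dict, TextIO


def epoch_means(csv_file: TextIO, epoch_col: str = "epoch") -> Dict[int, Dict[str, float]]:
    """Average every metric per epoch, skipping empty cells (each row holds only some metrics)."""
    sums: Dict[int, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    counts: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    for row in csv.DictReader(csv_file):
        epoch = int(float(row[epoch_col]))
        for name, cell in row.items():
            if name == epoch_col or cell in ("", None):
                continue  # metric not logged on this row
            sums[epoch][name] += float(cell)
            counts[epoch][name] += 1
    return {
        epoch: {name: sums[epoch][name] / counts[epoch][name] for name in sums[epoch]}
        for epoch in sorted(sums)
    }
```

Usage: `with open(f"{trainer.logger.log_dir}/metrics.csv") as fh: per_epoch = epoch_means(fh)`.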