# HuBMAP PyTorch ⚡ MONAI Train & Infer

## Combine the power of [PyTorch Lightning](https://www.pytorchlightning.ai/) and [MONAI](https://monai.io/)

# Imports

In [2]:
import os
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple

import monai
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import tifffile
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from monai.data import CSVDataset
from monai.data import DataLoader
from monai.data import ImageReader
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm

# Paths & Settings

In [None]:
# Root of the competition workspace. Set the HUBMAP_KAGGLE_DIR environment
# variable to relocate it; the fallback is the original hard-coded location.
KAGGLE_DIR = Path(os.environ.get("HUBMAP_KAGGLE_DIR", "/home/ziyang/kaggle/hubmap-organ-segmentation"))

INPUT_DIR = KAGGLE_DIR / "input"
OUTPUT_DIR = KAGGLE_DIR / "working"

COMPETITION_DATA_DIR = INPUT_DIR / "hubmap-organ-segmentation"

# Prepared CSV files are relative paths, i.e. they live in the working directory.
TRAIN_PREPARED_CSV_PATH = "train_prepared.csv"
VAL_PRED_PREPARED_CSV_PATH = "val_pred_prepared.csv"
TEST_PREPARED_CSV_PATH = "test_prepared.csv"

# Cross-validation and data settings.
N_SPLITS = 4
RANDOM_SEED = 2022
SPATIAL_SIZE = 1024  # images and masks are resized to SPATIAL_SIZE x SPATIAL_SIZE
VAL_FOLD = 0
NUM_WORKERS = 2
BATCH_SIZE = 16

# Optimizer and trainer settings.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0
FAST_DEV_RUN = False
GPUS = 1
MAX_EPOCHS = 10
PRECISION = 16
DEBUG = False  # when True, train() uses a tiny batch size, 10% of the batches and 2 epochs

# Inference settings.
DEVICE = "cuda"
THRESHOLD = 0.5  # cut-off used to binarize predictions

# Prepare DataFrames (Add paths and create folds)

In [None]:
def add_path_to_df(df: pd.DataFrame, data_dir: Path, type_: str, stage: str) -> pd.DataFrame:
    """Add a ``type_`` column with one file path per row, derived from the ``id`` column.

    For ``type_ == "image"`` the path is rooted at ``data_dir`` and ends in
    ``.tiff``. For any other ``type_`` the path is relative
    (``{stage}_{type_}s/<id>.npy``) and ``data_dir`` is not used.

    The frame is modified in place and also returned.
    """
    folder_name = f"{stage}_{type_}s"

    if type_ == "image":
        folder, extension = str(data_dir / folder_name), ".tiff"
    else:
        folder, extension = folder_name, ".npy"

    df[type_] = folder + "/" + df["id"].astype(str) + extension
    return df


def add_paths_to_df(df: pd.DataFrame, data_dir: Path, stage: str) -> pd.DataFrame:
    """Add both the ``image`` and the ``mask`` path columns for ``stage`` and return the frame."""
    for type_ in ("image", "mask"):
        df = add_path_to_df(df, data_dir, type_, stage)
    return df


def create_folds(df: pd.DataFrame, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Assign every row a stratified validation fold in a new integer ``fold`` column.

    Rows are stratified on the ``organ`` column. The frame is modified in place
    and also returned.

    Args:
        df: Frame with an ``organ`` column.
        n_splits: Number of folds.
        random_seed: Seed for the shuffled split.

    Returns:
        ``df`` with ``fold`` set to the index of the split in which the row is
        part of the validation set.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)

    # `split` yields positional indices, so folds are collected positionally in
    # an integer array instead of being written through label-based `.loc`
    # (which is only correct for a default RangeIndex and yields a float column).
    folds = np.full(len(df), -1, dtype=int)
    for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df["organ"])):
        folds[val_idx] = fold

    df["fold"] = folds
    return df


def prepare_data(data_dir: Path, stage: str, n_splits: int, random_seed: int) -> pd.DataFrame:
    """Load ``{stage}.csv``, add path columns (plus folds for train) and save the result.

    Args:
        data_dir: Directory containing ``{stage}.csv``.
        stage: Name of the CSV to prepare; folds are only created for ``"train"``.
        n_splits: Number of folds (used for ``"train"`` only).
        random_seed: Seed for the fold split (used for ``"train"`` only).

    Returns:
        The prepared frame. It is also written to ``{stage}_prepared.csv`` in the
        current working directory.
    """
    df = pd.read_csv(data_dir / f"{stage}.csv")
    df = add_paths_to_df(df, data_dir, stage)

    if stage == "train":
        df = create_folds(df, n_splits, random_seed)

    filename = f"{stage}_prepared.csv"
    df.to_csv(filename, index=False)

    print(f"Created {filename} with shape {df.shape}")

    return df

In [None]:
# Build the prepared train/test frames; each call also writes "<stage>_prepared.csv".
# N_SPLITS and RANDOM_SEED only take effect for the "train" stage.
train_df = prepare_data(COMPETITION_DATA_DIR, "train", N_SPLITS, RANDOM_SEED)
test_df = prepare_data(COMPETITION_DATA_DIR, "test", N_SPLITS, RANDOM_SEED)

In [None]:
# Display the prepared training frame (path columns and fold assignment included).
train_df

In [None]:
# Display the prepared test frame (path columns included, no folds).
test_df

# Save Train Masks as NumPy Arrays

In [None]:
def rle2mask(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
    """
    Decode a run-length encoded string into a binary mask.

    mask_rle: run-length string formatted as "start length start length ..."
              with 1-based starts
    shape: (width, height) of the encoded image
    Returns a uint8 numpy array of shape (shape[1], shape[0]): 1 - mask, 0 - background
    Source: https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
    """
    tokens = np.asarray(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1  # convert 1-based starts to 0-based offsets
    run_ends = run_starts + tokens[1::2]

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1

    return flat.reshape(shape).T


def save_array(file_path: str, array: np.ndarray) -> None:
    """Save ``array`` with ``np.save`` at ``file_path``, creating missing parent directories."""
    target = Path(file_path)
    parent_dir = target.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    np.save(target, array)


def save_masks(df: pd.DataFrame) -> None:
    """Decode each row's ``rle`` into a mask and save it at the row's ``mask`` path.

    Reads the ``rle``, ``img_width``, ``img_height`` and ``mask`` columns.
    """
    rows = df.itertuples()
    for record in tqdm(rows, total=len(df)):
        decoded = rle2mask(record.rle, shape=(record.img_width, record.img_height))
        save_array(record.mask, decoded)

In [None]:
# Decode every training RLE and write it to the path in the "mask" column.
save_masks(train_df)

# Lightning DataModule

In [None]:
class TIFFImageReader(ImageReader):
    """MONAI ``ImageReader`` implementation that loads files with ``tifffile``."""

    def read(self, data: str) -> np.ndarray:
        """Read ``data`` with ``tifffile.imread`` and return the resulting array."""
        return tifffile.imread(data)

    def get_data(self, img: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Return ``img`` unchanged together with a metadata dictionary.

        The metadata stores ``img.shape`` under ``spatial_shape`` and -1 under
        ``original_channel_dim``.
        NOTE(review): -1 presumably tells MONAI that the channel axis is last,
        and ``spatial_shape`` then includes that channel axis -- confirm.
        """
        return img, {"spatial_shape": np.asarray(img.shape), "original_channel_dim": -1}

    def verify_suffix(self, filename: str) -> bool:
        """Return whether ``".tiff"`` is contained in ``filename``.

        NOTE(review): this is a containment test, not a strict suffix test; it
        does not match ``.tif`` or upper-case extensions, and for a non-string
        sequence it tests element membership instead -- verify against how the
        loader calls it.
        """
        return ".tiff" in filename

In [None]:
class LitDataModule(pl.LightningDataModule):
    """Lightning data module that feeds image/mask dictionaries through MONAI transforms.

    Rows of the train CSV whose ``fold`` equals ``val_fold`` form the validation
    set; the remaining rows form the training set. The test CSV is used as-is.

    Args:
        train_csv_path: CSV providing the ``image``, ``mask`` and ``fold`` columns.
        test_csv_path: CSV providing the ``image`` column.
        spatial_size: Edge length of the square every image/mask is resized to.
        val_fold: Fold value held out for validation.
        batch_size: Batch size used by all dataloaders.
        num_workers: Number of workers used by all dataloaders.
    """

    def __init__(
        self,
        train_csv_path: str,
        test_csv_path: str,
        spatial_size: int,
        val_fold: int,
        batch_size: int,
        num_workers: int,
    ):
        super().__init__()

        # Makes the constructor arguments available as ``self.hparams``.
        self.save_hyperparameters()

        self.train_df = pd.read_csv(train_csv_path)
        self.test_df = pd.read_csv(test_csv_path)

        self.train_transform, self.val_transform, self.test_transform = self._init_transforms()

    @staticmethod
    def _image_loading_transforms() -> list:
        """Return fresh transforms that load, channel-first and intensity-scale the image."""
        return [
            monai.transforms.LoadImaged(keys=["image"], reader=TIFFImageReader),
            monai.transforms.EnsureChannelFirstd(keys=["image"]),
            monai.transforms.ScaleIntensityd(keys=["image"]),
        ]

    @staticmethod
    def _mask_loading_transforms() -> list:
        """Return fresh transforms that load the mask and add a channel axis to it."""
        return [
            monai.transforms.LoadImaged(keys=["mask"]),
            monai.transforms.AddChanneld(keys=["mask"]),
        ]

    def _init_transforms(self) -> Tuple[Callable, Callable, Callable]:
        """Build the (train, val, test) pipelines.

        All three share the same image loading and the final resize. Train and
        val also load the mask; only train applies random augmentation.
        NOTE(review): ``RandGridDistortiond`` and ``Resized`` are applied to the
        mask with their default interpolation modes -- confirm the masks stay
        binary, otherwise pass a nearest-neighbour mode for the mask key.
        """
        spatial_size = (self.hparams.spatial_size, self.hparams.spatial_size)

        train_transform = monai.transforms.Compose(
            [
                *self._image_loading_transforms(),
                *self._mask_loading_transforms(),
                # Geometric augmentation, applied identically to image and mask.
                monai.transforms.RandFlipd(keys=["image", "mask"], spatial_axis=[0], prob=0.5),
                monai.transforms.RandFlipd(keys=["image", "mask"], spatial_axis=[1], prob=0.5),
                monai.transforms.RandRotate90d(keys=["image", "mask"], prob=0.5),
                # Single-option OneOf; a RandAffined alternative is currently disabled.
                monai.transforms.OneOf(
                    [
                        monai.transforms.RandGridDistortiond(keys=["image", "mask"], prob=0.5, distort_limit=0.2),
                    ]
                ),
                # Intensity augmentation, image only.
                monai.transforms.OneOf(
                    [
                        monai.transforms.RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
                        monai.transforms.RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(1.5, 2.5)),
                        monai.transforms.RandHistogramShiftd(keys=["image"], prob=0.5),
                    ]
                ),
                monai.transforms.Resized(keys=["image", "mask"], spatial_size=spatial_size),
            ]
        )

        val_transform = monai.transforms.Compose(
            [
                *self._image_loading_transforms(),
                *self._mask_loading_transforms(),
                monai.transforms.Resized(keys=["image", "mask"], spatial_size=spatial_size),
            ]
        )

        test_transform = monai.transforms.Compose(
            [
                *self._image_loading_transforms(),
                monai.transforms.Resized(keys=["image"], spatial_size=spatial_size),
            ]
        )

        return train_transform, val_transform, test_transform

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` (all of them when ``stage`` is None)."""
        # "validate" needs the validation dataset too (e.g. when only validation is run).
        if stage in (None, "fit", "validate"):
            is_val = self.train_df.fold == self.hparams.val_fold
            train_df = self.train_df[~is_val].reset_index(drop=True)
            val_df = self.train_df[is_val].reset_index(drop=True)

            self.train_dataset = self._dataset(train_df, transform=self.train_transform)
            self.val_dataset = self._dataset(val_df, transform=self.val_transform)

        if stage in (None, "test"):
            self.test_dataset = self._dataset(self.test_df, transform=self.test_transform)

    def _dataset(self, df: pd.DataFrame, transform: Callable) -> CSVDataset:
        """Wrap ``df`` in a ``CSVDataset`` that applies ``transform``."""
        return CSVDataset(src=df, transform=transform)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training dataloader."""
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled validation dataloader."""
        return self._dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the unshuffled test dataloader."""
        return self._dataloader(self.test_dataset)

    def _dataloader(self, dataset: CSVDataset, train: bool = False) -> DataLoader:
        """Build a dataloader for ``dataset``; shuffling is enabled only when ``train`` is True."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )

# Visualize Images and Masks

In [None]:
def show_image(title: str, image: np.ndarray, mask: Optional[np.ndarray] = None):
    """Draw ``image`` on the current axes, optionally overlaid with a translucent ``mask``."""
    plt.title(title)

    # The image is always drawn first; the mask (if any) goes on top at 20% opacity.
    layers = [(image, {})]
    if mask is not None:
        layers.append((mask, {"alpha": 0.2}))

    for layer, imshow_kwargs in layers:
        plt.imshow(layer, **imshow_kwargs)

    plt.tight_layout()
    plt.axis("off")


def show_batch(batch: Dict, nrows: int, show_mask: bool = True):
    """Plot the samples of ``batch`` on an ``nrows`` x ``nrows`` grid.

    Args:
        batch: Collated batch with ``image`` and ``id`` entries, plus ``mask``
            when ``show_mask`` is True. Images and masks are indexed
            channel-first per sample and converted with ``.numpy()``.
        nrows: Number of rows and of columns in the grid.
        show_mask: Overlay each sample's mask on its image.
    """
    # Start from an empty figure: plt.subplots() would add a full-size Axes that
    # newer Matplotlib versions no longer remove when plt.subplot() draws over it.
    plt.figure(figsize=(3 * nrows, 3 * nrows))

    for idx, image_tensor in enumerate(batch["image"]):
        plt.subplot(nrows, nrows, idx + 1)

        title = batch["id"][idx].numpy()
        # Move the channel axis last, which is the layout imshow expects.
        image = np.transpose(image_tensor.numpy(), axes=(1, 2, 0))
        mask = np.transpose(batch["mask"][idx].numpy(), axes=(1, 2, 0)) if show_mask else None

        show_image(title, image, mask)

## Setup DataModule

In [None]:
# Side length of the preview grid drawn by show_batch.
nrows = 3

# batch_size is nrows ** 2 so that one batch fills the nrows x nrows grid.
data_module = LitDataModule(
    train_csv_path=TRAIN_PREPARED_CSV_PATH,
    test_csv_path=TEST_PREPARED_CSV_PATH,
    spatial_size=SPATIAL_SIZE,
    val_fold=VAL_FOLD,
    batch_size=nrows ** 2,
    num_workers=NUM_WORKERS,
)
# Called without a stage, so the train, val and test datasets are all created.
data_module.setup()

## Train Images

In [None]:
# Preview the first (shuffled, augmented) training batch with mask overlays.
train_batch = next(iter(data_module.train_dataloader()))
show_batch(train_batch, nrows)

## Test Images

In [None]:
# Preview the first test batch; the test pipeline loads no masks, so none are drawn.
test_batch = next(iter(data_module.test_dataloader()))
show_batch(test_batch, nrows, show_mask=False)

# Lightning Module

In [None]:
class LitModule(pl.LightningModule):
    """Lightning module wrapping a 2D MONAI UNet trained with a sigmoid Dice loss.

    Args:
        learning_rate: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.
    """

    def __init__(
        self,
        learning_rate: float,
        weight_decay: float,
    ):
        super().__init__()

        # Makes the constructor arguments available as ``self.hparams``.
        self.save_hyperparameters()

        self.model = self._init_model()

        self.loss_fn = self._init_loss_fn()

        # TODO: add metric

    def _init_model(self) -> nn.Module:
        """Create the segmentation network: 3 input channels, 1 output channel."""
        # TODO: try other networks
        network = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        return network

    def _init_loss_fn(self):
        """Create the loss; the sigmoid is applied inside the loss, so the model outputs logits."""
        # TODO: try other losses
        dice_loss = monai.losses.DiceLoss(sigmoid=True)
        return dice_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, configured from ``self.hparams``."""
        # TODO: try other optimizers and schedulers
        hparams = self.hparams
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
        )
        return optimizer

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Run the network on ``images`` and return its raw outputs."""
        return self.model(images)

    def _compute_loss(self, batch: Dict) -> Tuple[torch.Tensor, int]:
        """Run the model on the batch images and return (loss, number of images)."""
        images, masks = batch["image"], batch["mask"]
        outputs = self(images)
        return self.loss_fn(outputs, masks), images.shape[0]

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """Compute, log (as ``train_loss``) and return the loss for one training batch."""
        loss, n_images = self._compute_loss(batch)

        self.log("train_loss", loss, batch_size=n_images)

        return loss

    def validation_step(self, batch: Dict, batch_idx: int) -> None:
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        loss, n_images = self._compute_loss(batch)

        self.log("val_loss", loss, prog_bar=True, batch_size=n_images)

    @classmethod
    def load_eval_checkpoint(cls, checkpoint_path: str, device: str) -> nn.Module:
        """Load the module from ``checkpoint_path``, move it to ``device`` and set eval mode."""
        restored = cls.load_from_checkpoint(checkpoint_path=checkpoint_path)
        module = restored.to(device)
        module.eval()

        return module

# Train

In [None]:
def train(
    random_seed: int = RANDOM_SEED,
    train_csv_path: str = str(TRAIN_PREPARED_CSV_PATH),
    test_csv_path: str = str(TEST_PREPARED_CSV_PATH),
    spatial_size: int = SPATIAL_SIZE,
    val_fold: int = VAL_FOLD,
    batch_size: int = BATCH_SIZE,
    num_workers: int = NUM_WORKERS,
    learning_rate: float = LEARNING_RATE,
    weight_decay: float = WEIGHT_DECAY,
    fast_dev_run: bool = FAST_DEV_RUN,
    gpus: int = GPUS,
    max_epochs: int = MAX_EPOCHS,
    precision: int = PRECISION,
    debug: bool = DEBUG,
) -> pl.Trainer:
    """Seed everything, fit a ``LitModule`` on a ``LitDataModule`` and return the trainer.

    Args:
        random_seed: Seed passed to ``pl.seed_everything``.
        train_csv_path: Prepared train CSV for the data module.
        test_csv_path: Prepared test CSV for the data module.
        spatial_size: Edge length images/masks are resized to.
        val_fold: Fold held out for validation.
        batch_size: Batch size (replaced by 2 when ``debug`` is True).
        num_workers: Dataloader workers.
        learning_rate: Optimizer learning rate.
        weight_decay: Optimizer weight decay.
        fast_dev_run: Forwarded to ``pl.Trainer``.
        gpus: Forwarded to ``pl.Trainer``.
        max_epochs: Epoch limit (replaced by 2 when ``debug`` is True).
        precision: Forwarded to ``pl.Trainer``.
        debug: Use a batch size of 2, 10% of the train/val batches and 2 epochs.

    Returns:
        The fitted trainer. Metrics are logged as CSV under ``logs/``.
    """
    pl.seed_everything(random_seed)

    data_module = LitDataModule(
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=spatial_size,
        val_fold=val_fold,
        batch_size=2 if debug else batch_size,
        num_workers=num_workers,
    )

    module = LitModule(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    trainer = pl.Trainer(
        fast_dev_run=fast_dev_run,
        gpus=gpus,
        limit_train_batches=0.1 if debug else 1.0,
        limit_val_batches=0.1 if debug else 1.0,
        log_every_n_steps=5,
        logger=pl.loggers.CSVLogger(save_dir='logs/'),
        max_epochs=2 if debug else max_epochs,
        precision=precision,
    )

    trainer.fit(module, datamodule=data_module)

    return trainer

In [None]:
# Run training with the defaults from the settings cell; later cells read trainer.logger.log_dir.
trainer = train()

In [None]:
# From https://www.kaggle.com/code/jirkaborovec?scriptVersionId=93358967&cellId=22
# Plot the logged train/val loss against the epoch.
metrics = (
    pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
    .loc[:, ["epoch", "train_loss", "val_loss"]]
    .set_index("epoch")
)

sns.relplot(data=metrics, kind="line", height=5, aspect=1.5)
plt.grid()

# Infer

In [None]:
def mask2rle(img):
    '''
    Run-length encode a binary mask (column-major order, 1-based starts).
    --
    img: numpy array, 1 - mask, 0 - background
    Returns the run lengths as a string formatted "start length start length ..."
    Source: https://www.kaggle.com/xhlulu/efficient-mask2rle
    '''
    # Pad with one background pixel on each side so every run has a start and an end edge.
    padded = np.pad(img.T.flatten(), 1)
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    run_starts = edges[0::2]
    run_lengths = edges[1::2] - run_starts

    interleaved = np.stack([run_starts, run_lengths], axis=1).ravel()
    return ' '.join(str(value) for value in interleaved)


@torch.no_grad()
def create_pred_df(module, dataloader, threshold):
    """Predict a mask for every sample in ``dataloader`` and return the RLEs as a frame.

    Each prediction is resized to the sample's original ``(img_height,
    img_width)``, passed through a sigmoid, binarized at ``threshold`` and
    run-length encoded.

    Args:
        module: Model called on the image batch; its ``device`` attribute decides
            where the images are moved.
        dataloader: Yields batches with ``id``, ``img_height``, ``img_width``
            and ``image`` entries.
        threshold: Cut-off used to binarize the sigmoid output.

    Returns:
        DataFrame with ``id`` and ``rle`` columns, one row per sample.
    """
    ids = []
    rles = []
    for batch in tqdm(dataloader):
        batch_ids = batch["id"].numpy()
        heights = batch["img_height"].numpy()
        widths = batch["img_width"].numpy()

        images = batch["image"].to(module.device)
        outputs = module(images)

        # Handle every sample of the batch, not only the first one, so that
        # dataloaders with batch_size > 1 do not silently lose predictions.
        for sample_idx in range(len(batch_ids)):
            # The target size differs per sample, so the transform is built per sample.
            post_pred_transform = monai.transforms.Compose(
                [
                    monai.transforms.Resize(
                        spatial_size=(int(heights[sample_idx]), int(widths[sample_idx])), mode="nearest"
                    ),
                    monai.transforms.Activations(sigmoid=True),
                    monai.transforms.AsDiscrete(threshold=threshold),
                ]
            )

            # [0] drops the channel axis of the single-channel prediction.
            mask = post_pred_transform(outputs[sample_idx]).to(torch.uint8).cpu().detach().numpy()[0]

            ids.append(batch_ids[sample_idx])
            rles.append(mask2rle(mask))

    return pd.DataFrame({"id": ids, "rle": rles})


def infer(
    checkpoint_path: str,
    device: str = DEVICE,
    train_csv_path: str = TRAIN_PREPARED_CSV_PATH,
    test_csv_path: str = TEST_PREPARED_CSV_PATH,
    spatial_size: int = SPATIAL_SIZE,
    num_workers: int = NUM_WORKERS,
    threshold: float = THRESHOLD,
    val_fold: int = VAL_FOLD,
):
    """Predict RLE masks for the validation fold and the test set.

    Args:
        checkpoint_path: Checkpoint to load the ``LitModule`` from.
        device: Device the module is moved to.
        train_csv_path: Prepared train CSV (source of the validation rows).
        test_csv_path: Prepared test CSV.
        spatial_size: Edge length the model inputs are resized to.
        num_workers: Dataloader workers.
        threshold: Cut-off used to binarize predictions.
        val_fold: Fold used as the validation set. Defaults to ``VAL_FOLD`` so
            it follows the fold held out during training instead of a fixed 0.

    Returns:
        ``(val_pred_df, test_pred_df)``, each with ``id`` and ``rle`` columns.
    """
    module = LitModule.load_eval_checkpoint(checkpoint_path, device)

    data_module = LitDataModule(
        train_csv_path=train_csv_path,
        test_csv_path=test_csv_path,
        spatial_size=spatial_size,
        val_fold=val_fold,
        # One sample per batch: every image is resized back to its own original size.
        batch_size=1,
        num_workers=num_workers,
    )
    data_module.setup()

    val_dataloader = data_module.val_dataloader()
    test_dataloader = data_module.test_dataloader()

    val_pred_df = create_pred_df(module, val_dataloader, threshold)
    test_pred_df = create_pred_df(module, test_dataloader, threshold)

    return val_pred_df, test_pred_df


In [None]:
# Use the first "*.ckpt" file found in this run's "checkpoints" directory.
# NOTE(review): raises IndexError when no checkpoint exists, and which file comes
# first is unspecified if several are present -- confirm only one is written.
checkpoint_path = list((Path(trainer.logger.log_dir) / "checkpoints").glob("*.ckpt"))[0]
val_pred_df, test_pred_df = infer(checkpoint_path)

# Submit

In [None]:
# Write the test predictions (id, rle) as the submission file, then display them.
test_pred_df.to_csv("submission.csv", index=False)
test_pred_df

## Visualize Val Predictions

In [None]:
# Add a "mask" column of the form "pred_masks/<id>.npy" for every validation prediction.
val_pred_df = add_path_to_df(val_pred_df, COMPETITION_DATA_DIR, "mask", "pred")
val_pred_df

In [None]:
# Ground-truth rows of the validation fold.
val_df = train_df[train_df.fold == VAL_FOLD].reset_index(drop=True)
val_df

In [None]:
# Join predictions with the ground-truth rows on "id". Columns present in both
# frames keep their name for the prediction and get a "_gt" suffix for the ground truth.
val_pred_df = val_pred_df.merge(val_df, on="id", suffixes=("", "_gt"))
val_pred_df

In [None]:
# Decode the predicted RLEs ("rle") and save them at the predicted-mask paths ("mask").
save_masks(val_pred_df)

In [None]:
# Persist the merged frame so it can be loaded through LitDataModule like a train CSV.
val_pred_df.to_csv(VAL_PRED_PREPARED_CSV_PATH, index=False)

## GT Val Images

In [None]:
# Side length of the preview grid drawn by show_batch.
nrows = 3

# Data module over the original prepared train CSV, i.e. with ground-truth masks.
data_module = LitDataModule(
    train_csv_path=TRAIN_PREPARED_CSV_PATH,
    test_csv_path=TEST_PREPARED_CSV_PATH,
    spatial_size=SPATIAL_SIZE,
    val_fold=VAL_FOLD,
    batch_size=nrows ** 2,
    num_workers=0,
)
data_module.setup()

# The validation loader is not shuffled, so this is always the first nrows ** 2 rows.
val_batch = next(iter(data_module.val_dataloader()))
show_batch(val_batch, nrows)

## Pred Val Images

In [None]:
# Same preview as for the ground truth, but the "train" CSV is the prediction CSV,
# whose "mask" column points at the saved predicted masks. Reuses `nrows` from
# the ground-truth preview cell.
data_module = LitDataModule(
    train_csv_path=VAL_PRED_PREPARED_CSV_PATH,
    test_csv_path=TEST_PREPARED_CSV_PATH,
    spatial_size=SPATIAL_SIZE,
    val_fold=VAL_FOLD,
    batch_size=nrows ** 2,
    num_workers=0,
)
data_module.setup()

val_batch = next(iter(data_module.val_dataloader()))
show_batch(val_batch, nrows)