# Cancer Detection With Tiles

Try cancer-subtype classification on tiles: for each large slide image, we randomly select one tile and assign it the original image's label.

We can write up to 20 GB to the current directory (/kaggle/working/); it is preserved as output when you create a version using "Save & Run All".
We can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session.

## Imports

In [None]:
import pytorch_lightning as pl
import torch
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights, convnext_base, ConvNeXt_Base_Weights, convnext_small, ConvNeXt_Small_Weights, efficientnet_b4, EfficientNet_B4_Weights
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from typing import List, Dict, Optional
import pandas as pd
import numpy as np
import os

import albumentations as albu
from albumentations.pytorch import ToTensorV2
import random
import matplotlib.pyplot as plt

from pathlib import Path
import random
import cv2

## Base Classes

In [None]:
class AugmentationTransforms:
    """Factory for the albumentations pipelines applied to tiles.

    Args:
        image_size: Side length in pixels; every pipeline that resizes
            produces a square ``image_size x image_size`` image.
    """

    # Per-channel normalisation statistics on the [0, 1] scale. The channel
    # order must match the images the pipelines are applied to.
    UBC_MEAN = (0.8894420586142374, 0.8208752169441305, 0.8864016141389351)
    UBC_STD = (0.10106393015358608, 0.15637655015581306, 0.09892687853183287)

    def __init__(self, image_size: int):
        self.image_size = image_size

    def get_training_augmentation(self):
        """Return the randomised pipeline for training tiles.

        The pipeline applies, in order:
        - random horizontal and vertical flips (p=0.5 each);
        - Gaussian noise (p=0.2);
        - with p=0.5, one of CLAHE, brightness/contrast, gamma or
          hue/saturation jitter;
        - an unconditional resize to ``image_size x image_size``.
        """
        train_transform = [
            albu.HorizontalFlip(p=0.5),
            albu.VerticalFlip(p=0.5),
            albu.augmentations.transforms.GaussNoise(p=0.2),
            albu.OneOf(
                [
                    albu.CLAHE(p=1),
                    albu.RandomBrightnessContrast(p=1),
                    albu.RandomGamma(p=1),
                    albu.HueSaturationValue(p=1),
                ],
                p=0.5,
            ),
            albu.augmentations.geometric.resize.Resize(
                self.image_size, self.image_size, always_apply=True
            ),
        ]
        return albu.Compose(train_transform)

    def get_validation_augmentation(self):
        """Return the deterministic pipeline for validation tiles.

        It only resizes to ``image_size x image_size``; no padding or random
        transforms are applied.
        """
        test_transform = [
            albu.augmentations.geometric.resize.Resize(
                self.image_size, self.image_size, always_apply=True
            ),
        ]
        return albu.Compose(test_transform)

    def get_preprocessing(self):
        """Return the normalisation + tensor-conversion pipeline.

        ``albu.Normalize`` standardises each channel with ``UBC_MEAN`` and
        ``UBC_STD``. ``ToTensorV2`` then converts the HWC image into a CHW
        tensor, which is the layout the model expects per sample.

        Returns:
            albumentations.Compose
        """
        transform = [
            albu.Normalize(mean=list(self.UBC_MEAN), std=list(self.UBC_STD)),
            ToTensorV2(),
        ]
        return albu.Compose(transform)


In [None]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import pandas as pd
import cv2


class CancerDataset(Dataset):
    """Slide-level dataset that serves one randomly chosen tile per slide.

    Each row of ``data_df`` describes one slide. Every access to index ``i``
    picks a random tile of slide ``i`` and returns it with the slide's label,
    so repeated epochs see different tiles of the same slide.

    Tile selection in ``__getitem__``:
      * For slides with a mask, tiles whose mask tile has no non-zero red
        pixels (no tumour) are skipped. They are pruned from the candidate
        list for later accesses, but the list is never pruned to empty.
      * The first remaining tile whose black-background share is at most
        ``max_bg_threshold`` is served.
      * If no tile qualifies, the last tumour tile examined is served. If no
        tile passed the tumour check, the last tile read is served.

    Args:
        data_df: One row per slide with ``image_id``, ``label`` (a name from
            ``CLASSES``) and ``has_mask`` columns. The frame is copied, so
            the caller's frame is never modified.
        black_threshold: A pixel whose channel sum is <= this value counts as
            background when measuring a tile's background share. Pixels whose
            every channel is <= this value are painted white before serving.
        max_bg_threshold: Maximum fraction of background pixels a tile may
            have to be served without falling back.
        preprocessing: Optional callable ``f(image=...) -> {"image": ...}``
            applied after ``augmentation``.
        augmentation: Optional callable with the same calling convention.
    """

    CLASSES = ("HGSC", "LGSC", "EC", "CC", "MC")
    # Tiles of slides without a mask.
    TILES_DIR = Path("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/")
    # Tiles of slides that have a mask.
    MASKED_TILES_DIR = Path(
        "/kaggle/input/ubc-ocean-tiles-w-masks-2048px-scale-0-25/train_images"
    )
    # Mask tiles, stored as <image_id>/<tile file name>. Annotations contain
    # pixel values; masks contain only values 0, 1, 2, 3 for different classes.
    MASK_DIR = Path(
        "/kaggle/input/ubc-ocean-tiles-w-masks-2048px-scale-0-25/train_annotations"
    )

    def __init__(
        self,
        data_df: pd.DataFrame,
        black_threshold: int = 20,
        max_bg_threshold: float = 0.35,
        preprocessing=None,
        augmentation=None,
    ):
        # Copy first. Callers pass ``iloc`` slices of a larger frame, and
        # writing the encoded labels into a slice would mutate (or warn
        # about) the caller's data. It would also double-encode the labels
        # if the same frame were used twice.
        self.data_df = data_df.copy()
        self.black_threshold = black_threshold
        self.max_bg_threshold = max_bg_threshold
        self.preprocessing = preprocessing
        self.augmentation = augmentation

        self.classes = list(self.CLASSES)
        self.class2idx = {label: idx for idx, label in enumerate(self.classes)}
        self.data_df["label"] = self.data_df["label"].map(self.class2idx)

        self.mask_dir = self.MASK_DIR

        # One list of candidate tile paths per slide, in data_df row order.
        # Sorted so that shuffling with a fixed seed is reproducible regardless
        # of file-system listing order.
        print("Generating image list")
        self.images_lists = []
        for image_id, has_mask in zip(self.data_df["image_id"], self.data_df["has_mask"]):
            root = self.MASKED_TILES_DIR if has_mask else self.TILES_DIR
            self.images_lists.append(sorted((root / str(image_id)).rglob("*.png")))
        print("Done")

    def __getitem__(self, i):
        row = self.data_df.iloc[i]
        image_id, label, has_mask = row["image_id"], row["label"], row["has_mask"]

        candidates = self.images_lists[i]
        if not candidates:
            raise RuntimeError(f"No tiles available for image {image_id}")
        # Shuffle in place so a different tile order is tried on every access.
        random.shuffle(candidates)

        img = None  # last tile read
        fallback = None  # last tile that passed the tumour check
        no_tumour = set()
        # Checking every tile up front is too slow, so tumour-free tiles are
        # discovered lazily and pruned below.
        for img_path in candidates:
            img = cv2.imread(str(img_path))
            if has_mask and not self._tile_has_tumor(image_id, img_path):
                no_tumour.add(img_path)
                continue
            fallback = img
            if self._background_is_acceptable(img):
                break

        # Prune tumour-free tiles for later accesses, but never empty the list.
        if no_tumour and len(no_tumour) < len(candidates):
            candidates[:] = [p for p in candidates if p not in no_tumour]

        if fallback is not None:
            img = fallback

        img = self._whiten_black_pixels(img, self.black_threshold)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.augmentation:
            img = self.augmentation(image=img)["image"]

        if self.preprocessing:
            img = self.preprocessing(image=img)["image"]

        return img, label

    def __len__(self):
        return len(self.data_df)

    def _tile_has_tumor(self, image_id, img_path: Path) -> bool:
        """Return True if the mask tile matching ``img_path`` has non-zero red pixels."""
        mask = cv2.imread(str(self.mask_dir / str(image_id) / img_path.name))
        # OpenCV loads BGR, so the red channel (tumour annotations) is index 2.
        return bool(np.any(mask[..., 2] != 0))

    def _background_is_acceptable(self, img: np.ndarray) -> bool:
        """Return True if at most ``max_bg_threshold`` of the pixels are black background."""
        black_bg = np.sum(img, axis=2) <= self.black_threshold
        return bool(np.sum(black_bg) <= np.prod(black_bg.shape[:2]) * self.max_bg_threshold)

    @staticmethod
    def _whiten_black_pixels(img: np.ndarray, threshold: int) -> np.ndarray:
        """Paint pixels whose every channel is <= ``threshold`` white, in place.

        Whole pixels are replaced rather than individual channel values, so a
        coloured pixel with one low channel keeps its colour.
        """
        img[np.all(img <= threshold, axis=-1)] = 255
        return img


In [None]:
class CancerDataModule(pl.LightningDataModule):
    """Lightning data module that splits ``train.csv`` into train/validation slides.

    Args:
        image_size: Side length tiles are resized to.
        batch_size: Training batch size.
        cutoff: Fraction of rows (in CSV order) used for training; the rest
            form the validation set. ``1.0`` leaves the validation set empty.
        black_threshold: Forwarded to ``CancerDataset``.
        max_bg_threshold: Forwarded to ``CancerDataset``.
        shuffle: Whether the training loader shuffles slides.
        val_batch_size: Validation batch size.
        num_workers: Worker processes used by each DataLoader.
    """

    TRAIN_CSV = "/kaggle/input/tiles-of-cancer-2048px-scale-0-25/train.csv"
    MASKS_DIR = Path("/kaggle/input/ubc-ocean-tiles-w-masks-2048px-scale-0-25/train_masks")

    def __init__(
        self,
        image_size: int,
        batch_size: int,
        cutoff: float = 0.8,
        black_threshold: int = 20,
        max_bg_threshold: float = 0.35,
        shuffle: Optional[bool] = True,
        val_batch_size: int = 8,
        num_workers: int = 4,
    ):
        super().__init__()
        self.image_size = image_size
        self.train_batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.cutoff = cutoff
        self.black_threshold = black_threshold
        self.max_bg_threshold = max_bg_threshold

        aug_transforms = AugmentationTransforms(self.image_size)

        self.preprocess_transforms = aug_transforms.get_preprocessing()
        self.train_transforms = aug_transforms.get_training_augmentation()
        self.val_transforms = aug_transforms.get_validation_augmentation()

    def setup(self, stage: Optional[str] = None):
        """Read ``train.csv``, flag slides with masks and build both datasets.

        Runs for the ``"fit"`` and ``"validate"`` stages and for a bare
        ``setup()`` call. Other stages are ignored.
        """
        if stage not in ("fit", "validate", None):
            return

        train_csv = pd.read_csv(self.TRAIN_CSV)

        # A slide has a mask when a directory named after its image_id exists
        # under train_masks.
        train_csv["has_mask"] = [
            (self.MASKS_DIR / str(image_id)).exists()
            for image_id in train_csv["image_id"]
        ]

        # Split in CSV order (no shuffling or stratification): the first
        # `cutoff` fraction trains, the remainder validates.
        cutoff_point = int(len(train_csv) * self.cutoff)

        self.train_dataset = CancerDataset(
            train_csv.iloc[:cutoff_point],
            self.black_threshold,
            self.max_bg_threshold,
            self.preprocess_transforms,
            self.train_transforms,
        )
        self.validation_dataset = CancerDataset(
            train_csv.iloc[cutoff_point:],
            self.black_threshold,
            self.max_bg_threshold,
            self.preprocess_transforms,
            self.val_transforms,
        )

        print(f"The training set has {len(self.train_dataset)} images")
        print(f"The validation set has {len(self.validation_dataset)} images")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )


In [None]:
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from torchmetrics import MetricCollection
import timm


class CancerDetector(pl.LightningModule):
    """Lightning module wrapping a timm image classifier for the cancer subtypes.

    Args:
        lr: Base learning rate for AdamW.
        gamma: Multiplicative decay applied by ``ExponentialLR`` after warm-up.
        model_name: One of ``SUPPORTED_MODELS``. Hyphens and underscores are
            interchangeable, e.g. ``"efficientnet-b5"``.
        warmup_epochs: Length of the ``LinearLR`` warm-up; ``0`` disables it.
        num_classes: Number of output classes.
        use_pretrain: Passed to timm as ``pretrained``.
    """

    # timm architecture names accepted by ``_get_model``.
    SUPPORTED_MODELS = (
        "efficientnet_b4",
        "efficientnet_b5",
        "tiny_vit_21m_384",
        "convnextv2_tiny",
    )

    def __init__(
        self,
        lr: float,
        gamma: float,
        model_name: str,
        warmup_epochs: int = 4,
        num_classes: int = 5,
        use_pretrain: bool = True,
    ):
        super().__init__()
        # TODO Use model preprocessing function
        self.model = self._get_model(model_name, num_classes, use_pretrain)

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr
        self.gamma = gamma
        self.warmup_epochs = warmup_epochs

        self.save_hyperparameters()

        # Should we use micro average? Default is macro
        metrics = MetricCollection(
            [
                MulticlassAccuracy(num_classes),
                MulticlassF1Score(num_classes),
                MulticlassPrecision(num_classes),
                MulticlassRecall(num_classes),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.valid_metrics = metrics.clone(prefix="val/")

        # Per-step losses, averaged and cleared at the end of each epoch.
        self.train_step_outputs = []
        self.validation_step_outputs = []

    def _get_model(self, model_name: str, num_classes: int, use_pretrain: bool):
        """Create the timm classifier for ``model_name``.

        Raises:
            ValueError: if the (normalised) name is not in ``SUPPORTED_MODELS``.
        """
        # timm registers these architectures with underscores
        # ("efficientnet_b4"). Hyphenated spellings, as used in sweep configs,
        # are normalised so both forms work.
        timm_name = model_name.replace("-", "_")
        if timm_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unknown model name {model_name}")
        return timm.create_model(timm_name, pretrained=use_pretrain, num_classes=num_classes)

    def forward(self, imgs: torch.Tensor):
        return self.model(imgs)

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        output = self(x)
        loss = self.loss_fn(output, y)

        self.train_metrics.update(output, y)
        self.train_step_outputs.append(loss.detach().item())

        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        output = self(x)
        loss = self.loss_fn(output, y)

        self.valid_metrics.update(output, y)
        self.validation_step_outputs.append(loss.detach().item())

        return loss

    def on_train_epoch_end(self):
        loss = np.mean(self.train_step_outputs)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        output = self.train_metrics.compute()
        self.log_dict(output)

        self.train_metrics.reset()
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self):
        # Results from Lightning's sanity-check run are discarded, not logged.
        if not self.trainer.sanity_checking:
            loss = np.mean(self.validation_step_outputs)
            self.log("val/loss", loss, on_step=False, on_epoch=True)

            output = self.valid_metrics.compute()
            self.log_dict(output)

        self.valid_metrics.reset()
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW with an optional linear warm-up followed by exponential decay."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.lr)

        if self.warmup_epochs <= 0:
            # LinearLR with total_iters=0 would leave the LR stuck at its
            # start factor, so skip warm-up entirely.
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
        else:
            warmup = optim.lr_scheduler.LinearLR(optimizer, total_iters=self.warmup_epochs)
            exponential = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
            scheduler = optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, exponential], milestones=[self.warmup_epochs]
            )

        return [optimizer], [scheduler]


## Visualization
Visualize a batch of inputs exactly as they are fed to the network (after augmentation), denormalized for display.

In [None]:
# Visualize input images
# Normalisation statistics; must match AugmentationTransforms.get_preprocessing.
_UBC_MEAN = np.array([0.8894420586142374, 0.8208752169441305, 0.8864016141389351])
_UBC_STD = np.array([0.10106393015358608, 0.15637655015581306, 0.09892687853183287])
_IDX2CLASS = dict(enumerate(["HGSC", "LGSC", "EC", "CC", "MC"]))


def denormalize_batch(imgs: torch.Tensor) -> np.ndarray:
    """Undo ``albu.Normalize`` on a batch and return displayable uint8 images.

    Args:
        imgs: Float tensor laid out B x C x H x W, as produced by the
            preprocessing pipeline.

    Returns:
        uint8 array laid out B x H x W x C. The channel order is unchanged.
    """
    # B x C x H x W -> B x H x W x C, as a CPU numpy array for plotting.
    arr = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
    # Normalize used mean*255 / std*255 (max_pixel_value=255), so invert with the same scale.
    arr = arr * (_UBC_STD * 255) + _UBC_MEAN * 255
    # Round rather than truncate (9.9999 must become 10). Clip so float error
    # cannot wrap around when casting to uint8.
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


def visualize_input(datamodule: "pl.LightningDataModule", max_images: int = 16):
    """Plot tiles from the first training batch, titled with their class names.

    Calls ``prepare_data()`` and ``setup("fit")`` on ``datamodule``, then takes
    one batch from its training loader, so augmentations are applied. Up to
    ``max_images`` tiles are shown in a square-ish grid (4x4 for the default).
    """
    datamodule.prepare_data()
    datamodule.setup("fit")
    imgs, labels = next(iter(datamodule.train_dataloader()))

    # The dataset already converts tiles from BGR to RGB, which is what
    # imshow expects, so no channel flip is needed here.
    images = denormalize_batch(imgs)

    n_shown = min(len(images), max_images)
    n_cols = max(1, int(np.ceil(np.sqrt(max_images))))
    n_rows = max(1, int(np.ceil(max_images / n_cols)))

    plt.figure(figsize=(8.0, 8.0))
    for i, (img, label) in enumerate(zip(images[:n_shown], labels)):
        plt.subplot(n_rows, n_cols, 1 + i)
        plt.imshow(img)
        plt.title(_IDX2CLASS[int(label)])
        plt.axis("off")

    plt.show()
    
# Preview a 16-tile batch at 320 px before training.
cancer_module = CancerDataModule(
    image_size=320,
    batch_size=16,
    cutoff=0.8,
    black_threshold=15,
)
visualize_input(cancer_module)

## Training

In [3]:
# Authenticate with Weights & Biases without hard-coding the API key in the
# notebook, since notebooks are shared and versioned. Set WANDB_API_KEY in
# the environment (e.g. from a Kaggle secret). If it is unset, wandb.login
# falls back to its usual credential lookup.
# NOTE: any key that was previously committed in this cell should be revoked.
import os

import wandb

wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [4]:
# import wandb

# # Example sweep configuration
# sweep_configuration= {
#     "method": "random",
#     "name": "sweep",
#     "metric": {"goal": "maximize", "name": "val/MulticlassF1Score"},
#     "parameters": {
#         "batch_size": {"values": [8, 16]},
#         "lr": {"distribution": "log_uniform_values", "max": 0.03, "min": 3e-5},
#         "lr_decay_gamma": {"distribution": "uniform", "max": 0.998, "min": 0.99},
#         "model_name": {"values": ["efficientnet-b5", "tiny_vit_21m_384", "convnextv2_tiny"]}
#     },
# }

# sweep_id = wandb.sweep(sweep=sweep_configuration, project="UBC Ovarian Cancer Classification")

Create sweep with ID: ahjvv36h
Sweep URL: https://wandb.ai/malev/UBC%20Ovarian%20Cancer%20Classification/sweeps/ahjvv36h


In [None]:
import wandb

# Fixed settings for every sweep run. In `train`, the sweep's sampled
# hyperparameters are merged on top of these with `main_config | ...`.
main_config = dict(
    warmup_epochs=5,
    input_size=448,
    train_val_cutoff=0.8,
    black_threshold=15,
    max_bg_threshold=0.3,
    epochs=350,
)

def train():
    """Run one W&B sweep trial: build the model and data from the config, then fit.

    The sweep's hyperparameters (``lr``, ``lr_decay_gamma``, ``model_name``,
    ``batch_size``) are merged over ``main_config``. An optional boolean
    ``early_stopping`` key (default False) enables early stopping on the
    validation F1 score, but only when a validation split exists.

    Side effects: writes Lightning checkpoints, ``best_model.txt`` and
    ``cancer_classification_model.pt`` to the working directory.
    """
    run = wandb.init()
    config = main_config | dict(wandb.config)

    pl.seed_everything(seed=31415, workers=True)

    wandb_logger = WandbLogger(
        project="UBC Ovarian Cancer Classification", log_model=False, id=run.id, name=run.name
    )
    model = CancerDetector(
        lr=config["lr"],
        gamma=config["lr_decay_gamma"],
        model_name=config["model_name"],
        warmup_epochs=config["warmup_epochs"],
    )
    cancer_module = CancerDataModule(
        config["input_size"],
        config["batch_size"],
        cutoff=config["train_val_cutoff"],
        black_threshold=config["black_threshold"],
        max_bg_threshold=config["max_bg_threshold"],
    )

    # With a cutoff of 1.0 every slide is used for training, so there is no
    # validation split and no validation metrics.
    has_validation = config["train_val_cutoff"] < 1.0

    checkpoints = ModelCheckpoint(
        monitor="val/MulticlassF1Score" if has_validation else "train/MulticlassF1Score",
        save_top_k=3,
        mode="max",
        save_weights_only=True,
        save_last=True,
        auto_insert_metric_name=False,
        filename="epoch={epoch}-loss={val/loss:.4f}-f1={val/MulticlassF1Score:.4f}",
    )
    callbacks = [LearningRateMonitor(), checkpoints]
    if has_validation and config.get("early_stopping", False):
        callbacks.append(
            EarlyStopping(
                monitor="val/MulticlassF1Score", min_delta=0.0001, patience=25, mode="max"
            )
        )

    trainer = pl.Trainer(
        logger=wandb_logger,
        max_epochs=config["epochs"],
        accelerator="gpu",
        devices=1,
        # limit_val_batches=0 disables the validation loop entirely.
        limit_val_batches=1.0 if has_validation else 0,
        precision="16-mixed",
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=cancer_module)

    best_model_path = trainer.checkpoint_callback.best_model_path
    print(best_model_path)
    with open("best_model.txt", "w", encoding="utf-8") as f:
        f.write(best_model_path)

    trainer.save_checkpoint("cancer_classification_model.pt")
    
    
# Run a single trial of the existing sweep; the agent invokes `train`.
wandb.agent(
    "ahjvv36h",
    function=train,
    project="UBC Ovarian Cancer Classification",
    count=1,
)

### Convert to ONNX

In [None]:
# Load the final checkpoint written by `train` onto the CPU. use_pretrain=False
# stops timm from fetching pretrained weights that the checkpoint overwrites.
model = CancerDetector.load_from_checkpoint(
    "cancer_classification_model.pt", map_location="cpu", use_pretrain=False
)
model.eval()

# `config` exists only inside `train`, so read the input size from the
# module-level `main_config`. Tracing needs no gradients.
input_size = main_config["input_size"]
x = torch.randn(1, 3, input_size, input_size)

# Export the model; the batch dimension stays dynamic.
model.to_onnx(
    "cancer_classification_model.onnx",
    x,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

In [None]:
# Download (without installing) the onnxruntime package files into the current
# directory. Presumably this lets an offline inference notebook install it
# from this notebook's output; confirm.
!pip download onnxruntime

- https://huggingface.co/timm/efficientnet_b5.sw_in12k_ft_in1k input size 448
- https://huggingface.co/timm/tf_efficientnet_b4.ns_jft_in1k input size 380
- tf_efficientnetv2_s_in21ft1k
- tf_efficientnet_b4.ns_jft_in1k
- tiny_vit_21m_384.dist_in22k_ft_in1k
- edgenext_base.in21k_ft_in1k