# Cancer Detection With Tiles

Try detection with tiles, where for each large image, we randomly select a tile and assign it the original image's label. 

We can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
We can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Imports

In [1]:
import pytorch_lightning as pl
import torch
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights, convnext_base, ConvNeXt_Base_Weights, convnext_small, ConvNeXt_Small_Weights, efficientnet_b4, EfficientNet_B4_Weights
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from typing import List, Dict, Optional
import pandas as pd
import numpy as np
import os

import albumentations as albu
from albumentations.pytorch import ToTensorV2
import random
import matplotlib.pyplot as plt

from pathlib import Path
import random
import cv2



## Base Classes

In [2]:
class AugmentationTransforms:
    def __init__(self, image_size: int):
        self.image_size = image_size

    def get_training_augmentation(self):
        train_transform = [
            albu.HorizontalFlip(p=0.5),
            albu.VerticalFlip(p=0.5),
            albu.augmentations.transforms.GaussNoise(p=0.2),
            albu.OneOf(
                [
                    albu.CLAHE(p=1),
                    albu.RandomBrightnessContrast(p=1),
                    albu.RandomGamma(p=1),
                    albu.HueSaturationValue(p=1),
                ],
                p=0.5,
            ),
#             albu.OneOf(
#                 [
#                     albu.augmentations.transforms.Sharpen(p=1),
#                     albu.Blur(blur_limit=3, p=1),
#                     albu.MotionBlur(blur_limit=3, p=1),
#                 ],
#                 p=0.5,
#             ),
              albu.augmentations.geometric.resize.Resize(
                self.image_size, self.image_size, always_apply=True
            ),
        ]
        return albu.Compose(train_transform)

    def get_validation_augmentation(self):
        """Add paddings to make image shape divisible by 32"""
        test_transform = [
            albu.augmentations.geometric.resize.Resize(
                self.image_size, self.image_size, always_apply=True
            ),
        ]
        return albu.Compose(test_transform)

    def get_preprocessing(self):
        """Construct preprocessing transform

        Args:
            preprocessing_fn (callbale): data normalization function
                (can be specific for each pretrained neural network)
        Return:
            transform: albumentations.Compose

        """

        # Model expects input [N, C, H, W]
        # ToTensor convert HWC image to CHW image
        ubc_mean = [0.8894420586142374,0.8208752169441305,0.8864016141389351]
        ubc_std = [0.10106393015358608,0.15637655015581306,0.09892687853183287]
        transform = [
            albu.Normalize(mean=ubc_mean, std=ubc_std),
            ToTensorV2(),
        ]

        return albu.Compose(transform)

In [3]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import pandas as pd
import cv2


class CancerDataset(Dataset):
    def __init__(
        self,
        data_df: pd.DataFrame,
        bag_size: int,
        black_threshold: int = 20,
        max_bg_threshold: float = 0.35,
        preprocessing=None,
        augmentation=None,
    ):
        self.data_df = data_df
        self.bag_size = bag_size
        self.black_threshold = black_threshold
        self.max_bg_threshold = max_bg_threshold
        self.preprocessing = preprocessing
        self.augmentation = augmentation

        self.classes = ["HGSC", "LGSC", "EC", "CC", "MC"]
        self.class2idx = {label: idx for idx, label in enumerate(self.classes)}

        self.data_df.loc[:, "label"] = self.data_df.loc[:, "label"].map(self.class2idx)
        self.images_lists = [
            list(
                Path(
                    "/kaggle/input/tiles-of-cancer-2048px-scale-0-25", str(image_id)
                ).rglob("**/*.png")
            )
            for image_id in self.data_df["image_id"]
        ]

    def __getitem__(self, i):
        image_id, label, _, _, _ = self.data_df.iloc[i]

        image_id_list = self.images_lists[i]
        # or
        # image_list = list(Path("/kaggle/input/tiles-of-cancer-2048px-scale-0-25/", str(image_id)).rglob("**/*.png"))
        random.shuffle(image_id_list)
        
        suitable_imgs = []
        suitable_img_count = 0
        
        # Construct a bag
        for img_path in image_id_list:
            img = cv2.imread(str(img_path))
            
            # Check if doesn't contain too much black background
            black_bg = np.sum(img, axis=2) <= self.black_threshold
            if np.sum(black_bg) <= (np.prod(black_bg.shape) * self.max_bg_threshold):
                # Replace black bg with white
                black_pixels = np.where(img == 0)
                img[black_pixels] = 255
                
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                
                if self.augmentation:
                    img = self.augmentation(image=img)["image"]

                if self.preprocessing:
                    img = self.preprocessing(image=img)["image"]
                
                suitable_imgs.append(img)
                suitable_img_count += 1
                
            # Check if we have reached the limit on the number of allowed elements
            if suitable_img_count == self.bag_size:
                break
        #print(suitable_imgs)
        bag = torch.stack(suitable_imgs)

        return bag, torch.tensor(label)

    def __len__(self):
        return len(self.data_df)


In [4]:
def collate(batch):
    batch_data = []
    batch_labels = []

    for data, label in batch:
        batch_data.append(data)
        batch_labels.append(label)
    
    out_data = torch.cat(batch_data, dim = 0)
    out_labels = torch.stack(batch_labels)

    return out_data, out_labels

class CancerDataModule(pl.LightningDataModule):
    def __init__(
        self,
        image_size: int,
        batch_size: int,
        bag_size: int,
        cutoff: float = 0.8,
        black_threshold: int = 20,
        max_bg_threshold: float = 0.35,
        shuffle: Optional[bool] = True,
    ):
        super().__init__()
        self.image_size = image_size
        self.train_batch_size = batch_size
        self.bag_size = bag_size
        self.shuffle = shuffle
        self.cutoff = cutoff
        self.black_threshold = black_threshold
        self.max_bg_threshold = max_bg_threshold

        aug_transforms = AugmentationTransforms(self.image_size)

        self.preprocess_transforms = aug_transforms.get_preprocessing()
        self.train_transforms = aug_transforms.get_training_augmentation()
        self.val_transforms = aug_transforms.get_validation_augmentation()

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            train_csv = pd.read_csv(
                "/kaggle/input/tiles-of-cancer-2048px-scale-0-25/train.csv"
            )
            cutoff_point = int(len(train_csv) * self.cutoff)
            
            self.train_dataset = CancerDataset(
                train_csv.iloc[:cutoff_point],
                self.bag_size,
                self.black_threshold,
                self.max_bg_threshold,
                self.preprocess_transforms,
                self.train_transforms,
            )
            self.validation_dataset = CancerDataset(
                train_csv.iloc[cutoff_point:],
                self.bag_size,
                self.black_threshold,
                self.max_bg_threshold,
                self.preprocess_transforms,
                self.val_transforms,
            )

            print(f"The the training set has {len(self.train_dataset)} images")
            print(f"The the validation set has {len(self.validation_dataset)} images")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=self.shuffle,
            num_workers=4,
            collate_fn=collate,
        )

    def val_dataloader(self):
        batch_size = min(self.train_batch_size, 8)
        return DataLoader(
            self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            collate_fn=collate,
        )


In [5]:
import timm
import torch.nn.functional as F

class AttnMIL(nn.Module):
    def __init__(self, num_classes, pretrained: bool):
        # Input of 256
#         torch.Size([1, 64, 128, 128])
#         torch.Size([1, 64, 64, 64])
#         torch.Size([1, 128, 32, 32])
#         torch.Size([1, 256, 16, 16])
#         torch.Size([1, 512, 8, 8])
        super(AttnMIL, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1

#         self.feature_extractor_part1 = nn.Sequential(
#             nn.Conv2d(3, 20, kernel_size=5),
#             nn.ReLU(),
#             nn.MaxPool2d(2, stride=2),
#             nn.Conv2d(20, 50, kernel_size=5),
#             nn.ReLU(),
#             nn.MaxPool2d(2, stride=2)
#         )

        self.feature_extractor_part1 = timm.create_model('resnet34', pretrained=pretrained, features_only=True, out_indices=[4])
    
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(512*8*8, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, num_classes),
        )

    def forward(self, x):
        #x = x.squeeze(0)

        H = self.feature_extractor_part1(x)[0]
        H = H.view(-1, 512*8*8)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        y = self.classifier(M)

        return y

In [6]:
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from torchmetrics import MetricCollection


class CancerDetector(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        gamma: float,
        model_name: str,
        batch_size: int,
        warmup_epochs: int = 4,
        num_classes: int = 5,
        init_weights: bool = True,
    ):
        super().__init__()
        # TODO Use model preprocessing function
        self.model = self._get_model(model_name, num_classes, init_weights)

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr
        self.gamma = gamma
        self.warmup_epochs = warmup_epochs
        self.batch_size = batch_size

        self.save_hyperparameters()

        # Should we use micro average? Default is macro
        metrics = MetricCollection(
            [
                MulticlassAccuracy(num_classes),
                MulticlassF1Score(num_classes),
                MulticlassPrecision(num_classes),
                MulticlassRecall(num_classes),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.valid_metrics = metrics.clone(prefix="val/")

        self.train_step_outputs = []
        self.validation_step_outputs = []

    def _get_model(self, model_name: str, num_classes: int, init_weights: bool):
        if model_name == "attnmil":
            model = AttnMIL(num_classes, init_weights)
        else:
            raise Exception(f"Unknown model name {model_name}")

        return model

    def forward(self, imgs: torch.Tensor):
        return self.model(imgs)

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        output = self(x)
        loss = self.loss_fn(output, y)

        self.train_metrics.update(output, y)
        self.train_step_outputs.append(loss.detach().item())

        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        output = self(x)
        loss = self.loss_fn(output, y)

        self.valid_metrics.update(output, y)
        self.validation_step_outputs.append(loss.detach().item())

        return loss

    def on_train_epoch_end(self):
        loss = np.mean(self.train_step_outputs)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        output = self.train_metrics.compute()
        self.log_dict(output)

        self.train_metrics.reset()
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking:
            loss = np.mean(self.validation_step_outputs)
            self.log("val/loss", loss, on_step=False, on_epoch=True)

            output = self.valid_metrics.compute()
            self.log_dict(output)

        self.valid_metrics.reset()
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), lr=self.lr)

        warmup = optim.lr_scheduler.LinearLR(optimizer, total_iters=self.warmup_epochs)
        exponential = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)
        scheduler = optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup, exponential], milestones=[self.warmup_epochs]
        )

        return [optimizer], [scheduler]


## Visualization
Visualize inputs before the network

In [7]:
# # Visualize input images
# def visualize_input(datamodule: pl.LightningDataModule):
#     mean=np.array((0.485, 0.456, 0.406)) * 255
#     std=np.array((0.229, 0.224, 0.225)) * 255
    
#     datamodule.prepare_data()
#     datamodule.setup("fit")
#     train = datamodule.train_dataloader()
#     imgs, labels = next(iter(train))
    
#     # B x C x H x W to B x H x W x C
#     imgs = imgs.permute((0,2,3,1))
#     imgs = imgs * std + mean
#     # Change the order of channels
#     imgs = imgs.flip(3)
#     imgs = imgs.numpy().astype('uint8')
    
#     classes = ['HGSC', 'LGSC', 'EC', 'CC', 'MC']
#     idx2class = {idx: class_name for idx, class_name in enumerate(classes)}
    
#     plt.rcParams['figure.figsize'] = (8.0, 8.0) # set default size of plots

#     for i, (img, label) in enumerate(zip(imgs, labels)):
#         plt.subplot(4, 4, 1 + i)
#         plt.imshow(img)
#         plt.title(idx2class[label.item()])
#         plt.axis('off')
        
#         if i == 15:
#             break
            
#     plt.show()
    
# cancer_module = CancerDataModule(320, 16)
# visualize_input(cancer_module)

### Find Max Number of Patches
What is the largest bag size

In [8]:
folder = Path("/kaggle/input/tiles-of-cancer-2048px-scale-0-25")
subfolder = [f for f in folder.iterdir() if f.is_dir()]

counts = 0
for s in subfolder:
    counts = max(counts, len(list(s.glob("*.png"))))

print(f"Maximum number of slices is {counts}")

Maximum number of slices is 663


## Training

In [9]:
!echo -e "machine api.wandb.ai\n  login user\n  password 33b28a5d3e362bc21dfe1fc1759af32fdd74dec0" >> /root/.netrc

In [10]:
config = {
    "lr": 9e-4, #3e-3,
    "lr_decay_gamma": 0.89, # 0.054 after 25 epochs
    "warmup_epochs": 3,
    "model_name": "attnmil",
    "input_size": 256,
    "batch_size": 1,
    "bag_size": 350,
    "epochs": 25,
    "train_val_cutoff": 0.8,
    "black_threshold": 20,
    "max_bg_threshold": 0.3,
}


pl.seed_everything(seed=31415, workers=True)

wandb_logger = WandbLogger(project="UBC Ovarian Cancer Classification", log_model=False)
model = CancerDetector(
    lr=config["lr"],
    gamma=config["lr_decay_gamma"],
    model_name=config["model_name"],
    batch_size=config["batch_size"],
    warmup_epochs=config["warmup_epochs"],
)
cancer_module = CancerDataModule(
    config["input_size"],
    config["batch_size"],
    config["bag_size"],
    cutoff=config["train_val_cutoff"],
    black_threshold=config["black_threshold"],
    max_bg_threshold=config["max_bg_threshold"],
)

# Initialize callbacks
lr_monitor = LearningRateMonitor()
# early_stopping = EarlyStopping(
#     monitor="val/f1", min_delta=0.0002, patience=8, mode="max"
# )
checkpoints = ModelCheckpoint(
    monitor="val/MulticlassF1Score",
    save_top_k=3,
    mode="max",
    save_weights_only=True,
    save_last=True,
    auto_insert_metric_name=False,
    filename="epoch={epoch}-loss={val/loss:.4f}-f1={val/f1:.4f}",
)

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=config["epochs"],
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    callbacks=[lr_monitor, checkpoints], #early_stopping,
) 

trainer.fit(model, datamodule=cancer_module)
print(trainer.checkpoint_callback.best_model_path)

with open("best_model.txt", "w") as f:
    f.write(trainer.checkpoint_callback.best_model_path)


trainer.save_checkpoint("cancer_classification_model.pt")


[34m[1mwandb[0m: Currently logged in as: [33mkarl-joan-alesma[0m ([33mmalev[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.16.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20231210_140411-xfzavdh6[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mdifferent-deluge-84[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/malev/UBC%20Ovarian%20Cancer%20Classification[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/malev/UBC%20Ovarian%20Cancer%20Classification/runs/xfzavdh6[0m


Downloading model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

The the training set has 430 images
The the validation set has 108 images


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

./UBC Ovarian Cancer Classification/xfzavdh6/checkpoints/epoch=16-loss=2.7357-f1=0.0000.ckpt


### TODO
#### Validate Code
- Fix seed
- [x] Overfit a batch with a single image
- [x] Input independent baseline test
#### Fit
- [x] Bigger model
- [x] weight decay
- [ ] augmentations
- [ ] early stopping
- [ ] lr scheduler
- [ ] hyper parameters