In [None]:
!pip install pytorch-lightning-spells geffnet

In [None]:
from pathlib import Path
from typing import Callable, Union, Tuple, List

import torch
import numpy as np
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
from albumentations import SmallestMaxSize, CenterCrop, Compose
from albumentations.pytorch.transforms import ToTensor
import pytorch_lightning as pl
from pytorch_lightning_spells.callbacks import CutMixCallback, MixUpCallback, SnapMixCallback
from torch.utils.data import Dataset, DataLoader

torch.multiprocessing.set_sharing_strategy('file_system')

# Per-channel normalization mean (OFFSET) and std (SCALE), stored with shape
# (3, 1, 1) so they broadcast against a channel-first image array when the
# plotting cells undo the normalization for display.
OFFSET = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
SCALE = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Resize so the shorter side is 448, center-crop to 448x448, then convert to a
# tensor normalized with the same per-channel statistics as above.
TRANSFORMATIONS = Compose(
    [
        SmallestMaxSize(448),
        CenterCrop(448, 448),
        ToTensor(normalize={"mean": OFFSET.ravel(), "std": SCALE.ravel()}),
    ]
)

In [None]:
def load_image(filepath: Path) -> np.ndarray:
    """Read an image file and return its pixels as an RGB NumPy array.

    Args:
        filepath: Path to the image file.

    Returns:
        ``np.array`` of the image after conversion to RGB. This is the value
        passed as ``image=`` to the albumentations pipeline.
    """
    # The context manager closes the underlying file handle right away instead
    # of leaving it open until the PIL object is garbage-collected.
    with Image.open(filepath) as img:
        return np.array(img.convert('RGB'))


def load_transform_image(filepath):
    """Load the image at ``filepath`` and run it through ``TRANSFORMATIONS``.

    Returns the ``"image"`` entry of the pipeline's output, i.e. the
    transformed image.
    """
    transformed = TRANSFORMATIONS(image=load_image(filepath))
    return transformed["image"]


class CassavaDataset(Dataset):
    """Map-style dataset yielding ``[image, label]`` pairs.

    Args:
        folder: Directory holding the image files.
        df: Table with one row per sample. ``image_id`` is read as the file
            name relative to ``folder``; ``label`` is read as the class id.
    """

    def __init__(self, folder: Union[Path, str], df: pd.DataFrame):
        super().__init__()
        self._folder = Path(folder)
        self._df = df

    def __len__(self):
        # One sample per row of the table.
        return self._df.shape[0]

    def __getitem__(self, idx: int):
        row = self._df.iloc[idx]
        image_path = self._folder / row.image_id
        label = torch.tensor(row.label, dtype=torch.int64)
        return [load_transform_image(image_path), label]

In [None]:
# Shuffle with a dedicated, seeded generator so a fresh "Run All" draws the same
# sequence of sample batches for the visualizations. The mixing callbacks draw
# their own randomness separately. batch_size=16 fills the 4x4 plotting grid.
data_loader = DataLoader(
    CassavaDataset(
        Path("/kaggle/input/cassava-leaf-disease-classification/train_images/"),
        pd.read_csv("/kaggle/input/cassava-leaf-disease-classification/train.csv")
    ),
    shuffle=True,
    batch_size=16,
    num_workers=0,
    drop_last=False,
    generator=torch.Generator().manual_seed(42),
)

## MixUp

In [None]:
# NOTE(review): alpha presumably controls the distribution the mixing ratio is
# sampled from; confirm in the pytorch-lightning-spells docs.
mixup_cb = MixUpCallback(softmax_target=True, alpha=0.5)

In [None]:
batch, targets = next(iter(data_loader))
# The callback is handed the [images, targets] list and the mixed results are
# read back out of that same list afterwards.
batch_packed = [batch, targets]
mixup_cb.on_train_batch_start(None, None, batch_packed, None, None)
batch_new, targets_new = batch_packed
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(12, 12))
for idx, panel in enumerate(axes.flat):
    panel.set_axis_off()
    if idx >= len(batch_new):
        # Fewer images than panels (e.g. a short batch): leave the rest blank.
        continue
    # Undo the normalization and go from CHW to HWC. Clip to [0, 1] before the
    # uint8 cast so an out-of-range value cannot wrap around.
    image = batch_new[idx].detach().cpu().numpy() * SCALE + OFFSET
    panel.imshow(
        (np.clip(image.transpose(1, 2, 0), 0., 1.) * 255.).astype(np.uint8)
    )
    panel.set_title(
        f"L1: {targets_new[idx][0]:.0f} ({targets_new[idx][2]*100:.0f}%) "
        f"| L2: {targets_new[idx][1]:.0f} "
    )
plt.show()

Because of the nature of this dataset, the mixed-up images are hard to read. This suggests that MixUp augmentation is probably not the best choice for this dataset.

## CutMix

In [None]:
# NOTE(review): minmax presumably bounds the relative size of the pasted box;
# confirm in the pytorch-lightning-spells docs.
cutmix_cb = CutMixCallback(softmax_target=True, minmax=[0.2, 0.8], alpha=0.9)

In [None]:
batch, targets = next(iter(data_loader))
# The callback is handed the [images, targets] list and the mixed results are
# read back out of that same list afterwards.
batch_packed = [batch, targets]
cutmix_cb.on_train_batch_start(None, None, batch_packed, None, None)
batch_new, targets_new = batch_packed
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(12, 12))
for idx, panel in enumerate(axes.flat):
    panel.set_axis_off()
    if idx >= len(batch_new):
        # Fewer images than panels (e.g. a short batch): leave the rest blank.
        continue
    # Undo the normalization and go from CHW to HWC. Clip to [0, 1] before the
    # uint8 cast so an out-of-range value cannot wrap around.
    image = batch_new[idx].detach().cpu().numpy() * SCALE + OFFSET
    panel.imshow(
        (np.clip(image.transpose(1, 2, 0), 0., 1.) * 255.).astype(np.uint8)
    )
    panel.set_title(
        f"L1: {targets_new[idx][0]:.0f} ({targets_new[idx][2]*100:.0f}%) "
        f"| L2: {targets_new[idx][1]:.0f} "
    )
plt.show()

## SnapMix

Reference: [SnapMix: Semantically Proportional Mixing for Augmenting Fine-grained Data](https://arxiv.org/abs/2012.04846)

In [None]:
import geffnet
from fastcore.basics import patch_to

# patch_to attaches each decorated function to GenEfficientNet as a method,
# so the model exposes the extract_features/get_fc pair used for SnapMix.
@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def extract_features(self, input_tensor):
    """Return the backbone features for ``input_tensor``.

    Delegates to the model's own ``features`` method.
    NOTE(review): presumably called by ``SnapMixCallback`` to obtain feature
    maps; confirm against the library.
    """
    return self.features(input_tensor)

@patch_to(geffnet.gen_efficientnet.GenEfficientNet)
def get_fc(self):
    """Return the model's final classification layer (``self.classifier``).

    The attribute is read at call time, so this returns whatever head is
    currently installed (e.g. a replaced 5-way ``Linear``).
    NOTE(review): presumably used by ``SnapMixCallback`` together with
    ``extract_features``; confirm.
    """
    return self.classifier

In [None]:
# Load a baseline model
model_dict = torch.load("/kaggle/input/cassava-model/full_b4_41341_3290_randaug.pth")
model = geffnet.tf_efficientnet_b4_ns(pretrained=False, drop_rate=0.4, as_sequential=False)
model.classifier = torch.nn.Linear(model.classifier.in_features, 5)
model.load_state_dict(model_dict["model_states"])
_ = model.cuda()

### CutMix-style Bounding Boxes

In [None]:
# image_size matches the 448x448 crop produced by TRANSFORMATIONS.
# cutmix_bbox=True: use CutMix-style boxes.
snapmix_cb = SnapMixCallback(
    model, image_size=(448, 448), alpha=0.9, minmax=(0.2, 0.8), cutmix_bbox=True
)

In [None]:
batch, targets = next(iter(data_loader))
# The callback is handed the [images, targets] list and the mixed results are
# read back out of that same list afterwards.
batch_packed = [batch, targets]
snapmix_cb.on_train_batch_start(None, None, batch_packed, None, None)
batch_new, targets_new = batch_packed
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(12, 12))
for idx, panel in enumerate(axes.flat):
    panel.set_axis_off()
    if idx >= len(batch_new):
        # Fewer images than panels (e.g. a short batch): leave the rest blank.
        continue
    # .cpu() guards against the batch coming back on the model's device.
    # Undo the normalization and go from CHW to HWC. Clip to [0, 1] before the
    # uint8 cast so an out-of-range value cannot wrap around.
    image = batch_new[idx].detach().cpu().numpy() * SCALE + OFFSET
    panel.imshow(
        (np.clip(image.transpose(1, 2, 0), 0., 1.) * 255.).astype(np.uint8)
    )
    panel.set_title(
        f"L1: {targets_new[idx][0]:.0f} ({targets_new[idx][2]*100:.0f}%) "
        f"| L2: {targets_new[idx][1]:.0f} ({targets_new[idx][3]*100:.0f}%)"
    )
plt.show()

### SnapMix-style Bounding Boxes

1. Randomly generate two bounding boxes (instead of one as in CutMix).
2. The first bounding box is used to extract a patch from the source image.
3. The extracted patch is resized to the size of the second bounding box.
4. The resized patch is pasted into the target image at the second bounding box.

In [None]:
# image_size matches the 448x448 crop produced by TRANSFORMATIONS.
# cutmix_bbox=False: use SnapMix-style boxes.
snapmix_cb = SnapMixCallback(
    model, image_size=(448, 448), alpha=0.9, minmax=(0.2, 0.8), cutmix_bbox=False
)

In [None]:
batch, targets = next(iter(data_loader))
# The callback is handed the [images, targets] list and the mixed results are
# read back out of that same list afterwards.
batch_packed = [batch, targets]
snapmix_cb.on_train_batch_start(None, None, batch_packed, None, None)
batch_new, targets_new = batch_packed
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(12, 12))
for idx, panel in enumerate(axes.flat):
    panel.set_axis_off()
    if idx >= len(batch_new):
        # Fewer images than panels (e.g. a short batch): leave the rest blank.
        continue
    # .cpu() guards against the batch coming back on the model's device.
    # Undo the normalization and go from CHW to HWC. Clip to [0, 1] before the
    # uint8 cast so an out-of-range value (e.g. from resizing a patch) cannot
    # wrap around.
    image = batch_new[idx].detach().cpu().numpy() * SCALE + OFFSET
    panel.imshow(
        (np.clip(image.transpose(1, 2, 0), 0., 1.) * 255.).astype(np.uint8)
    )
    panel.set_title(
        f"L1: {targets_new[idx][0]:.0f} ({targets_new[idx][2]*100:.0f}%) "
        f"| L2: {targets_new[idx][1]:.0f} ({targets_new[idx][3]*100:.0f}%)"
    )
plt.show()