In [None]:
!pip install kornia
!pip install pytorch_lightning

In [None]:
# Mount Google Drive at /gdrive (needs the google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [None]:
# Lookup table from class name to integer label.
# The label of each class is its position in the list below.
cat_to_idx = {
    class_name: class_index
    for class_index, class_name in enumerate([
        'Cassava bacterial blight (cbb)',
        'Cassava brown streak disease (cbsd)',
        'Cassava green mottle (cgm)',
        'Cassava mosaic disease (cmd)',
        'Healthy',
    ])
}


In [None]:
import random
from typing import Callable, Tuple
from typing import Union


from kornia import augmentation as aug
from kornia import filters
from kornia.geometry import transform as tf
import torch
from torch import nn, Tensor

from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
from torchvision import transforms


from copy import deepcopy
from itertools import chain
from typing import Dict, List

import pytorch_lightning as pl
from torch import optim
import torch.nn.functional as f

from torchvision.models import resnet18, resnet34, resnet50

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a pandas DataFrame.

    Every row supplies an image location (``file_name`` column) and a class
    name (``classes`` column); the class name is translated to an integer
    label through ``mapping``.

    Args:
        data_path: DataFrame with ``file_name`` and ``classes`` columns.
            Despite its name this is the annotation table, not a path.
        transforms: Callable applied to the loaded PIL image. Skipped when
            falsy (e.g. ``None``).
        mapping: Dict from class name to integer label. Defaults to the
            notebook-level ``cat_to_idx``.
    """

    def __init__(self, data_path, transforms, mapping=cat_to_idx):
        self.data_path = data_path
        self.transforms = transforms
        # Use the mapping the caller passed in; it used to be ignored in
        # favour of the global ``cat_to_idx``.
        self.mapping = mapping

    def __len__(self):
        """Return the number of rows (samples) in the annotation table."""
        return len(self.data_path)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Raises:
            KeyError: if the row's class name is missing from ``mapping``.
        """
        # Positional lookup: ``idx`` ranges over 0..len-1, which only matches
        # the DataFrame's labels when the index is the default RangeIndex.
        row = self.data_path.iloc[idx]

        # Decode inside a context manager so the file handle is released
        # immediately; convert() returns a fully loaded 3-channel copy, so
        # grayscale/RGBA files still match the 3-channel model input.
        with Image.open(row['file_name']) as handle:
            im = handle.convert('RGB')
        label = int(self.mapping[row['classes']])

        if self.transforms:
            im = self.transforms(im)
        return im, label

In [None]:
# Annotation tables; CustomDataset reads their 'file_name' and 'classes' columns.
train_data = pd.read_csv('/content/remo_train.csv')
validation_data = pd.read_csv('/content/remo_valid.csv')

# Only PIL -> tensor conversion happens per sample.
train_transforms = transforms.Compose([transforms.ToTensor()])
val_transforms = transforms.Compose([transforms.ToTensor()])

# pin_memory is left at its default; pass pin_memory=True to both loaders to
# enable pinned host memory when training on a GPU.
train_dl = DataLoader(
    CustomDataset(data_path=train_data, transforms=train_transforms),
    batch_size=128,
    # Reshuffle each epoch so batches are not tied to the CSV row order.
    shuffle=True,
    num_workers=2,
)
val_dl = DataLoader(
    CustomDataset(data_path=validation_data, transforms=val_transforms),
    batch_size=50,
    num_workers=2,
)

In [None]:
class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``, else pass it through.

    A single random draw is made per ``forward`` call, so the decision
    covers the whole input tensor at once.

    Args:
        fn: Callable (or module) to apply to the input.
        p: Probability threshold compared against ``random.random()``.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged when the draw exceeds ``p``, else ``fn(x)``."""
        draw = random.random()
        if draw > self.p:
            return x
        return self.fn(x)


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the default augmentation pipeline as an ``nn.Sequential``.

    Stages, in order: resize to ``image_size``, color jitter (applied 80% of
    the time), random grayscale, random horizontal flip, Gaussian blur
    (applied 10% of the time), random resized crop back to ``image_size``,
    and per-channel normalization.

    Args:
        image_size: Target (height, width) used by the resize and crop stages.

    Returns:
        The assembled augmentation module.
    """
    # Per-channel statistics used by the final normalization stage.
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])

    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(mean=channel_mean, std=channel_std),
    ]
    return nn.Sequential(*stages)

In [None]:
def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module:
    """Create a two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        dim: Number of input features.
        projection_size: Number of output features.
        hidden_size: Width of the hidden layer.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    expand = nn.Linear(in_features=dim, out_features=hidden_size)
    norm = nn.BatchNorm1d(num_features=hidden_size)
    activation = nn.ReLU(inplace=True)
    project = nn.Linear(in_features=hidden_size, out_features=projection_size)
    return nn.Sequential(expand, norm, activation, project)


class EncoderWrapper(nn.Module):
    """Wrap a model and expose projected features from one of its layers.

    A forward hook on the chosen layer flattens that layer's output and
    feeds it through a projector MLP. The projector is built lazily, the
    first time the hook fires, because its input width is measured from
    the hooked output.

    Args:
        model: The wrapped network.
        projection_size: Output width of the projector.
        hidden_size: Hidden width of the projector.
        layer: Which layer to hook. A string is looked up among
            ``model.named_modules()``; an int indexes ``model.children()``.
    """

    def __init__(
        self,
        model: nn.Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: Union[str, int] = -2,
    ):
        super().__init__()
        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        # Filled in lazily by the forward hook.
        self._projector = None
        self._projector_dim = None
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self):
        """The projector MLP, created on first access from the measured width."""
        if self._projector is not None:
            return self._projector
        self._projector = mlp(
            self._projector_dim, self.projection_size, self.hidden_size
        )
        return self._projector

    def _hook(self, _, __, output):
        """Forward hook: flatten the layer output, project it, and store it."""
        flat = output.flatten(start_dim=1)
        # Record the feature width the first time through, for the projector.
        if self._projector_dim is None:
            self._projector_dim = flat.shape[-1]

        self._encoded = self.projector(flat)

    def _register_hook(self):
        """Attach ``_hook`` to the layer selected by ``self.layer``."""
        if isinstance(self.layer, str):
            hooked = dict(self.model.named_modules())[self.layer]
        else:
            hooked = list(self.model.children())[self.layer]

        hooked.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model and return the encoding captured by the hook."""
        # The model's own output is discarded; only the hooked value is used.
        self.model(x)
        return self._encoded

In [None]:
def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    """Mean of ``2 - 2 * <x_hat, y_hat>`` over all leading dimensions.

    Both inputs are L2-normalized along the last dimension first, so each
    term lies in [0, 4]: 0 for identical directions, 4 for opposite ones.
    """
    x_unit = f.normalize(x, dim=-1)
    y_unit = f.normalize(y, dim=-1)
    similarity = (x_unit * y_unit).sum(dim=-1)
    return (2 - 2 * similarity).mean()


class BYOL(pl.LightningModule):
    """BYOL-style self-supervised training wrapped as a LightningModule.

    An online network (``encoder`` followed by ``predictor``) is trained so
    that its prediction for one augmented view matches the projection that a
    target encoder produces for another augmented view of the same batch.
    The target encoder is a copy of the online encoder whose weights follow
    it through an exponential moving average with factor ``beta``.

    Args:
        model: Backbone network handed to ``EncoderWrapper``.
        image_size: (height, width) used for the default augmentation and
            for the dummy forward pass that initializes the projector.
        hidden_layer: Layer of ``model`` to hook (name or child index).
        projection_size: Width of the projector and predictor outputs.
        hidden_size: Hidden width of the projector and predictor MLPs.
        augment_fn: Optional replacement for ``default_augmentation``.
        beta: Moving-average factor for the target weights.
        **hparams: Extra hyperparameters. ``optimizer``, ``lr`` and
            ``weight_decay`` are read in ``configure_optimizers``.
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: Tuple[int, int] = (96, 96),
        hidden_layer: Union[str, int] = -2,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Callable = None,
        beta: float = 0.99,
        **hparams,
    ):
        super().__init__()
        self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn
        self.beta = beta
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        # The predictor is an MLP with the same layout as the projector.
        # It was previously nn.Linear(projection_size, projection_size,
        # hidden_size); nn.Linear's third positional parameter is ``bias``,
        # so hidden_size was silently used as a truthy bias flag.
        self.predictor = mlp(projection_size, projection_size, hidden_size)
        # Register the extra keyword arguments as Lightning hyperparameters.
        # Assigning to ``self.hparams`` directly is not supported by every
        # Lightning release; ``self.hparams.get(...)`` keeps working.
        self.save_hyperparameters(hparams)
        self._target = None

        # Perform a single forward pass, which initializes the 'projector' in
        # our 'EncoderWrapper' layer. No autograd graph is needed for it.
        with torch.no_grad():
            self.encoder(torch.zeros(2, 3, *image_size))

    def forward(self, x: Tensor) -> Tensor:
        """Return the online network's prediction for ``x``."""
        return self.predictor(self.encoder(x))

    @property
    def target(self):
        """Target encoder: a deep copy of the online encoder, made on first use."""
        if self._target is None:
            self._target = deepcopy(self.encoder)
        return self._target

    def update_target(self):
        """Move each target parameter towards its online counterpart (EMA)."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.data = self.beta * pt.data + (1 - self.beta) * p.data

    # --- Methods required for PyTorch Lightning only! ---

    def configure_optimizers(self):
        """Build the optimizer named in ``hparams`` (default: Adam)."""
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    def on_before_zero_grad(self, optimizer) -> None:
        """Refresh the target network once per optimizer step.

        ``update_target`` was defined but never invoked, so the target stayed
        frozen at its initial copy. This Lightning hook runs in the training
        loop after the optimizer step and before gradients are zeroed.
        """
        self.update_target()

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        """Compute the symmetric BYOL loss for one batch.

        ``batch[0]`` is taken as the image tensor; any labels are ignored.
        """
        x = batch[0]
        # Augmentation needs no gradients.
        with torch.no_grad():
            x1, x2 = self.augment(x), self.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        # The target network only supplies regression targets.
        with torch.no_grad():
            targ1, targ2 = self.target(x1), self.target(x2)
        # Each view's prediction is matched against the other view's target.
        loss = (normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) / 2

        self.log("train_loss", loss.item())
        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        """Compute the same symmetric loss as ``training_step``, without gradients."""
        x = batch[0]
        x1, x2 = self.augment(x), self.augment(x)
        pred1, pred2 = self.forward(x1), self.forward(x2)
        targ1, targ2 = self.target(x1), self.target(x2)
        loss = (normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1)) / 2

        return {"loss": loss}

    @torch.no_grad()
    def validation_epoch_end(self, outputs: List[Dict]) -> None:
        """Log the mean validation loss over all validation batches."""
        if not outputs:
            # Nothing to average; avoids dividing by zero.
            return
        val_loss = sum(x["loss"] for x in outputs) / len(outputs)
        self.log("val_loss", val_loss.item())

In [None]:
# Learning-rate range test with a pretrained ResNet-50 backbone and the
# default BYOL hyperparameters.
model = resnet50(pretrained=True)
byol = BYOL(model=model, image_size=(96, 96))

trainer = pl.Trainer(gpus=1)
lr_finder = trainer.tuner.lr_find(byol, train_dl, val_dl)

# Plot the results of the sweep and keep the suggested learning rate.
fig = lr_finder.plot()
fig.show()
suggested_lr = lr_finder.suggestion()

In [None]:
# Display the learning rate suggested by the LR finder.
suggested_lr

In [None]:
# resnet50 is already imported in the imports cell near the top of the notebook.
model = resnet50(pretrained=True)
byol = BYOL(model, lr = 0.001, image_size=(96, 96))

trainer = pl.Trainer(
    max_epochs=50,
    gpus=1,
    # Accumulate gradients towards an effective batch size of 2048
    # (original note: matches the BYOL paper, TODO confirm). The divisor is
    # the *training* loader's batch size; it used to be the validation
    # batch size (50), which overshot the intended effective batch.
    accumulate_grad_batches=2048 // train_dl.batch_size,
    weights_summary=None
)

lr_finder = trainer.tuner.lr_find(byol, train_dl, val_dl)

fig = lr_finder.plot()
fig.show()
suggested_lr = lr_finder.suggestion()

# Next step (not run here): rebuild BYOL with lr=suggested_lr, since the
# learning rate is read from the constructor's keyword arguments, and then
# call trainer.fit(byol, train_dl, val_dl).

In [None]:
# (Re)load the TensorBoard notebook extension and open it on the lightning_logs/ directory.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/