# BYOL Demo

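An end-to-end demonstration of BYOL (Bootstrap Your Own Latent) self-supervised pre-training, using an ImageNet-pretrained ResNet-50 on the CACD73 image dataset (train/val/test `ImageFolder` splits).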

### Install dependencies

PyTorch and Torchvision are pre-installed in Colab instances, so no need to worry about those.

In [None]:
# Install dependencies.  Note that pytorch and torchvision are pre-installed 
# in standard Colab instances, so no need to worry about those.
!pip install -q kornia pytorch_lightning
!pip install torchinfo

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m612.0/612.0 KB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m715.6/715.6 KB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 KB[0m [31m49.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m60.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.8/158.8 KB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.2/114.2 KB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m264.6/264.6 KB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[?25hLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.2

In [None]:
# Pin Lightning 1.9.3 (the recorded output shows this replacing a 2.0.0 install).
# NOTE(review): restart the runtime afterwards if pytorch_lightning was already imported.
!pip install pytorch_lightning==1.9.3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning==1.9.3
  Downloading pytorch_lightning-1.9.3-py3-none-any.whl (826 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m826.4/826.4 KB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pytorch_lightning
  Attempting uninstall: pytorch_lightning
    Found existing installation: pytorch-lightning 2.0.0
    Uninstalling pytorch-lightning-2.0.0:
      Successfully uninstalled pytorch-lightning-2.0.0
Successfully installed pytorch_lightning-1.9.3


### Data Augmentations

In [None]:
import random
from typing import Callable, Tuple

from kornia import augmentation as aug
from kornia import filters
from kornia.geometry import transform as tf
import torch
from torch import nn, Tensor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``; otherwise return the input unchanged.

    A single ``random.random()`` draw is made per call, so a whole batch is
    either transformed together or passed through together.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        draw = random.random()
        if draw <= self.p:
            return self.fn(x)
        return x


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the default BYOL augmentation pipeline (Kornia ops on batched tensors).

    Stages, in order: resize -> colour jitter (p=0.8) -> grayscale (p=0.2)
    -> horizontal flip -> Gaussian blur (p=0.1) -> random resized crop
    -> ImageNet mean/std normalization.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])
    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return nn.Sequential(*stages)

### Encoder Wrapper

Enables us to work with *any* model, not just the ResNet-50 used in this demo.

In [None]:
from typing import Union


def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module:
    """Two-layer MLP head: Linear(dim->hidden) -> BatchNorm1d -> ReLU -> Linear(hidden->projection)."""
    layers = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)


class EncoderWrapper(nn.Module):
    """Expose projected features from an intermediate layer of any backbone.

    A forward hook is registered on one layer of ``model``. Each call to
    ``forward`` runs the full backbone; the hook flattens the hooked layer's
    output (from dim 1 onward), passes it through a lazily-built MLP projector
    (see ``mlp``) and stores the result, which ``forward`` returns. The
    backbone's own final output is discarded.

    Args:
        model: Backbone network to wrap.
        projection_size: Output width of the projector MLP.
        hidden_size: Hidden width of the projector MLP.
        layer: Layer to hook. A ``str`` is looked up in
            ``model.named_modules()``; an ``int`` indexes
            ``list(model.children())`` (default ``-2``: the second-to-last
            direct child).
    """

    def __init__(
        self,
        model: nn.Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: Union[str, int] = -2,
    ):
        super().__init__()
        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        # The projector's input width is only known once the hooked layer has
        # produced an output, so the projector is created during the first forward.
        self._projector = None
        self._projector_dim = None
        # Most recent projector output, written by ``_hook`` and returned by ``forward``.
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self):
        """Projector MLP, created on first access with input width ``_projector_dim``.

        NOTE(review): its parameters only exist after the first forward pass, so
        run one forward before building an optimizer over ``parameters()``.
        """
        if self._projector is None:
            self._projector = mlp(
                self._projector_dim, self.projection_size, self.hidden_size
            )
        return self._projector

    def _hook(self, _, __, output):
        """Forward hook: flatten the hooked layer's output and project it.

        Assumes ``output`` is batch-first (dim 0 is the batch) — TODO confirm
        for non-ResNet backbones.
        """
        output = output.flatten(start_dim=1)
        if self._projector_dim is None:
            # First call: remember the flattened width so the projector can be built.
            self._projector_dim = output.shape[-1]
        self._encoded = self.projector(output)

    def _register_hook(self):
        """Attach ``_hook`` to the layer selected by ``self.layer``.

        Raises ``KeyError`` for an unknown layer name and ``IndexError`` for an
        out-of-range child index.
        """
        if isinstance(self.layer, str):
            layer = dict([*self.model.named_modules()])[self.layer]
        else:
            layer = list(self.model.children())[self.layer]

        layer.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        """Run the backbone on ``x`` and return the projected hooked-layer features.

        NOTE(review): if the hooked layer is not executed for this input, the
        value stored by the previous call is returned unchanged.
        """
        _ = self.model(x)
        return self._encoded

### BYOL and Training Code

Encapsulate BYOL into a single module.  I use PyTorch Lightning here, because it makes training very easy.  This code also works for multi-GPU or TPU training, and experiments are logged automatically.

In [None]:
from copy import deepcopy
from itertools import chain
from typing import Dict, List

import pytorch_lightning as pl
from torch import optim
import torch.nn.functional as f


def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    """Squared L2 distance between L2-normalized vectors along the last dimension.

    For unit vectors ``||x - y||^2 = 2 - 2 * <x, y>``; one value is returned per row.
    """
    x_unit = f.normalize(x, dim=-1)
    y_unit = f.normalize(y, dim=-1)
    cos_sim = torch.sum(x_unit * y_unit, dim=-1)
    return 2 - 2 * cos_sim


class BYOL(pl.LightningModule):
    """Bootstrap Your Own Latent (BYOL) self-supervised learner.

    The online network is ``encoder`` (backbone + projector, see
    ``EncoderWrapper``) followed by ``predictor``. The target network is a copy
    of ``encoder``, created on first use and updated as an exponential moving
    average (EMA) of the online encoder. The loss compares each view's online
    prediction with the target projection of the other view (symmetrized).

    Args:
        model: Backbone to pre-train.
        image_size: (H, W) for the default augmentation and the dummy forward
            pass that builds the projector.
        hidden_layer: Layer of ``model`` whose output is projected.
        projection_size: Projector / predictor output size.
        hidden_size: Projector hidden width.
        augment_fn: Batch augmentation; ``default_augmentation(image_size)`` if None.
        beta: EMA decay of the target network.
        **hparams: Optimizer settings read by ``configure_optimizers``
            (``optimizer`` name in ``torch.optim``, ``lr``, ``weight_decay``).
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: Tuple[int, int] = (128, 128),
        hidden_layer: Union[str, int] = -2,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Callable = None,
        beta: float = 0.999,
        **hparams,
    ):
        super().__init__()
        self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn
        self.beta = beta
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        # nn.Linear's third positional argument is ``bias``; ``hidden_size`` was
        # previously passed there (truthy, hence bias=True). State it explicitly.
        self.predictor = nn.Linear(projection_size, projection_size)
        # Keep optimizer settings visible to configure_optimizers; they used to
        # be collected and silently dropped.
        self.hparams.update(hparams)
        self._target = None
        # Per-batch validation losses, averaged in on_validation_epoch_end.
        self._val_losses: List[Tensor] = []

        # Dummy forward pass so the encoder's lazily-built projector exists
        # before the optimizer collects parameters. Run it in eval mode without
        # grad so the all-zero batch does not shift the backbone's BatchNorm
        # running statistics and leaves no autograd graph behind.
        self.encoder.eval()
        with torch.no_grad():
            self.encoder(torch.zeros(2, 3, *image_size))
        self.encoder.train()

    def forward(self, x: Tensor) -> Tensor:
        """Online prediction: predictor(encoder(x))."""
        return self.predictor(self.encoder(x))

    @property
    def target(self):
        """Target encoder, deep-copied from the online encoder on first access."""
        if self._target is None:
            # EncoderWrapper caches its last output in ``_encoded``. If that
            # tensor belongs to an autograd graph, deepcopy raises ("Only Tensors
            # created explicitly by the user ... support the deepcopy protocol"),
            # so store a detached version before copying.
            self.encoder._encoded = self.encoder._encoded.detach()
            self._target = deepcopy(self.encoder)
            # The target is only ever updated via EMA, never by gradients.
            self._target.requires_grad_(False)
        return self._target

    @torch.no_grad()
    def update_target(self):
        """EMA update, in place: target <- beta * target + (1 - beta) * online."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.mul_(self.beta).add_(p, alpha=1 - self.beta)

    # --- Methods required for PyTorch Lightning only! ---

    def configure_optimizers(self):
        """Optimizer over the online networks (defaults: Adam, lr=1e-4, weight_decay=1e-6)."""
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        online_params = chain(self.encoder.parameters(), self.predictor.parameters())
        return optimizer(online_params, lr=lr, weight_decay=weight_decay)

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        """Symmetric BYOL loss on two augmented views of ``batch[0]``, then EMA update."""
        x = batch[0]
        with torch.no_grad():
            x1, x2 = self.augment(x), self.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        with torch.no_grad():
            targ1, targ2 = self.target(x1), self.target(x2)
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        self.log("train_loss", loss.item())
        # NOTE(review): runs once per batch, so with gradient accumulation the
        # EMA is updated more often than the online weights actually change.
        self.update_target()

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        """Same symmetric loss as training (no EMA update); stored for the epoch mean."""
        x = batch[0]
        x1, x2 = self.augment(x), self.augment(x)
        pred1, pred2 = self.forward(x1), self.forward(x2)
        targ1, targ2 = self.target(x1), self.target(x2)
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        self._val_losses.append(loss.detach())
        return {"loss": loss}

    def on_validation_epoch_end(self) -> None:
        """Log ``val_loss`` as the unweighted mean of per-batch losses.

        Replaces ``validation_epoch_end``, which Lightning 2.0 no longer supports.
        """
        if not self._val_losses:
            return
        val_loss = torch.stack(self._val_losses).mean()
        self._val_losses.clear()
        self.log("val_loss", val_loss.item())

### Supervised Training Module

We also need a Lightning module for supervised training on CACD73, after any self-supervised training has completed.  There's not much special here -- just standard supervised training with a cross-entropy loss.

In [None]:
class SupervisedLightningModule(pl.LightningModule):
    """Plain supervised classification (cross-entropy) around any ``model``.

    Args:
        model: Network mapping a batch of images to class logits.
        **hparams: Optional optimizer settings read by ``configure_optimizers``
            (``optimizer`` name in ``torch.optim``, ``lr``, ``weight_decay``).
    """

    def __init__(self, model: nn.Module, **hparams):
        super().__init__()
        self.model = model
        # Records the init args, including ``model`` itself; later cells rely on
        # this to call ``load_from_checkpoint`` without passing ``model``.
        self.save_hyperparameters()

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def configure_optimizers(self):
        """Optimizer over all parameters (defaults: Adam, lr=1e-4, weight_decay=1e-6)."""
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x, y = batch
        loss = f.cross_entropy(self.forward(x), y)
        self.log("train_loss", loss.item())
        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x, y = batch
        loss = f.cross_entropy(self.forward(x), y)
        # Single epoch-level ``val_loss`` (Lightning averages it over validation
        # batches); this is the value ModelCheckpoint monitors. The former
        # ``validation_epoch_end`` re-logged the same key and is unsupported in
        # Lightning 2.0, so it was removed.
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

### Supervised Training without BYOL

Run supervised training without any BYOL pre-training (starting from ImageNet weights), and measure the accuracy.  This is the baseline, and performance should already be reasonably good.

In [None]:
!unrar x -y "/content/drive/MyDrive/TRAIN CACD/CACD73.rar"

In [None]:
from torchvision import datasets, models, transforms

# Dataset root and ImageNet normalization shared by every split.
input_path = "./CACD73"
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Train: resize plus light geometric augmentation; validation/test: resize only.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
}
for split in ('validation', 'test'):
    data_transforms[split] = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

# ImageFolder sub-directory per split (note: 'validation' reads from 'val').
split_dirs = {'train': 'train', 'validation': 'val', 'test': 'test'}
image_datasets = {
    split: datasets.ImageFolder(f"{input_path}/{folder}", data_transforms[split])
    for split, folder in split_dirs.items()
}

BATCH_SIZE = 16
train_dataloader = torch.utils.data.DataLoader(
    image_datasets['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(
    image_datasets['validation'], batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(
    image_datasets['test'], batch_size=BATCH_SIZE, shuffle=False)

class_names = image_datasets['train'].classes

In [None]:
from os import cpu_count

from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet50

# Load an ImageNet-pretrained ResNet-50 and replace its classifier with a small
# MLP head sized to the number of classes in the training split.
# ``weights=`` replaces the deprecated ``pretrained=True`` (same IMAGENET1K_V1
# weights) and matches how the BYOL model is built later.
model = resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(512, len(class_names)),
)





In [None]:
from torchinfo import summary
# Layer-by-layer summary for a dummy batch of eight 3x224x224 images.
dummy_batch_shape = (8, 3, 224, 224)
summary(model, input_size=dummy_batch_shape)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [8, 2000]                 --
├─Conv2d: 1-1                            [8, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [8, 64, 112, 112]         128
├─ReLU: 1-3                              [8, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [8, 64, 56, 56]           --
├─Sequential: 1-5                        [8, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [8, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [8, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [8, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [8, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [8, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [8, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [8, 64, 56, 56]           --
│ 

In [None]:
# Keep only the single best baseline checkpoint (lowest validation loss).
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoint_model/',
    filename='sample-CACD-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [None]:
# early_stop_callback = EarlyStopping(
#                                     monitor="val_loss",
#                                     mode="min",
#                                     patience=10)

In [None]:
# Baseline: fine-tune the ImageNet-pretrained ResNet-50 without BYOL.
supervised = SupervisedLightningModule(model)
trainer = pl.Trainer(
    max_epochs=70,
    # ``gpus=`` is deprecated (removed in Lightning 2.0), and the recorded run
    # logged "GPU available: True (cuda), used: False". Request the GPU explicitly.
    accelerator="gpu",
    devices=1,
    callbacks=[checkpoint_callback],
    enable_model_summary=False,  # boolean flag; None only disabled it by being falsy
)

trainer.fit(supervised, train_dataloader, val_dataloader)

  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [None]:
# Mount Google Drive so archives under /content/drive/MyDrive can be extracted.
# NOTE(review): the CACD73 extraction cell earlier already reads from Drive, so
# on a fresh kernel this mount must happen before it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Extract a second archive (FPT) from Google Drive.
# NOTE(review): no cell in this notebook reads the extracted files — confirm it is still needed.
!unrar x -y "/content/drive/MyDrive/vip/FPT.rar"

In [None]:
def accuracy(pred: Tensor, labels: Tensor) -> float:
    """Top-1 accuracy: fraction of rows whose arg-max along the last dim equals the label."""
    predicted_class = torch.argmax(pred, dim=-1)
    hits = predicted_class.eq(labels)
    return hits.float().mean().item()

In [None]:
# Persist the fine-tuned baseline weights (``model`` was trained in place via ``supervised``).
baseline_weights_path = './CACD_r50_baseline_1.pth'
torch.save(model.state_dict(), baseline_weights_path)

In [None]:
# Rebuild the baseline architecture and reload the saved weights for evaluation.
model_test = resnet50()
model_test.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(512, len(class_names)),  # was hard-coded 2000; keep in sync with training
)

# Load from the same relative path the save cell wrote to (was '/content/...').
model_test.load_state_dict(torch.load('./CACD_r50_baseline_1.pth', map_location='cpu'))
model_test.eval()
model_test.cuda()

# Weight each batch by its size so a smaller final batch does not skew the mean;
# no_grad avoids building autograd graphs during evaluation.
# NOTE(review): this is validation accuracy (the split used for checkpoint
# selection); ``test_dataloader`` is never evaluated.
with torch.no_grad():
    acc = sum(
        accuracy(model_test(x.cuda()), y.cuda()) * len(y) for x, y in val_dataloader
    ) / len(val_dataloader.dataset)
print(f"Accuracy: {acc:.3f}")

Accuracy: 0.761


In [None]:
import glob
# List checkpoints written by the baseline ModelCheckpoint (save_top_k=1 keeps one).
baseline_ckpt_pattern = './checkpoint_model/*'
t = glob.glob(baseline_ckpt_pattern)
print(t)

['./checkpoint_model/sample-CACD-epoch=13-val_loss=1.81.ckpt']


In [None]:
# Inspect GPU utilisation and memory usage.
!nvidia-smi

Mon Mar 20 10:32:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    28W /  70W |  15097MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [None]:
# Reload the best baseline checkpoint and evaluate it. No ``model`` argument is
# needed because SupervisedLightningModule saves its init args (including model).
test = SupervisedLightningModule.load_from_checkpoint(t[0])
test.eval()
test.cuda()

# Batch-size-weighted accuracy without autograd bookkeeping.
with torch.no_grad():
    acc = sum(
        accuracy(test(x.cuda()), y.cuda()) * len(y) for x, y in val_dataloader
    ) / len(val_dataloader.dataset)
print(f"Accuracy: {acc:.3f}")

INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.4 to v2.0.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file checkpoint_model/sample-CACD-epoch=13-val_loss=1.81.ckpt`


Accuracy: 0.757


In [None]:
# Free cached GPU memory before the memory-heavy BYOL pre-training.
torch.cuda.empty_cache()

### Self-Supervised Training with BYOL

Perform our self-supervised training.  This is the most computationally intensive part of the script; its runtime depends on the dataset size and the GPU (this notebook was run on a Tesla T4), so budget time accordingly.

In [None]:
# Keep only the BYOL checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoint_BYOL/',
    filename='BYOL-CACD-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [None]:
model_SSL = resnet50(weights='ResNet50_Weights.IMAGENET1K_V1')
model_SSL.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(512, len(class_names)),
)

# The supervised loaders already apply ImageNet normalization, but BYOL's
# default Kornia augmentation colour-jitters and then normalizes again.
# Give BYOL un-normalized ToTensor() images instead.
byol_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
byol_train_dataloader = DataLoader(
    datasets.ImageFolder(input_path + '/train', byol_transform),
    batch_size=train_dataloader.batch_size,
    shuffle=True,
)
byol_val_dataloader = DataLoader(
    datasets.ImageFolder(input_path + '/val', byol_transform),
    batch_size=val_dataloader.batch_size,
    shuffle=False,
)

byol = BYOL(model_SSL, image_size=(224, 224))
trainer = pl.Trainer(
    max_epochs=50,
    accelerator="gpu",
    devices=-1,  # all visible GPUs (replaces ``gpus=-1``, removed in Lightning 2.0)
    # Effective batch = loader batch size (16) * 16 = 256 images per optimizer step.
    # NOTE(review): ``2048 // 128`` implies a target of 2048 with batch_size 128.
    accumulate_grad_batches=2048 // 128,
    callbacks=[checkpoint_callback],
    enable_model_summary=False,
)

trainer.fit(byol, byol_train_dataloader, byol_val_dataloader)

In [None]:
# Save the BYOL-pretrained backbone (``model_SSL`` is updated in place through BYOL's encoder).
byol_weights_path = './CACD_BYOL_baseline_1.pth'
torch.save(model_SSL.state_dict(), byol_weights_path)

### Supervised Training Again

Extract the state dictionary from BYOL, and load it into a fresh ResNet-50 model before starting training.  Then run supervised training again and compare the accuracy with the baseline (in the recorded run the gain is small: 0.761 → 0.769).

In [None]:
# Extract the BYOL-trained state dictionary, initialize a new ResNet-50 with the
# same classifier head, and load the state dictionary into the new model.
#
# Building a fresh model ensures that none of the forward hooks registered on
# ``model_SSL`` by BYOL's EncoderWrapper are carried over.
state_dict = model_SSL.state_dict()
model_improved = resnet50()
model_improved.fc = nn.Sequential(
               nn.Linear(num_ftrs, 512),
               nn.ReLU(inplace=True),
               nn.Dropout(0.4),
               nn.Linear(512, len(class_names)))
model_improved.load_state_dict(state_dict)

In [None]:
# Keep only the best checkpoint of the BYOL-initialized fine-tuning run.
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoint_improved/',
    filename='sample-CACD-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

In [None]:
# Fine-tune the BYOL-initialized ResNet-50 with the same supervised setup as the baseline.
supervised = SupervisedLightningModule(model_improved)
trainer = pl.Trainer(
    max_epochs=70,
    accelerator="gpu",
    devices=-1,  # all visible GPUs (replaces ``gpus=-1``, removed in Lightning 2.0)
    callbacks=[checkpoint_callback],  # list form, consistent with the other trainers
    enable_model_summary=False,
)

trainer.fit(supervised, train_dataloader, val_dataloader)


In [None]:
# Persist the fine-tuned BYOL-initialized weights (trained in place via ``supervised``).
improved_weights_path = './CACD_r50_improved_1.pth'
torch.save(model_improved.state_dict(), improved_weights_path)

In [None]:
# Rebuild the architecture and reload the BYOL-initialized, fine-tuned weights.
model_test_2 = resnet50()
model_test_2.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(512, len(class_names)),
)

model_test_2.load_state_dict(torch.load('./CACD_r50_improved_1.pth', map_location='cpu'))
model_test_2.eval()
model_test_2.cuda()

# Batch-size-weighted accuracy (a smaller final batch no longer skews it),
# computed without autograd bookkeeping.
with torch.no_grad():
    acc = sum(
        accuracy(model_test_2(x.cuda()), y.cuda()) * len(y) for x, y in val_dataloader
    ) / len(val_dataloader.dataset)
print(f"Accuracy: {acc:.3f}")

# Release cached GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

Accuracy: 0.769


In [None]:
# Reload the best checkpoint of the BYOL-initialized run (save_top_k=1 keeps one).
t = glob.glob('./checkpoint_improved/*')

test2 = SupervisedLightningModule.load_from_checkpoint(t[0])
test2.eval()
test2.cuda()

# Batch-size-weighted accuracy without autograd bookkeeping.
with torch.no_grad():
    acc = sum(
        accuracy(test2(x.cuda()), y.cuda()) * len(y) for x, y in val_dataloader
    ) / len(val_dataloader.dataset)
print(f"Accuracy: {acc:.3f}")

# Release cached GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.4 to v2.0.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file checkpoint_improved/sample-CACD-epoch=20-val_loss=1.96.ckpt`


Accuracy: 0.767
