# BYOL

In this notebook we are going to implement [BYOL: Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733.pdf) and compare the results of a classification task before and after pretraining the model with BYOL.

In [14]:
from google.colab import drive
drive.mount('/content/gdrive')
import sys
sys.path.append('/content/gdrive//My Drive/Colab Notebooks/Deep_hw1/resnet.py')
import resnet


Mounted at /content/gdrive


### Data Augmentations

In [15]:
import random
from typing import Callable, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torchvision import transforms as T
from torch.nn import functional as F


class RandomApply(nn.Module):
    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        return x if torch.rand(1).item() > self.p else self.fn(x)


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """
        1. resize images to 'image_size'
        2. RandomApply color jitter
        3. RandomApply grayscale
        4. RandomApply horizon flip
        5. RandomApply gaussian blur with kernel_size(3, 3), sigma=(1.5, 1.5)
        6. RandomApply ResizedCrop to 'image_size'
        7. Normalize
        choosing hyperparameters that are not mentioned is up to you
    """
    return nn.Sequential(
        T.Resize(image_size),
        RandomApply(
            T.ColorJitter(0.8, 0.8, 0.8, 0.2),
            p=0.8
        ),
        RandomApply(
            T.Grayscale(num_output_channels=3),
            p=0.2
        ),
        T.RandomHorizontalFlip(),
        RandomApply(
            T.GaussianBlur(kernel_size=3, sigma=(1.5, 1.5)),
            p=0.1
        ),
        T.RandomResizedCrop(image_size),
        T.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        )
    )

# Model
We will use ResNet18 as our representation model.

In [16]:
def get_encoder_model():
    resnet = torchvision.models.resnet18(pretrained=False)
    # remove last fully-connected layer
    #resnet.fc = nn.Identity()
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    return resnet

### Loss Function
We need to use NormalizedMSELoss as our loss function.
$$NormalizedMSELoss(v_1, v_2) = \Vert \bar{v_1} - \bar{v_2}\Vert_2^2 = 2 - 2.\frac{\langle v_1, v_2 \rangle}{\Vert v_1\Vert_2 \Vert v_2\Vert_2}$$

In [17]:
class NormalizedMSELoss(nn.Module):
    def __init__(self) -> None:
        super(NormalizedMSELoss,self).__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        view1_norm = F.normalize(view1, p=2, dim=-1)
        view2_norm = F.normalize(view2, p=2, dim=-1)
        return 2 - 2 * (view1_norm * view2_norm).sum(dim=-1)

### MLP
Here you will implement a simple MLP class with one hidden layer with BatchNorm and ReLU activation, and a linear output layer. This class will be used for both the projections and the prediction networks.

In [18]:
class MLP(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int = 256, hidden_dim: int = 4096) -> None:
        super(MLP, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)

### Encoder + Projector Network
This is the network structure that is shared between online and target networks. It consists of our encoder model, followed by a projection MLP.

In [19]:
class EncoderProjecter(nn.Module):
    def __init__(self,
                 encoder: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256
                 ) -> None:
        super(EncoderProjecter, self).__init__()

        # your code
        self.encoder = encoder
        self.projector = MLP(encoder.output_dim, projection_out_dim, hidden_dim)
    def forward(self, x: Tensor) -> Tensor:
        # your code
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        x = self.projector(x)
        return x

## BYOL

In [20]:
import copy

class BYOL(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256,
                 target_decay: float = 0.99

                ) -> None:
        super(BYOL, self).__init__()

        # your code
        self.online_network = EncoderProjecter(model, hidden_dim, projection_out_dim)
        self.online_predictor = MLP(projection_out_dim, projection_out_dim, hidden_dim)

        self.target_network = copy.deepcopy(self.online_network)  # init with copy of parameters of online network
        # set target_network's weights to be untrainable

        self.target_network.eval()

        for param in self.target_network.parameters():
            param.requires_grad = False

        self.loss_function = NormalizedMSELoss()
        self.target_decay = target_decay

    @torch.no_grad()
    def soft_update_target_network(self) -> None:
        for online_params, target_params in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_params.data = self.target_decay * target_params.data + (1 - self.target_decay) * online_params.data


    def forward(self, view) -> Tuple[Tensor]:
        online_projection = self.online_network(view)
        target_projection = self.target_network(view)
        return online_projection, target_projection


    def loss(self, view1, view2):
        online_projection1, target_projection1 = self.forward(view1)
        online_projection2, target_projection2 = self.forward(view2)

        online_prediction1 = self.online_predictor(online_projection1)
        online_prediction2 = self.online_predictor(online_projection2)

        loss1 = self.loss_function(online_prediction1, target_projection2.detach())
        loss2 = self.loss_function(online_prediction2, target_projection1.detach())

        return (loss1 + loss2) / 2

# STL10 Datasets

We need 3 separate datasets from STL10 for this experiment:
1. `"train"` -- Contains only labeled training images. Used for supervised training.
2. `"train+unlabeled"` -- Contains training images, plus a large number of unlabelled images.  Used for self-supervised learning with BYOL.
3. `"test"` -- Labeled test images.  We use it both as a validation set, and for computing the final model accuracy.

In [21]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor


TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor())
TRAIN_UNLABELED_DATASET = STL10(root="data", split="train+unlabeled", download=True, transform=ToTensor())
TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor())

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Create dataloaders:

In [22]:
import torch
from torch.utils.data import DataLoader

train_dataloader = DataLoader(TRAIN_DATASET, batch_size=64, shuffle=True)
train_unlabeled_dataloader = DataLoader(TRAIN_UNLABELED_DATASET, batch_size=64, shuffle=True)
test_dataloader = DataLoader(TEST_DATASET, batch_size=64, shuffle=False)

# Supervised Training without BYOL

First create a classifier model by simply adding a linear layer on top of the encoder model. Then train the model using the labeled training set. Performance should be pretty good already.

In [23]:
encoder = get_encoder_model()
classifier = nn.Sequential(
    encoder,
    nn.Flatten(),
    nn.Linear(512, 10)
)



### Self-Supervised Training with BYOL

Now perform the self-supervised training. This is the most computationally intensive part of the script.

In [25]:
# Initialize BYOL and optimizer
byol = BYOL(encoder)
optimizer = torch.optim.Adam(byol.parameters(), lr=3e-4)

# Training loop
num_epochs = 100

for epoch in range(num_epochs):
    for data in train_unlabeled_dataloader:
        view1, view2 = data['image1'], data['image2']

        optimizer.zero_grad()
        loss = byol.loss(view1, view2)
        loss.backward()
        optimizer.step()

        byol.soft_update_target_network()

    print("Epoch: {}/{}, Loss: {:.4f}".format(epoch + 1, num_epochs, loss.item()))

AttributeError: ignored

### Supervised Training Again

Extract the encoder network's state dictionary from BYOL, and load it into our ResNet18 model before starting training.  Then run supervised training, and watch the accuracy improve from last time!

In [None]:
# Load the pretrained weights from BYOL
encoder.load_state_dict(byol.online_network.encoder.state_dict())

# Train the classifier using the labeled data
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = classifier(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluate the classifier on the test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = classifier(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print("Epoch: {}/{}, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs, loss.item(), 100 * correct / total))