In [None]:
!pip install lightly split-folders

In [None]:
import os
import sys
import tarfile
import errno
import numpy as np
import matplotlib.pyplot as plt
import time
import urllib.request
import random
import shutil
import copy
import math
from typing import List, Tuple

import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import GradientAccumulationScheduler

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Identity, ModuleList
from torch.nn import functional as F
from torch.optim import SGD
from torchvision.models import resnet50
import torchvision.transforms as transforms
import torchvision.datasets

from lightly.data import LightlyDataset
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.utils.dist import print_rank_zero
from lightly.models.utils import (
    batch_shuffle,
    batch_unshuffle,
    get_weight_decay_parameters,
    update_momentum,
    deactivate_requires_grad
)
from lightly.transforms import SwaVTransform, utils
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler
from lightly.utils.lars import LARS
from imageio import imsave
from tqdm import tqdm
import splitfolders

CROP_COUNTS: Tuple[int, int] = (2, 6)

In [None]:
from google.colab import drive
from google.colab import files
drive.mount('/content/drive')

In [None]:
num_workers = 2
batch_size = 64
memory_bank_size = 4096
seed = 1
max_epochs=200

## Downloading the STL10 Dataset

The labelled part of the STL-10 dataset consists of 5,000 training images and 8,000 test images (the unlabelled split is not used here). To have more data for training the classification head, we pool the official training and test images and re-split them in a 70/15/15 ratio into training, validation, and test sets.



In [None]:
stl10_dataset = torchvision.datasets.STL10('/content/stl10/', download=True)

In [None]:
HEIGHT = 96
WIDTH = 96
DEPTH = 3

SIZE = HEIGHT * WIDTH * DEPTH

DATA_DIR = '/content/Datasets'
# DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

TRAIN_DATA_PATH = '/content/stl10/stl10_binary/train_X.bin'
TRAIN_LABEL_PATH = '/content/stl10/stl10_binary/train_y.bin'

TEST_DATA_PATH = '/content/stl10/stl10_binary/test_X.bin'
TEST_LABEL_PATH = '/content/stl10/stl10_binary/test_y.bin'

UNLAB_DATA_PATH = '/content/stl10/stl10_binary/unlabeled_X.bin'

In [None]:
def read_single_image(image_file):
    """Read the next image record from an open STL-10 binary file.

    Consumes ``SIZE`` bytes (HEIGHT * WIDTH * DEPTH) as uint8 and returns them
    transposed from the stored (3, 96, 96) layout into a (96, 96, 3) array
    suitable for ``plt.imshow``.
    """
    pixels = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
    # Stored channel-first and column-major; axes (2, 1, 0) give row-major HWC.
    return pixels.reshape(3, 96, 96).transpose(2, 1, 0)

def plot_image(image):
    """Draw ``image`` with ``plt.imshow`` and display the current figure."""
    plt.imshow(image)
    plt.show()

In [None]:
# Preview the first training image. The file is raw uint8 data, so open it in
# binary mode rather than text mode before handing it to np.fromfile.
with open(TRAIN_DATA_PATH, 'rb') as f:
  image = read_single_image(f)
  plot_image(image)

In [None]:
def read_labels(path_to_labels):
    """Load a label file as a flat uint8 numpy array (one byte per entry).

    Args:
        path_to_labels: Path to the binary label file.

    Returns:
        1-D ``np.ndarray`` of dtype ``uint8``.
    """
    with open(path_to_labels, 'rb') as label_file:
        label_values = np.fromfile(label_file, dtype=np.uint8)
    return label_values

def read_all_images(path_to_data):
    """Load every image from an STL-10-style binary file.

    The file is read as a flat uint8 buffer, viewed as (N, 3, 96, 96), and
    transposed with axes (0, 3, 2, 1) into an (N, 96, 96, 3) array.

    Args:
        path_to_data: Path to the binary image file.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape (N, 96, 96, 3).
    """
    with open(path_to_data, 'rb') as data_file:
        raw_bytes = np.fromfile(data_file, dtype=np.uint8)
    # Channel-first, column-major records -> row-major channel-last images.
    return raw_bytes.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

def save_image(image, name):
    """Write ``image`` to ``<name>.png`` using imageio's ``imsave``."""
    png_path = f"{name}.png"
    imsave(png_path, image, format="png")

def save_images(images, labels, types):
    """Save labelled images as PNGs under ``DATA_DIR/<types>/<label>/``.

    Each file is named ``<index>_<timestamp>.png``. The timestamp keeps images
    with the same index from separate calls into the same ``types`` folder
    (e.g. train and test both written to "all") from overwriting each other.

    Args:
        images: Iterable of HWC uint8 image arrays.
        labels: Indexable sequence; ``labels[i]`` is the class of the i-th image.
        types: Name of the split folder under ``DATA_DIR``.

    Raises:
        OSError: If a class directory cannot be created (for example, the path
            exists as a regular file or permissions are missing).
    """
    for i, image in enumerate(tqdm(images, position=0)):
        label = labels[i]
        directory = DATA_DIR + '/' + types + '/' + str(label) + '/'
        # exist_ok=True already tolerates an existing directory. Any other
        # OSError now propagates instead of being silently swallowed and then
        # surfacing later as a confusing imsave failure.
        os.makedirs(directory, exist_ok=True)
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        filename = directory + str(i) + "_" + timestamp
        save_image(image, filename)

def save_unlabelled_images(images):
    """Save unlabelled images as ``DATA_DIR/unlabelled/<index>.png``.

    Args:
        images: Iterable of HWC uint8 image arrays.

    Raises:
        OSError: If the output directory cannot be created.
    """
    directory = DATA_DIR + '/' + 'unlabelled' + '/'
    # The target directory is the same for every image, so create it once.
    # exist_ok=True covers the "already exists" case; other errors propagate.
    os.makedirs(directory, exist_ok=True)
    for i, image in enumerate(tqdm(images, position=0)):
        save_image(image, directory + str(i))


def create_val_dataset(num_per_class=50):
    """Move a random subset of each class from ``DATA_DIR/test`` to ``DATA_DIR/val``.

    For every class folder in ``DATA_DIR/test``, up to ``num_per_class`` images
    are chosen without replacement and moved (``os.replace``) into
    ``DATA_DIR/val/<class>/``.

    Args:
        num_per_class: Maximum number of images to move per class (default 50).
            Classes with fewer images have all of their images moved.
    """
    source_root = DATA_DIR + "/test"
    # Sort so that, with a seeded `random`, the selection is reproducible
    # regardless of filesystem listing order.
    folders = sorted(os.listdir(source_root))

    for folder in tqdm(folders, position=0):
        temp_dir = source_root + "/" + folder
        if not os.path.isdir(temp_dir):
            continue
        val_dir = DATA_DIR + "/val/" + folder
        os.makedirs(val_dir, exist_ok=True)

        temp_image_list = sorted(os.listdir(temp_dir))
        # The per-image loop must run for EVERY class folder. Previously it was
        # dedented and only processed the last folder. random.sample also avoids
        # the IndexError that repeated random.choice raised on small classes.
        chosen = random.sample(temp_image_list, min(num_per_class, len(temp_image_list)))
        for image_name in chosen:
            os.replace(temp_dir + '/' + image_name, val_dir + '/' + image_name)

In [None]:
train_labels = read_labels(TRAIN_LABEL_PATH)
train_images = read_all_images(TRAIN_DATA_PATH)

test_labels = read_labels(TEST_LABEL_PATH)
test_images = read_all_images(TEST_DATA_PATH)

#unlabelled_images = read_all_images(UNLAB_DATA_PATH)
# !rm -rf Datasets
# save_images(train_images, train_labels, "test")
# save_images(test_images, test_labels, "train")
save_images(train_images, train_labels, "all")
save_images(test_images, test_labels, "all")
#save_unlabelled_images(unlabelled_images)

In [None]:
def count_images(data_dir, extensions=(".png",)):
    """Count image files inside the immediate subdirectories of ``data_dir``.

    Only entries one level down (``data_dir/<class>/<file>``) are counted.
    Files directly inside ``data_dir`` are ignored.

    Args:
        data_dir: Root folder containing one subfolder per class.
        extensions: A suffix string or an iterable of suffixes to count
            (case-sensitive). Defaults to ``(".png",)``.

    Returns:
        Number of matching files.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith accepts a tuple of suffixes.
    suffixes = tuple(extensions)

    num_images = 0
    for subdir in os.listdir(data_dir):
        sub_dir_path = os.path.join(data_dir, subdir)
        if os.path.isdir(sub_dir_path):
            num_images += sum(
                1 for file in os.listdir(sub_dir_path) if file.endswith(suffixes)
            )
    return num_images

print(f"Number of images: {count_images('Datasets/all')}")


In [None]:
splitfolders.ratio('Datasets/all', output="Datasets", seed=seed, ratio=(.7, .15, .15))
!rm -rf Datasets/all/

In [None]:
path_to_train = "/content/Datasets/train"
path_to_test = "/content/Datasets/test"
path_to_val = "/content/Datasets/val"
#path_to_unlabelled= "/content/Datasets/unlabelled"

## Downloading the Pneumonia Dataset



In [None]:
!pip install -q kaggle

In [None]:
from google.colab import files
files.upload()
# Choose the kaggle.json file that you downloaded for the API token

In [None]:
!rm -rf ~/.kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!cp kaggle.json ~/.kaggle/

In [None]:
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia

In [None]:
!unzip chest-xray-pneumonia.zip -d chest_xray1
!rm -rf chest_xray1/chest_xray/chest_xray
!rm -rf chest_xray1/chest_xray/__MACOSX/

In [None]:
!mkdir chest_xray
!cp -r chest_xray1/chest_xray/train/* -d chest_xray/
!cp -r chest_xray1/chest_xray/test/* -d chest_xray/
!cp -r chest_xray1/chest_xray/val/* -d chest_xray/

In [None]:
!rm -rf chest_xray1

In [None]:
import os
def count_images(data_dir, extensions=(".jpeg",)):
    """Count image files inside the immediate subdirectories of ``data_dir``.

    Only entries one level down (``data_dir/<class>/<file>``) are counted.
    Files directly inside ``data_dir`` are ignored.

    Args:
        data_dir: Root folder containing one subfolder per class.
        extensions: A suffix string or an iterable of suffixes to count
            (case-sensitive). Defaults to ``(".jpeg",)``.

    Returns:
        Number of matching files.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith accepts a tuple of suffixes.
    suffixes = tuple(extensions)

    num_images = 0
    for subdir in os.listdir(data_dir):
        sub_dir_path = os.path.join(data_dir, subdir)
        if os.path.isdir(sub_dir_path):
            num_images += sum(
                1 for file in os.listdir(sub_dir_path) if file.endswith(suffixes)
            )
    return num_images

print(f"Number of images: {count_images('chest_xray')}")

In [None]:
import splitfolders
splitfolders.ratio('chest_xray', output="xray", seed=seed, ratio=(.7, .15, .15))

In [None]:
!rm -rf chest_xray

In [None]:
path_to_train = "xray/train"
path_to_test = "xray/test"
path_to_val = "xray/val"

## Downloading the CIFAR-10 Dataset

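The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. We download a copy stored as PNG files in per-class folders from Kaggle. We keep the official test split as our test set and split the official training split 80/20 into training and validation sets, using the following steps: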


In [None]:
!pip install -q kaggle
from google.colab import files
files.upload()
#upload your API key generated from your profile on kaggle

In [None]:
!rm -rf ~/.kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!cp kaggle.json ~/.kaggle/

In [None]:
!kaggle datasets download -d swaroopkml/cifar10-pngs-in-folders

In [None]:
!unzip cifar10-pngs-in-folders.zip -d cifar10-pngs-in-folders

In [None]:
!rm -rf cifar10
!mkdir cifar10 cifar10_temp cifar10/test
!cp -r cifar10-pngs-in-folders/cifar10/cifar10/test/* cifar10/test
!cp -r cifar10-pngs-in-folders/cifar10/cifar10/train/* cifar10_temp
!rm -rf cifar10-pngs-in-folders

In [None]:
splitfolders.ratio('cifar10_temp', output="cifar10", seed=seed, ratio=(0.8, 0.2), group_prefix=None)
!rm -rf cifar10_temp

In [None]:
path_to_test = "cifar10/test"
path_to_train = "cifar10/train"
path_to_val = "cifar10/val"

In [None]:
import os
def count_images(data_dir, extensions=(".png",)):
    """Count image files inside the immediate subdirectories of ``data_dir``.

    Only entries one level down (``data_dir/<class>/<file>``) are counted.
    Files directly inside ``data_dir`` are ignored.

    Args:
        data_dir: Root folder containing one subfolder per class.
        extensions: A suffix string or an iterable of suffixes to count
            (case-sensitive). Defaults to ``(".png",)``.

    Returns:
        Number of matching files.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # str.endswith accepts a tuple of suffixes.
    suffixes = tuple(extensions)

    num_images = 0
    for subdir in os.listdir(data_dir):
        sub_dir_path = os.path.join(data_dir, subdir)
        if os.path.isdir(sub_dir_path):
            num_images += sum(
                1 for file in os.listdir(sub_dir_path) if file.endswith(suffixes)
            )
    return num_images

print(f"Number of images: {count_images(path_to_train)}")
print(f"Number of images: {count_images(path_to_test)}")
print(f"Number of images: {count_images(path_to_val)}")

## Data Augmentations and Transformations


In [None]:
pl.seed_everything(seed)

In [None]:
transform = SwaVTransform(crop_counts=CROP_COUNTS)


In [None]:
# Per-channel normalization constants shared by both classifier pipelines.
# NOTE(review): these appear to be CIFAR-10 statistics, yet they are applied to
# every dataset in this notebook. Confirm they are appropriate for STL-10 / X-ray.
_NORMALIZE_MEAN = [0.4914, 0.4822, 0.4465]
_NORMALIZE_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline for the linear classifier: resize, light augmentation, normalize.
train_classifier_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

# Evaluation pipeline: identical resize and normalization, no random augmentation.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

# Load the Dataset

In [None]:
dataset_train_SwaV = LightlyDataset(input_dir=path_to_train, transform=transform)

dataset_train_classifier = LightlyDataset(
    input_dir=path_to_train, transform=train_classifier_transforms
)

dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms)

dataset_val = LightlyDataset(input_dir=path_to_val, transform=test_transforms)


# unlabelled_dataset_train = LightlyDataset(
#     input_dir=path_to_unlabelled, transform=train_classifier_transforms
# )

In [None]:
dataloader_train_SwaV = torch.utils.data.DataLoader(
    dataset_train_SwaV,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

# dataloader_unlabelled = torch.utils.data.DataLoader(
#     unlabelled_dataset_train,
#     batch_size=batch_size,
#     shuffle=True,
#     drop_last=True,
#     num_workers=num_workers,
# )

# Model Definitions

In [None]:
class SwAV(LightningModule):
    """SwAV self-supervised pretraining module with a ResNet-50 backbone.

    Besides the SwAV loss, it trains an ``OnlineLinearClassifier`` on detached
    backbone features of the first multi-crop view. This lets linear accuracy be
    monitored during pretraining without affecting the backbone.

    Args:
        batch_size_per_device: Batch size on each device. Used to size the
            queues and to scale the LARS learning rate.
        num_classes: Number of classes for the online linear classifier.
    """

    def __init__(self, batch_size_per_device: int=64, num_classes: int=10) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        self.backbone = resnet
        self.projection_head = SwaVProjectionHead()
        self.prototypes = SwaVPrototypes(n_steps_frozen_prototypes=1)
        self.criterion = SwaVLoss(sinkhorn_gather_distributed=True)
        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

        # Use a queue for small batch sizes (<= 256).
        # One memory bank per high-resolution crop (CROP_COUNTS[0] of them), each
        # sized (n_batches_in_queue * batch_size_per_device, 128). Queues are
        # only filled and read from epoch `start_queue_at_epoch` onwards.
        self.start_queue_at_epoch = 15
        self.n_batches_in_queue = 15
        self.queues = ModuleList(
            [
                MemoryBankModule(
                    size=(self.n_batches_in_queue * self.batch_size_per_device, 128)
                )
                for _ in range(CROP_COUNTS[0])
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return backbone features for a batch of images."""
        return self.backbone(x)

    def project(self, x: Tensor) -> Tensor:
        """Apply the projection head and L2-normalize the result along dim 1."""
        x = self.projection_head(x)
        return F.normalize(x, dim=1, p=2)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Compute the SwAV loss (optionally with queue outputs) plus the online classifier loss.

        Returns:
            Sum of the SwAV loss and the online classification loss.
        """
        # Normalize the prototypes so they are on the unit sphere.
        self.prototypes.normalize()

        # The dataloader returns a list of image crops where the
        # first few items are high resolution crops and the rest are low
        # resolution crops.
        multi_crops, targets = batch[0], batch[1]

        # Forward pass through backbone and projection head.
        multi_crop_features = [
            self.forward(crops).flatten(start_dim=1) for crops in multi_crops
        ]
        multi_crop_projections = [
            self.project(features) for features in multi_crop_features
        ]

        # Get the queue projections and logits.
        queue_crop_logits = None
        with torch.no_grad():
            if self.current_epoch >= self.start_queue_at_epoch:
                # Start filling the queue.
                queue_crop_projections = _update_queue(
                    projections=multi_crop_projections[: CROP_COUNTS[0]],
                    queues=self.queues,
                )
                # NOTE(review): batch_idx restarts at 0 every epoch, so even after
                # the queue has been filled in an earlier epoch it is ignored for
                # the first n_batches_in_queue + 1 batches of each epoch.
                if batch_idx > self.n_batches_in_queue:
                    # The queue is filled, so we can start using it.
                    queue_crop_logits = [
                        self.prototypes(projections, step=self.current_epoch)
                        for projections in queue_crop_projections
                    ]

        # Get the rest of the multi-crop logits.
        # `step` is the epoch index here, so n_steps_frozen_prototypes=1
        # presumably freezes the prototypes during epoch 0 — confirm against
        # SwaVPrototypes.
        multi_crop_logits = [
            self.prototypes(projections, step=self.current_epoch)
            for projections in multi_crop_projections
        ]

        # Calculate the SwAV loss.
        loss = self.criterion(
            high_resolution_outputs=multi_crop_logits[: CROP_COUNTS[0]],
            low_resolution_outputs=multi_crop_logits[CROP_COUNTS[0] :],
            queue_outputs=queue_crop_logits,
        )
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Calculate the classification loss.
        # Features are detached so the online classifier does not backpropagate
        # into the backbone.
        cls_loss, cls_log = self.online_classifier.training_step(
            (multi_crop_features[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online linear classifier on backbone features of a validation batch."""
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """Build a LARS optimizer with a per-step cosine warmup schedule.

        Three parameter groups are used: SwAV weights with weight decay, SwAV
        weights without weight decay, and the online classifier without weight decay.
        """
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.backbone, self.projection_head, self.prototypes]
        )
        optimizer = LARS(
            [
                {"name": "swav", "params": params},
                {
                    "name": "swav_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Smaller learning rate for smaller batches: lr=0.6 for batch_size=256
            # scaled linearly by batch size to lr=4.8 for batch_size=2048.
            # See Appendix A.1. and A.6. in SwAV paper https://arxiv.org/pdf/2006.09882.pdf
            lr=0.6 * (self.batch_size_per_device * self.trainer.world_size) / 256,
            momentum=0.9,
            weight_decay=1e-6,
        )
        # The scheduler steps every optimizer step ("interval": "step"), so its
        # `warmup_epochs` / `max_epochs` arguments are counted in steps here:
        # warmup spans 10 epochs' worth of steps, and the cosine decay spans all
        # estimated training steps.
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
                end_value=0.0006
                * (self.batch_size_per_device * self.trainer.world_size)
                / 256,
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]


transform = SwaVTransform(crop_counts=CROP_COUNTS)


@torch.no_grad()
def _update_queue(
    projections: List[Tensor],
    queues: ModuleList,
):
    """Adds the high resolution projections to the queues and returns the queues."""

    if len(projections) != len(queues):
        raise ValueError(
            f"The number of queues ({len(queues)}) should be equal to the number of high "
            f"resolution inputs ({len(projections)})."
        )

    # Get the queue projections
    queue_projections = []
    for i in range(len(queues)):
        _, queue_proj = queues[i](projections[i], update=True)
        # Queue projections are in (num_ftrs X queue_length) shape, while the high res
        # projections are in (batch_size_per_device X num_ftrs). Swap the axes for interoperability.
        queue_proj = torch.permute(queue_proj, (1, 0))
        queue_projections.append(queue_proj)

    return queue_projections

### Linear Probing Classifier ###

In [None]:
class Classifier(pl.LightningModule):
    """Linear probe: a frozen pretrained backbone followed by a trainable linear head.

    Args:
        backbone: Feature extractor whose output is flattened and fed to the head.
            Its parameters are frozen with ``deactivate_requires_grad``, and it is
            kept in eval mode during training so BatchNorm statistics stay fixed.
        num_classes: Number of output classes (default 10).
        feature_dim: Size of the flattened backbone output (default 2048, matching
            ResNet-50 with ``fc = Identity()``).
    """

    def __init__(self, backbone, num_classes: int = 10, feature_dim: int = 2048):
        super().__init__()
        # use the pretrained ResNet backbone
        self.backbone = backbone

        # freeze the backbone
        deactivate_requires_grad(backbone)

        # create a linear layer for our downstream classification model
        self.fc = nn.Linear(feature_dim, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        self.validation_step_outputs = []

    def forward(self, x):
        y_hat = self.backbone(x).flatten(start_dim=1)
        y_hat = self.fc(y_hat)
        return y_hat

    def on_train_epoch_start(self):
        # Lightning switches the whole module to train mode before training, which
        # would let the "frozen" backbone's BatchNorm layers keep updating their
        # running statistics. Keep the backbone in eval mode for a true linear probe.
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss_fc", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y, _ = batch  # input images, labels, and filenames (unused)
        y_hat = self.forward(x)

        # Per-batch accuracy; Lightning aggregates it over the epoch.
        predictions = torch.argmax(y_hat, dim=1)
        correct_predictions = (predictions == y).sum().item()
        total_samples = len(y)
        accuracy = correct_predictions / total_samples

        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)

        return accuracy

    def on_train_epoch_end(self):
        self.custom_histogram_weights()

    # Helper that logs weight histograms to TensorBoard for debugging.
    def custom_histogram_weights(self):
        experiment = getattr(self.logger, "experiment", None)
        # Only TensorBoard's SummaryWriter provides add_histogram. Skip for other
        # loggers (or no logger) instead of crashing at the end of every epoch.
        if experiment is None or not hasattr(experiment, "add_histogram"):
            return
        for name, params in self.named_parameters():
            experiment.add_histogram(name, params, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        y_hat = torch.nn.functional.softmax(y_hat, dim=1)

        # calculate number of correct predictions
        _, predicted = torch.max(y_hat, 1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        self.validation_step_outputs.append((num, correct))
        return num, correct

    def on_validation_epoch_end(self):
        # calculate and log top1 accuracy over all validation batches
        if self.validation_step_outputs:
            total_num = 0
            total_correct = 0
            for num, correct in self.validation_step_outputs:
                total_num += num
                total_correct += correct
            acc = total_correct / total_num
            self.log("val_acc", acc, on_epoch=True, prog_bar=True)
            self.validation_step_outputs.clear()

    def configure_optimizers(self):
        # Only the linear head is optimized; the backbone is frozen.
        optim = torch.optim.SGD(self.fc.parameters(), lr=0.03, momentum=0.9, weight_decay=0.0005)
        # Anneal over the number of epochs this trainer will actually run. The
        # notebook-level `max_epochs` (200) differs from the classifier trainer's
        # 100 epochs, which left the cosine schedule only partly decayed. Fall
        # back to the global when the trainer has no finite epoch limit.
        trainer_epochs = self.trainer.max_epochs
        t_max = trainer_epochs if trainer_epochs is not None and trainer_epochs > 0 else max_epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, t_max)
        return [optim], [scheduler]

# Model Pretraining Code

In [None]:
# from torch.utils.data import DataLoader
# import torchvision.transforms as T

# batch_size_per_device = 64
# num_classes = 1000
# train_dataset = LightlyDataset(input_dir=path_to_train, transform=transform)
# train_dataloader = DataLoader(
#         train_dataset,
#         batch_size=batch_size_per_device,
#         shuffle=True,
#         num_workers=num_workers,
#         drop_last=True,
#         persistent_workers=False,
# )
# # Setup validation data.
# val_transform = T.Compose(
#     [
#         T.Resize(256),
#         T.CenterCrop(224),
#         T.ToTensor(),
#         T.Normalize(mean=[0.4467106, 0.43980986, 0.40664646], std=[0.26034098, 0.25657727, 0.27126738]),
#     ]
# )
# val_dataset = LightlyDataset(input_dir=str(path_to_val), transform=val_transform)
# val_dataloader = DataLoader(
#     val_dataset,
#     batch_size=batch_size_per_device,
#     shuffle=False,
#     num_workers=num_workers,
#     persistent_workers=False,
# )

In [None]:
# from lightly.utils.benchmarking import MetricCallback
# from pytorch_lightning.callbacks import (
#     DeviceStatsMonitor,
#     EarlyStopping,
#     LearningRateMonitor,
# )
# from pytorch_lightning.loggers import TensorBoardLogger
# # Check if CUDA (GPU) is available
# os.environ['RANK'] = '0'
# os.environ['WORLD_SIZE'] = '1'
# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '12355'
# if torch.cuda.is_available():
#     # Initialize the default process group
#     torch.distributed.init_process_group(backend='nccl')
#     # Now you can perform your training using PyTorch DistributedDataParallel
#     # or other distributed training techniques
# else:
#     print("CUDA is not available. Cannot perform distributed training.")

# accelerator = "gpu" if torch.cuda.is_available() else "cpu"

# metric_callback = MetricCallback()
# trainer = pl.Trainer(
#         max_epochs=102,
#         accelerator=accelerator,
#         devices=1,
#         callbacks=[
#             LearningRateMonitor(),
#             # Stop if training loss diverges.
#             EarlyStopping(monitor="train_loss", patience=int(1e12), check_finite=True),
#             DeviceStatsMonitor(),
#             metric_callback,
#         ],
#         logger=TensorBoardLogger(save_dir="/content/logs/", name="pretrain"),
#         precision="16-mixed",
#         sync_batchnorm=accelerator != "cpu",  # Sync batchnorm is not supported on CPU.
#         num_sanity_val_steps=0,
# )

# model = MoCoV2(batch_size_per_device=batch_size_per_device, num_classes=num_classes)
# trainer.fit(
#         model=model,
#         train_dataloaders=train_dataloader,
#         val_dataloaders=val_dataloader,
#         ckpt_path=ckpt_path,
# )
# for metric in ["val_online_cls_top1", "val_online_cls_top5"]:
#         print_rank_zero(f"max {metric}: {max(metric_callback.val_metrics[metric])}")

### Load the Checkpoint

!cp "/content/drive/MyDrive/Colab Notebooks/swav_epoch-99.ckpt" "/content/swav_epoch-99.ckpt"

In [None]:
ckpt_path = "/content/swav_epoch-99.ckpt"
# Reuse ckpt_path instead of repeating the literal. Mapping to CPU lets a
# GPU-saved checkpoint load on any runtime; the Trainer moves the model to the
# accelerator later.
model = SwAV.load_from_checkpoint(ckpt_path, map_location="cpu")

# Linear Probing

In [None]:
# # copy the classifier checkpoints from drive
# !cp "drive/My Drive/Colab Notebooks/classifier_model_stl10_200.ckpt" "/content/classifier_model_stl10_200.ckpt"

In [None]:
# Gradient accumulation schedule (keys are zero-indexed epochs):
# accumulate over 10 batches for epochs 0-3, over 5 batches for epochs 4-7,
# and over 2 batches from epoch 8 onwards.
accumulator = GradientAccumulationScheduler(scheduling={0: 10, 4: 5, 8: 2})

classifier_trainer = pl.Trainer(
    max_epochs=100,  # Adjust the number of epochs as needed
    devices=1,
    accelerator="gpu",
    callbacks=[accumulator],
)

classifier = Classifier(model.backbone)

# classifier_trainer.fit(
#     classifier,
#     dataloader_train_classifier,
#     dataloader_val,
#     #ckpt_path="classifier_model_cifar10.ckpt"
# )

# Save the model and checkpoints

In [None]:
classifier_trainer.save_checkpoint("classifier_swav_cifar10_100.ckpt")

In [None]:
!cp "classifier_swav_cifar10_100.ckpt" "drive/My Drive/Colab Notebooks/classifier_swav_cifar10_100.ckpt"

In [None]:
!zip -r lightning_logs_swav_cifar10.zip /content/lightning_logs

In [None]:
!cp "/content/lightning_logs_swav_cifar10.zip" "drive/My Drive/lightning_logs_swav_cifar10.zip"

In [None]:
!cp "drive/My Drive/Colab Notebooks/classifier_swav_cifar10_100.ckpt" "classifier_swav_cifar10_100.ckpt"

# Test the model

In [None]:
classifier_trainer.test(classifier, dataloader_train_classifier, ckpt_path="classifier_swav_cifar10_100.ckpt") #training accuracy

In [None]:
classifier_trainer.test(classifier, dataloader_test, ckpt_path="classifier_swav_cifar10_100.ckpt")

# Download and view logs

#connect to google drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!zip -r lightning_logs_SwaV_stl_10.zip /content/lightning_logs

In [None]:
!cp "drive/My Drive/Colab Notebooks/classifier_model_SwaV_stl10.ckpt" "/content/classifier_model_stl10_200.ckpt"

In [None]:
!cp "drive/My Drive/lightning_logs_swav_cifar10.zip" "/content/lightning_logs_SwaV_stl_10.zip"

In [None]:
!rm -rf lightning_logs

In [None]:
!unzip "lightning_logs_SwaV_stl_10.zip" -d "/content/lightning_logs1"
!mv lightning_logs1/content/* .
!rm -rf lightning_logs1

In [None]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs