# 🧈 Training MoCo v2 on CIFAR-10 🔥 with PyTorch Lightning 🧈

In [1]:
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers

SEED = 1
# Apply the seed to the Python, NumPy and PyTorch RNGs so that runs are
# repeatable; workers=True also asks Lightning to seed dataloader workers.
pl.seed_everything(SEED, workers=True)



# ⬇️ Build Dataset

In [2]:
# Data-loading hyperparameters
num_workers = 6  # worker processes used by every DataLoader below

# Batch sizes: MoCo pre-training first, then classifier train / test
moco_batch_size = 512
classifier_train_batch_size, classifier_test_batch_size = 512, 512

In [3]:
# Expected on-disk layout of the training split: one sub-folder per class,
# each holding that class's images ("L" marks a child entry in the tree).
# cifar10/train/
#  L airplane/
#    L 10008_airplane.png
#    L ...
#  L automobile/
#  L bird/
#  L cat/
#  L deer/
#  L dog/
#  L frog/
#  L horse/
#  L ship/
#  L truck/
# NOTE(review): the test split presumably mirrors this layout - confirm.
# Root folders handed to lightly.data.LightlyDataset as `input_dir`.
path_to_train = './data/cifar10_lightly/train/'
path_to_test = './data/cifar10_lightly/test/'

### Augmentations

In [4]:
################### Classifier Augmentations ###################
# Both classifier pipelines normalize with the statistics lightly exposes
# as `imagenet_normalize`; look them up once and share them.
_imagenet_stats = lightly.data.collate.imagenet_normalize

# Training pipeline: the light augmentations typically used on CIFAR-10
# (padded random crop + horizontal flip), then tensor conversion + normalize.
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_imagenet_stats['mean'], std=_imagenet_stats['std']),
])

# Test pipeline: no augmentation, only resize to 32x32, convert and normalize.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_imagenet_stats['mean'], std=_imagenet_stats['std']),
])

################### MOCO Augmentations ###################
# MoCo v2 uses the SimCLR augmentations; gaussian blur is switched off here.
collate_fn = lightly.data.SimCLRCollateFunction(input_size=32, gaussian_blur=0.)

### Datasets

In [5]:
################### Classifier Datasets ###################
# A linear classifier is also trained on top of the pre-trained MoCo model.
# It learns with cross entropy from the dataset labels, so it gets the light
# classifier augmentations for training and the plain test transforms for
# evaluation (MoCo augmentations are very strong and usually hurt models that
# are not trained contrastively).
dataset_train_classifier = lightly.data.LightlyDataset(
    input_dir=path_to_train, transform=train_classifier_transforms,
)
dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_test, transform=test_transforms,
)

################### MOCO Dataset ###################
# No transform is set on this dataset: the MoCo augmentations come from the
# `collate_fn` that is passed to the MoCo dataloader.
dataset_train_moco = lightly.data.LightlyDataset(input_dir=path_to_train)

### Dataloaders

In [6]:
################### Classifier Dataloaders ###################
# Training loader: reshuffled every epoch, incomplete last batch dropped.
dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset=dataset_train_classifier,
    batch_size=classifier_train_batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,
)

# Test loader: fixed order, every sample kept.
dataloader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=classifier_test_batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
)

################### MOCO Dataloader ###################
# Same settings as the classifier training loader, plus the SimCLR-style
# collate function that builds the augmented batches.
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset=dataset_train_moco,
    batch_size=moco_batch_size,
    num_workers=num_workers,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# 🧠 Load Model

In [7]:
# Model / schedule hyperparameters (all handed to MocoModel below;
# moco_max_epochs is also the Trainer's max_epochs)
memory_bank_size = 4096
moco_max_epochs = 10000
downstream_max_epochs, downstream_test_every = 60, 300

In [8]:
import moco_model_RESNETCOPY

In [9]:
# Build the MoCo LightningModule from the hyperparameters and the two
# classifier dataloaders defined above.
model = moco_model_RESNETCOPY.MocoModel(
    memory_bank_size,
    moco_max_epochs,
    downstream_max_epochs,
    dataloader_train_classifier,
    dataloader_test,
    downstream_test_every=downstream_test_every,
)

 -- 0 -- 


In [10]:
# List every parameter that will receive gradients during training.
print('--- Training these layers ---')
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen parameter: skip
    print(name, param.requires_grad)

--- Training these layers ---
resnet_moco.backbone.8.weight True
resnet_moco.backbone.8.bias True
resnet_moco.backbone.10.weight True
resnet_moco.backbone.10.bias True
resnet_moco.projection_head.0.weight True
resnet_moco.projection_head.0.bias True
resnet_moco.projection_head.2.weight True
resnet_moco.projection_head.2.bias True


In [11]:
# Scratch inspection: parameters() returns a generator, so this cell only
# displays the generator's repr, not the parameters themselves.
model.resnet_moco.parameters()

<generator object Module.parameters at 0x7f9cef1df970>

I will continue training from the 900-epoch model checkpoint.

In [12]:
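# # # WHEN LOADING A SAVED MODEL, DO IT LIKE THIS.
# # # NOTE: load_from_checkpoint is a classmethod that RETURNS a new model; assign
# # # its result (model = ...), otherwise the weights in `model` stay unchanged.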
# model.load_from_checkpoint('./saved_models/resnet_moco/epoch=2989-train_loss_ssl=1.37.ckpt',
#                           memory_bank_size=memory_bank_size, moco_max_epochs=moco_max_epochs,
#                           downstream_max_epochs=downstream_max_epochs, dataloader_train_classifier=dataloader_train_classifier,
#                            dataloader_test=dataloader_test,
#                            downstream_test_every=downstream_test_every)
# #                           last_epoch=2989);
# # checkpoint_callback = ModelCheckpoint(dirpath="./saved_models/resnet_moco/epoch=897-train_loss_ssl=1.67.ckpt")

In [13]:
# Checkpointing, similar to Keras' "save best model": keep the 5 checkpoints
# with the lowest monitored 'train_loss_ssl'.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss_ssl',
    mode='min',
    save_top_k=5,
    dirpath='./saved_models/resnet_moco_RESNETCOPY_extralinear',
    filename='{epoch}-{train_loss_ssl:.2f}-2head',
    verbose=True,
)

# Record the learning rate at every training step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# 🏋️‍♂️ Train

In [14]:
# Request one GPU when CUDA is available, otherwise none.
gpus = int(torch.cuda.is_available())
print(f'Using gpu: {bool(gpus)}')
if not gpus:
    print('--- NOT USING GPUS THIS TAKE LONG TIME ---')

# TensorBoard logger; the run name embeds the planned number of MoCo epochs.
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='./lightning_logs/',
    name=f'RESNETCOPYmoco_{moco_max_epochs}eps',
)

Using gpu: True


In [15]:
# ### WHEN TRAINING FROM CHECKPOINT ###

# trainer = pl.Trainer(
#     resume_from_checkpoint="./saved_models/resnet_moco/epoch=2989-train_loss_ssl=1.37.ckpt",
#     max_epochs=moco_max_epochs,
#     gpus=gpus,
#     callbacks=[checkpoint_callback],
#     logger=tb_logger)
# trainer.fit(model, dataloader_train_moco)

In [None]:
## WHEN TRAINING FROM SCRATCH (EPOCH 0) ###

# `gpus` comes from the torch.cuda.is_available() check above, so training
# falls back to the CPU (slowly) when no GPU is present instead of failing.
trainer = pl.Trainer(
    max_epochs=moco_max_epochs,
    gpus=gpus,
    callbacks=[checkpoint_callback, lr_monitor],
    logger=tb_logger,
)

# Self-supervised MoCo training on the augmented CIFAR-10 batches.
trainer.fit(model, dataloader_train_moco)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 24.0 M
1 | criterion   | NTXentLoss | 0     
-------------------------------------------
853 K     Trainable params
23.2 M    Non-trainable params
24.0 M    Total params
96.180    Total estimated model params size (MB)


Training: -1it [00:00, ?it/s]

Epoch 0, global step 96: train_loss_ssl reached 6.72615 (best 6.72615), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=0-train_loss_ssl=6.73-2head.ckpt" as top 5
Epoch 1, global step 193: train_loss_ssl reached 6.57033 (best 6.57033), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=1-train_loss_ssl=6.57-2head.ckpt" as top 5
Epoch 2, global step 290: train_loss_ssl reached 6.54647 (best 6.54647), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=2-train_loss_ssl=6.55-2head.ckpt" as top 5
Epoch 3, global step 387: train_loss_ssl reached 6.42723 (best 6.42723), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=3-train_loss_ssl=6.43-2head.ckpt" as top 5
Epoch 4, global step 484: train_loss_ssl reached 6.26957 (best 6.26957), saving model to "/home/shatz/Documents/more_

Epoch 50, global step 4946: train_loss_ssl was not in top 5
Epoch 51, global step 5043: train_loss_ssl was not in top 5
Epoch 52, global step 5140: train_loss_ssl was not in top 5
Epoch 53, global step 5237: train_loss_ssl was not in top 5
Epoch 54, global step 5334: train_loss_ssl reached 5.65616 (best 5.53780), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=54-train_loss_ssl=5.66-2head.ckpt" as top 5
Epoch 55, global step 5431: train_loss_ssl was not in top 5
Epoch 56, global step 5528: train_loss_ssl was not in top 5
Epoch 57, global step 5625: train_loss_ssl was not in top 5
Epoch 58, global step 5722: train_loss_ssl was not in top 5
Epoch 59, global step 5819: train_loss_ssl reached 5.67377 (best 5.53780), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=59-train_loss_ssl=5.67-2head.ckpt" as top 5
Epoch 60, global step 5916: train_loss_ssl was not in top 5
Epoch 61, gl

Epoch 121, global step 11833: train_loss_ssl was not in top 5
Epoch 122, global step 11930: train_loss_ssl was not in top 5
Epoch 123, global step 12027: train_loss_ssl was not in top 5
Epoch 124, global step 12124: train_loss_ssl was not in top 5
Epoch 125, global step 12221: train_loss_ssl was not in top 5
Epoch 126, global step 12318: train_loss_ssl reached 5.46104 (best 5.44025), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=126-train_loss_ssl=5.46-2head.ckpt" as top 5
Epoch 127, global step 12415: train_loss_ssl reached 5.32366 (best 5.32366), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=127-train_loss_ssl=5.32-2head.ckpt" as top 5
Epoch 128, global step 12512: train_loss_ssl was not in top 5
Epoch 129, global step 12609: train_loss_ssl reached 5.34622 (best 5.32366), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinea

Epoch 216, global step 21048: train_loss_ssl was not in top 5
Epoch 217, global step 21145: train_loss_ssl was not in top 5
Epoch 218, global step 21242: train_loss_ssl was not in top 5
Epoch 219, global step 21339: train_loss_ssl was not in top 5
Epoch 220, global step 21436: train_loss_ssl was not in top 5
Epoch 221, global step 21533: train_loss_ssl was not in top 5
Epoch 222, global step 21630: train_loss_ssl was not in top 5
Epoch 223, global step 21727: train_loss_ssl was not in top 5
Epoch 224, global step 21824: train_loss_ssl was not in top 5
Epoch 225, global step 21921: train_loss_ssl was not in top 5
Epoch 226, global step 22018: train_loss_ssl reached 5.23765 (best 5.23765), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco_RESNETCOPY_extralinear/epoch=226-train_loss_ssl=5.24-2head.ckpt" as top 5
Epoch 227, global step 22115: train_loss_ssl was not in top 5
Epoch 228, global step 22212: train_loss_ssl was not in top 5
Epoch 229, global step 22309:

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 349, global step 33949: train_loss_ssl was not in top 5
Epoch 350, global step 34046: train_loss_ssl was not in top 5
Epoch 351, global step 34143: train_loss_ssl was not in top 5
Epoch 352, global step 34240: train_loss_ssl was not in top 5
Epoch 353, global step 34337: train_loss_ssl was not in top 5
Epoch 354, global step 34434: train_loss_ssl was not in top 5
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_lim

In [None]:
# Deliberate stop so "Run All" does not continue into the scratch cells below.
raise RuntimeError('Stop here: the cells below are scratch experiments.')

In [None]:
# Display the repr of the MoCo model built above.
model

In [None]:
# Scratch inspection: build lightly's 'resnet-18' generator and display its
# repr; the instance is not stored.
lightly.models.ResNetGenerator('resnet-18', 1, num_splits=0)

In [None]:
from plr18_RESNETCOPY import plr18

In [None]:
# Re-wrap the child modules of plr18's inner `model` in a plain Sequential.
mod = nn.Sequential(*plr18().model.children())

In [None]:
# Experimental network: lightly's ResNet-18 without its last child module,
# followed by pooling, a 512 -> 10 linear layer, and a reshape back to a
# (10, 1, 1) feature map.
mod = nn.Sequential(
    *[*lightly.models.ResNetGenerator('resnet-18', 1, num_splits=0).children()][:-1],
    nn.AvgPool2d(kernel_size=4),
    nn.Flatten(),
    nn.Linear(in_features=512, out_features=10),
    nn.Unflatten(dim=1, unflattened_size=(10, 1, 1)),
    nn.AdaptiveAvgPool2d(output_size=1),
)

In [None]:
# Display the layer-by-layer repr of the experimental network.
mod

In [None]:
# Sanity check: push a random batch of 512 images (3 x 32 x 32) through
# `mod` and display the shape of the output.
x = torch.rand(512, 3, 32, 32)
y = mod(x)
y.shape