# 🧈 Training MOCOv2 on CIFAR10 🔥 With pytorch lightning 🧈

In [1]:
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

SEED = 1

  rank_zero_deprecation(


# ⬇️ Build Dataset

In [2]:
# DATA hyperparams
num_workers = 6
moco_batch_size = 512
classifier_train_batch_size = 512
classifier_test_batch_size = 512

In [3]:
# The dataset structure should be like this:
# cifar10/train/
#  L airplane/
#    L 10008_airplane.png
#    L ...
#  L automobile/
#  L bird/
#  L cat/
#  L deer/
#  L dog/
#  L frog/
#  L horse/
#  L ship/
#  L truck/
path_to_train = './data/cifar10_lightly/train/'
path_to_test = './data/cifar10_lightly/test/'

### Augmentations

In [4]:
################### Classifier Augmentations ###################
# Augmentations typically used to train on cifar-10
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])

# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])

################### MOCO Augmentations ###################
# MoCo v2 uses SimCLR augmentations, additionally, disable blur
collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=32,
    gaussian_blur=0.,
)

### Datasets

In [5]:
################### Classifier Datasets ###################
#Since we also train a linear classifier on the pre-trained moco model we
# reuse the test augmentations here (MoCo augmentations are very strong and
# usually reduce accuracy of models which are not used for contrastive learning.
# Our linear layer will be trained using cross entropy loss and labels provided
# by the dataset. Therefore we chose light augmentations.)
dataset_train_classifier = lightly.data.LightlyDataset(
    input_dir=path_to_train,
    transform=train_classifier_transforms
)

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_test,
    transform=test_transforms
)

################### MOCO Dataset ###################
# We use the moco augmentations for training moco
dataset_train_moco = lightly.data.LightlyDataset(
    input_dir=path_to_train
)

### Dataloaders

In [6]:
################### Classifier Dataloaders ###################
dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier,
    batch_size=classifier_train_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=classifier_test_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

################### MOCO Dataloader ###################
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco,
    batch_size=moco_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

# 🧠 Load Model

In [7]:
# MODEL hyperparams
memory_bank_size = 4096
moco_max_epochs = 9000
downstream_max_epochs = 60
downstream_test_every = 100

In [8]:
import moco_model

In [9]:
model = moco_model.MocoModel(memory_bank_size, moco_max_epochs, 
                             downstream_max_epochs, dataloader_train_classifier, dataloader_test,
                            downstream_test_every=downstream_test_every)

I will continue training from 900 epoch model

In [10]:
# # WHEN LOADING A SAVED MODEL DO IT LIKE DIS
model.load_from_checkpoint('./saved_models/resnet_moco/epoch=2989-train_loss_ssl=1.37.ckpt',
                          memory_bank_size=memory_bank_size, moco_max_epochs=moco_max_epochs,
                          downstream_max_epochs=downstream_max_epochs, dataloader_train_classifier=dataloader_train_classifier,
                           dataloader_test=dataloader_test,
                           downstream_test_every=downstream_test_every);
# checkpoint_callback = ModelCheckpoint(dirpath="./saved_models/resnet_moco/epoch=897-train_loss_ssl=1.67.ckpt")

In [11]:
# you can also define a checkpoint callback to save best model like keras.
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved_models/resnet_moco',
    filename='{epoch}-{train_loss_ssl:.2f}',
    save_top_k=5,
    verbose=True,
    monitor='train_loss_ssl',
    mode='min'
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


# 🏋️‍♂️ Train

In [12]:
# use a GPU if available
gpus = 1 if torch.cuda.is_available() else 0
print(f'Using gpu: {bool(gpus)}')
if(gpus == 0): print('--- NOT USING GPUS THIS TAKE LONG TIME ---')

# set up tensorboard logger
tb_logger = pl_loggers.TensorBoardLogger(save_dir='./lightning_logs/', name=f'TESTmoco_{moco_max_epochs}eps')

Using gpu: True


In [13]:
trainer = pl.Trainer(
    resume_from_checkpoint="./saved_models/resnet_moco/epoch=2989-train_loss_ssl=1.37.ckpt",
    max_epochs=moco_max_epochs,
    gpus=gpus,
    callbacks=[checkpoint_callback],
    logger=tb_logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, dataloader_train_moco)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 23.0 M
1 | criterion   | NTXentLoss | 0     
-------------------------------------------
11.5 M    Trainable params
11.5 M    Non-trainable params
23.0 M    Total params
91.977    Total estimated model params size (MB)
Restored states from the checkpoint file at ./saved_models/resnet_moco/epoch=2989-train_loss_ssl=1.37.ckpt


Training: 0it [00:00, ?it/s]

Epoch 2990, global step 290126: train_loss_ssl reached 1.62537 (best 1.62537), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2990-train_loss_ssl=1.63.ckpt" as top 5
Epoch 2991, global step 290223: train_loss_ssl reached 1.57939 (best 1.57939), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2991-train_loss_ssl=1.58.ckpt" as top 5
Epoch 2992, global step 290320: train_loss_ssl reached 1.45108 (best 1.45108), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2992-train_loss_ssl=1.45.ckpt" as top 5
Epoch 2993, global step 290417: train_loss_ssl reached 1.62739 (best 1.45108), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2993-train_loss_ssl=1.63.ckpt" as top 5
Epoch 2994, global step 290514: train_loss_ssl reached 1.66138 (best 1.45108), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2994-train_loss_ssl=1.66.ckpt" as t

... training downstream classifier...


Validation sanity check: 0it [00:00, ?it/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 3000, global step 291096: train_loss_ssl was not in top 5


[1.0979354, 0.73298025, 0.5943247, 0.5201256, 0.47338232, 0.44068217, 0.41678634, 0.39874253, 0.38361573, 0.3713445, 0.3611299, 0.35190663, 0.34427562, 0.33745602, 0.33141768, 0.325974, 0.32106307, 0.31661698, 0.3125279, 0.30867705, 0.30524954, 0.30186403, 0.29885274, 0.29590496, 0.29325986, 0.29071814, 0.2881616, 0.28592587, 0.2837138, 0.28162882, 0.2796103, 0.27771598, 0.2758793, 0.27411047, 0.2724421, 0.27083424, 0.269256, 0.26773906, 0.266338, 0.26492855, 0.26360735, 0.26222062, 0.26090747, 0.25966802, 0.258513, 0.2574088, 0.25628924, 0.25519964, 0.25414208, 0.25313953, 0.2521674, 0.25119928, 0.2503316, 0.24942796, 0.24854916, 0.24769999, 0.24690984, 0.24617553, 0.24544938, 0.24468957]
[0.5945768, 0.6809809, 0.7256998, 0.7533838, 0.77252686, 0.78669477, 0.79753447, 0.80626047, 0.81340545, 0.81939167, 0.8245007, 0.8289243, 0.8327958, 0.8362214, 0.83927804, 0.8420287, 0.84452116, 0.84677863, 0.8488372, 0.85072744, 0.85247815, 0.854098, 0.85559744, 0.85699797, 0.8583053, 0.8595323, 0.

Epoch 3001, global step 291193: train_loss_ssl reached 1.45632 (best 1.45108), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=3001-train_loss_ssl=1.46.ckpt" as top 5
Epoch 3002, global step 291290: train_loss_ssl was not in top 5
Epoch 3003, global step 291387: train_loss_ssl was not in top 5
Epoch 3004, global step 291484: train_loss_ssl was not in top 5
Epoch 3005, global step 291581: train_loss_ssl was not in top 5
Epoch 3006, global step 291678: train_loss_ssl was not in top 5
Epoch 3007, global step 291775: train_loss_ssl reached 1.45087 (best 1.45087), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=3007-train_loss_ssl=1.45.ckpt" as top 5
Epoch 3008, global step 291872: train_loss_ssl was not in top 5
Epoch 3009, global step 291969: train_loss_ssl was not in top 5
Epoch 3010, global step 292066: train_loss_ssl was not in top 5
Epoch 3011, global step 292163: train_loss_ssl reached 1.50054 (best 1.45087), saving 

In [14]:
trainer = pl.Trainer(max_epochs=moco_max_epochs, gpus=gpus, callbacks=[checkpoint_callback], logger=tb_logger)
trainer.fit(
    model,
    dataloader_train_moco
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 23.0 M
1 | criterion   | NTXentLoss | 0     
-------------------------------------------
11.5 M    Trainable params
11.5 M    Non-trainable params
23.0 M    Total params
91.977    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  stream(template_mgs % msg_args)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 23.0 M
1 | fc          | Sequential | 267 K 
2 | accuracy    | Accuracy   | 0     
-------------------------------------------
267 K     Trainable params
23.0 M    Non-trainable params
23.3 M    Total params
93.048    Total estimated model params size (MB)


... training downstream classifier...


Validation sanity check: 0it [00:00, ?it/s]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 96: train_loss_ssl reached 7.02934 (best 7.02934), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=0-train_loss_ssl=7.03.ckpt" as top 5


[2.0283015, 1.9727648, 1.9470253, 1.9317042, 1.9205841, 1.9113137, 1.9022614, 1.8948704, 1.887622, 1.881369, 1.8755313, 1.8699584, 1.8650374, 1.8604697, 1.8558487, 1.8516186, 1.8471583, 1.8428562, 1.8385209, 1.8344973, 1.8302956, 1.8264673, 1.822944, 1.8194226, 1.8161191, 1.8126774, 1.8091489, 1.8058736, 1.802578, 1.7993127, 1.7961755, 1.7930323, 1.7900654, 1.7870538, 1.7840934, 1.7811106, 1.7781818, 1.7753371, 1.7724425, 1.7696832, 1.7669392, 1.7642806, 1.7616137, 1.7589884, 1.7564116, 1.7539533, 1.7516084, 1.7492803, 1.7469553, 1.7447519, 1.7424794, 1.7403294, 1.7382442, 1.7362604, 1.7342988, 1.7324601, 1.7306111, 1.7288764, 1.7272174, 1.7255715]
[0.14195748, 0.19591479, 0.21982193, 0.2351632, 0.2434787, 0.25068486, 0.256891, 0.2619934, 0.26623875, 0.26991007, 0.27346855, 0.27668294, 0.2797659, 0.28235126, 0.28487542, 0.28726166, 0.28952643, 0.29164657, 0.29361933, 0.29552457, 0.2972754, 0.29890656, 0.30041993, 0.3019088, 0.30335063, 0.304776, 0.3060627, 0.30730867, 0.30850747, 0.309

Epoch 1, global step 193: train_loss_ssl reached 6.79746 (best 6.79746), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=1-train_loss_ssl=6.80.ckpt" as top 5
Epoch 2, global step 290: train_loss_ssl reached 6.55716 (best 6.55716), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2-train_loss_ssl=6.56.ckpt" as top 5
Epoch 3, global step 387: train_loss_ssl reached 6.48588 (best 6.48588), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=3-train_loss_ssl=6.49.ckpt" as top 5
Epoch 4, global step 484: train_loss_ssl reached 6.24258 (best 6.24258), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=4-train_loss_ssl=6.24.ckpt" as top 5
  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


# Final train

In [14]:
break

SyntaxError: 'break' outside loop (<ipython-input-14-6aaf1f276005>, line 1)

In [None]:
class testMocoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # create a moco based on ResNet
        self.resnet_moco = \
            lightly.models.MoCo(backbone, num_ftrs=512, m=0.99, batch_shuffle=True)

        # create our loss with the optional memory bank
        self.criterion = lightly.loss.NTXentLoss(
            temperature=0.1,
            memory_bank_size=memory_bank_size)

    def forward(self, x):
        self.resnet_moco(x)
        
    def contrastive_loss(self, x0, x1):
        # calculate the contrastive loss for some transformed x -> x0, x1
        # also return grad for each of these
        self.zero_grad()
        x0.requires_grad = True
        x1.requires_grad = True
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        loss.backward()
        return x0.grad, x1.grad, loss
    
    def contrastive_loss_nograd(self, x0, x1):
        with torch.no_grad():
            y0, y1 = self.resnet_moco(x0, x1)
            loss = self.criterion(y0, y1)
        return loss
        

    # We provide a helper method to log weights in tensorboard
    # which is useful for debugging.
    def custom_histogram_weights(self):
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(
                name, params, self.current_epoch)

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        self.log('train_loss_ssl', loss)
        return loss

    def training_epoch_end(self, outputs):
        self.custom_histogram_weights()


    def configure_optimizers(self):
        optim = torch.optim.SGD(self.resnet_moco.parameters(), lr=6e-2,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]


In [None]:
mocomodel = testMocoModel()
mocomodel.load_from_checkpoint('./saved_models/resnet_moco/epoch=142-train_loss_ssl=2.46.ckpt')
mocomodel.eval();

In [None]:
clf = moco_model.Classifier(mocomodel.resnet_moco, max_epochs=25)

In [None]:
trainer = pl.Trainer(max_epochs=max_epochs, gpus=1)
trainer.fit(
    clf,
    dataloader_train_classifier,
    dataloader_test
)

In [None]:
trainer.