# Lightly is great

In [1]:
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

  rank_zero_deprecation(


In [2]:
# Expected on-disk layout: one sub-folder per class, images inside it.
# The test split is assumed to mirror the train split -- TODO confirm.
#
# data/cifar10_lightly/train/
#   |- airplane/
#   |    |- 10008_airplane.png
#   |    |- ...
#   |- automobile/
#   |- bird/
#   |- cat/
#   |- deer/
#   |- dog/
#   |- frog/
#   |- horse/
#   |- ship/
#   |- truck/

# Root folder of the training images.
path_to_train = './data/cifar10_lightly/train/'

# Root folder of the held-out test images.
path_to_test = './data/cifar10_lightly/test/'

In [4]:
# Collate function for the MoCo dataloader. MoCo v2 follows the SimCLR
# augmentation recipe; here images are handled at 32x32 and the gaussian-blur
# setting is 0 to switch the blur off.
collate_fn = lightly.data.SimCLRCollateFunction(input_size=32, gaussian_blur=0.)

In [4]:
# Mean/std taken from lightly's `imagenet_normalize`; both pipelines below
# normalize with the same values, so the transform is built once and shared.
_imagenet_stats = lightly.data.collate.imagenet_normalize
_normalize = torchvision.transforms.Normalize(
    mean=_imagenet_stats['mean'],
    std=_imagenet_stats['std'],
)

# Light augmentations (padded random crop + horizontal flip) for training the
# classifier on cifar-10.
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    _normalize,
])

# Test pipeline: resize and normalize only, no random augmentation.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    _normalize,
])

# Dataset for MoCo pre-training. No transform is passed here.
dataset_train_moco = lightly.data.LightlyDataset(input_dir=path_to_train)

# Dataset for the linear classifier trained on top of the pre-trained MoCo
# model. It reads the same training images but uses the light classifier
# augmentations above: strong contrastive augmentations usually hurt models
# trained with cross entropy on the dataset labels.
dataset_train_classifier = lightly.data.LightlyDataset(
    input_dir=path_to_train,
    transform=train_classifier_transforms,
)

# Held-out test images with the deterministic test pipeline.
dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_test,
    transform=test_transforms,
)

In [5]:
# hyperparams
num_workers = 6          # worker processes per DataLoader
batch_size = 512         # samples per batch for every DataLoader
memory_bank_size = 4096  # size of the memory bank passed to NTXentLoss
seed = 1                 # global random seed, applied below
max_epochs = 150         # training epochs; also the cosine schedule length

# Apply the seed. Without this call `seed` was defined but never used, so
# weight initialisation, shuffling and augmentations differed on every run.
pl.seed_everything(seed)

In [6]:
# Keyword arguments that are identical for all three loaders.
_common_loader_args = dict(batch_size=batch_size, num_workers=num_workers)

# Loader for MoCo pre-training: shuffled, incomplete last batch dropped,
# batches assembled by `collate_fn`.
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
    **_common_loader_args,
)

# Loader for training the classifier: shuffled, incomplete last batch dropped,
# default collation.
dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier,
    shuffle=True,
    drop_last=True,
    **_common_loader_args,
)

# Loader for evaluation: fixed order, every sample kept.
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    shuffle=False,
    drop_last=False,
    **_common_loader_args,
)

In [7]:
class MocoModel(pl.LightningModule):
    """LightningModule that trains a lightly MoCo model on a ResNet-18 backbone.

    Reads the module-level hyperparameters ``memory_bank_size`` (size of the
    loss memory bank) and ``max_epochs`` (length of the cosine LR schedule).
    """

    def __init__(self):
        """Build the backbone, the MoCo wrapper and the contrastive loss."""
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # create a moco based on ResNet
        self.resnet_moco = \
            lightly.models.MoCo(backbone, num_ftrs=512, m=0.99, batch_shuffle=True)

        # create our loss with the optional memory bank
        self.criterion = lightly.loss.NTXentLoss(
            temperature=0.1,
            memory_bank_size=memory_bank_size)

    def forward(self, x):
        """Run ``x`` through the wrapped MoCo model and return its output.

        Args:
            x: Input batch, passed unchanged to ``self.resnet_moco``.

        Returns:
            Whatever ``self.resnet_moco`` returns for a single input.
        """
        # The result used to be computed and discarded, so calling the model
        # always produced None; it is now returned to the caller.
        return self.resnet_moco(x)

    # We provide a helper method to log weights in tensorboard
    # which is useful for debugging.
    def custom_histogram_weights(self):
        """Log a histogram of every named parameter for the current epoch."""
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(
                name, params, self.current_epoch)

    def training_step(self, batch, batch_idx):
        """Compute and log the contrastive loss for one batch.

        Args:
            batch: A 3-tuple whose first item is a pair of inputs ``(x0, x1)``;
                the other two items are ignored (presumably labels and
                filenames from the lightly collate function -- TODO confirm).
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss for this batch, also logged as ``train_loss_ssl``.
        """
        (x0, x1), _, _ = batch
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        self.log('train_loss_ssl', loss)
        return loss

    def training_epoch_end(self, outputs):
        """Log the weight histograms once at the end of every training epoch."""
        self.custom_histogram_weights()

    def configure_optimizers(self):
        """Return SGD with a cosine-annealed learning rate over ``max_epochs``.

        Returns:
            A ``([optimizer], [scheduler])`` pair.
        """
        optim = torch.optim.SGD(self.resnet_moco.parameters(), lr=6e-2,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]


In [8]:
# Checkpoint callback (comparable to Keras' "save best model"): monitors the
# logged `train_loss_ssl` value and keeps the 5 checkpoints with the lowest
# value under ./saved_models/resnet_moco.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss_ssl',
    mode='min',
    save_top_k=5,
    dirpath='./saved_models/resnet_moco',
    filename='{epoch}-{train_loss_ssl:.2f}',
    verbose=True,
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [9]:
# Start TensorBoard inside the notebook, reading logs from lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [10]:
# Train on one GPU when CUDA is available, otherwise on the CPU (gpus=0).
gpus = int(torch.cuda.is_available())

model = MocoModel()

# TensorBoard logs go to ./lightning_logs/moco_150/.
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='./lightning_logs/',
    name='moco_150',
)

trainer = pl.Trainer(
    max_epochs=max_epochs,
    gpus=gpus,
    callbacks=[checkpoint_callback],
    logger=tb_logger,
)

# Self-supervised pre-training on the MoCo dataloader.
trainer.fit(model, dataloader_train_moco)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 23.0 M
1 | criterion   | NTXentLoss | 0     
-------------------------------------------
11.5 M    Trainable params
11.5 M    Non-trainable params
23.0 M    Total params
91.977    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Epoch 0, global step 96: train_loss_ssl reached 6.98624 (best 6.98624), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=0-train_loss_ssl=6.99.ckpt" as top 5
Epoch 1, global step 193: train_loss_ssl reached 7.15018 (best 6.98624), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=1-train_loss_ssl=7.15.ckpt" as top 5
Epoch 2, global step 290: train_loss_ssl reached 6.95802 (best 6.95802), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=2-train_loss_ssl=6.96.ckpt" as top 5
Epoch 3, global step 387: train_loss_ssl reached 6.76561 (best 6.76561), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=3-train_loss_ssl=6.77.ckpt" as top 5
Epoch 4, global step 484: train_loss_ssl reached 6.57230 (best 6.57230), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=4-train_loss_ssl=6.57.ckpt" as top 5
Epoch 5, global step 581: train_loss_ssl 

Epoch 43, global step 4267: train_loss_ssl reached 3.47386 (best 3.47386), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=43-train_loss_ssl=3.47.ckpt" as top 5
Epoch 44, global step 4364: train_loss_ssl was not in top 5
Epoch 45, global step 4461: train_loss_ssl reached 3.43905 (best 3.43905), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=45-train_loss_ssl=3.44.ckpt" as top 5
Epoch 46, global step 4558: train_loss_ssl reached 3.55804 (best 3.43905), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=46-train_loss_ssl=3.56.ckpt" as top 5
Epoch 47, global step 4655: train_loss_ssl reached 3.61927 (best 3.43905), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=47-train_loss_ssl=3.62.ckpt" as top 5
Epoch 48, global step 4752: train_loss_ssl reached 3.44943 (best 3.43905), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=48-

Epoch 94, global step 9214: train_loss_ssl was not in top 5
Epoch 95, global step 9311: train_loss_ssl reached 2.83680 (best 2.83099), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=95-train_loss_ssl=2.84.ckpt" as top 5
Epoch 96, global step 9408: train_loss_ssl was not in top 5
Epoch 97, global step 9505: train_loss_ssl reached 2.88812 (best 2.83099), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=97-train_loss_ssl=2.89.ckpt" as top 5
Epoch 98, global step 9602: train_loss_ssl reached 2.80977 (best 2.80977), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=98-train_loss_ssl=2.81.ckpt" as top 5
Epoch 99, global step 9699: train_loss_ssl was not in top 5
Epoch 100, global step 9796: train_loss_ssl reached 2.76853 (best 2.76853), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_moco/epoch=100-train_loss_ssl=2.77.ckpt" as top 5
Epoch 101, global step 9893: train_los