# Just train and save a resnet18 on CIFAR10 with pytorch lightning

In [1]:
%matplotlib inline

In [2]:
import torch
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets import CIFAR10

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

import torchmetrics
from tqdm.notebook import tqdm
import lightly



# DATA

In [3]:
# data params
num_workers = 6
batch_size = 512

In [4]:
# The dataset structure should be like this:
# cifar10/train/
#  L airplane/
#    L 10008_airplane.png
#    L ...
#  L automobile/
#  L bird/
#  L cat/
#  L deer/
#  L dog/
#  L frog/
#  L horse/
#  L ship/
#  L truck/
path_to_train = './data/cifar10_lightly/train/'
path_to_test = './data/cifar10_lightly/test/'

In [5]:
# ------------- transforms ------------------- #

# Augmentations typically used to train on cifar-10
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.ToTensor(),
    #     torchvision.transforms.Normalize(
    #         mean=(0.5, 0.5, 0.5),
    #         std=(0.5, 0.5, 0.5),
    #     ),
])

# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    #     torchvision.transforms.Normalize(
    #         mean=lightly.data.collate.imagenet_normalize['mean'],
    #         std=lightly.data.collate.imagenet_normalize['std'],
    #     ),
])

# --------------- datasets --------------------- #
dataset_train_classifier = lightly.data.LightlyDataset(
    input_dir=path_to_train,
    transform=train_classifier_transforms
)

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_test,
    transform=test_transforms
)

# ------------------ dataloaders ----------------- #
train_dataloader = torch.utils.data.DataLoader(
    dataset_train_classifier,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers
)

val_dataloader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

# Model

In [6]:
# model hyperparams
max_epochs = 200

In [7]:
from plr18 import plr18
model = plr18()

# Train

In [8]:
# # Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 236117), started 5 days, 21:30:28 ago. (Use '!kill 236117' to kill it.)

In [9]:
# you can also define a checkpoint callback to save best model like keras.
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved_models/resnet_80',
    filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    save_top_k=5,
    verbose=True,
    monitor='avg_val_acc',
    mode='max'
)

In [10]:
tb_logger = pl_loggers.TensorBoardLogger(save_dir='./lightning_logs/', name='resnet_80')
trainer = Trainer(gpus=1, callbacks=[checkpoint_callback], max_epochs=80, logger=tb_logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [11]:
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | model    | ResNet           | 11.2 M
1 | criteria | CrossEntropyLoss | 0     
----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 96: avg_val_acc reached 0.37483 (best 0.37483), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=0-val_loss=1.78-val_acc=0.38.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 193: avg_val_acc reached 0.37457 (best 0.37483), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=1-val_loss=1.91-val_acc=0.37.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 290: avg_val_acc reached 0.45832 (best 0.45832), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=2-val_loss=1.64-val_acc=0.46.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 387: avg_val_acc reached 0.52592 (best 0.52592), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=3-val_loss=1.33-val_acc=0.53.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 484: avg_val_acc reached 0.43286 (best 0.52592), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=4-val_loss=1.87-val_acc=0.44.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 581: avg_val_acc reached 0.59018 (best 0.59018), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=5-val_loss=1.18-val_acc=0.59.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 678: avg_val_acc reached 0.54636 (best 0.59018), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=6-val_loss=1.29-val_acc=0.54.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 775: avg_val_acc reached 0.63771 (best 0.63771), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=7-val_loss=1.04-val_acc=0.64.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 872: avg_val_acc reached 0.50869 (best 0.63771), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=8-val_loss=1.58-val_acc=0.50.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 969: avg_val_acc reached 0.62031 (best 0.63771), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=9-val_loss=1.13-val_acc=0.62.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 1066: avg_val_acc reached 0.65194 (best 0.65194), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=10-val_loss=0.99-val_acc=0.65.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 1163: avg_val_acc reached 0.69014 (best 0.69014), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=11-val_loss=0.90-val_acc=0.69.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 1260: avg_val_acc reached 0.67054 (best 0.69014), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=12-val_loss=0.97-val_acc=0.67.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 1357: avg_val_acc reached 0.67202 (best 0.69014), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=13-val_loss=0.94-val_acc=0.67.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 1454: avg_val_acc reached 0.67410 (best 0.69014), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=14-val_loss=0.94-val_acc=0.68.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 1551: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 1648: avg_val_acc reached 0.66148 (best 0.69014), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=16-val_loss=1.04-val_acc=0.66.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 17, global step 1745: avg_val_acc reached 0.69169 (best 0.69169), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=17-val_loss=0.91-val_acc=0.69.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 1842: avg_val_acc reached 0.68784 (best 0.69169), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=18-val_loss=0.98-val_acc=0.68.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 1939: avg_val_acc reached 0.70138 (best 0.70138), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=19-val_loss=0.88-val_acc=0.70.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 20, global step 2036: avg_val_acc reached 0.71988 (best 0.71988), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=20-val_loss=0.84-val_acc=0.72.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 21, global step 2133: avg_val_acc reached 0.70788 (best 0.71988), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=21-val_loss=0.89-val_acc=0.70.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 22, global step 2230: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 23, global step 2327: avg_val_acc reached 0.69906 (best 0.71988), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=23-val_loss=0.89-val_acc=0.70.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 24, global step 2424: avg_val_acc reached 0.71843 (best 0.71988), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=24-val_loss=0.85-val_acc=0.72.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 25, global step 2521: avg_val_acc reached 0.73516 (best 0.73516), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=25-val_loss=0.78-val_acc=0.73.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 26, global step 2618: avg_val_acc reached 0.70141 (best 0.73516), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=26-val_loss=0.95-val_acc=0.70.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 27, global step 2715: avg_val_acc reached 0.73003 (best 0.73516), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=27-val_loss=0.82-val_acc=0.73.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 28, global step 2812: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 2909: avg_val_acc reached 0.73945 (best 0.73945), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=29-val_loss=0.76-val_acc=0.74.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 30, global step 3006: avg_val_acc reached 0.72740 (best 0.73945), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=30-val_loss=0.85-val_acc=0.72.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 31, global step 3103: avg_val_acc reached 0.74995 (best 0.74995), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=31-val_loss=0.78-val_acc=0.75.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 32, global step 3200: avg_val_acc reached 0.74862 (best 0.74995), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=32-val_loss=0.77-val_acc=0.75.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 33, global step 3297: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 34, global step 3394: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 35, global step 3491: avg_val_acc reached 0.75476 (best 0.75476), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=35-val_loss=0.76-val_acc=0.75.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 36, global step 3588: avg_val_acc reached 0.76189 (best 0.76189), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=36-val_loss=0.71-val_acc=0.76.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 37, global step 3685: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 38, global step 3782: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 39, global step 3879: avg_val_acc reached 0.75811 (best 0.76189), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=39-val_loss=0.74-val_acc=0.76.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 40, global step 3976: avg_val_acc reached 0.77022 (best 0.77022), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=40-val_loss=0.70-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 41, global step 4073: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 42, global step 4170: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 43, global step 4267: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 44, global step 4364: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 45, global step 4461: avg_val_acc reached 0.75499 (best 0.77022), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=45-val_loss=0.77-val_acc=0.75.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 46, global step 4558: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 47, global step 4655: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 48, global step 4752: avg_val_acc reached 0.77386 (best 0.77386), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=48-val_loss=0.70-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 49, global step 4849: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 50, global step 4946: avg_val_acc reached 0.77481 (best 0.77481), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=50-val_loss=0.71-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 51, global step 5043: avg_val_acc reached 0.76631 (best 0.77481), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=51-val_loss=0.74-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 52, global step 5140: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 53, global step 5237: avg_val_acc reached 0.76430 (best 0.77481), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=53-val_loss=0.76-val_acc=0.76.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 54, global step 5334: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 55, global step 5431: avg_val_acc reached 0.77879 (best 0.77879), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=55-val_loss=0.72-val_acc=0.78.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 56, global step 5528: avg_val_acc reached 0.76642 (best 0.77879), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=56-val_loss=0.75-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 57, global step 5625: avg_val_acc reached 0.79226 (best 0.79226), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=57-val_loss=0.65-val_acc=0.79.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 58, global step 5722: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 59, global step 5819: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 60, global step 5916: avg_val_acc reached 0.78142 (best 0.79226), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=60-val_loss=0.73-val_acc=0.78.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 61, global step 6013: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 62, global step 6110: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 63, global step 6207: avg_val_acc reached 0.79846 (best 0.79846), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=63-val_loss=0.65-val_acc=0.80.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 64, global step 6304: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 65, global step 6401: avg_val_acc reached 0.77489 (best 0.79846), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=65-val_loss=0.73-val_acc=0.77.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 66, global step 6498: avg_val_acc reached 0.80195 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=66-val_loss=0.68-val_acc=0.80.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 67, global step 6595: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 68, global step 6692: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 69, global step 6789: avg_val_acc reached 0.78261 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=69-val_loss=0.71-val_acc=0.78.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 70, global step 6886: avg_val_acc reached 0.79077 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=70-val_loss=0.73-val_acc=0.79.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 71, global step 6983: avg_val_acc reached 0.78505 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=71-val_loss=0.73-val_acc=0.78.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 72, global step 7080: avg_val_acc reached 0.79055 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=72-val_loss=0.70-val_acc=0.79.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 73, global step 7177: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 74, global step 7274: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 75, global step 7371: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 76, global step 7468: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 77, global step 7565: avg_val_acc reached 0.79922 (best 0.80195), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=77-val_loss=0.70-val_acc=0.80.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 78, global step 7662: avg_val_acc was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 79, global step 7759: avg_val_acc reached 0.80526 (best 0.80526), saving model to "/home/shatz/Documents/more_better/saved_models/resnet_80/epoch=79-val_loss=0.70-val_acc=0.80.ckpt" as top 5
