# Minimal notebook: train and save a ResNet-18 on CIFAR-10 with PyTorch Lightning

In [None]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): no visible cell plots with matplotlib; this may be a leftover.
%matplotlib inline

In [None]:
import torch
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets import CIFAR10

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

import torchmetrics
from tqdm.notebook import tqdm
import lightly

# Data

In [None]:
# DataLoader settings shared by the training and validation loaders:
# worker processes used for loading, and samples per batch.
num_workers, batch_size = 6, 512

In [None]:
# Expected on-disk layout: one sub-folder of images per class.
# data/cifar10_lightly/
#   train/
#     airplane/
#       10008_airplane.png
#       ...
#     automobile/
#     bird/
#     cat/
#     deer/
#     dog/
#     frog/
#     horse/
#     ship/
#     truck/
#   test/
#     (presumably the same class sub-folders as train/ -- confirm)
_cifar10_root = './data/cifar10_lightly'
path_to_train = f'{_cifar10_root}/train/'
path_to_test = f'{_cifar10_root}/test/'

In [None]:
# ------------- transforms ------------------- #

# Standard CIFAR-10 training augmentation: random 32x32 crops taken from the
# image padded by 4 pixels on each side, plus random horizontal flips.
# Vertical flips are deliberately NOT used: they are not part of the usual
# CIFAR-10 recipe, because upside-down airplanes, ships, trucks, etc. do not
# occur in the data.
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
])

# Evaluation pipeline: no augmentation, only resize to the 32x32 training
# resolution and convert to a tensor.
# NOTE(review): normalization is disabled in both pipelines. If you enable it,
# use the SAME mean/std for training and test, e.g. add
#   torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# after ToTensor() in both Compose lists. Mismatched statistics between the
# two pipelines would silently degrade validation accuracy.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])

# --------------- datasets --------------------- #
# Both splits are built the same way with LightlyDataset; only the input
# directory and the transform pipeline differ. The training split is
# constructed first, then the test split.
dataset_train_classifier, dataset_test = (
    lightly.data.LightlyDataset(input_dir=split_dir, transform=split_tf)
    for split_dir, split_tf in (
        (path_to_train, train_classifier_transforms),
        (path_to_test, test_transforms),
    )
)

# ------------------ dataloaders ----------------- #
# Settings common to both loaders.
_loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}

# Training loader: reshuffle each epoch and drop the final incomplete batch,
# so every optimisation step sees exactly `batch_size` samples.
train_dataloader = torch.utils.data.DataLoader(
    dataset_train_classifier, shuffle=True, drop_last=True, **_loader_kwargs
)

# Validation loader: fixed order, and every sample is kept (including a
# partial last batch).
val_dataloader = torch.utils.data.DataLoader(
    dataset_test, shuffle=False, drop_last=False, **_loader_kwargs
)

# Model

In [None]:
# model hyperparams
max_epochs = 150

In [None]:
from plr18 import plr18
model = plr18()

# Train

In [None]:
# # Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# you can also define a checkpoint callback to save best model like keras.
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved_models/resnet_80',
    filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    save_top_k=5,
    verbose=True,
    monitor='avg_val_acc',
    mode='max'
)

In [None]:
tb_logger = pl_loggers.TensorBoardLogger(save_dir='./lightning_logs/', name='resnet_80')
trainer = Trainer(gpus=1, callbacks=[checkpoint_callback], max_epochs=80, logger=tb_logger)

In [None]:
# Fit `model` on train_dataloader, using val_dataloader for validation.
# Logging and checkpointing follow the logger and callbacks configured on
# `trainer`.
trainer.fit(model, train_dataloader, val_dataloader)