# CNN image classifier for CIFAR-10 (PyTorch Lightning)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Preferred compute device for manual tensor work in this notebook.
# Nothing below passes this variable to the Lightning Trainer.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Hyper-parameters

In [None]:
batch_size = 64
learning_rate = 0.0003

In [None]:
# Convert images to tensors, then normalise each RGB channel with
# (x - 0.5) / 0.5, mapping ToTensor's [0, 1] output onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

## Load the CIFAR-10 train and test datasets

In [None]:
# Both CIFAR-10 splits share the storage location, download flag and
# preprocessing; only the `train` flag differs.
# NOTE(review): the test split is also used as the validation set during
# training (val_dataloader / trainer.fit), so its metrics are not an
# untouched held-out estimate.
_CIFAR10_COMMON = {
    "root": './CIFAR10/data',
    "download": True,
    "transform": transform,
}
train_dataset = torchvision.datasets.CIFAR10(train=True, **_CIFAR10_COMMON)
test_dataset = torchvision.datasets.CIFAR10(train=False, **_CIFAR10_COMMON)

## Dataloaders

In [None]:
# Same batch size and worker count for both loaders; only the training
# split is shuffled, so evaluation order stays fixed.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_dl = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_dl = torch.utils.data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Human-readable CIFAR-10 label names, indexed by class id (0-9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Model

In [None]:
# The first convolution takes 3 input channels (one per RGB colour channel).
input_size = 3
# Number of feature maps produced by conv1 and consumed by conv2.
output_size = 6
kernel_size = 5

class ConvNet(pl.LightningModule):
    """LeNet-style CNN that maps RGB images to 10 class logits.

    Two ``conv -> ReLU -> 2x2 max-pool`` stages extract features, then three
    fully connected layers classify them. ``fc1`` expects a 16x5x5 feature
    map, which is what a 3x32x32 input produces
    (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).

    Args:
        learning_rate: Step size handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.configure_metrics()
        # Feature learning
        self.conv1 = nn.Conv2d(input_size, output_size, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(output_size, 16, kernel_size)
        # Classification
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (N, 10) for images ``x``.

        Assumes ``x`` is (N, 3, H, W) with H, W such that the final feature map
        is 16x5x5 (e.g. 32x32); otherwise ``fc1`` raises a shape error.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # reshape(-1, 16*5*5), this can never silently regroup values across
        # samples when the input size is wrong.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def train_dataloader(self):
        """Training loader built from the notebook-level ``train_dataset``.

        Lightning uses this only when no loaders are passed to ``Trainer.fit``
        (e.g. the learning-rate finder).
        """
        return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        """Validation loader built from the notebook-level ``test_dataset``."""
        return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    def configure_metrics(self):
        """Create the torchmetrics objects updated in the train/validation steps."""
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        # average='macro' reports the unweighted mean of per-class scores. The
        # default ('micro') makes single-label multiclass precision and recall
        # numerically equal to accuracy, which adds no information.
        # NOTE(review): torchmetrics >= 0.11 also requires task="multiclass".
        self.valid_precision = torchmetrics.Precision(num_classes=10, average='macro')
        self.valid_recall = torchmetrics.Recall(num_classes=10, average='macro')

    def configure_optimizers(self):
        """Adam over all parameters, using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Return the cross-entropy loss for one batch; log epoch loss/accuracy."""
        x, y = train_batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.train_acc(output, y)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Update and log epoch-level validation loss, accuracy, precision, recall."""
        x, y = val_batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.valid_precision(output, y)
        self.valid_recall(output, y)
        self.valid_acc(output, y)
        self.log("precision", self.valid_precision, on_step=False, on_epoch=True)
        self.log("recall", self.valid_recall, on_step=False, on_epoch=True)
        self.log('val_acc', self.valid_acc, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)


model = ConvNet(learning_rate)

In [None]:
from pytorch_lightning.callbacks import Callback


class MyCallback(Callback):
    """Debugging callback.

    Prints validation lifecycle events and logs the first validation batch of
    each epoch to TensorBoard as an image grid.
    """

    def on_init_start(self, trainer):
        # NOTE(review): this hook was deprecated in Lightning 1.6 and later
        # removed; delete it if your installed version rejects it.
        print("Starting to init trainer!")

    def on_validation_epoch_start(self, trainer, pl_module):
        print("validation start")

    def on_validation_epoch_end(self, trainer, pl_module):
        print("validation ends")

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        """Log the first validation batch of each epoch as an image grid.

        Writing a grid for every batch under one tag with no step floods the
        event file, and TensorBoard cannot tell the epochs apart.
        """
        if batch_idx != 0 or pl_module.logger is None:
            return
        x, _ = batch
        # normalize=True rescales the normalised tensors into [0, 1] for display.
        grid = torchvision.utils.make_grid(x, normalize=True)
        # Assumes the logger's experiment is a TensorBoard SummaryWriter
        # (Lightning's default logger).
        pl_module.logger.experiment.add_image('Epoch start images', grid, global_step=trainer.global_step)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Print what Lightning passes this hook, for the first validation batch."""
        if batch_idx != 0:
            return
        # ConvNet.validation_step returns nothing, so `outputs` is None here.
        # `batch` is an (images, labels) pair, not a tensor. Calling `.shape`
        # on either raises AttributeError, so unpack it and print the parts.
        x, y = batch
        print(type(outputs))
        print(type(batch))
        print(x.shape, y.shape)

## Find best learning rate

In [None]:
# trainer = pl.Trainer(auto_lr_find=True)
# lr_finder = trainer.tuner.lr_find(model)
# lr_finder.results
# fig = lr_finder.plot(suggest=True)
# fig.show()
# new_lr = lr_finder.suggestion()
# model.learning_rate = new_lr  # configure_optimizers reads self.learning_rate, not hparams
# print(new_lr)

## Train and validate

In [None]:
# accelerator="auto" trains on a GPU when one is available. Without it,
# Lightning 1.x defaults to CPU even when CUDA is present.
trainer = pl.Trainer(max_epochs=2, accelerator="auto", devices=1, callbacks=[MyCallback()])
# Explicit loaders take precedence over ConvNet's own *_dataloader hooks.
trainer.fit(model, train_dl, test_dl)

In [None]:
%tensorboard