# CNN

In [222]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

## Hyper-parameters

In [223]:
# Training hyper-parameters.
batch_size = 32        # samples per mini-batch for both dataloaders
learning_rate = 0.001  # step size handed to the optimizer

# Seed the random number generators so that weight initialisation and
# training-set shuffling are repeatable after "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [224]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

## Set train and test datasets

In [225]:
# Build the CIFAR10 training split first and the test split second; both are
# stored under the same root, downloaded if missing, and share `transform`.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./CIFAR10/data',
        train=is_train_split,
        download=True,
        transform=transform,
    )
    for is_train_split in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


## Dataloaders

In [226]:
# Only the training loader shuffles; the test loader keeps a fixed order.
train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [227]:
# Human-readable names for the ten class labels.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Model

In [228]:
# Convolution settings read by ConvNet:
#   input_size  - channels of the incoming images (3 colour channels)
#   output_size - output channels of the first convolution
#   kernel_size - kernel size used by both convolutions
input_size, output_size, kernel_size = 3, 6, 5

class ConvNet(pl.LightningModule):
    """Small convolutional classifier producing 10 logits per image.

    Two convolution + ReLU + max-pool stages feed three fully connected
    layers. The first fully connected layer expects 16*5*5 features per
    sample, so the convolutional stack must yield 16 feature maps of size
    5x5 (assumes 3x32x32 inputs such as CIFAR10 — TODO confirm for other data).

    Parameters
    ----------
    learning_rate : float
        Learning rate passed to the Adam optimizer.
    """

    def __init__(self, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.configure_metrics()

        # Feature learning. Channel counts and kernel size are read from the
        # module-level `input_size`, `output_size` and `kernel_size` values.
        self.conv1 = nn.Conv2d(input_size, output_size, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(output_size, 16, kernel_size)

        # Classification
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # `reshape(-1, 16*5*5)` this can never silently change the batch size
        # when the input resolution is wrong; `fc1` raises a shape error instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def configure_metrics(self):
        """Create the metric objects updated in the training/validation steps."""
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        # Macro averaging: with the default micro average, precision and
        # recall of a single-label multiclass problem both equal the accuracy,
        # so the three validation metrics would be the same number.
        self.valid_precision = torchmetrics.Precision(num_classes=10, average='macro')
        self.valid_recall = torchmetrics.Recall(num_classes=10, average='macro')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the model."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, train_batch, batch_idx):
        """Compute the cross-entropy loss for one training batch and log metrics.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        x, y = train_batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        self.train_acc(output, y)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, logger=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True, logger=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Update and log the validation metrics and loss for one batch."""
        x, y = val_batch
        output = self(x)
        loss = F.cross_entropy(output, y)

        self.valid_precision(output, y)
        self.valid_recall(output, y)
        self.valid_acc(output, y)
        self.log("precision", self.valid_precision, on_step=False, on_epoch=True, logger=True)
        self.log("recall", self.valid_recall, on_step=False, on_epoch=True, logger=True)
        self.log('val_acc', self.valid_acc, on_step=False, on_epoch=True, logger=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True)

    # def activation_hook(self, inst, inp, out):
    #     """Run activation hook

    #     Parameters
    #     ----------
    #     inst : torch.nn.Module
    #         The layer we want to attach the hook to.
    #     inp : torch.Tensor
    #         The input to the `forward` method.
    #     out : torch.Tensor
    #         The output of the `forward` method.
    #     """
    #     # tb = SummaryWriter()
    #     # Create histogram of layer weights
    #     self.logger.experiment.add_histogram(repr(inst), out)

    #     img_grid = torchvision.utils.make_grid(inp[0])
    #     self.logger.experiment.add_image('Output images', img_grid)

    #     img_grid = torchvision.utils.make_grid(out)
    #     self.logger.experiment.add_image('Output images', img_grid)

# Instantiate the network with the configured learning rate.
model = ConvNet(learning_rate=learning_rate)

# model.fc1.register_forward_hook(model.activation_hook)
# model.fc2.register_forward_hook(model.activation_hook)
# model.fc3.register_forward_hook(model.activation_hook)



In [229]:
from pytorch_lightning.callbacks import Callback


class MyCallback(Callback):
    """Log the activations of the fully connected layers while fitting.

    Forward hooks are attached to `fc1`, `fc2` and `fc3` of the module when
    fitting starts and removed again when fitting ends, so repeated calls to
    `fit` with the same callback do not stack duplicate hooks.
    """

    def __init__(self):
        super().__init__()
        self.writer = None        # experiment writer of the module's logger
        self._trainer = None      # trainer, used to read the current global step
        self._hook_handles = []   # handles of the registered forward hooks

    def on_fit_start(self, trainer, pl_module):
        """Callback function that gets executed before the fit starts

        Parameters
        ----------
        trainer : pl.Trainer
            The trainer of the CNN module (pl_module)
        pl_module : pl.LightningModule
            The model we want to use to retrieve information
        """
        print("Starting to fit trainer!")

        self.writer = pl_module.logger.experiment
        self._trainer = trainer

        # Drop hooks left over from an earlier fit before registering new ones.
        self._remove_hooks()
        self._hook_handles = [
            layer.register_forward_hook(self.activation_hook)
            for layer in (pl_module.fc1, pl_module.fc2, pl_module.fc3)
        ]

    def on_fit_end(self, trainer, pl_module):
        """Detach the forward hooks once fitting has finished."""
        self._remove_hooks()

    def _remove_hooks(self):
        """Remove every forward hook registered by this callback."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def activation_hook(self, inst, inp, out):
        """Run activation hook

        Parameters
        ----------
        inst : torch.nn.Module
            The layer we want to attach the hook to.
        inp : tuple of torch.Tensor
            The positional inputs to the `forward` method; only the first
            one is logged.
        out : torch.Tensor
            The output of the `forward` method.
        """
        # Record against the trainer's global step so successive forward
        # passes are not all written to the same step.
        step = self._trainer.global_step if self._trainer is not None else None

        # Detach so that logging never extends the autograd graph.
        layer_input = inp[0].detach()
        layer_output = out.detach()

        # Create histogram of the layer's output activations
        self.writer.add_histogram(repr(inst), layer_output, global_step=step)

        # NOTE(review): all hooked layers share the two image tags below, so
        # their images are not distinguishable per layer in the logger.
        img_grid = torchvision.utils.make_grid(layer_input)
        self.writer.add_image('Forward Input images', img_grid, global_step=step)

        img_grid = torchvision.utils.make_grid(layer_output)
        self.writer.add_image('Forward Output images', img_grid, global_step=step)

    # def get_all_layers(self, model):
    #     """Gets all the layers of the CNN

    #     Parameters
    #     ----------
    #     model : pl.LightningModule
    #         The model we want to use to retrieve information
    #     """
    #     for name, layer in model._modules.items():
    #         if isinstance(layer, nn.Sequential):
    #             self.get_all_layers(layer)
    #         else:
    #             layer.register_forward_hook(self.activation_hook)

    # def matplotlib_imshow(self, img, one_channel=False):
    #     if one_channel:
    #         img = img.mean(dim=0)
    #     img = img / 2 + 0.5
    #     npimg = img.numpy()
        # if one_channel:
        #     plt.imshow(npimg, cmap="Greys")
        # else:
        #     plt.imshow(np.transpose(npimg, (1, 2, 0)))

    # def on_train_epoch_end(self, trainer, pl_module):
    #     print("Callbak training epoch end")

    #     for name, params in pl_module.named_parameters():
    #         print(f"Callback for in {pl_module.current_epoch}")
    #         pl_module.logger.experiment.add_histogram(name, params, pl_module.current_epoch)

    # def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
    #     x, y = batch
    #     tb = pl_module.logger.experiment

    #     # img = np.reshape(x[0:], -1, 28, 28, 1)
    #     grid = torchvision.utils.make_grid(x, normalize=True)
    #     tb.add_image('Epoch start images', grid)

    # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    #     x, y = outputs
    #     tb = pl_module.logger.experiment

    #     # img = np.reshape(x[0:], -1, 28, 28, 1)
    #     grid = torchvision.utils.make_grid(x, normalize=True)
    #     tb.add_image('Epoch end images', grid)

## Find best learning rate

In [230]:
# trainer = pl.Trainer(auto_lr_find=True)
# lr_finder = trainer.tuner.lr_find(model)
# lr_finder.results
# fig = lr_finder.plot(suggest=True)
# fig.show()
# new_lr = lr_finder.suggestion()
# model.hparams.lr = new_lr
# print(new_lr)

## Train and validate

In [231]:
# Fit for two epochs with the activation-logging callback attached.
trainer = pl.Trainer(
    max_epochs=2,
    callbacks=[MyCallback()],
)
# The test dataloader is passed in the validation-dataloader position.
trainer.fit(model, train_dl, test_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name            | Type      | Params
----------------------------------------------
0 | train_acc       | Accuracy  | 0     
1 | valid_acc       | Accuracy  | 0     
2 | valid_precision | Precision | 0     
3 | valid_recall    | Recall    | 0     
4 | conv1           | Conv2d    | 456   
5 | pool            | MaxPool2d | 0     
6 | conv2           | Conv2d    | 2.4 K 
7 | fc1             | Linear    | 48.1 K
8 | fc2             | Linear    | 10.2 K
9 | fc3             | Linear    | 850   
----------------------------------------------
62.0 K    Trainable params
0         Non-trainable params
62.0 K    Total params
0.248     Total estimated model params size (MB)


Starting to fit trainer!
Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(


Epoch 1: 100%|██████████| 1876/1876 [03:45<00:00,  8.33it/s, loss=1.24, v_num=1]


In [232]:
%tensorboard

UsageError: Line magic function `%tensorboard` not found.
