# CNN

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np

In [23]:
# Pick the CUDA device when one is usable, otherwise fall back to the CPU.
# NOTE(review): no later cell visible here reads `device` — confirm it is still needed.
device = (
    torch.device('cuda')
    if torch.cuda.is_available()
    else torch.device('cpu')
)

## Hyper-parameters

In [24]:
# Number of full passes over the training set.
n_epochs = 4
# Samples per mini-batch for both the training and the evaluation loader.
# NOTE(review): the saved progress bar reports 1876 steps per epoch, which
# corresponds to a batch size of 32 on CIFAR-10 rather than 6 — confirm which
# value is intended.
batch_size = 6
# Step size handed to the optimizer.
learning_rate = 0.003

# Seed torch's global RNG so that weight initialisation and DataLoader
# shuffling are repeatable under "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed);

In [25]:
# Pre-processing applied to every image:
#   1. ToTensor   - convert the image to a float tensor.
#   2. Normalize  - per channel, compute (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

## Set train and test datasets

In [26]:
# Arguments shared by both CIFAR-10 splits: same cache directory, download on
# demand, and the same pre-processing pipeline.
cifar_kwargs = dict(
    root='./CIFAR10/data',
    download=True,
    transform=transform,
)

# `train` selects the split: True -> training set, False -> test set.
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


## Dataloaders

In [27]:
# Training batches are reshuffled every epoch; the evaluation loader keeps a
# fixed order.
train_dl = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dl = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [28]:
# Human-readable label for each of the ten class indices (index -> name).
# Fixed the misspelled 'brid' -> 'bird' so index 2 is labelled correctly.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Model

In [29]:
# Channel count of the images fed to the first convolution (3 colour channels).
input_size: int = 3
# Number of feature maps produced by the first convolution; also the number
# of input channels of the second convolution.
output_size: int = 6
# Side length of the square kernel used by both convolutions.
kernel_size: int = 5

class ConvNet(pl.LightningModule):
    """Small convolutional classifier wrapped as a LightningModule.

    Architecture: two (conv -> ReLU -> 2x2 max-pool) stages followed by three
    fully connected layers that end in 10 logits. The first linear layer
    expects 16*5*5 features per sample, which is what the two conv/pool
    stages produce for 32x32 inputs when ``kernel_size`` is 5.

    Reads the module-level ``input_size``, ``output_size``, ``kernel_size``
    and ``learning_rate`` values.
    """

    def __init__(self):
        super().__init__()
        self.configure_metrics()
        self.loss_func = nn.CrossEntropyLoss()
        # Feature learning
        self.conv1 = nn.Conv2d(input_size, output_size, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(output_size, 16, kernel_size)
        # Classification
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each sample's feature maps into one vector for the MLP head.
        x = x.reshape(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax here: CrossEntropyLoss consumes unnormalised logits.
        x = self.fc3(x)
        return x

    def configure_metrics(self):
        """Create separate accuracy metrics for training and validation."""
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def configure_optimizers(self):
        """Return the optimizer Lightning should use for this module.

        The optimizer is built from ``self.parameters()`` so that it always
        updates the instance being trained, regardless of what (if anything)
        is bound to a global variable named ``model``.
        """
        # Plain SGD; torch.optim.Adam(self.parameters(), lr=learning_rate)
        # is a drop-in alternative.
        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Compute the loss for one training batch and log loss/accuracy.

        Returns the loss tensor, which Lightning uses for the backward pass.
        """
        x, y = train_batch
        output = self(x)
        loss = self.loss_func(output, y)
        self.train_acc(output, y)
        self.log('train_acc', self.train_acc)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one validation batch and log its loss/accuracy."""
        x, y = val_batch
        output = self(x)
        loss = self.loss_func(output, y)
        self.valid_acc(output, y)
        self.log('val_acc', self.valid_acc)
        self.log('val_loss', loss)

# Build the network instance that is later handed to the Lightning Trainer.
model = ConvNet()

## Train and validate

In [30]:
# Train for `n_epochs` epochs. The second loader is used for validation, so
# the CIFAR-10 test split doubles as the validation set here.
trainer = pl.Trainer(
    max_epochs=n_epochs,
)
trainer.fit(
    model,
    train_dl,
    test_dl,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type             | Params
-----------------------------------------------
0 | train_acc | Accuracy         | 0     
1 | valid_acc | Accuracy         | 0     
2 | loss_func | CrossEntropyLoss | 0     
3 | conv1     | Conv2d           | 456   
4 | pool      | MaxPool2d        | 0     
5 | conv2     | Conv2d           | 2.4 K 
6 | fc1       | Linear           | 48.1 K
7 | fc2       | Linear           | 10.2 K
8 | fc3       | Linear           | 850   
-----------------------------------------------
62.0 K    Trainable params
0         Non-trainable params
62.0 K    Total params
0.248     Total estimated model params size (MB)


Epoch 3: 100%|██████████| 1876/1876 [00:56<00:00, 33.40it/s, loss=1.77, v_num=3]
