# 🏞 Convolutional Neural Network

In this notebook, we'll walk through the steps required to train your own convolutional neural network (CNN) on the CIFAR dataset

In [1]:
import os

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms


from torchmetrics import Accuracy
import lightning.pytorch as L




## 0. Parameters <a name="parameters"></a>

In [33]:
batch_size = 32

NUM_WORKERS = int(os.cpu_count() / 2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
NUM_CLASSES = len(classes)

## 1. Prepare the Data <a name="prepare"></a>

In [34]:
# Data transformation
transform = transforms.Compose(
    [
        # torchvision.transforms.RandomCrop(32, padding=4),
        # torchvision.transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

)



In [35]:
# Loading TrainSet
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Splitting Trainset to Train / Val
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

Files already downloaded and verified


In [36]:
# Loading
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)


Files already downloaded and verified


In [37]:
# Making dataloaders

train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=NUM_WORKERS)

valid_dataloader = torch.utils.data.DataLoader(validset, batch_size=batch_size,
                                          shuffle=False, num_workers=NUM_WORKERS)

test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=NUM_WORKERS)

## 2. Build the model <a name="build"></a>

In [38]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=32)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=64)


        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=64)

        self.l1 = nn.Sequential(nn.Linear(64 * 8 * 8, 128), nn.BatchNorm1d(128))

        self.dropout = nn.Dropout(p=0.5)
        
        self.l2 = nn.Linear(128, 10)

    def forward(self, x):
      
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = F.leaky_relu(self.bn4(self.conv4(x)))

        x = x.view(x.size(0), -1)

        x = F.leaky_relu(self.l1(x))

        x = self.dropout(x)
       
        x = self.l2(x)
    
        return x




### `torchsummary` can help you see what the input and output in each layer is going on

In [39]:

from torchsummary import summary

model = Net()
summary(model, (3, 32, 32))  # Assuming input size of 3x32x32

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
            Conv2d-3           [-1, 32, 16, 16]           9,248
       BatchNorm2d-4           [-1, 32, 16, 16]              64
            Conv2d-5           [-1, 64, 16, 16]          18,496
       BatchNorm2d-6           [-1, 64, 16, 16]             128
            Conv2d-7             [-1, 64, 8, 8]          36,928
       BatchNorm2d-8             [-1, 64, 8, 8]             128
            Linear-9                  [-1, 128]         524,416
      BatchNorm1d-10                  [-1, 128]             256
          Dropout-11                  [-1, 128]               0
           Linear-12                   [-1, 10]           1,290
Total params: 591,914
Trainable params: 591,914
Non-trainable params: 0
-------------------------------

In [2]:
class MyLitModel(L.LightningModule):
    def __init__(self, lr = 0.0005):
        super().__init__()

        self.save_hyperparameters()
        # Tip: 
        # The LightningModule allows you to automatically save all the hyperparameters passed to init simply by calling self.save_hyperparameters().
        self.hparams.lr = lr
        self.model = Net()
        self.train_accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss)
        acc = self.train_accuracy(logits, y)
        self.log("train_accuracy", acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("val_loss", loss)
        acc = self.val_accuracy(logits, y)
        self.log("val_accuracy", acc)


    def test_step(self, batch, batch_idx):
        x, y = batch
 
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("val_loss", loss)
        acc = self.val_accuracy(logits, y)
        self.log("val_accuracy", acc)


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

## 3. Train the model <a name="train"></a>

In [30]:
model = MyLitModel()

# You can use the checkpoint

# checkpoint = "./log/lightning_logs/version_1/checkpoints/epoch=0-step=1250.ckpt"
# model = MyLitModel.load_from_checkpoint(checkpoint)

trainer = L.Trainer(
    max_epochs=1,
    default_root_dir="./log",
)

trainer.fit(model, train_dataloader, valid_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: log/lightning_logs

  | Name           | Type               | Params
------------------------------------------------------
0 | model          | Net                | 591 K 
1 | train_accuracy | MulticlassAccuracy | 0     
2 | val_accuracy   | MulticlassAccuracy | 0     
3 | criterion      | CrossEntropyLoss   | 0     
------------------------------------------------------
591 K     Trainable params
0         Non-trainable params
591 K     Total params
2.368     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/jjookim/miniconda3/envs/ml_project/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/jjookim/miniconda3/envs/ml_project/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


## 4. Evaluation <a name="evaluate"></a>

In [31]:
checkpoint = "./log/lightning_logs/version_0/checkpoints/epoch=0-step=1250.ckpt"
model = MyLitModel.load_from_checkpoint(checkpoint)


trainer.test(model, dataloaders=test_dataloader)

/Users/jjookim/miniconda3/envs/ml_project/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.9919371008872986, 'val_accuracy': 0.6521000266075134}]