In [None]:
#| default_exp models.conv

In [2]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

# Convolution-based Model

In [3]:
#| export
import torch.nn as nn
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_lightning import LightningModule, Trainer
from torchmetrics import Accuracy
from hydra.utils import instantiate
from omegaconf import OmegaConf

from nimrod.image.datasets import MNISTDataModule
from nimrod.utils import get_device

## Conv Layer

A convolution with a stride of 2 achieves the same goal as max pooling: it downsamples the feature map by reducing its spatial dimensions. The key difference is that a strided convolution has learnable weights, so it can learn how to combine features from overlapping regions, whereas max pooling only keeps the maximum value in each window and may discard finer detail within that region. For this reason, strided convolutions are often preferred when the network should preserve more spatial information.

In [52]:
#| export
class ConvLayer(nn.Module):
    """2D convolution that downsamples via its stride, optionally followed by a ReLU.

    The convolution is padded with `kernel_size//2`, so for odd kernel sizes an
    input of spatial size (H, W) comes out as (ceil(H/stride), ceil(W/stride)).
    With the default `stride=2` every layer halves the resolution, playing the
    role a pooling layer would otherwise have. Even kernel sizes are accepted but
    produce a slightly different output size.
    """
    def __init__(self,
                in_channels:int, # input channels
                out_channels:int, # output channels
                kernel_size:int=3, # kernel size
                activation:bool=True, # apply a ReLU after the convolution
                stride:int=2, # convolution stride; the default of 2 halves H and W
                ):

        super().__init__()
        self.activation = activation
        # downsample with a strided convolution instead of a max/average pooling layer
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=kernel_size//2,
        )
        # always registered (ReLU has no parameters) so the module layout does not
        # depend on `activation`; it is only applied in forward() when requested
        self.relu = nn.ReLU()

    def forward(self, x:torch.Tensor # input tensor of dimension (B, C_in, H, W)
                ) -> torch.Tensor: # output tensor of dimension (B, C_out, H_out, W_out)
        x = self.conv(x)
        if self.activation:
            x = self.relu(x)
        return x

### Usage

In [67]:
# mock batch: 64 single-channel images of 28x28 pixels
B, C, H, W = 64, 1, 28, 28
X = torch.rand(B, C, H, W)
c = ConvLayer(1, 16, 3)
# the stride-2 conv maps 28x28 -> 14x14; keep the batch dim and flatten the rest,
# i.e. 16 channels * 14 * 14 = 3136 features per image
Y = c(X).flatten(start_dim=1)
print(Y.shape)

torch.Size([64, 3136])


## Convnet
A simple convolutional network for image recognition.

In [110]:
#| export
class ConvNet(nn.Module):
    """Small fully-convolutional classifier built from five stride-2 `ConvLayer`s.

    Each `ConvLayer` halves the spatial resolution and is followed by batch
    normalisation. The last layer has `out_channels` channels and no ReLU, and
    the result is flattened. The sizes in the inline comments assume a 28x28
    input, for which the final feature map is 1x1 and the output has shape
    (B, out_channels). Larger inputs give a larger final map and therefore more
    than `out_channels` values per sample after flattening.
    """
    def __init__(self,
                 in_channels:int=1, # channels of the input image
                 out_channels:int=10, # number of classes (channels of the last conv layer)
                 ):
        super().__init__()
        self.net = nn.Sequential(
            ConvLayer(in_channels, 8, kernel_size=5), #14x14
            nn.BatchNorm2d(8),
            ConvLayer(8, 16), #7x7
            nn.BatchNorm2d(16),
            ConvLayer(16, 32), #4x4
            nn.BatchNorm2d(32),
            ConvLayer(32, 64), #2x2
            nn.BatchNorm2d(64),
            # the head must follow `out_channels`; it used to be hard-coded to 10
            ConvLayer(64, out_channels, activation=False), #1x1
            nn.BatchNorm2d(out_channels),
            nn.Flatten()
        )

    def forward(self, x:torch.Tensor # input image tensor of dimension (B, C, H, W)
                ) -> torch.Tensor: # unnormalised class scores (logits) of dimension (B, N_classes)

        return self.net(x)

### Usage

#### Mock data

In [105]:
# mock batch shaped like the 28x28 single-channel inputs used below: (B, C, H, W)
B, C, H, W = 64, 1, 28, 28
X = torch.rand((B, C, H, W))
X.size()

torch.Size([64, 1, 28, 28])

#### Model

In [111]:
# build the network with its defaults (1 input channel, 10 outputs) and show its layers
convnet = ConvNet()
print(convnet)
# forward pass on the mock batch; one row of scores per image is expected
out = convnet(X)
print(out.size())

ConvNet(
  (net): Sequential(
    (0): ConvLayer(
      (conv): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (relu): ReLU()
    )
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ConvLayer(
      (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (relu): ReLU()
    )
    (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ConvLayer(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (relu): ReLU()
    )
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ConvLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (relu): ReLU()
    )
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ConvLayer(
      (conv): Conv2d(64, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (rel

## Basic Training

### Dataloaders

In [135]:
# data module config: load the YAML config and override a few settings for this notebook
cfg = OmegaConf.load('../config/image/data/mnist.yaml')
cfg.datamodule.batch_size = 512
cfg.datamodule.pin_memory = True
# NOTE(review): this writes a top-level `num_workers` key, unlike the two
# `cfg.datamodule.*` overrides above. Only `cfg.datamodule` is passed to
# `instantiate` below, so unless the YAML interpolates the top-level key this
# has no effect on the datamodule. Presumably `cfg.datamodule.num_workers`
# was intended; confirm against mnist.yaml.
cfg.num_workers = 1

# data module instantiation (hydra builds the object described by cfg.datamodule)
datamodule = instantiate(cfg.datamodule)
datamodule.prepare_data()
datamodule.setup()

# one data point
# NOTE: this rebinds `X`, which earlier cells used for the mock batch
X,y = datamodule.data_test[0]
print("X (C,H,W): ", X.shape, "y: ", y)

# a batch of data via dataloader
XX,YY = next(iter(datamodule.test_dataloader()))
print("XX (B,C,H,W): ", XX.shape, "YY: ", YY.shape)

# loaders used by the training loop further down
train_loader = datamodule.train_dataloader()
# number of batches per epoch vs. number of samples vs. samples//batch_size:
# the floor division is one less than len(train_loader) whenever the last batch
# is partial and kept (110 vs 109 in the recorded output)
print(len(train_loader))
print(len(datamodule.data_train))
print(len(datamodule.data_train)//cfg.datamodule.batch_size)
val_loader = datamodule.val_dataloader()
test_loader = datamodule.test_dataloader()

X (C,H,W):  torch.Size([1, 28, 28]) y:  0
XX (B,C,H,W):  torch.Size([512, 1, 28, 28]) YY:  torch.Size([512])
110
56000
109


### Model & hardware

In [136]:
# pick the best available backend: Apple MPS first, then CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# fresh network, moved onto the selected device
model = ConvNet()
model = model.to(device)

mps


### Loss, optimizer, scheduler

In [137]:
# cross-entropy on the raw model outputs (logits) against integer class labels
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The schedule needs one step per optimizer step, i.e. per batch actually yielded
# by the loader. len(data_train)//batch_size undercounts when the last, partial
# batch is kept (109 vs 110 batches here), and OneCycleLR raises once it is
# stepped more times than its total step count.
steps_per_epochs = len(train_loader)
print(len(train_loader))
# NOTE: OneCycleLR drives the learning rate itself (starting from
# max_lr / div_factor), so the `lr` passed to Adam above is overridden.
# `epochs` must match the number of epochs the training loop runs.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=steps_per_epochs, epochs=1)

110


### Training loop

In [138]:
%%time
n_epochs = 1
for epoch in range(n_epochs):
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # scheduler.step()

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            # model expects input (B,H*W)
            images = images.to(device)
            labels = labels.to(device)
            # Pass the input through the model
            outputs = model(images)
            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Loss {loss.item():.4f}, Accuracy = {100 * correct / total:.2f}%")

Epoch 1: Loss 0.7540, Accuracy = 88.07%
CPU times: user 3.55 s, sys: 145 ms, total: 3.7 s
Wall time: 3.8 s


In [None]:
#| hide
# export the cells marked `#| export` from this notebook into the library module
import nbdev
nbdev.nbdev_export()