In [None]:
#| default_exp ImageEncoder

## Vision Encoder (Simple CNN)

1. **Input**: Fashion MNIST images (bs, 1, 28, 28)
1. Four residual blocks (stride-2 convolutions with skip connections) + global average pooling
1. **Output**: Single feature vector (bs, 512)

### Setup

In [1]:
try:
    import google.colab
    !pip install -q git+https://github.com/tripathysagar/NanoTransformer.git
except Exception as e:
    pass

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [2]:
#|export
from NanoTransformer.data import *
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

Create the train/validation dataloaders and grab one validation batch to check the tensor shapes.

In [3]:
#|export
dls = get_vision_classifier_dl()
# Take the first validation batch to inspect shapes. next(iter(...)) raises
# StopIteration immediately if the loader is empty, instead of leaving x/y
# undefined and failing later with a NameError.
# NOTE(review): this cell is exported, so importing the module builds the
# dataloaders and pulls a batch as a side effect — confirm that is intended.
x, y = next(iter(dls['valid']))
x.shape, y.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

Configuration: device, mixed-precision dtype, and training/model hyperparameters.

In [None]:
#|export
from dataclasses import dataclass

@dataclass
class VisionConfig:
    """
    Device, mixed-precision dtype and hyperparameters for the vision encoder.

    Attributes are annotated so they are real dataclass fields; this allows
    overrides such as `VisionConfig(lr=3e-4)`. Class-level access such as
    `VisionConfig.lr` keeps working as before.
    """
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Autocast dtype: bfloat16 on Ampere+ GPUs, float16 on older GPUs.
    # On CPU use bfloat16, the dtype CPU autocast is designed for; float16 is
    # unsupported there on older PyTorch releases.
    dtype: torch.dtype = (
        (torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16)
        if torch.cuda.is_available()
        else torch.bfloat16
    )

    lr: float = 1e-3            # AdamW learning rate
    max_grad_norm: float = 1.0  # gradient-clipping threshold

    head_op_dim: int = 512      # width of the encoder's output feature vector
    nc: int = 10                # number of classes

visConfig = VisionConfig()

### Model
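The classifier is built from three components:
1. **ResBlock**: a residual block (two convolutions plus a skip connection)
1. **VisionEncoder**: a stack of ResBlocks followed by global pooling that turns an image into a feature vector
1. **Classifier head**: a Linear → BatchNorm → Linear head that maps the feature vector to class logits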

In [5]:
#|export
import torch.nn.functional as F

class ResBlock(nn.Module):
    """
    Residual block with a skip connection.

    out = relu(skip(x) + bn2(conv2(relu(bn1(conv1(x))))))

    `skip` is a 1x1 strided convolution whenever the main path changes the
    channel count or the spatial size, and an identity otherwise.
    """
    def __init__(self, ni, nf, ks=3, stride=2):
        """
        Args:
            ni: number of input channels
            nf: number of output channels (filters)
            ks: kernel size (default 3). Padding is ks//2, so an odd ks keeps the
                main path and the 1x1 skip path the same spatial size.
            stride: stride for first conv (default 2 for downsampling)
        """
        super().__init__()
        # First conv: changes channels and (with stride > 1) spatial dims
        self.conv1 = nn.Sequential(
            nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride),
            nn.BatchNorm2d(nf))

        # Second conv: keeps channels and spatial dims constant
        self.conv2 = nn.Sequential(
            nn.Conv2d(nf, nf, ks, padding=ks//2, stride=1),
            nn.BatchNorm2d(nf))

        # An identity skip only lines up with the main path when BOTH the channel
        # count and the spatial size are unchanged; otherwise project the input.
        needs_projection = ni != nf or stride != 1
        self.skip = nn.Conv2d(ni, nf, 1, stride=stride) if needs_projection else nn.Identity()

    def forward(self, x):
        # Add skip connection to output of two convs, then apply the final ReLU
        return F.relu(self.skip(x) + self.conv2(F.relu(self.conv1(x))))

In [None]:
#|export
class VisionEncoder(nn.Module):
    """
    Residual CNN that turns a batch of images into one feature vector per image.

    Four stride-2 ResBlocks each halve the spatial size (rounding up) while
    widening the channels 1 -> 64 -> 128 -> 256 -> visConfig.head_op_dim. A
    global average pool then collapses every feature map to a single value.
    """
    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size) per stage; spatial 28->14->7->4->2
        stage_specs = [
            (1, 64, 7),
            (64, 128, 3),
            (128, 256, 3),
            (256, visConfig.head_op_dim, 3),
        ]
        layers = [ResBlock(c_in, c_out, ks=k, stride=2) for c_in, c_out, k in stage_specs]
        layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]  # (bs, C, 2, 2) -> (bs, C)
        self.VisionHead = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: image batch, (bs, 1, 28, 28) for Fashion MNIST
        Returns:
            feature vectors of shape (bs, visConfig.head_op_dim)
        """
        features = self.VisionHead(x)
        return features

In [None]:
#|export
# Full model: CNN feature extractor (index 0) followed by the classification head (index 1).
# NOTE(review): the head has no activation between BatchNorm1d and the final Linear,
# so at inference it reduces to a single affine map; consider inserting nn.ReLU()
# (doing so shifts the head's Sequential indices / state_dict keys).
classifier = nn.Sequential(
    VisionEncoder(),
    nn.Sequential(
        nn.Linear(visConfig.head_op_dim, 1024),
        nn.BatchNorm1d(1024),
        nn.Linear(1024, visConfig.nc),
    ),
)

### Loss function

In [8]:
#|export
# Multi-class cross-entropy on raw logits; averaged over the batch (the default).
loss_func = nn.CrossEntropyLoss(reduction='mean')

In [9]:
# Smoke test: forward the sample validation batch through the untrained model.
# NOTE(review): classifier is still in training mode here (the nn.Module default),
# so this call also updates the BatchNorm running statistics.
pred = classifier(x)   # logits: one row of visConfig.nc scores per image
pred[0]                # logits of the first image in the batch

tensor([ 0.5929,  0.2011, -0.2210, -0.7946,  0.3044,  0.2682,  0.2516,  0.3780,
         0.8079,  0.0866], grad_fn=<SelectBackward0>)

In [10]:
# Batch accuracy of the untrained model, expected to be near chance (1/nc).
# softmax is monotonic, so argmax over raw logits gives the same prediction.
(pred.argmax(-1) == y).float().mean()

tensor(0.1094)

In [11]:
# Cross-entropy of the untrained model on this batch; a uniform guess over
# 10 classes would give ln(10) ≈ 2.30.
loss_func(pred, y)

tensor(2.4303, grad_fn=<NllLossBackward0>)

### Training

In [12]:
#|export
def log(*args):
    """Print one row of the training table: epoch, train loss, valid loss, accuracy.

    Only the first four positional arguments are used; the three numeric ones
    are shown with four decimals.
    """
    epoch, trn_loss, val_loss, acc = args[0], args[1], args[2], args[3]
    row = f"{epoch}   \t{trn_loss:.4f}   \t{val_loss:.4f}\t\t{acc:.4f}"
    print(row)

In [13]:
#|export
def _train_one_epoch(model, loader, optimizer, scaler):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for x, y in loader:
        x, y = x.to(visConfig.device), y.to(visConfig.device)
        optimizer.zero_grad()

        with torch.autocast(device_type=visConfig.device, dtype=visConfig.dtype):
            loss = loss_func(model(x), y)

        # With a disabled scaler these calls reduce to loss.backward() / optimizer.step().
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), visConfig.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(n_batches, 1)


def _evaluate(model, loader):
    """Return (mean batch loss, accuracy) of `model` over `loader`, without gradients."""
    # Switch the model being trained (not a global) to eval mode so BatchNorm
    # uses its running statistics during validation.
    model.eval()
    total_loss, n_batches, correct, seen = 0.0, 0, 0, 0
    with torch.no_grad(), torch.autocast(device_type=visConfig.device, dtype=visConfig.dtype):
        for x, y in loader:
            x, y = x.to(visConfig.device), y.to(visConfig.device)
            logits = model(x)
            total_loss += loss_func(logits, y).item()
            n_batches += 1
            # argmax over logits == argmax over softmax(logits); skip the softmax.
            correct += (logits.argmax(-1) == y).sum().item()
            seen += y.size(0)
    return total_loss / max(n_batches, 1), correct / max(seen, 1)


def vision_encoder_train(model, epochs=10, dataloaders=None):
    """
    Train `model` with AdamW under autocast and print one metrics row per epoch.

    Args:
        model: classifier producing logits that `loss_func` (cross-entropy)
            accepts against the integer labels yielded by the loaders.
        epochs: number of passes over the training loader.
        dataloaders: mapping with 'train' and 'valid' loaders yielding (x, y)
            batches; defaults to the module-level `dls`.

    Each printed row shows the epoch, mean train loss, mean validation loss, and
    validation accuracy. The model is moved to `visConfig.device` and trained in place.
    """
    loaders = dls if dataloaders is None else dataloaders
    model = model.to(visConfig.device)
    optimizer = AdamW(model.parameters(), lr=visConfig.lr)

    # float16 autocast can underflow small gradients, and GradScaler counteracts
    # that. bfloat16 keeps float32's exponent range and needs no scaling.
    use_scaler = visConfig.device == 'cuda' and visConfig.dtype == torch.float16
    try:
        scaler = torch.amp.GradScaler(visConfig.device, enabled=use_scaler)
    except AttributeError:  # torch < 2.3 only provides the CUDA-specific class
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    print("Epoch \tTrain Loss \tValid Loss \tAccuracy")

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, loaders['train'], optimizer, scaler)
        val_loss, accuracy = _evaluate(model, loaders['valid'])
        log(epoch + 1, train_loss, val_loss, accuracy)

In [None]:
vision_encoder_train(classifier, 15)

Epoch 	Train Loss 	Valid Loss 	accurecy
1   	0.4452   	0.3378		0.8728
2   	0.2998   	0.2488		0.9080
3   	0.2496   	0.2210		0.9178
4   	0.2149   	0.1625		0.9411
5   	0.1841   	0.1398		0.9501
6   	0.1607   	0.1164		0.9571
7   	0.1340   	0.0999		0.9644
8   	0.1126   	0.0833		0.9702
9   	0.0928   	0.0622		0.9776
10   	0.0827   	0.0523		0.9815
11   	0.0696   	0.0634		0.9762
12   	0.0604   	0.0397		0.9862
13   	0.0564   	0.0533		0.9813
14   	0.0481   	0.0270		0.9904
15   	0.0430   	0.0317		0.9895


### Save model

In [None]:
#|export
def save_model(fn='classfier.pth', model=None):
    """
    Serialise a model to `path/fn` with `torch.save` and return the destination.

    Args:
        fn: file name inside `path`. The default spelling is kept unchanged so
            existing code that loads 'classfier.pth' keeps finding the file.
        model: module to save; defaults to the module-level `classifier`.

    Returns:
        The destination `path/fn`.

    NOTE(review): `path` is not defined in this notebook; presumably it comes from
    the `NanoTransformer.data` star import — confirm. This pickles the whole
    module, so loading needs the same class definitions (and `weights_only=False`
    with recent `torch.load` defaults). Saving `model.state_dict()` would be more
    portable.
    """
    to_save = classifier if model is None else model
    dest = path/fn
    torch.save(to_save, dest)
    return dest