## Vision Encoder (Simple CNN)

1. **Input**: Fashion MNIST images (bs, 1, 28, 28)
1. 5 stride-2 residual convolutional blocks + global average pooling
1. **Output**: Single feature vector (bs, 512)

### Setup

In [None]:
try:
    import google.colab
    !pip install -q git+https://github.com/tripathysagar/NanoTransformer.git
except Exception as e:
    pass

Name: NanoTransformer
Version: 0.0.1
Summary: a transformer experiments
Home-page: https://github.com/tripathysagar/NanoTransformer
Author: tripathysagar
Author-email: tripathysagar08@gmail.com
License: Apache Software License 2.0
Location: /app/data/.local/lib/python3.12/site-packages
Editable project location: /app/data/NanoTransformer
Requires: 
Required-by: 


Note: you may need to restart the kernel to use updated packages.


In [1]:
#|export
from NanoTransformer.data import *
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_

Create the dataloaders and inspect the shape of one validation batch.

In [2]:
#|export
dls = get_vision_classifier_dl()

# Grab a single validation batch to sanity-check tensor shapes.
# NOTE(review): this cell is exported, so the batch fetch below also runs whenever the
# generated module is imported — consider moving it to a non-exported cell.
x, y = next(iter(dls['valid']))
x.shape, y.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

Configuration: compute device, mixed-precision dtype, and optimizer hyperparameters.

In [3]:
#|export
from dataclasses import dataclass

@dataclass
class VisionConfig:
    """Training configuration for the Fashion MNIST vision classifier.

    Every attribute is a real dataclass field, so individual values can be
    overridden per instance, e.g. ``VisionConfig(lr=3e-4)``.
    """
    # Train on the GPU when one is visible, otherwise on the CPU.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Mixed-precision dtype passed to torch.autocast:
    #   * Ampere+ GPUs (compute capability >= 8) -> bfloat16
    #   * older GPUs                             -> float16 (pair it with a GradScaler)
    #   * CPU                                    -> bfloat16, the dtype CPU autocast supports
    dtype: torch.dtype = (
        (torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16)
        if torch.cuda.is_available()
        else torch.bfloat16
    )

    lr: float = 1e-3            # AdamW learning rate
    max_grad_norm: float = 1.0  # global-norm threshold used for gradient clipping

visConfig = VisionConfig()

### Model
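The classifier is built from three components:
1. **ResBlock**: a residual convolutional block (two conv + BatchNorm layers with a skip connection)
1. **VisionEncoder**: a stack of ResBlocks followed by global average pooling, mapping each image to a 512-dim feature vector
1. **Classifier head**: an MLP mapping the 512-dim feature vector to 10 class logits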

In [4]:
#|export
import torch.nn.functional as F

class ResBlock(nn.Module):
    """
    Residual block: two conv + BatchNorm layers plus a shortcut, summed and passed through ReLU.

    The shortcut is shaped to match the conv path (for odd kernel sizes):
      * channels differ           -> 1x1 conv with the same stride
      * channels equal, stride > 1 -> parameter-free average pooling
      * otherwise                 -> identity
    """
    def __init__(self, ni, nf, ks=3, stride=2):
        """
        Args:
            ni: number of input channels
            nf: number of output channels (filters)
            ks: kernel size (default 3)
            stride: stride for first conv (default 2 for downsampling)
        """
        super().__init__()
        # First conv: changes channels and (for stride > 1) spatial dims
        self.conv1 = nn.Sequential(
            nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride),
            nn.BatchNorm2d(nf))

        # Second conv: keeps channels and spatial dims constant
        self.conv2 = nn.Sequential(
            nn.Conv2d(nf, nf, ks, padding=ks//2, stride=1),
            nn.BatchNorm2d(nf))

        # Shortcut must produce the same (channels, H, W) as the conv path.
        if ni != nf:
            self.skip = nn.Conv2d(ni, nf, 1, stride=stride)
        elif stride != 1:
            # An Identity here would keep the input's spatial size and either crash or
            # silently broadcast against the smaller conv output. ceil_mode rounds odd
            # sizes up (e.g. 7 -> 4), matching the padded strided conv.
            self.skip = nn.AvgPool2d(stride, ceil_mode=True)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        # Add skip connection to output of two convs
        return F.relu(self.skip(x) + self.conv2(F.relu(self.conv1(x))))

In [5]:
#|export
class VisionEncoder(nn.Module):
    """
    Residual CNN encoder for Fashion MNIST images.

    Five stride-2 ResBlocks widen the channels (1 -> 64 -> 128 -> 256 -> 512 -> 512)
    while shrinking the spatial grid; global average pooling then collapses each
    channel to one value, yielding a single feature vector per image.
    """
    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size) for each stage; every stage uses stride 2.
        # Conv-path spatial size for a 28x28 input: 28 -> 14 -> 7 -> 4 -> 2 -> 1.
        # NOTE(review): in the last stage (512 -> 512) the shortcut must also downsample,
        # otherwise the residual sum broadcasts back to 2x2.
        stages = [
            (1, 64, 7),
            (64, 128, 3),
            (128, 256, 3),
            (256, 512, 3),
            (512, 512, 3),
        ]
        blocks = [ResBlock(ni, nf, ks=ks, stride=2) for ni, nf, ks in stages]
        self.VisionHead = nn.Sequential(*blocks, nn.AdaptiveAvgPool2d(1), nn.Flatten())

    def forward(self, x):
        """
        Encode a batch of images.

        Args:
            x: images, expected shape (bs, 1, 28, 28)
        Returns:
            feature vectors of shape (bs, 512)
        """
        features = self.VisionHead(x)
        return features

In [6]:
#|export 
# Encoder followed by an MLP head. The ReLU after BatchNorm1d is what makes the hidden
# layer useful: without it Linear -> BatchNorm -> Linear is a composition of affine maps
# (BatchNorm is affine at eval time), i.e. equivalent to a single Linear(512, 10).
classifier = nn.Sequential(
    VisionEncoder(),              # (bs, 1, 28, 28) -> (bs, 512)
    nn.Sequential(                # classification head: (bs, 512) -> (bs, 10) logits
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 10)))

### Loss function

In [7]:
#|export
# Multi-class criterion on raw logits (log-softmax is applied internally), averaged over the batch.
loss_func = nn.CrossEntropyLoss(reduction='mean')

In [8]:
# Sanity check: a forward pass on the sample batch yields one row of 10 logits per image.
pred = classifier(x)
assert pred.shape == (x.shape[0], 10), pred.shape
pred[0]

tensor([ 0.7743, -0.4576, -0.0745,  0.6165,  0.7950, -0.0955, -0.5897, -0.6505,
        -0.7653,  0.1114], grad_fn=<SelectBackward0>)

In [9]:
# Loss of the untrained model on the sample batch; for 10 classes this is typically near ln(10) ≈ 2.30.
loss_func(input=pred, target=y)

tensor(2.4736, grad_fn=<NllLossBackward0>)

### Training

In [None]:
#|export 
def vision_encoder_train(model, epochs=10, loaders=None, config=None, loss_fn=None):
    """Train `model` with AdamW under autocast, printing mean train/validation loss per epoch.

    Args:
        model: network to train; it is moved to `config.device`.
        epochs: number of passes over the training loader.
        loaders: dict with 'train' and 'valid' iterables of (x, y) batches that support len().
            Defaults to the module-level `dls`.
        config: object exposing `device`, `dtype`, `lr` and `max_grad_norm`.
            Defaults to the module-level `visConfig`.
        loss_fn: criterion called as `loss_fn(logits, y)`. Defaults to the module-level `loss_func`.

    Returns:
        None; per-epoch losses are printed.
    """
    loaders = dls if loaders is None else loaders
    cfg = visConfig if config is None else config
    loss_fn = loss_func if loss_fn is None else loss_fn

    # autocast needs a bare device type ('cuda'/'cpu'), not an indexed device such as 'cuda:0'.
    device_type = torch.device(cfg.device).type

    model = model.to(cfg.device)
    optimizer = AdamW(model.parameters(), lr=cfg.lr)

    # float16 on CUDA needs loss scaling so small gradients do not underflow to zero;
    # bfloat16 shares float32's exponent range and does not. A disabled scaler is a pass-through.
    use_scaler = device_type == 'cuda' and cfg.dtype == torch.float16
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=use_scaler)
    else:  # older PyTorch releases only provide the CUDA-specific class
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    # Guard the averages against empty loaders.
    n_train = max(len(loaders['train']), 1)
    n_valid = max(len(loaders['valid']), 1)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for x, y in loaders['train']:
            x, y = x.to(cfg.device), y.to(cfg.device)
            optimizer.zero_grad()

            with torch.autocast(device_type=device_type, dtype=cfg.dtype):
                logits = model(x)
                loss = loss_fn(logits, y)

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        # Switch *this* model (not a global) to eval mode so BatchNorm uses its running
        # statistics and does not update them with validation data.
        model.eval()
        val_loss = 0.0
        with torch.no_grad(), torch.autocast(device_type=device_type, dtype=cfg.dtype):
            for x, y in loaders['valid']:
                x, y = x.to(cfg.device), y.to(cfg.device)
                val_loss += loss_fn(model(x), y).item()

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/n_train:.4f} Validation Loss: {val_loss/n_valid:.4f}")

In [None]:
# Train the full classifier (encoder + head) for 10 epochs.
vision_encoder_train(classifier, epochs=10)

### Save model

In [None]:
# Deliberately NOT exported: as library code this line would run on every import and
# overwrite the trained checkpoint with a freshly initialised (untrained) classifier.
# NOTE(review): `path` is assumed to come from the `NanoTransformer.data` star import — confirm.
# This pickles the whole module; recent PyTorch needs torch.load(..., weights_only=False) to read it.
torch.save(classifier, path/'classfier.pth')