<a href="https://colab.research.google.com/github/pmgarg/ERA_V4_Session6/blob/main/Session_7_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm
import numpy as np
import random
from PIL import Image
import platform
import argparse
import sys
import os


In [2]:
DATASET = 'CIFAR-10'
NUM_CLASSES = 10
IMAGE_SIZE = 32
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)

INITIAL_CHANNELS = 3

# Training
BATCH_SIZE = 128
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 50

# Augmentation
USE_AUGMENTATION = True

In [3]:
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [4]:
class CutoutTransform:
    """Custom implementation of Cutout/CoarseDropout using torchvision"""
    def __init__(self, n_holes=1, length=16, fill_value=None):
        self.n_holes = n_holes
        self.length = length
        self.fill_value = fill_value

    def __call__(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to apply cutout
        Returns:
            PIL Image or Tensor: Image with cutout applied
        """
        if isinstance(img, Image.Image):
            img = np.array(img)
            was_pil = True
        else:
            was_pil = False

        h, w = img.shape[:2]

        # Use dataset mean if no fill value provided
        if self.fill_value is None:
            self.fill_value = [125, 122, 113]  # CIFAR-10 approximate means in 0-255 range

        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = max(0, y - self.length // 2)
            y2 = min(h, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(w, x + self.length // 2)

            # Apply cutout
            img[y1:y2, x1:x2] = self.fill_value

        if was_pil:
            return Image.fromarray(img)
        return img


class ShiftScaleRotate:
    """Custom implementation of ShiftScaleRotate using torchvision"""
    def __init__(self, shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5):
        self.shift_limit = shift_limit
        self.scale_limit = scale_limit
        self.rotate_limit = rotate_limit
        self.p = p

    def __call__(self, img):
        if random.random() > self.p:
            return img

        # Random parameters
        angle = random.uniform(-self.rotate_limit, self.rotate_limit)
        scale = random.uniform(1 - self.scale_limit, 1 + self.scale_limit)

        # Get image dimensions
        width, height = img.size if isinstance(img, Image.Image) else (img.shape[1], img.shape[0])

        # Calculate shift
        max_dx = self.shift_limit * width
        max_dy = self.shift_limit * height
        dx = random.uniform(-max_dx, max_dx)
        dy = random.uniform(-max_dy, max_dy)

        # Apply transformations using torchvision
        if isinstance(img, Image.Image):
            # Apply affine transformation
            img = transforms.functional.affine(
                img,
                angle=angle,
                translate=(dx, dy),
                scale=scale,
                shear=0
            )

        return img


In [5]:
fill_value = tuple([int(m * 255) for m in MEAN])
train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.1,
            rotate_limit=15,
            p=0.5
        ),
        CutoutTransform(
            n_holes=1,
            length=16,
            fill_value=fill_value
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])
val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])

In [6]:
train_dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        transform=train_transform,
        download=True
    )

val_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        transform=val_transform,
        download=True
    )

train_loader = DataLoader(
        train_dataset,
        batch_size=128,
        shuffle=True
    )

val_loader = DataLoader(
        val_dataset,
        batch_size=128,
        shuffle=False
    )

100%|██████████| 170M/170M [01:15<00:00, 2.26MB/s]


In [7]:
class CIFAR10_CNN(nn.Module):
    """
    Advanced CNN for CIFAR-10 with:
    - C1C2C3C4 architecture (no MaxPooling)
    - Dilated convolutions for downsampling
    - Depthwise Separable Convolution
    - Global Average Pooling
    - Total RF > 44
    - Parameters < 200k
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # C1: Initial convolution block (RF: 3)
        self.c1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3,stride=1, padding=1,bias=False), # RF: 3
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3,stride=1, padding=1,bias=False), # RF: 5
            nn.BatchNorm2d(32),
        )

        # C2: Depthwise Separable Convolution block (RF increases)
        self.c2 = nn.Sequential(
                  nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, groups=32,
                            bias=False),
                  #  convolution (1x1)
                  nn.Conv2d(32, 64, kernel_size=1, bias=False),# RF: 7

                  nn.BatchNorm2d(64),
                  nn.Conv2d(64, 64, kernel_size=3,stride=1, padding=1,bias=False), # RF: 9
                  nn.BatchNorm2d(64),
                )

        # C3: Standard convolution with dilated conv (RF increases significantly)
        self.c3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3,stride=1, padding=1,bias=False), # RF: 11
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, padding=2, dilation=2, bias=False), # RF: 15 (dilated)
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3,stride=1, padding=1,bias=False), # RF: 5
            nn.BatchNorm2d(128),
        )

        # C4: Final block with dilated convolution for downsampling (instead of stride)
        # Using dilated convolutions with higher dilation for effective downsampling
        self.c4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=4, dilation=4, bias=False), # RF: 25 (high dilation)
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3,stride=1, padding=1,bias=False), # RF: 27
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, padding=8, dilation=8, bias=False), # RF: 43 (very high dilation)
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=1,stride=1, padding=0,bias=False), # 1x1 conv, RF unchanged
            nn.BatchNorm2d(256),
        )

        # Global Average Pooling
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Fully Connected layer after GAP
        self.fc = nn.Linear(256, num_classes)

        # Initialize weights
        #self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = F.relu(self.c1(x))
        x = F.relu(self.c2(x))
        x = F.relu(self.c3(x))
        x = F.relu(self.c4(x))
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_receptive_field(self):
        """Calculate and return the total receptive field"""
        # With the dilated convolutions:
        # C1: RF = 5
        # C2: RF = 9
        # C3: RF = 17 (with dilation=2)
        # C4: RF = 43+ (with dilation=4 and 8)
        # Total RF > 44 ✓
        return 45


In [8]:
class Trainer:
    """Trainer class for the CNN model"""

    def __init__(self, model):
        self.model = model.to(DEVICE)
        self.device = DEVICE

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=LEARNING_RATE,
            momentum=MOMENTUM,
            weight_decay=WEIGHT_DECAY
        )

        self.best_accuracy = 0

    def train_epoch(self, train_loader):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc='Training')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': running_loss / len(pbar),
                'acc': 100. * correct / total
            })

        return running_loss / len(train_loader), 100. * correct / total

    def train_epoch_with_scheduler(self, train_loader):
        """Train for one epoch with scheduler step"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc='Training')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Step scheduler after each batch
            if hasattr(self, 'scheduler'):
                self.scheduler.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': running_loss / len(pbar),
                'acc': 100. * correct / total
            })

        return running_loss / len(train_loader), 100. * correct / total

    def validate(self, val_loader):
        """Validate the model"""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(val_loader, desc='Validation')
            for inputs, labels in pbar:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                pbar.set_postfix({
                    'loss': running_loss / len(pbar),
                    'acc': 100. * correct / total
                })

        accuracy = 100. * correct / total
        return running_loss / len(val_loader), accuracy

    def train(self, train_loader, val_loader, num_epochs):
        """Full training loop"""

        # Setup learning rate scheduler
        total_steps = num_epochs * len(train_loader)
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=0.1,
            total_steps=total_steps,
            pct_start=0.3,
            anneal_strategy='cos'
        )

        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch+1}/{num_epochs}')

            # Train
            train_loss, train_acc = self.train_epoch_with_scheduler(train_loader)

            # Validate
            val_loss, val_acc = self.validate(val_loader)

            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

            # Save best model
            if val_acc > self.best_accuracy:
                self.best_accuracy = val_acc
                self.save_checkpoint(epoch, val_acc)
                print(f'Best model saved! Accuracy: {val_acc:.2f}%')

    def save_checkpoint(self, epoch, accuracy):
        """Save model checkpoint"""
        os.makedirs('./checkpoints', exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'accuracy': accuracy,
        }

        path = os.path.join('./checkpoints', f'best_model.pth')
        torch.save(checkpoint, path)


In [9]:
def count_parameters(model):
    """Count the number of trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def calculate_receptive_field(model):
    """Calculate theoretical receptive field"""
    rf = 1
    stride = 1

    # This is a simplified calculation
    # For accurate RF, trace through each layer
    layers_info = [
        (3, 1, 1),  # kernel, stride, padding
        (3, 1, 1),
        (3, 1, 1),
        (3, 1, 1),
        (3, 1, 1),
        (3, 2, 1),  # dilated conv acts like larger kernel
        (3, 1, 1),
        (3, 4, 1),  # dilated conv
        (3, 1, 1),
        (3, 8, 1),  # dilated conv
    ]

    for k, s, _ in layers_info:
        rf = rf + (k - 1) * stride
        stride = stride * s

    return rf


def print_model_summary(model):
    """Print model summary"""
    num_params = count_parameters(model)
    print(f"\n{'='*50}")
    print(f"Model: CIFAR-10 Advanced CNN")
    print(f"{'='*50}")
    print(f"Total Parameters: {num_params:,}")
    print(f"Receptive Field: {model.get_receptive_field()}")
    print(f"Architecture: C1-C2-C3-C4-GAP-FC")
    print(f"Depthwise Separable Conv: ✓ (in C2)")
    print(f"Dilated Convolution: ✓ (in C3 and C4)")
    print(f"Global Average Pooling: ✓")
    print(f"Target Accuracy: 85%")
    print(f"Parameter Limit: 200,000")
    print(f"Parameter Check: {'✓ PASS' if num_params < 200000 else '✗ FAIL'}")
    print(f"{'='*50}\n")

In [None]:
    if torch.backends.mps.is_available():
        DEVICE = 'mps'
        print("Using MPS (Metal Performance Shaders) device")
    elif torch.cuda.is_available():
        DEVICE = 'cuda'
        print("Using CUDA device")
    else:
        DEVICE = 'cpu'
        print("Using CPU device")

    # Model
    model = CIFAR10_CNN(num_classes=10)
    print_model_summary(model)


    # Trainer
    trainer = Trainer(model)

    # Train
    trainer.train(train_loader, val_loader, NUM_EPOCHS)

    print(f"\n{'='*50}")
    print(f"Training Complete!")
    print(f"Best Validation Accuracy: {trainer.best_accuracy:.2f}%")
    print(f"{'='*50}")

Using CUDA device

Model: CIFAR-10 Advanced CNN
Total Parameters: 1,963,786
Receptive Field: 45
Architecture: C1-C2-C3-C4-GAP-FC
Depthwise Separable Conv: ✓ (in C2)
Dilated Convolution: ✓ (in C3 and C4)
Global Average Pooling: ✓
Target Accuracy: 85%
Parameter Limit: 200,000
Parameter Check: ✗ FAIL


Epoch 1/50


Training: 100%|██████████| 391/391 [02:34<00:00,  2.53it/s, loss=1.6, acc=40.4]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.63it/s, loss=1.35, acc=51.1]


Train Loss: 1.6040, Train Acc: 40.45%
Val Loss: 1.3487, Val Acc: 51.12%
Best model saved! Accuracy: 51.12%

Epoch 2/50


Training: 100%|██████████| 391/391 [02:38<00:00,  2.47it/s, loss=1.31, acc=52.4]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.50it/s, loss=1.12, acc=60.1]


Train Loss: 1.3118, Train Acc: 52.41%
Val Loss: 1.1219, Val Acc: 60.06%
Best model saved! Accuracy: 60.06%

Epoch 3/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=1.2, acc=57]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.47it/s, loss=1.21, acc=57.3]


Train Loss: 1.1985, Train Acc: 56.96%
Val Loss: 1.2116, Val Acc: 57.33%

Epoch 4/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=1.14, acc=59.1]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.52it/s, loss=1.03, acc=63.6]


Train Loss: 1.1368, Train Acc: 59.13%
Val Loss: 1.0337, Val Acc: 63.63%
Best model saved! Accuracy: 63.63%

Epoch 5/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=1.09, acc=61.2]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.54it/s, loss=1.24, acc=59.9]


Train Loss: 1.0893, Train Acc: 61.23%
Val Loss: 1.2430, Val Acc: 59.85%

Epoch 6/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=1.04, acc=63.1]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.52it/s, loss=0.944, acc=67]


Train Loss: 1.0401, Train Acc: 63.07%
Val Loss: 0.9439, Val Acc: 66.96%
Best model saved! Accuracy: 66.96%

Epoch 7/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.985, acc=65.2]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.53it/s, loss=0.829, acc=70.6]


Train Loss: 0.9846, Train Acc: 65.17%
Val Loss: 0.8286, Val Acc: 70.58%
Best model saved! Accuracy: 70.58%

Epoch 8/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.938, acc=66.7]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.52it/s, loss=0.817, acc=72.1]


Train Loss: 0.9383, Train Acc: 66.74%
Val Loss: 0.8170, Val Acc: 72.08%
Best model saved! Accuracy: 72.08%

Epoch 9/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.897, acc=68.3]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.51it/s, loss=0.808, acc=72.5]


Train Loss: 0.8967, Train Acc: 68.31%
Val Loss: 0.8076, Val Acc: 72.53%
Best model saved! Accuracy: 72.53%

Epoch 10/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.863, acc=69.7]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.56it/s, loss=0.684, acc=76.3]


Train Loss: 0.8634, Train Acc: 69.72%
Val Loss: 0.6835, Val Acc: 76.35%
Best model saved! Accuracy: 76.35%

Epoch 11/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.839, acc=70.7]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.57it/s, loss=0.689, acc=76.3]


Train Loss: 0.8388, Train Acc: 70.72%
Val Loss: 0.6889, Val Acc: 76.29%

Epoch 12/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.799, acc=71.8]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.58it/s, loss=0.709, acc=76]


Train Loss: 0.7990, Train Acc: 71.79%
Val Loss: 0.7088, Val Acc: 76.05%

Epoch 13/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.45it/s, loss=0.784, acc=72.5]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.55it/s, loss=0.638, acc=77.1]


Train Loss: 0.7844, Train Acc: 72.50%
Val Loss: 0.6376, Val Acc: 77.10%
Best model saved! Accuracy: 77.10%

Epoch 14/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s, loss=0.763, acc=73.1]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.56it/s, loss=0.692, acc=76.4]


Train Loss: 0.7632, Train Acc: 73.13%
Val Loss: 0.6925, Val Acc: 76.44%

Epoch 15/50


Training: 100%|██████████| 391/391 [02:39<00:00,  2.46it/s, loss=0.746, acc=73.9]
Validation: 100%|██████████| 79/79 [00:10<00:00,  7.57it/s, loss=0.596, acc=80]


Train Loss: 0.7463, Train Acc: 73.90%
Val Loss: 0.5957, Val Acc: 80.03%
Best model saved! Accuracy: 80.03%

Epoch 16/50


Training:  59%|█████▉    | 232/391 [01:34<01:04,  2.47it/s, loss=0.434, acc=74.5]