<a href="https://colab.research.google.com/github/vikrampathare/ERA_Session05/blob/main/ERA_Session05.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        dropout_value = 0.05

        # Block 1 - Light feature extraction
        self.conv1 = nn.Conv2d(1, 10, 3, padding=1)  # 28x28x10, RF=3
        self.bn1 = nn.BatchNorm2d(10)
        self.dropout1 = nn.Dropout(dropout_value)

        self.conv2 = nn.Conv2d(10, 16, 3, padding=1)  # 28x28x16, RF=5
        self.bn2 = nn.BatchNorm2d(16)
        self.dropout2 = nn.Dropout(dropout_value)

        # Transition Block 1 - Reduce spatial dimensions
        self.pool1 = nn.MaxPool2d(2, 2)  # 14x14x16, RF=6
        self.conv1x1_1 = nn.Conv2d(16, 10, 1)  # 14x14x10, RF=6

        # Block 2 - Deep feature extraction
        self.conv3 = nn.Conv2d(10, 16, 3, padding=1)  # 14x14x16, RF=10
        self.bn3 = nn.BatchNorm2d(16)
        self.dropout3 = nn.Dropout(dropout_value)

        self.conv4 = nn.Conv2d(16, 16, 3, padding=1)  # 14x14x16, RF=14
        self.bn4 = nn.BatchNorm2d(16)
        self.dropout4 = nn.Dropout(dropout_value)

        # Transition Block 2
        self.pool2 = nn.MaxPool2d(2, 2)  # 7x7x16, RF=16
        self.conv1x1_2 = nn.Conv2d(16, 10, 1)  # 7x7x10, RF=16

        # Block 3 - Final feature extraction
        self.conv5 = nn.Conv2d(10, 16, 3, padding=1)  # 7x7x16, RF=24
        self.bn5 = nn.BatchNorm2d(16)
        self.dropout5 = nn.Dropout(dropout_value)

        self.conv6 = nn.Conv2d(16, 16, 3, padding=1)  # 7x7x16, RF=32
        self.bn6 = nn.BatchNorm2d(16)
        self.dropout6 = nn.Dropout(dropout_value)

        # Output Block with GAP
        self.conv7 = nn.Conv2d(16, 10, 1)  # 7x7x10, RF=32
        self.gap = nn.AdaptiveAvgPool2d(1)  # 1x1x10 - Global Average Pooling

    def forward(self, x):
        # Block 1
        x = self.dropout1(self.bn1(F.relu(self.conv1(x))))
        x = self.dropout2(self.bn2(F.relu(self.conv2(x))))

        # Transition 1
        x = self.pool1(x)
        x = F.relu(self.conv1x1_1(x))

        # Block 2
        x = self.dropout3(self.bn3(F.relu(self.conv3(x))))
        x = self.dropout4(self.bn4(F.relu(self.conv4(x))))

        # Transition 2
        x = self.pool2(x)
        x = F.relu(self.conv1x1_2(x))

        # Block 3
        x = self.dropout5(self.bn5(F.relu(self.conv5(x))))
        x = self.dropout6(self.bn6(F.relu(self.conv6(x))))

        # Output
        x = self.conv7(x)
        x = self.gap(x)
        x = x.view(-1, 10)

        return F.log_softmax(x, dim=1)

# Setup
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Device: {device}")

# Model initialization and summary
model = Net().to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Install and use torchsummary if available
try:
    from torchsummary import summary
    summary(model, input_size=(1, 28, 28))
except:
    print("torchsummary not available, skipping summary")

# Set random seed for reproducibility
torch.manual_seed(1)

# Data loaders with augmentation for training
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Enhanced data augmentation for training
train_transform = transforms.Compose([
    transforms.RandomRotation((-7.0, 7.0), fill=0),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Standard normalization for testing
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=train_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=test_transform),
    batch_size=batch_size, shuffle=False, **kwargs)

# Training function
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    train_loss = 0
    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Epoch={epoch} Loss={loss.item():.4f} Batch_id={batch_idx} Accuracy={100*correct/processed:.2f}%')

# Testing function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')

    return accuracy

# Training with learning rate scheduling
model = Net().to(device)

# Using SGD with momentum and learning rate scheduling
optimizer = optim.SGD(model.parameters(), lr=0.015, momentum=0.9, weight_decay=0.0001)

# Learning rate scheduler - ReduceLROnPlateau or StepLR
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.015,
    epochs=19,
    steps_per_epoch=len(train_loader),
    pct_start=0.2,
    div_factor=10,
    final_div_factor=100
)

# Training loop
best_accuracy = 0
epochs = 19

for epoch in range(1, epochs + 1):
    print(f'\nEpoch {epoch}/{epochs}')
    train(model, device, train_loader, optimizer, epoch)
    accuracy = test(model, device, test_loader)

    # Update scheduler after each epoch
    if epoch <= 18:  # OneCycleLR handles its own stepping
        for _ in range(len(train_loader)):
            scheduler.step()

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        print(f'New best accuracy: {best_accuracy:.2f}%')

    if accuracy >= 99.4:
        print(f'Target accuracy of 99.4% reached at epoch {epoch}!')
        break

print(f'\nFinal best accuracy: {best_accuracy:.2f}%')
print(f'Total parameters: {total_params:,}')

Device: cuda
Total parameters: 9,798
Trainable parameters: 9,798
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 28, 28]             100
       BatchNorm2d-2           [-1, 10, 28, 28]              20
           Dropout-3           [-1, 10, 28, 28]               0
            Conv2d-4           [-1, 16, 28, 28]           1,456
       BatchNorm2d-5           [-1, 16, 28, 28]              32
           Dropout-6           [-1, 16, 28, 28]               0
         MaxPool2d-7           [-1, 16, 14, 14]               0
            Conv2d-8           [-1, 10, 14, 14]             170
            Conv2d-9           [-1, 16, 14, 14]           1,456
      BatchNorm2d-10           [-1, 16, 14, 14]              32
          Dropout-11           [-1, 16, 14, 14]               0
           Conv2d-12           [-1, 16, 14, 14]           2,320
      BatchNorm2d-13           [-1, 16

Epoch=1 Loss=0.6395 Batch_id=468 Accuracy=55.80%: 100%|██████████| 469/469 [00:22<00:00, 21.28it/s]



Test set: Average loss: 0.6716, Accuracy: 8470/10000 (84.70%)

New best accuracy: 84.70%

Epoch 2/19


Epoch=2 Loss=0.1060 Batch_id=468 Accuracy=93.90%: 100%|██████████| 469/469 [00:20<00:00, 22.64it/s]



Test set: Average loss: 0.1255, Accuracy: 9700/10000 (97.00%)

New best accuracy: 97.00%

Epoch 3/19


Epoch=3 Loss=0.0444 Batch_id=468 Accuracy=96.83%: 100%|██████████| 469/469 [00:21<00:00, 22.20it/s]



Test set: Average loss: 0.0918, Accuracy: 9748/10000 (97.48%)

New best accuracy: 97.48%

Epoch 4/19


Epoch=4 Loss=0.0968 Batch_id=468 Accuracy=97.62%: 100%|██████████| 469/469 [00:21<00:00, 22.10it/s]



Test set: Average loss: 0.0560, Accuracy: 9832/10000 (98.32%)

New best accuracy: 98.32%

Epoch 5/19


Epoch=5 Loss=0.0514 Batch_id=468 Accuracy=97.97%: 100%|██████████| 469/469 [00:21<00:00, 21.96it/s]



Test set: Average loss: 0.0608, Accuracy: 9811/10000 (98.11%)


Epoch 6/19


Epoch=6 Loss=0.0275 Batch_id=468 Accuracy=98.20%: 100%|██████████| 469/469 [00:21<00:00, 21.97it/s]



Test set: Average loss: 0.0426, Accuracy: 9886/10000 (98.86%)

New best accuracy: 98.86%

Epoch 7/19


Epoch=7 Loss=0.0431 Batch_id=468 Accuracy=98.39%: 100%|██████████| 469/469 [00:21<00:00, 21.90it/s]



Test set: Average loss: 0.0354, Accuracy: 9898/10000 (98.98%)

New best accuracy: 98.98%

Epoch 8/19


Epoch=8 Loss=0.0315 Batch_id=468 Accuracy=98.44%: 100%|██████████| 469/469 [00:21<00:00, 21.98it/s]



Test set: Average loss: 0.0385, Accuracy: 9878/10000 (98.78%)


Epoch 9/19


Epoch=9 Loss=0.0969 Batch_id=468 Accuracy=98.60%: 100%|██████████| 469/469 [00:21<00:00, 22.23it/s]



Test set: Average loss: 0.0368, Accuracy: 9886/10000 (98.86%)


Epoch 10/19


Epoch=10 Loss=0.1025 Batch_id=468 Accuracy=98.64%: 100%|██████████| 469/469 [00:20<00:00, 22.52it/s]



Test set: Average loss: 0.0319, Accuracy: 9901/10000 (99.01%)

New best accuracy: 99.01%

Epoch 11/19


Epoch=11 Loss=0.0567 Batch_id=468 Accuracy=98.77%: 100%|██████████| 469/469 [00:20<00:00, 22.73it/s]



Test set: Average loss: 0.0267, Accuracy: 9912/10000 (99.12%)

New best accuracy: 99.12%

Epoch 12/19


Epoch=12 Loss=0.0397 Batch_id=468 Accuracy=98.83%: 100%|██████████| 469/469 [00:20<00:00, 22.68it/s]



Test set: Average loss: 0.0259, Accuracy: 9918/10000 (99.18%)

New best accuracy: 99.18%

Epoch 13/19


Epoch=13 Loss=0.0596 Batch_id=468 Accuracy=98.86%: 100%|██████████| 469/469 [00:21<00:00, 22.19it/s]



Test set: Average loss: 0.0268, Accuracy: 9913/10000 (99.13%)


Epoch 14/19


Epoch=14 Loss=0.0279 Batch_id=468 Accuracy=98.93%: 100%|██████████| 469/469 [00:21<00:00, 21.99it/s]



Test set: Average loss: 0.0262, Accuracy: 9914/10000 (99.14%)


Epoch 15/19


Epoch=15 Loss=0.0584 Batch_id=468 Accuracy=99.03%: 100%|██████████| 469/469 [00:21<00:00, 22.12it/s]



Test set: Average loss: 0.0225, Accuracy: 9929/10000 (99.29%)

New best accuracy: 99.29%

Epoch 16/19


Epoch=16 Loss=0.0468 Batch_id=468 Accuracy=99.10%: 100%|██████████| 469/469 [00:21<00:00, 22.26it/s]



Test set: Average loss: 0.0211, Accuracy: 9938/10000 (99.38%)

New best accuracy: 99.38%

Epoch 17/19


Epoch=17 Loss=0.0791 Batch_id=468 Accuracy=99.12%: 100%|██████████| 469/469 [00:20<00:00, 22.34it/s]



Test set: Average loss: 0.0218, Accuracy: 9928/10000 (99.28%)


Epoch 18/19


Epoch=18 Loss=0.0271 Batch_id=468 Accuracy=99.24%: 100%|██████████| 469/469 [00:21<00:00, 22.30it/s]



Test set: Average loss: 0.0203, Accuracy: 9933/10000 (99.33%)


Epoch 19/19


Epoch=19 Loss=0.0019 Batch_id=468 Accuracy=99.18%: 100%|██████████| 469/469 [00:20<00:00, 22.80it/s]



Test set: Average loss: 0.0197, Accuracy: 9933/10000 (99.33%)


Final best accuracy: 99.38%
Total parameters: 9,798
