<a href="https://colab.research.google.com/github/shreyasbkgit/dllab/blob/main/Week_6_Transfer_Learningalexnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Week-6: Transfer Learning**

Implement the standard LeNet, AlexNet, and VGG CNN architectures to classify a multi-category image dataset.

Dataset: MNIST handwritten digits (0–9)

Record the test accuracies obtained after training for 5, 50, and 250 epochs.



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 1. Device: run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 2. Transform: upsample the digits to 64x64 so AlexNet's stride-4 convolution
#    and three max-pools still leave a non-empty (1x1) feature map.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)

# 3. Datasets and loaders (MNIST is downloaded into ./data on first use).
#    The training split is built first, then the test split.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 4. AlexNet model
class AlexNet(nn.Module):
    """AlexNet-style CNN sized for small images (here MNIST upsampled to 64x64).

    Args:
        num_classes: number of output logits (default 10, one per digit).
        in_channels: channels of the input images (default 1, grayscale).
        dropout: dropout probability used in the classifier (default 0.5,
            which is the ``nn.Dropout`` default).

    The spatial sizes in the comments assume a 64x64 input. ``avgpool``
    adaptively pools the feature map to 1x1. For a 64x64 input the map is
    already 1x1, so the pool is a no-op there. For larger inputs (e.g. 96x96
    gives 2x2) it still feeds exactly 256 features to the classifier instead
    of failing with a shape mismatch.
    """

    def __init__(self, num_classes=10, in_channels=1, dropout=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=11, stride=4, padding=2),  # 64x64 → 15x15
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 15x15 → 7x7
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # → 3x3
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # → 1x1
        )
        # Parameter-free, so state_dict keys stay identical to the original model.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 1 * 1, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        x = self.features(x)
        x = self.avgpool(x)
        return self.classifier(x)

# 5. Train & Evaluate Function
def train_and_evaluate(model, epochs=5, *, lr=0.001, train_data=None,
                       test_data=None, run_device=None, progress=True):
    """Train ``model`` with Adam and cross-entropy, then return its test accuracy.

    Args:
        model: classifier whose output is one logit per class along dim 1.
        epochs: number of full passes over the training batches. ``0`` skips
            training and only evaluates.
        lr: Adam learning rate (default 0.001).
        train_data: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``train_loader``.
        test_data: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``test_loader``.
        run_device: device to move the model and batches to. Defaults to the
            module-level ``device``.
        progress: wrap each training epoch in a ``tqdm`` progress bar.

    Returns:
        Test accuracy as a percentage (``100 * correct / total``).

    Raises:
        ValueError: if ``test_data`` yields no samples.
    """
    # Fall back to the notebook-level loaders/device only when not overridden.
    if train_data is None:
        train_data = train_loader
    if test_data is None:
        test_data = test_loader
    if run_device is None:
        run_device = device

    model = model.to(run_device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        batches = train_data
        if progress:
            batches = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in batches:
            images, labels = images.to(run_device), labels.to(run_device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # The summed loss scales with the number of batches; the per-batch
        # mean is comparable across batch sizes and dataset sizes.
        mean_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}: Loss = {total_loss:.4f} (mean per batch = {mean_loss:.4f})")

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_data:
            images, labels = images.to(run_device), labels.to(run_device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test data yielded no samples; cannot compute accuracy")

    acc = 100 * correct / total
    print(f"\nTest Accuracy: {acc:.2f}%")
    return acc

# 6. Run Training
# Train a freshly initialised AlexNet for each epoch budget the exercise asks
# about, and keep every resulting test accuracy keyed by its epoch count.
alexnet_results = {}
for num_epochs in (5, 50, 250):
    alexnet_results[num_epochs] = train_and_evaluate(AlexNet(), epochs=num_epochs)

# Kept for compatibility with code that reads this name: the 250-epoch accuracy.
alexnet_acc = alexnet_results[250]

print("\nAlexNet test accuracy by epoch budget:")
for num_epochs, acc in alexnet_results.items():
    print(f"  {num_epochs:>3} epochs: {acc:.2f}%")

Epoch 1/5: 100%|██████████| 938/938 [00:22<00:00, 42.55it/s]


Epoch 1: Loss = 320.6687


Epoch 2/5: 100%|██████████| 938/938 [00:22<00:00, 42.56it/s]


Epoch 2: Loss = 79.7515


Epoch 3/5: 100%|██████████| 938/938 [00:22<00:00, 42.58it/s]


Epoch 3: Loss = 64.2931


Epoch 4/5: 100%|██████████| 938/938 [00:22<00:00, 42.34it/s]


Epoch 4: Loss = 57.3871


Epoch 5/5: 100%|██████████| 938/938 [00:22<00:00, 42.32it/s]


Epoch 5: Loss = 47.5815

Test Accuracy: 98.36%


Epoch 1/50: 100%|██████████| 938/938 [00:22<00:00, 42.35it/s]


Epoch 1: Loss = 372.8656


Epoch 2/50: 100%|██████████| 938/938 [00:21<00:00, 43.22it/s]


Epoch 2: Loss = 80.2481


Epoch 3/50: 100%|██████████| 938/938 [00:22<00:00, 41.90it/s]


Epoch 3: Loss = 67.4557


Epoch 4/50: 100%|██████████| 938/938 [00:22<00:00, 41.94it/s]


Epoch 4: Loss = 56.5227


Epoch 5/50: 100%|██████████| 938/938 [00:22<00:00, 42.27it/s]


Epoch 5: Loss = 57.0791


Epoch 6/50: 100%|██████████| 938/938 [00:22<00:00, 42.21it/s]


Epoch 6: Loss = 45.1618


Epoch 7/50: 100%|██████████| 938/938 [00:22<00:00, 42.14it/s]


Epoch 7: Loss = 44.5469


Epoch 8/50: 100%|██████████| 938/938 [00:21<00:00, 43.04it/s]


Epoch 8: Loss = 43.6144


Epoch 9/50: 100%|██████████| 938/938 [00:22<00:00, 42.41it/s]


Epoch 9: Loss = 36.3009


Epoch 10/50: 100%|██████████| 938/938 [00:22<00:00, 42.23it/s]


Epoch 10: Loss = 36.2373


Epoch 11/50: 100%|██████████| 938/938 [00:22<00:00, 42.53it/s]


Epoch 11: Loss = 36.5869


Epoch 12/50: 100%|██████████| 938/938 [00:22<00:00, 42.51it/s]


Epoch 12: Loss = 31.7676


Epoch 13/50: 100%|██████████| 938/938 [00:22<00:00, 42.34it/s]


Epoch 13: Loss = 33.1017


Epoch 14/50: 100%|██████████| 938/938 [00:21<00:00, 43.13it/s]


Epoch 14: Loss = 25.9079


Epoch 15/50: 100%|██████████| 938/938 [00:22<00:00, 42.56it/s]


Epoch 15: Loss = 34.7617


Epoch 16/50: 100%|██████████| 938/938 [00:22<00:00, 42.27it/s]


Epoch 16: Loss = 30.3798


Epoch 17/50: 100%|██████████| 938/938 [00:22<00:00, 42.37it/s]


Epoch 17: Loss = 31.5316


Epoch 18/50: 100%|██████████| 938/938 [00:22<00:00, 42.11it/s]


Epoch 18: Loss = 22.9156


Epoch 19/50: 100%|██████████| 938/938 [00:22<00:00, 42.30it/s]


Epoch 19: Loss = 31.2753


Epoch 20/50: 100%|██████████| 938/938 [00:21<00:00, 43.39it/s]


Epoch 20: Loss = 27.0406


Epoch 21/50: 100%|██████████| 938/938 [00:22<00:00, 42.54it/s]


Epoch 21: Loss = 31.9065


Epoch 22/50: 100%|██████████| 938/938 [00:22<00:00, 42.60it/s]


Epoch 22: Loss = 22.9147


Epoch 23/50: 100%|██████████| 938/938 [00:22<00:00, 42.59it/s]


Epoch 23: Loss = 42.9524


Epoch 24/50: 100%|██████████| 938/938 [00:22<00:00, 42.44it/s]


Epoch 24: Loss = 24.2085


Epoch 25/50: 100%|██████████| 938/938 [00:22<00:00, 42.51it/s]


Epoch 25: Loss = 14.4450


Epoch 26/50: 100%|██████████| 938/938 [00:21<00:00, 43.16it/s]


Epoch 26: Loss = 24.4704


Epoch 27/50: 100%|██████████| 938/938 [00:22<00:00, 42.44it/s]


Epoch 27: Loss = 23.1457


Epoch 28/50: 100%|██████████| 938/938 [00:22<00:00, 42.62it/s]


Epoch 28: Loss = 30.1636


Epoch 29/50: 100%|██████████| 938/938 [00:22<00:00, 42.47it/s]


Epoch 29: Loss = 15.7952


Epoch 30/50: 100%|██████████| 938/938 [00:22<00:00, 42.62it/s]


Epoch 30: Loss = 24.5001


Epoch 31/50: 100%|██████████| 938/938 [00:21<00:00, 43.26it/s]


Epoch 31: Loss = 40.3891


Epoch 32/50: 100%|██████████| 938/938 [00:21<00:00, 42.79it/s]


Epoch 32: Loss = 18.5086


Epoch 33/50: 100%|██████████| 938/938 [00:21<00:00, 42.75it/s]


Epoch 33: Loss = 19.8140


Epoch 34/50: 100%|██████████| 938/938 [00:22<00:00, 42.63it/s]


Epoch 34: Loss = 29.2025


Epoch 35/50: 100%|██████████| 938/938 [00:21<00:00, 42.78it/s]


Epoch 35: Loss = 17.8540


Epoch 36/50: 100%|██████████| 938/938 [00:21<00:00, 42.73it/s]


Epoch 36: Loss = 25.7005


Epoch 37/50: 100%|██████████| 938/938 [00:21<00:00, 43.47it/s]


Epoch 37: Loss = 13.8000


Epoch 38/50: 100%|██████████| 938/938 [00:21<00:00, 42.67it/s]


Epoch 38: Loss = 15.5388


Epoch 39/50: 100%|██████████| 938/938 [00:21<00:00, 42.67it/s]


Epoch 39: Loss = 40.8766


Epoch 40/50: 100%|██████████| 938/938 [00:22<00:00, 42.49it/s]


Epoch 40: Loss = 14.9283


Epoch 41/50: 100%|██████████| 938/938 [00:21<00:00, 42.77it/s]


Epoch 41: Loss = 43.8014


Epoch 42/50: 100%|██████████| 938/938 [00:21<00:00, 43.43it/s]


Epoch 42: Loss = 16.0564


Epoch 43/50: 100%|██████████| 938/938 [00:21<00:00, 43.14it/s]


Epoch 43: Loss = 21.2239


Epoch 44/50: 100%|██████████| 938/938 [00:21<00:00, 42.94it/s]


Epoch 44: Loss = 23.3794


Epoch 45/50: 100%|██████████| 938/938 [00:21<00:00, 42.72it/s]


Epoch 45: Loss = 14.2913


Epoch 46/50: 100%|██████████| 938/938 [00:21<00:00, 42.84it/s]


Epoch 46: Loss = 42.0066


Epoch 47/50: 100%|██████████| 938/938 [00:21<00:00, 43.20it/s]


Epoch 47: Loss = 12.3589


Epoch 48/50: 100%|██████████| 938/938 [00:21<00:00, 43.54it/s]


Epoch 48: Loss = 20.3667


Epoch 49/50: 100%|██████████| 938/938 [00:21<00:00, 42.97it/s]


Epoch 49: Loss = 19.2877


Epoch 50/50: 100%|██████████| 938/938 [00:21<00:00, 43.05it/s]


Epoch 50: Loss = 33.7893

Test Accuracy: 98.76%


Epoch 1/250: 100%|██████████| 938/938 [00:21<00:00, 42.75it/s]


Epoch 1: Loss = 336.5359


Epoch 2/250: 100%|██████████| 938/938 [00:21<00:00, 42.79it/s]


Epoch 2: Loss = 84.7666


Epoch 3/250: 100%|██████████| 938/938 [00:21<00:00, 43.18it/s]


Epoch 3: Loss = 60.4919


Epoch 4/250: 100%|██████████| 938/938 [00:21<00:00, 43.56it/s]


Epoch 4: Loss = 56.7223


Epoch 5/250: 100%|██████████| 938/938 [00:21<00:00, 42.92it/s]


Epoch 5: Loss = 53.8002


Epoch 6/250: 100%|██████████| 938/938 [00:21<00:00, 42.87it/s]


Epoch 6: Loss = 44.6675


Epoch 7/250: 100%|██████████| 938/938 [00:21<00:00, 42.92it/s]


Epoch 7: Loss = 43.8810


Epoch 8/250: 100%|██████████| 938/938 [00:21<00:00, 42.84it/s]


Epoch 8: Loss = 45.9691


Epoch 9/250: 100%|██████████| 938/938 [00:21<00:00, 43.44it/s]


Epoch 9: Loss = 30.1950


Epoch 10/250: 100%|██████████| 938/938 [00:21<00:00, 42.71it/s]


Epoch 10: Loss = 31.0057


Epoch 11/250: 100%|██████████| 938/938 [00:22<00:00, 42.59it/s]


Epoch 11: Loss = 30.3630


Epoch 12/250: 100%|██████████| 938/938 [00:22<00:00, 42.59it/s]


Epoch 12: Loss = 30.7215


Epoch 13/250: 100%|██████████| 938/938 [00:21<00:00, 42.67it/s]


Epoch 13: Loss = 27.5641


Epoch 14/250: 100%|██████████| 938/938 [00:21<00:00, 43.22it/s]


Epoch 14: Loss = 35.1029


Epoch 15/250: 100%|██████████| 938/938 [00:21<00:00, 42.92it/s]


Epoch 15: Loss = 25.6313


Epoch 16/250: 100%|██████████| 938/938 [00:22<00:00, 42.27it/s]


Epoch 16: Loss = 28.8347


Epoch 17/250: 100%|██████████| 938/938 [00:22<00:00, 42.28it/s]


Epoch 17: Loss = 23.7225


Epoch 18/250: 100%|██████████| 938/938 [00:22<00:00, 41.62it/s]


Epoch 18: Loss = 26.5963


Epoch 19/250: 100%|██████████| 938/938 [00:22<00:00, 42.33it/s]


Epoch 19: Loss = 23.9207


Epoch 20/250: 100%|██████████| 938/938 [00:22<00:00, 41.79it/s]


Epoch 20: Loss = 28.8511


Epoch 21/250: 100%|██████████| 938/938 [00:22<00:00, 41.57it/s]


Epoch 21: Loss = 22.7519


Epoch 22/250: 100%|██████████| 938/938 [00:22<00:00, 41.29it/s]


Epoch 22: Loss = 28.2751


Epoch 23/250:  52%|█████▏    | 488/938 [00:11<00:10, 43.98it/s]