In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

batch_size = 64
lr = 0.001
epochs = 20

train_transform = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomAffine(0, translate=(0.05,0.05)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Using device: cpu


In [36]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,12 , kernel_size = 3,padding = 0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12,16, kernel_size = 3,padding = 0)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16,20 ,kernel_size = 3,padding = 0)
        self.bn3 = nn.BatchNorm2d(20)
        self.conv4 = nn.Conv2d(20,28 ,kernel_size = 3,padding = 0)
        self.bn4 = nn.BatchNorm2d(28)
        self.conv5 = nn.Conv2d(28,10 ,kernel_size = 3,padding = 0)
        self.bn5 = nn.BatchNorm2d(10)
        
        
        self.pool = nn.MaxPool2d(2,2)
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.relu = nn.ReLU()

        self.dropout = nn.Dropout(0.1)
    def forward(self,x):
        x = (self.bn1(self.conv1(x)))
        x = self.relu(x)
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = (self.bn3(self.conv3(x)))
        x = self.relu(x)
        x = self.dropout(x)
        x = (self.bn4(self.conv4(x)))
        x = self.relu(x)
        x = self.dropout(x)
        x = (self.bn5(self.conv5(x)))
        x = self.relu(x)
        x = self.dropout(x)
        x = self.gap(x)
        x = torch.flatten(x,1)
        return x

model = CNN().to(device)
print(model)

def count_params(layer):
    return sum(p.numel() for p in layer.parameters())

print("Conv1:", count_params(model.conv1))
print("Conv2:", count_params(model.conv2))
print("Conv3:", count_params(model.conv3))
print("Conv4:", count_params(model.conv4))
print("Conv5:", count_params(model.conv5))

total = sum(p.numel() for p in model.parameters())
print("Total Parameters:", total)


CNN(
  (conv1): Conv2d(1, 12, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(16, 20, kernel_size=(3, 3), stride=(1, 1))
  (bn3): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(20, 28, kernel_size=(3, 3), stride=(1, 1))
  (bn4): BatchNorm2d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(28, 10, kernel_size=(3, 3), stride=(1, 1))
  (bn5): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gap): AdaptiveAvgPool2d(output_size=(1, 1))
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
)
Conv1: 120
Conv2: 1744
Conv3: 2900
Conv4: 5068
Conv5

What is cosine annealing?
What is step decay?
- in the final layer if the number of feature map were not exactly 10 than there could be an issue in the accuracy of the model as gap would convert feature map into that numbers such that their mean is zero
10 degree of data augmenation is applied to increase accuracy
-

In [37]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = lr)
schedular = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=epochs,eta_min = 1e-5)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    
    pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]")

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        pbar.set_postfix(
            loss=f"{loss.item():.4f}",
            lr=f"{optimizer.param_groups[0]['lr']:.6f}"
        )

    schedular.step()


Epoch [1/20]: 100%|██████████| 938/938 [01:09<00:00, 13.55it/s, loss=0.3159, lr=0.001000]
Epoch [2/20]: 100%|██████████| 938/938 [01:11<00:00, 13.13it/s, loss=0.2794, lr=0.000994]
Epoch [3/20]: 100%|██████████| 938/938 [01:13<00:00, 12.76it/s, loss=0.0575, lr=0.000976]
Epoch [4/20]: 100%|██████████| 938/938 [01:13<00:00, 12.82it/s, loss=0.1097, lr=0.000946]
Epoch [5/20]: 100%|██████████| 938/938 [01:12<00:00, 12.92it/s, loss=0.0599, lr=0.000905]
Epoch [6/20]: 100%|██████████| 938/938 [01:10<00:00, 13.23it/s, loss=0.1173, lr=0.000855]
Epoch [7/20]: 100%|██████████| 938/938 [01:10<00:00, 13.36it/s, loss=0.1350, lr=0.000796]
Epoch [8/20]: 100%|██████████| 938/938 [01:11<00:00, 13.13it/s, loss=0.0479, lr=0.000730]
Epoch [9/20]: 100%|██████████| 938/938 [01:07<00:00, 13.81it/s, loss=0.1836, lr=0.000658]
Epoch [10/20]: 100%|██████████| 938/938 [01:12<00:00, 12.95it/s, loss=0.0797, lr=0.000582]
Epoch [11/20]: 100%|██████████| 938/938 [01:13<00:00, 12.83it/s, loss=0.0609, lr=0.000505]
Epoch [1

In [38]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 99.44%
