# Code to train an MNIST Classifier

In [1]:
!pip install torchvision

Defaulting to user installation because normal site-packages is not writeable
Collecting torch==2.5.1 (from torchvision)
  Obtaining dependency information for torch==2.5.1 from https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl.metadata
  Using cached torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1->torchvision)
  Obtaining dependency information for nvidia-cudnn-cu12==9.1.0.70 from https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nccl-cu12==2.21.5 (from torch==2.5.1->torchvision)
  Obtaining dependency information for nvidia-nccl-cu12==2.21.5 from https://files.pythonhosted.or

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [3]:
batch_size = 64
learning_rate = 1e-3
num_epochs = 10

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) 
])


train_dataset = datasets.MNIST(root='./mnist_data', train=True, transform = transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data', train=False, transform=transform, download=True)

In [5]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(test_dataset, batch_size=64)

In [6]:
xtrain, ytrain = next(iter(train_loader))
print(xtrain.shape, ytrain.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])


# Model

In [17]:
class MNISTClassifier(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 256)
        self.dropout = nn.Dropout(0.1)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
    
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        return x

# Training Loop

In [18]:
model = MNISTClassifier()
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.85)

In [19]:
def calculate_num_correct(outputs, labels):
    # outputs.shape: (b,10)
    # labels.shape: (b)
    max_output_vals, max_output_indices = outputs.max(dim=-1)
    correct = torch.where(max_output_indices == labels, 1, 0)
    
    num_correct = torch.sum(correct, dim=-1)
    total = correct.size(0)
    return num_correct, total


def run_training_loop(model, optimizer, train_loader, val_loader, scheduler, num_epochs=10, print_every=200):
    device = next(model.parameters()).device # device that model is stored on
    training_log = {
        'training_loss': [],
        'training_acc': [],
        'validation_loss': [],
        'validation_acc': []
    }
    
    model.train()
    num_iter = 0
    for i in range(num_epochs):
        
        for iter_idx, (xb,yb) in enumerate(train_loader):
            xb, yb = xb.to(device), yb.to(device)
            xb = xb.view(-1, 784)
            out = model(xb) # b, 10
            loss = F.cross_entropy(out, yb)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            num_correct, total = calculate_num_correct(out, yb)
            
            training_log['training_loss'].append((num_iter, loss.item()))
            training_log['training_acc'].append((num_iter, num_correct / total))
            
            if num_iter % print_every == 0:
                print(f'Epoch {i}, iter: {num_iter}: training loss: {training_log["training_loss"][-1][1]}, training acc: {training_log["training_acc"][-1][1]}')
            
            num_iter += 1
        
        
        total_num_correct = 0
        total_val_samples = 0
        total_loss = 0
        num_batches = 0
        
        # run evals
        for iter_idx, (xb, yb) in enumerate(val_loader):
            
            xb, yb = xb.to(device), yb.to(device)
            xb = xb.view(-1, 784)
            out = model(xb)
            loss = F.cross_entropy(out, yb)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_iter += 1
            
            num_correct, total = calculate_num_correct(out, yb)
            
            total_num_correct += num_correct
            total_val_samples += total
            total_loss += loss.item()
            
            num_batches += 1

            
            
        training_log['validation_loss'].append((num_iter, loss.item()))
        training_log['validation_acc'].append((num_iter, num_correct / total))
        
        
        scheduler.step()
        
        print(f'Epoch {i}, iter: {num_iter}: validation loss: {training_log["validation_loss"][-1][1]}, validation acc: {training_log["validation_acc"][-1][1]}')

    
    return training_log

In [20]:
training_log = run_training_loop(model, optimizer, train_loader, val_loader, scheduler, num_epochs=10)

Epoch 0, iter: 0: training loss: 2.3021204471588135, training acc: 0.046875
Epoch 0, iter: 200: training loss: 0.26834169030189514, training acc: 0.890625
Epoch 0, iter: 400: training loss: 0.3655937910079956, training acc: 0.90625
Epoch 0, iter: 600: training loss: 0.12318333983421326, training acc: 0.96875
Epoch 0, iter: 800: training loss: 0.3014961779117584, training acc: 0.921875
Epoch 0, iter: 1095: validation loss: 0.11952541768550873, validation acc: 0.9375
Epoch 1, iter: 1200: training loss: 0.22016116976737976, training acc: 0.90625
Epoch 1, iter: 1400: training loss: 0.1932019293308258, training acc: 0.96875
Epoch 1, iter: 1600: training loss: 0.11853666603565216, training acc: 0.953125
Epoch 1, iter: 1800: training loss: 0.04204023629426956, training acc: 0.984375
Epoch 1, iter: 2000: training loss: 0.06824303418397903, training acc: 0.984375
Epoch 1, iter: 2190: validation loss: 0.015101161785423756, validation acc: 1.0
Epoch 2, iter: 2200: training loss: 0.128111883997917

In [21]:
torch.save(model.state_dict(), 'mnist_state_dict.pt')