# In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Step 1: Load the MNIST dataset
# We'll use the torchvision package to download and preprocess the MNIST dataset.
# Transformations:
# - Convert the images to tensors (grayscale images of size 28x28); ToTensor scales pixel values to [0, 1]
# - Normalize with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5, which maps pixel values from [0, 1] to [-1, 1]

transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # image -> float tensor with values in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> values in [-1, 1]
    ]
)

# Download (only if not already present under ./data) and load both MNIST splits;
# the same preprocessing pipeline is applied to the training and test images.
train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST('./data', train=False, download=True, transform=transform)

# Wrap each split in a DataLoader that yields mini-batches of 64 samples.
# Only the training split is reshuffled every epoch; test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Step 2: Define the CNN architecture
# A CNN typically consists of convolutional layers (to learn spatial hierarchies), 
# followed by pooling layers (to downsample the data), and finally fully connected layers for classification.

class CNN(nn.Module):
    """Small two-stage convolutional classifier for 28x28 images such as MNIST.

    Architecture:
        conv 3x3 (16) -> ReLU -> max-pool 2x2
        -> conv 3x3 (32) -> ReLU -> max-pool 2x2
        -> flatten -> linear (128) -> ReLU -> linear (num_classes)

    The 3x3 convolutions use ``padding=1`` so they preserve the spatial size,
    and each 2x2 max-pool halves it: 28 -> 14 -> 7. ``fc1`` is therefore sized
    for a 32 x 7 x 7 feature map, so inputs must be 28x28 spatially.

    Args:
        in_channels: Number of channels in the input images (1 for grayscale MNIST).
        num_classes: Number of output logits (10 digit classes for MNIST).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Submodules are created in the same order and under the same names as
        # before, so saved state_dicts and seeded initialization stay compatible.

        # Convolutional layer 1: in_channels -> 16 channels, 3x3 kernel, size-preserving
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        # Convolutional layer 2: 16 -> 32 channels, 3x3 kernel, size-preserving
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

        # Max pooling: keeps the max of each non-overlapping 2x2 patch, halving height and width
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layer 1: after two poolings a 28x28 input becomes 32 channels of 7x7,
        # so the flattened feature vector has 32 * 7 * 7 = 1568 entries
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        # Fully connected layer 2 (output layer): one logit per class
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch of shape (batch, in_channels, 28, 28).

        Returns:
            Raw, unnormalized scores of shape (batch, num_classes). No softmax is
            applied here; nn.CrossEntropyLoss applies it internally during training.
        """
        # Stage 1: convolution + ReLU, then downsample 28x28 -> 14x14
        x = self.pool(torch.relu(self.conv1(x)))
        # Stage 2: convolution + ReLU, then downsample 14x14 -> 7x7
        x = self.pool(torch.relu(self.conv2(x)))

        # Flatten everything except the batch dimension: (batch, 32 * 7 * 7)
        x = torch.flatten(x, 1)
        # Classifier head
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# Step 3: Define the loss function, instantiate the model and create the optimizer.
# nn.CrossEntropyLoss applies log-softmax to the model's raw logits followed by the
# negative log-likelihood loss, the standard objective for multi-class classification.
criterion = nn.CrossEntropyLoss()
model = CNN()
# Adam over all of the model's trainable parameters, learning rate 1e-3 (= 0.001)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Step 4: Train the model
# The training loop involves several epochs. In each epoch, we pass the data through the model, compute the loss,
# backpropagate the error, and update the weights using the optimizer.


num_epochs = 5
for epoch in range(num_epochs):
    # Make sure layers run in training mode. This is a no-op for the current CNN
    # (it has no dropout/batch norm), but it keeps the loop correct if such layers
    # are added or if model.eval() was called earlier in the session.
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        # Zero the parameter gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass: compute the model's predictions (logits)
        outputs = model(images)
        # criterion uses the default reduction, so this is the mean loss over the batch
        loss = criterion(outputs, labels)
        # Backward pass: compute gradients
        loss.backward()
        # Update the model parameters
        optimizer.step()

        # Accumulate the summed (not averaged) batch loss, so the epoch figure below is a
        # true per-sample mean even though the final batch is smaller than the others.
        running_loss += loss.item() * images.size(0)

    # Mean training loss per sample over this epoch
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

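# Step 5: Test the model
# After training, we evaluate the model's performance on the test dataset (which it hasn't seen before).
# model.eval() switches the model to evaluation mode (this matters for layers such as dropout or batch norm,
# which this CNN does not use), and torch.no_grad() turns off gradient tracking so no gradients are computed.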

model.eval()  # evaluation mode; gradient tracking is disabled separately by torch.no_grad() below
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed for inference
    for images, labels in test_loader:
        outputs = model(images)
        # The predicted class is the index of the largest logit in each row.
        # argmax avoids the discouraged `.data` access and the unused max values.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Percentage of test samples classified correctly
accuracy = 100 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')

# Step 6 (optional): persist only the learned parameters (the model's state_dict),
# not the whole module object; load them back later with
# model.load_state_dict(torch.load('mnist_cnn.pth')).
torch.save(obj=model.state_dict(), f='mnist_cnn.pth')


# Output of a previous run:
# Epoch [1/5], Loss: 0.2203
# Epoch [2/5], Loss: 0.0569
# Epoch [3/5], Loss: 0.0394
# Epoch [4/5], Loss: 0.0310
# Epoch [5/5], Loss: 0.0237
# Accuracy on the test set: 99.11%
