In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix

# Define the transform for MNIST dataset.
# The pipeline is bound to its own name instead of `transforms`, so the imported
# `torchvision.transforms` module is not overwritten and this code can be run
# again without failing on `transforms.Compose`.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the dataset (downloaded into the current directory when not already present)
train = datasets.MNIST('.', train=True, download=True, transform=mnist_transform)
test = datasets.MNIST('.', train=False, download=True, transform=mnist_transform)

# Create data loaders: shuffle only the training split so evaluation order is stable
train_loader = DataLoader(train, batch_size=64, shuffle=True)
test_loader = DataLoader(test, batch_size=64, shuffle=False)

# Define the CNN model
class CNN(nn.Module):
    """Three-stage convolutional classifier followed by a small MLP head.

    Each stage is ``Conv2d(kernel_size=3) -> ReLU -> MaxPool2d(2, stride=2)``.
    The number of features entering the head is measured once at construction
    time by pushing a zero tensor through the convolutional stack.

    Args:
        in_channels: Number of channels in the input images (default ``1``).
        num_classes: Number of output logits (default ``10``).
        input_size: ``(height, width)`` of the input images
            (default ``(28, 28)``).

    Attributes:
        net: The convolutional feature extractor.
        flattened_size: Number of features per sample produced by ``net``.
        classify_head: The fully connected classifier.
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=(28, 28)):
        super().__init__()
        self.in_channels = in_channels
        self.input_size = tuple(input_size)

        # Convolutional layers
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
        )

        # Measure the per-sample feature count produced by the conv stack
        self.flattened_size = self._get_flattened_size()

        # Fully connected layers
        self.classify_head = nn.Sequential(
            nn.Linear(self.flattened_size, 20, bias=True),
            nn.ReLU(),
            nn.Linear(20, num_classes, bias=True),
        )

    def _get_flattened_size(self):
        """Return the number of features per sample emitted by ``self.net``.

        The probe tensor is created on the fly, on the same device and with the
        same dtype as the convolution weights, so nothing stale is kept on the
        module and the method still works after the model has been moved.
        """
        reference = next(self.net.parameters())
        probe = torch.zeros(
            1, self.in_channels, *self.input_size,
            device=reference.device, dtype=reference.dtype,
        )
        # Shape probing only: no autograd graph is needed.
        with torch.no_grad():
            features = self.net(probe)
        # Count the elements of one sample so the result does not depend on
        # the probe's batch dimension.
        return features[0].numel()

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalized logits.
        """
        # Pass through the convolutional layers
        x = self.net(x)

        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)

        # Pass through the fully connected layers
        return self.classify_head(x)

# Instantiate the model
model = CNN()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    model.train()
    # Sum of the per-batch loss values over the epoch (a total, not an average)
    running_loss = 0.0
    # The batch is named `images` so the builtin `input()` is not shadowed
    # at module scope.
    for images, target in train_loader:
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        output = model(images)  # Forward pass
        loss = criterion(output, target)  # Calculate the loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update model parameters
        running_loss += loss.item()
    print(f'Epoch - {epoch}, Loss = {running_loss}')

# Evaluation: collect the predicted and true labels over the whole test set
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    # The batch is named `images` so the builtin `input()` is not shadowed.
    for images, target in test_loader:
        output = model(images)
        # Only the index of the largest logit is needed; the value itself is unused
        predicted = output.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(target.cpu().numpy())

# Print confusion matrix (true labels first, predictions second)
cm = confusion_matrix(all_labels, all_preds)
print(cm)

# Report how many scalar weights are trainable (parameters with requires_grad set)
num_params = sum(map(torch.numel, filter(lambda p: p.requires_grad, model.parameters())))
print(f'Number of learnable parameters: {num_params}')


Epoch - 0, Loss = 2158.841808080673
Epoch - 1, Loss = 2145.4184141159058
Epoch - 2, Loss = 2124.737733602524
Epoch - 3, Loss = 2085.4698164463043
Epoch - 4, Loss = 1998.9035980701447
Epoch - 5, Loss = 1767.7134654521942
Epoch - 6, Loss = 1284.176250398159
Epoch - 7, Loss = 847.8750349283218
Epoch - 8, Loss = 613.6570642888546
Epoch - 9, Loss = 483.20047226548195
[[ 935    0    2    0    4   18   16    3    2    0]
 [   0 1095    0    7    0    4    2    1   26    0]
 [   7    8  846   38   17    7    9   29   63    8]
 [   0    7   26  902    0   16    0   31   24    4]
 [   4    1    0    0  885    2   36    3   19   32]
 [  14   24    3   22    6  743   32    9   29   10]
 [  27    5    1    0   51   13  849    0   12    0]
 [   4   13   45   56    1    5    0  851    3   50]
 [   0   15    8   53   12   26   19    4  784   53]
 [  14    2    3   19   31   15    4   22   15  884]]
Number of learnable parameters: 149798
