In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T 
from sklearn.metrics import confusion_matrix


# Preprocessing shared by both splits: convert each image to a tensor, then
# normalise its single channel with mean 0.5 and std 0.5.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train and test splits, rooted at (and downloaded into) the current
# working directory.
train = datasets.MNIST(root='.', train=True, download=True, transform=transform)
test = datasets.MNIST(root='.', train=False, download=True, transform=transform)

# Mini-batches of 64 samples; only the training loader shuffles.
train_loader = DataLoader(dataset=train, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test, batch_size=64, shuffle=False)


class CNN(nn.Module):
    """Three-stage convolutional classifier with a small fully connected head.

    The feature extractor ``net`` stacks three ``Conv2d -> ReLU -> MaxPool2d``
    stages (64, 128 and 64 output channels, 3x3 kernels, 2x2 pooling).  The
    size of its flattened output is discovered once at construction time by
    pushing a zero tensor through it, and is used to size the first linear
    layer of ``classify_head``.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        input_size: ``(height, width)`` of the input images.

    With the default arguments the module is identical to the original
    hard-coded 1x28x28 / 10-class network.
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=(28, 28)) -> None:
        super().__init__()
        # Convolutional layers
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
        )

        # Probe tensor for the shape inference below.  Registered as a
        # non-persistent buffer so it follows the module across devices
        # without being written into the state dict.
        self.register_buffer(
            '_dummy_input',
            torch.zeros(1, in_channels, *input_size),
            persistent=False,
        )
        self.flattened_size = self._get_flattened_size()

        # Classification head: flattened features -> 20 hidden units -> logits.
        self.classify_head = nn.Sequential(
            nn.Linear(self.flattened_size, 20, bias=True),
            nn.ReLU(),
            nn.Linear(20, num_classes, bias=True),
        )

    def _get_flattened_size(self) -> int:
        """Return the number of features ``net`` produces for one sample."""
        # Shape inference only: no autograd graph is needed.
        with torch.no_grad():
            features = self.net(self._dummy_input)
        # Count the elements of a single sample so the result does not depend
        # on the batch size of the probe tensor.
        return features[0].numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Batch of images laid out as ``(batch, channels, height, width)``.

        Returns:
            Tensor of raw (un-normalised) logits, one row per input sample.
        """
        features = self.net(x)
        # Collapse everything except the batch dimension; unlike ``view`` this
        # also works when the tensor is not contiguous.
        flat = torch.flatten(features, start_dim=1)
        return self.classify_head(flat)


model = CNN()

# Cross-entropy on the raw logits, optimised with plain SGD (no momentum).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# ---- Training: 10 passes over the training loader -------------------------
for epoch in range(10):
    model.train()
    # Accumulates the sum of the per-batch loss values for this epoch
    # (it is not divided by the number of batches).
    running_loss = 0.0
    # The batch is bound to `images`/`labels` rather than `input`/`target`
    # so the builtin `input` is not rebound at module level.
    for images, labels in train_loader:
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch - {epoch}, Loss = {running_loss}')

# ---- Evaluation on the test loader ----------------------------------------
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images)
        # Predicted class = index of the largest logit in each row.
        predicted = logits.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# True labels are passed first, predictions second.
cm = confusion_matrix(all_labels, all_preds)
print(cm)

# Total element count over all parameters that require gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of learnable parameters: {num_params}')


Epoch - 0, Loss = 2157.709214448929
Epoch - 1, Loss = 2145.9824142456055
Epoch - 2, Loss = 2130.759788751602
Epoch - 3, Loss = 2106.668748855591
Epoch - 4, Loss = 2061.107239484787
Epoch - 5, Loss = 1970.9392185211182
Epoch - 6, Loss = 1824.2634164094925
Epoch - 7, Loss = 1578.1991910934448
Epoch - 8, Loss = 1198.2527633309364
Epoch - 9, Loss = 874.4469019770622
[[ 912    0    0    1    2   21   33    6    2    3]
 [   0 1096   27    4    0    0    2    1    5    0]
 [  13   39  656  116   19   10    9   38  108   24]
 [   1    2  106  760    0   14    0   24   74   29]
 [   1    0    3    0  804   10   78    2    2   82]
 [  26   20   24  101   32  368   90   33   35  163]
 [  20    7    0    0   94   12  821    0    4    0]
 [   4   16   50    2    9    0    0  887   28   32]
 [   2   17   77   46   14   62   19   10  610  117]
 [   9    2    1    6   55   30    7   28   30  841]]
Number of learnable parameters: 149798
