In [1]:
# PyTorch core: tensors, device handling, autograd and model serialisation.
import torch

# Record the installed PyTorch version alongside the notebook output.
print(torch.__version__)

# Used further down to print per-class precision / recall / F1 for the test set.
from sklearn.metrics import classification_report

In [2]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ToTensor turns each (H, W) image into a (1, H, W) float tensor with pixel
# values scaled into [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Both splits live under the same root directory and share the transform;
# only the `train` flag differs. `download=True` fetches the files if missing.
train_dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("data", train=False, download=True, transform=transform)

# Shuffle only the training split; the evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

In [4]:
import torch.nn as nn


class MNIST_CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/pool stages plus an MLP head.

    The first Linear layer expects 64 * 7 * 7 input features, which is what
    the feature extractor produces for single-channel 28x28 inputs (each of
    the two 2x2 poolings halves the spatial size). The network emits 10
    unnormalised scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as they are applied so that
        # seeded parameter initialisation and state_dict keys stay stable.
        first_stage = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        second_stage = [
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*first_stage, *second_stage, nn.Dropout(p=0.25))

        head = [
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images `x`."""
        return self.classifier(self.features(x))


# Build the network, then move its parameters onto the selected device.
model = MNIST_CNN()
model = model.to(device)
print(model)

MNIST_CNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Dropout(p=0.25, inplace=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=3136, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [5]:
import torch.optim as optim


# Cross-entropy over raw logits and integer class labels (the PyTorch
# counterpart of Keras' sparse_categorical_crossentropy with from_logits=True).
criterion = nn.CrossEntropyLoss()

# Plain stochastic gradient descent over every parameter of the model.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [6]:

# Number of full passes over the training set.
epochs = 2

for epoch in range(epochs):
    # Training mode (enables the model's dropout layer).
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    samples_seen = 0    # number of samples those losses were computed over

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()  # reset gradients (Keras/TF does this implicitly)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()  # backpropagation (explicit in PyTorch)
        optimizer.step()  # update the weights

        # `loss` is the mean over the batch, so scale it by the batch size
        # before accumulating. Averaging the batch means directly would give
        # a smaller final batch the same weight as a full one.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Per-sample mean loss for the epoch; max(..., 1) avoids a division by
    # zero if the loader yielded no batches.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Epoch [{epoch + 1}/{epochs}] Loss: {epoch_loss:.4f}")

Epoch [1/2] Loss: 2.2594
Epoch [2/2] Loss: 1.5468


In [7]:

# Inference mode: dropout is disabled so predictions are deterministic.
model.eval()
all_preds, all_labels = [], []

# Gradients are not needed for evaluation, so skip tracking them.
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        # Predicted class = index of the largest logit in each row.
        batch_preds = logits.argmax(dim=1)

        # Bring results back to host memory and collect them per sample.
        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

In [8]:

# Per-class precision / recall / F1 for the collected test-set predictions.
print("Classification Report:")
report_text = classification_report(all_labels, all_preds)
print(report_text)

# Save only the learned parameters (the state_dict), not the module object.
weights_path = "data/MNIST/mnist_cnn.pth"
torch.save(model.state_dict(), weights_path)
print("Model saved.")

Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.91      0.89       980
           1       0.93      0.93      0.93      1135
           2       0.88      0.83      0.85      1032
           3       0.87      0.70      0.77      1010
           4       0.80      0.82      0.81       982
           5       0.92      0.52      0.66       892
           6       0.89      0.87      0.88       958
           7       0.80      0.88      0.84      1028
           8       0.55      0.86      0.67       974
           9       0.78      0.75      0.76      1009

    accuracy                           0.81     10000
   macro avg       0.83      0.81      0.81     10000
weighted avg       0.83      0.81      0.81     10000

Model saved.
