In [2]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import torchvision 
from torchvision import datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cpu


In [4]:
# Transformation

transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

In [5]:
# Load data

batch_size = 32

train_dataset = torchvision.datasets.MNIST(root='../dataset/', train=True, transform=transforms, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='../dataset/', train=False, transform=transforms, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)   

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../dataset/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ../dataset/MNIST/raw/train-images-idx3-ubyte.gz to ../dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ../dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ../dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ../dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ../dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../dataset/MNIST/raw






In [6]:
# Define CNN Model

class CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3,3), stride=(1,1), padding=(1,1))
        self.pool = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3,3), stride=(1,1), padding=(1,1))
        self.fc1 = nn.Linear(16*7*7, num_classes) # 28 -> 28 -> 14 -> 14 -> 7 (changes due to maxpool)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x) 
        
        return x 

In [12]:
# Training and validation loop

num_epochs = 5
learning_rate = 0.001
train_losses = []
test_losses = []

model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for X, y in train_loader:
        X = X.to(device=device)
        y = y.to(device=device)
        optimiser.zero_grad() 
        pred_y = model(X)
        loss = criterion(pred_y, y) # mean loss per sample
        loss.backward()
        optimiser.step()
        running_loss += loss.item() * X.size(0) 
        # running_loss = total loss per batch, X.size() refers to (batch_size, channels, height, width)
    epoch_loss = running_loss/len(train_loader)
    train_losses.append(epoch_loss)

    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        flag = 0
        for X, y in test_loader:
            X = X.to(device=device)
            y = y.to(device=device)
            pred_y = model(X)
            loss = criterion(pred_y, y)
            running_loss += loss.item() * X.size(0)
            print('=======')
            print(pred_y)
            flag += 1
            if flag == 10:
                break
    break
#             loss = criterion(pred_y, y)
#             running_loss += loss.item() * X.size(0)
#             _, preds = torch.max(pred_y, 1)
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(y.cpu().numpy())
#     epoch_test_loss = running_loss/len(test_loader)
#     test_losses.append(epoch_test_loss)

#     accuracy = accuracy_score(all_labels, all_preds)
#     f1 = f1_score(all_labels, all_preds, average='weighted')

#     print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {epoch_test_loss:.4f}')
#     print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}')

# plt.figure(figsize=(12, 6))
# plt.plot(train_losses, label='Training Loss')
# plt.plot(test_losses, label='Validation Loss')
# plt.title('Training and Validation Loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()

tensor([[  8.0273, -15.8409,  -1.5494,  -8.5430, -11.5556,  -5.9137,  -1.5350,
         -11.0371,  -5.2135,  -9.1996],
        [-11.3506,  -9.7556,  -7.4262,  -6.9818,   3.0399,  -7.0107,  -2.4349,
          -3.2308,  -6.7703,  -1.4435],
        [ -8.5327, -11.6476,  -7.3132,  -5.0934,   1.7724,  -5.6292,  -9.2205,
           6.5495,  -7.0410,   1.7285],
        [ -2.9724,  -8.6680,  -4.2396,  -6.6957,  -0.2765,  -1.8738,   5.2131,
          -4.5363,  -6.4817,  -4.6759],
        [ -8.3989,  -3.6127,   5.5248,  -6.0000,  -3.9111,  -8.4037,  -4.2972,
         -11.0815,  -1.5699,  -9.4974],
        [ -8.4102,  -6.0635,  -5.5775,   5.5975,  -7.0978,  -0.4320,  -8.6124,
         -10.3319,  -7.2493,  -0.8761],
        [-10.2187, -11.9509,   8.6337,  -0.2910,  -9.9659, -10.6290, -14.6443,
          -2.3065,  -4.5835,  -9.0880],
        [ -6.9894,  -6.3538,  -7.4813,  -4.1645,  -6.5769,   1.0124,  -2.7680,
          -9.8549,  -1.5336,  -2.0750],
        [ -6.1198, -15.9746,  15.3845,   0.7075,

In [11]:
print(len(pred_y))

32
