In [29]:
import torch
from torchvision.datasets.mnist import MNIST
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import v2
import torch.nn.functional as F
import torch.nn as nn
from matplotlib import pyplot as plt


# constants
batch_size = 16
picture_size = (28, 28)  # (height, width) every image is resized to
# Length of the flattened input vector; derived so it always matches picture_size.
picture_vector_size = picture_size[0] * picture_size[1]
hidden_unit_size = 128
number_of_label_classes = 10
epochs = 8
learning_rate = 0.01
seed = 0  # fixed RNG seed so weight init and DataLoader shuffling are reproducible

torch.manual_seed(seed)

In [30]:
# preparing data

# Per-sample pipeline: resize -> tensor image -> float32 scaled to [0, 1] -> flat vector.
feature_transforms = v2.Compose([
    v2.Resize(picture_size),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0.0, 1.0]
    v2.Lambda(torch.flatten),  # squish 2d matrices into vectors for the MLP input
])

# Both splits share the same storage location, download behaviour and transforms.
mnist_kwargs = dict(root='./resources', download=True, transform=feature_transforms)
dataset_train = MNIST(train=True, **mnist_kwargs)
dataset_test = MNIST(train=False, **mnist_kwargs)

dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)  # reshuffled every epoch
dataloader_test = DataLoader(dataset_test, batch_size=batch_size)  # fixed order for evaluation

In [31]:
# defining the model class

class Model(nn.Module):
    """Two-layer MLP classifier over flattened image vectors.

    Maps an input of length ``picture_vector_size`` through one ReLU hidden
    layer of ``hidden_unit_size`` units to ``number_of_label_classes`` outputs.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(picture_vector_size, hidden_unit_size),  # input -> hidden
            nn.ReLU(),
            nn.Linear(hidden_unit_size, number_of_label_classes),  # hidden -> class scores
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, X):
        # No softmax here: the outputs are raw logits, which is what nn.CrossEntropyLoss consumes.
        return self.model(X)

In [32]:
# train & test loop

# Per-epoch metric histories; train() appends to train_losses, test() to
# test_losses and accuracy. They start empty so every entry is a real
# measurement (one per epoch) rather than a placeholder point in the plot.
train_losses = []
test_losses = []
accuracy = []

def train(model, loss_fn, optimizer, dataloader):
    """Run one training epoch and record the mean per-batch loss.

    Puts ``model`` in training mode, performs one optimizer step per batch
    yielded by ``dataloader``, and appends the average of the per-batch
    losses to the module-level ``train_losses`` list.

    Args:
        model: the ``nn.Module`` being trained.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        dataloader: iterable of ``(X, y)`` batches.

    Returns:
        The mean per-batch loss for this epoch as a float (``nan`` if
        ``dataloader`` yields no batches).
    """
    # Training mode matters once layers like dropout/batch-norm are added, and
    # undoes any eval() left over from a previous evaluation pass.
    model.train()
    total_loss = 0.0
    n_batches = 0

    for X, y in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # plain float: no autograd graph kept alive
        n_batches += 1

    # Count batches ourselves so an empty loader doesn't raise ZeroDivisionError.
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    train_losses.append(mean_loss)
    return mean_loss


def test(model, loss_fn, dataloader):
    with torch.no_grad():
        full_loss = 0
        correct_predictions = 0

        for i, (X, y) in enumerate(dataloader):
            pred = model(X)
            loss = loss_fn(pred, y)
            full_loss += loss
            correct_predictions += torch.sum(torch.argmax(pred, dim=1) == y).item()

        test_losses.append(full_loss / len(dataloader))
        accuracy.append(correct_predictions / len(dataloader.dataset))
 

In [None]:
# Fresh model and optimiser, then one training pass followed by one
# evaluation pass per epoch, reporting the newest recorded metrics.
loss_fn = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch_number in range(1, epochs + 1):
    print(f'EPOCH {epoch_number}')
    train(model, loss_fn, optimizer, dataloader_train)
    test(model, loss_fn, dataloader_test)
    # train()/test() append to the module-level histories; [-1] is this epoch's value.
    print(
        f'Train loss: {train_losses[-1]}',
        f'Test loss: {test_losses[-1]}',
        f'Accuracy: {accuracy[-1]}',
        sep='\n',
    )

EPOCH 1
Train loss: 0.6208378638803959
Test loss: 0.3248015344142914
Accuracy: 0.9103
EPOCH 2
Train loss: 0.3063327048972249
Test loss: 0.2684134542942047
Accuracy: 0.9258
EPOCH 3
Train loss: 0.2572227278428773
Test loss: 0.2346692830324173
Accuracy: 0.9316
EPOCH 4
Train loss: 0.22509540668961903
Test loss: 0.20859172940254211
Accuracy: 0.9404
EPOCH 5
Train loss: 0.20036781322223446
Test loss: 0.18647386133670807
Accuracy: 0.9446
EPOCH 6
Train loss: 0.18066809093852837
Test loss: 0.17214424908161163
Accuracy: 0.9509
EPOCH 7


In [None]:
# Losses and accuracy have different units/scales, so draw them on separate panels.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))

loss_ax.plot(train_losses, marker='o', label='train_loss')
# This curve is computed on dataloader_test, so label it as the test loss.
loss_ax.plot(test_losses, marker='o', label='test_loss')
loss_ax.set(title='Cross-entropy loss', xlabel='epoch index', ylabel='mean batch loss')
loss_ax.legend()

acc_ax.plot(accuracy, marker='o', label='accuracy')
acc_ax.set(title='Test accuracy', xlabel='epoch index', ylabel='fraction correct')
acc_ax.legend()

fig.tight_layout()
plt.show()