In [None]:
import numpy as np
import torch

In [None]:
# Seed every RNG the notebook touches, up front, so a fresh
# Restart & Run All reproduces the same weights and batch order.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

datapath = 'data'

# Load MNIST train/test splits as tensors.
# download=True fetches the files on the first run, so a fresh checkout can
# Restart & Run All; when the files are already under `datapath` it is skipped.
data_train = MNIST(
    root=datapath,
    train=True,
    transform=ToTensor(),
    download=True,
)
data_test = MNIST(
    root=datapath,
    train=False,
    transform=ToTensor(),
    download=True,
)

In [None]:
# hyperparameters
# Training hyperparameters, all in one place:
#   batch_size -> mini-batch size of the training DataLoader
#   lr         -> Adam learning rate
#   n_epochs   -> number of passes over the training set
batch_size, lr, n_epochs = 16, 1e-3, 10

In [None]:
# Number of samples in each split.
n_train, n_test = len(data_train), len(data_test)
print("Size of training dataset:", n_train)
print("Size of test dataset:", n_test)

In [None]:
# Each dataset item is a (transformed image, label) pair; look at the first one.
first_item = data_train[0]
x_sample, y_sample = first_item

print("Shape of x_sample:", x_sample.shape)
print("y_sample is an integer:", type(y_sample))

In [None]:
import matplotlib.pyplot as plt

# Show the first training image with its label.
fig, ax = plt.subplots()
# Index the leading (channel) axis; assumes x_sample is (1, H, W) - see the shape printed above.
ax.imshow(x_sample[0], cmap='binary')
# Plain `{y_sample}`: the old `{y_sample: d}` spec added a stray sign-padding space.
ax.set_title(f'Label: {y_sample}')
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
import torch.nn as nn

class SimpleMLP(nn.Module):
    """Fully connected classifier: flatten -> 128 -> ReLU -> 32 -> ReLU -> n_classes.

    Parameters
    ----------
    input_shape : int or tuple of int
        Shape of a single input data point (e.g. ``(1, 28, 28)``). Each sample
        is flattened to ``prod(input_shape)`` features.
    n_classes : int
        Number of output logits per sample.
    """

    def __init__(self, input_shape, n_classes):
        super().__init__()
        self.input_shape = np.asarray(input_shape)
        self.n_classes = n_classes
        # Plain Python int (not np.int64): stored by nn.Linear as in_features
        # and used as the reshape size in forward().
        self.n_features = int(np.prod(self.input_shape))
        self.seq_model = nn.Sequential(
            nn.Linear(self.n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, self.n_classes),
        )

    def forward(self, x):
        """Return unnormalised class scores of shape (batch_size, n_classes).

        ``x`` is flattened to ``(-1, n_features)``, one row per sample.
        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. transposed or permuted tensors) are accepted too.
        """
        x = x.reshape(-1, self.n_features)
        return self.seq_model(x)


In [None]:
# Single-channel 28x28 inputs, 10 output classes.
model = SimpleMLP(n_classes=10, input_shape=(1, 28, 28))
print(model)

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training split; evaluation order does not affect accuracy.
train_loader = DataLoader(dataset=data_train, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(dataset=data_test, shuffle=False, batch_size=32)

In [None]:
# Pull a single mini-batch to check the tensor shapes the model will see.
first_batch = next(iter(train_loader))
x_batch, y_batch = first_batch
print("x_batch shape:", x_batch.shape)
print("y_batch shape:", y_batch.shape)

In [None]:
# Sanity-check the forward pass on the sample batch.
pred_batch = model(x_batch)
logits_shape = pred_batch.shape
print("Example of model's logits shape:", logits_shape)

In [None]:
def model_accuracy(data_loader, net=None):
    """Return the fraction of samples in ``data_loader`` that ``net`` classifies correctly.

    Parameters
    ----------
    data_loader : iterable of (x_batch, y_batch)
        Yields input batches and integer class-label tensors (e.g. a ``DataLoader``).
    net : torch.nn.Module, optional
        Classifier producing per-class scores along the last axis. Defaults to
        the notebook-level ``model`` (looked up at call time, for backward
        compatibility with the original single-argument signature).

    Returns
    -------
    float
        Accuracy in [0, 1].

    Raises
    ------
    ValueError
        If ``data_loader`` yields no samples.
    """
    if net is None:
        net = model  # notebook global, resolved at call time rather than definition time
    was_training = net.training
    net.eval()  # turn off training-only behaviour (dropout, batch-norm updates) if any
    n_total = 0
    n_correct = 0
    try:
        # Evaluation needs no gradients; skipping autograd avoids building a graph per batch.
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                logits_batch = net(x_batch)  # model's output scores
                n_total += len(y_batch)
                # Tensor .sum() is vectorised, unlike Python's builtin sum over elements.
                n_correct += (logits_batch.argmax(dim=-1) == y_batch).sum().item()
    finally:
        net.train(was_training)  # restore whatever mode the caller had
    if n_total == 0:
        raise ValueError("data_loader yielded no samples")
    return n_correct / n_total

# Baseline: an untrained classifier should be near chance level.
train_acc = model_accuracy(train_loader)
print(f"Train accuracy before training: {train_acc:.4f}")
test_acc = model_accuracy(test_loader)
print(f"Test accuracy before training: {test_acc:.4f}")

In [None]:
# Adam over all model weights, with the learning rate from the hyperparameter cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
# Cross-entropy on raw logits (the model has no final softmax).
loss_fn = nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm

# Make sure layers are in training mode, regardless of what earlier cells did.
model.train()
epoch_bar = tqdm(range(n_epochs), desc="epochs")
for _ in epoch_bar:
    running_loss = 0.0
    n_seen = 0
    for x_batch, y_batch in train_loader:

        optimizer.zero_grad()
        logits_batch = model(x_batch)
        loss_batch = loss_fn(logits_batch, y_batch)
        loss_batch.backward()
        optimizer.step()

        # CrossEntropyLoss() averages over the batch by default; re-weight by
        # batch size so the reported epoch loss is a true per-sample mean.
        running_loss += loss_batch.item() * len(y_batch)
        n_seen += len(y_batch)
    # Surface the epoch's mean training loss so divergence/plateaus are visible.
    epoch_bar.set_postfix(loss=running_loss / max(n_seen, 1))

In [None]:
# Accuracy of the trained model (the labels previously said "before training").
print(f"Train accuracy after training: {model_accuracy(train_loader):.4f}")
print(f"Test accuracy after training: {model_accuracy(test_loader):.4f}")

In [None]:
from pathlib import Path

# Persist only the weights (state_dict). Create the target directory first:
# torch.save fails when the parent folder does not exist on a fresh checkout.
model_path = Path("saved_models") / "MLP.pt"
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)