In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as tf

path = "mnist"
SPLIT_SEED = 42  # fixes the train/validation partition across kernel restarts

# Convert each image to a tensor and flatten it to 1-D so it can feed the
# 28 * 28 input features of the first Linear layer used in the training cell.
trans = tf.Compose([tf.ToTensor(), torch.flatten])

# Keep the un-split training set under its own name so re-running the split
# never depends on a previously rebound `train_ds`.
full_train_ds = MNIST(path, train=True, download=True, transform=trans)
test_ds = MNIST(path, train=False, download=True, transform=trans)

# 80/20 train/validation split; a dedicated, seeded generator makes the
# partition reproducible without touching the global RNG state.
train_ds, val_ds = data.random_split(
    full_train_ds,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
class Net(nn.Module):
    """Fully connected network with two hidden layers and a linear output.

    Args:
        layer_sizes: Indexable of at least four ints, read as
            ``(in_features, hidden_1, hidden_2, out_features)``. Only the
            first four entries are used.
        act: Activation module applied after each of the two hidden layers.

    ``forward`` returns the raw output of the last linear layer (logits).
    No activation is applied to it, so the result can be passed directly to
    ``nn.CrossEntropyLoss`` and negative logits are preserved.
    """

    def __init__(self, layer_sizes, act=nn.ReLU()) -> None:
        super().__init__()

        self.l1 = nn.Linear(layer_sizes[0], layer_sizes[1])
        self.l2 = nn.Linear(layer_sizes[1], layer_sizes[2])
        self.output = nn.Linear(layer_sizes[2], layer_sizes[3])
        self.act = act

    def forward(self, x):
        """Map inputs with ``layer_sizes[0]`` features to output logits."""
        v = self.act(self.l1(x))
        v = self.act(self.l2(v))

        # Deliberately no activation here: applying `act` (ReLU by default)
        # to the output would clamp every logit to >= 0 before the loss.
        return self.output(v)

In [None]:
# Display whether CUDA is usable; the training cell runs the same check to
# choose between 'cuda' and 'cpu' for `device`.
torch.cuda.is_available()

In [None]:
LEARNING_RATE = 1e-3
EPOCHS = 10
BATCH_SIZE = 128

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model = Net((28 * 28, 256, 128, 10)).to(device)
model = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
).to(device)
cost = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Build the loaders once instead of once per epoch. The training loader
# shuffles so every epoch visits the samples in a different order; the
# validation loader returns the whole validation set as a single batch.
train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(val_ds, batch_size=len(val_ds))

# Per-epoch history, stored as plain Python floats so it can be plotted
# directly regardless of the device used for training.
train_loss = []
train_acc = []
val_loss = []
val_acc = []

for epoch in tqdm(range(EPOCHS), 'Epoch'):
    epoch_loss = []
    epoch_acc = []

    # ---- training pass ----
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(x)
        loss = cost(pred, y)

        # Backpropagation
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accuracy of the predictions made *before* this optimizer step.
        # argmax over dim=1 (the class dimension); no squeeze, so a batch
        # containing a single sample keeps its (1, classes) shape.
        acc = (pred.detach().argmax(dim=1) == y).float().mean()

        # Keep metrics on the device; they are reduced once per epoch below
        # to avoid a host/device sync on every batch.
        epoch_loss.append(loss.detach())
        epoch_acc.append(acc)

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():  # no autograd graph for the full validation batch
        x, y = next(iter(val_loader))
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = cost(pred, y)
        acc = (pred.argmax(dim=1) == y).float().mean()

    val_loss.append(loss.item())
    val_acc.append(acc.item())

    # Mean of the per-batch means collected during this epoch.
    train_loss.append(torch.stack(epoch_loss).mean().item())
    train_acc.append(torch.stack(epoch_acc).mean().item())


In [None]:
# Evaluate the trained model on the held-out test set in one batch.
model.eval()

test_loader = data.DataLoader(test_ds, batch_size=len(test_ds))
x, y = next(iter(test_loader))

# Inference only: disable autograd so no graph is built for the whole test
# set (calling .detach() afterwards would not prevent its construction).
with torch.no_grad():
    output = model(x.to(device))

# argmax over dim=1 (the class dimension) gives the predicted label per sample.
acc = (output.argmax(dim=1) == y.to(device)).float().mean()
# .item() prints a plain number instead of a tensor repr.
print('test accuracy =', acc.item())

In [None]:
# Training vs. validation curves, one point per epoch, side by side.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))

loss_ax.set_title("Loss")
loss_ax.plot(train_loss, label="Training")
loss_ax.plot(val_loss, label="Validation")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

acc_ax.set_title("Accuracy")
acc_ax.plot(train_acc, label="Training")
acc_ax.plot(val_acc, label="Validation")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()

# Keep the axis labels of the two panels from overlapping.
fig.tight_layout()
plt.show()