# Dynamic Simple MLP implementation

In [3]:
from typing import List
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

class MLP(nn.Module):
    def __init__(self, inputs: int, layer_sizes: List[int], outputs: int):
        super().__init__()
        self.input = nn.Linear(inputs, layer_sizes[0])
        
        # create a model with hidden layers with the given sizes and ReLU
        layers = []
        for i in range(len(layer_sizes) - 1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            layers.append(nn.ReLU())
        self.hidden = nn.Sequential(*layers)
        
        self.output = nn.Linear(layer_sizes[-1], outputs)

    def forward(self, x):
        x1 = self.input(x)
        x2 = self.hidden(x1)
        return self.output(x2)
        
    
def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

# Load the data
mnist_train = datasets.MNIST(root="./datasets", 
                             train=True, 
                             transform=transforms.ToTensor(), 
                             download=True)
mnist_test = datasets.MNIST(root="./datasets", 
                            train=False, 
                            transform=transforms.ToTensor(), 
                            download=True)

# Split train set into train (90%) and validation (10%)
train_size = int(0.9 * len(mnist_train))
val_size = len(mnist_train) - train_size
# If this causes errors, restart the kernel lmao
mnist_train, mnist_val = random_split(mnist_train, [train_size, val_size])

train_loader = DataLoader(mnist_train, 
                          batch_size=100, 
                          shuffle=True)
val_loader = DataLoader(mnist_val,
                        batch_size=100, 
                        shuffle=True)
test_loader = DataLoader(mnist_test, 
                         batch_size=100, 
                         shuffle=False)


model = MLP(784, [500, 128, 64], 10).to(device)

# Training
epochs = 20
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        
        x = images.view(-1, 28*28)
        y = model(x) # weird syntax but whatever

        loss(y, labels).backward()
        optimizer.step()

    # Testing
    correct = 0
    total = len(mnist_test)

    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            x = images.view(-1, 28*28)
            y = model(x)

            predictions = torch.argmax(y, dim=1)
            correct += torch.sum((predictions == labels).float())
        
    print(f"Test acc: {correct/total}")

Using device: mps
Epoch 1


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9009000062942505
Epoch 2


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9294000267982483
Epoch 3


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9470999836921692
Epoch 4


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9538000226020813
Epoch 5


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9581000208854675
Epoch 6


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9627000093460083
Epoch 7


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9664999842643738
Epoch 8


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9656999707221985
Epoch 9


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.97079998254776
Epoch 10


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9700999855995178
Epoch 11


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9715999960899353
Epoch 12


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9713000059127808
Epoch 13


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9708999991416931
Epoch 14


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9740999937057495
Epoch 15


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9731000065803528
Epoch 16


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9729999899864197
Epoch 17


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9732000231742859
Epoch 18


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9740999937057495
Epoch 19


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9751999974250793
Epoch 20


  0%|          | 0/540 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9731000065803528
