# Dynamic Simple MLP implementation

In [11]:
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

class MLP(nn.Module):
    def __init__(self, inputs: int, layer_sizes: List[int], outputs: int):
        super().__init__()
        self.input = nn.Linear(inputs, layer_sizes[0])
        
        # create a model with hidden layers with the given sizes and ReLU
        layers = []
        for i in range(len(layer_sizes) - 1):
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
            layers.append(nn.ReLU())
        self.hidden = nn.Sequential(*layers)
        
        self.output = nn.Linear(layer_sizes[-1], outputs)

    def forward(self, x):
        x1 = self.input(x)
        x2 = self.hidden(x1)
        return self.output(x2)
        
    
mnist_train = datasets.MNIST(root="./datasets", train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root="./datasets", train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=100, shuffle=False)

model = MLP(784, [500, 128, 64], 10)

# Training
epochs = 20
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    for images, labels in tqdm(train_loader):
        optimizer.zero_grad()
        
        x = images.view(-1, 28*28)
        y = model(x) # weird syntax but whatever

        loss(y, labels).backward()
        optimizer.step()

    # Testing
    correct = 0
    total = len(mnist_test)

    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            x = images.view(-1, 28*28)
            y = model(x)

            predictions = torch.argmax(y, dim=1)
            correct += torch.sum((predictions == labels).float())
        
    print(f"Test acc: {correct/total}")

Epoch 1


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9180999994277954
Epoch 2


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9376000165939331
Epoch 3


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9531999826431274
Epoch 4


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9560999870300293
Epoch 5


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9643999934196472
Epoch 6


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9660000205039978
Epoch 7


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9702000021934509
Epoch 8


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9686999917030334
Epoch 9


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9695000052452087
Epoch 10


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9722999930381775
Epoch 11


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9757999777793884
Epoch 12


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9753999710083008
Epoch 13


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9753999710083008
Epoch 14


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9764000177383423
Epoch 15


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9768000245094299
Epoch 16


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9760000109672546
Epoch 17


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.978600025177002
Epoch 18


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9753000140190125
Epoch 19


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.9771000146865845
Epoch 20


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Test acc: 0.977400004863739
