In [2]:
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim

## Data

In [3]:
# Constructor arguments common to both MNIST splits: where the files live on
# disk, the conversion of each image to a tensor, and downloading the files
# when they are not already present under `root`.
mnist_kwargs = dict(
    root='data/mnist',
    transform=transforms.ToTensor(),
    download=True,
)

# Training split, used to fit the model.
train_data = datasets.MNIST(train=True, **mnist_kwargs)

# Held-out test split, used only for evaluation.
test_data = datasets.MNIST(train=False, **mnist_kwargs)

In [4]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 100

# Training batches are drawn in a freshly shuffled order on every pass.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Test batches keep the dataset order; shuffling is not needed for evaluation.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

## Network Structure

In [5]:
class MLP(nn.Module):
    """Fully connected classifier with two ReLU-activated hidden layers.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of each of the two hidden layers.
        num_classes: Number of output units (one raw score per class).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        # Parameter-free helpers, reused wherever they are needed in forward().
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

        # input -> hidden -> hidden -> class scores
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class scores for the batch ``x``.

        Every dimension after the first is flattened before the first linear
        layer, so the trailing dimensions of ``x`` must multiply to
        ``input_size``. No softmax is applied to the result.
        """
        features = self.flatten(x)
        hidden = self.relu(self.fc1(features))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [6]:
# Seed PyTorch's global RNG before the model is built and before the training
# data is iterated, so that weight initialisation and the shuffle order are
# repeatable from run to run (on the same machine and PyTorch version).
seed = 0
torch.manual_seed(seed)

input_size = 28 * 28  # features per sample once a 28x28 image is flattened
hidden_size = 1512    # width of each hidden layer
num_classes = 10      # number of output classes

In [7]:
# Build the network from the sizes configured in the previous cell.
model = MLP(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)

## Loss Function

In [8]:
# Cross-entropy over the raw class scores produced by the model, averaged
# over the samples of each mini-batch (the default reduction, spelled out).
loss_fn = nn.CrossEntropyLoss(reduction='mean')

## Optimizer

In [9]:
# Step size used by Adam for every parameter update.
learning_rate = 0.001

# Adam over all trainable parameters of the model.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

## Training

In [10]:
num_epochs = 10
log_every = 100  # print the current mini-batch loss every `log_every` steps

# Put the model in training mode explicitly, in case an earlier cell left it
# in evaluation mode.
model.train()

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Clear gradients at the START of each step. If a previous run of this
        # cell was interrupted between backward() and zero_grad(), stale
        # gradients would otherwise be accumulated into the first update.
        optimizer.zero_grad()

        # Forward pass and loss for this mini-batch.
        pred = model(images)
        loss = loss_fn(pred, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        if (i + 1) % log_every == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

Epoch [1/10], Step [100/600], Loss: 0.2178
Epoch [1/10], Step [200/600], Loss: 0.1722
Epoch [1/10], Step [300/600], Loss: 0.2339
Epoch [1/10], Step [400/600], Loss: 0.1604
Epoch [1/10], Step [500/600], Loss: 0.1409
Epoch [1/10], Step [600/600], Loss: 0.0691
Epoch [2/10], Step [100/600], Loss: 0.0561
Epoch [2/10], Step [200/600], Loss: 0.0411
Epoch [2/10], Step [300/600], Loss: 0.0668
Epoch [2/10], Step [400/600], Loss: 0.1042
Epoch [2/10], Step [500/600], Loss: 0.1444


KeyboardInterrupt: 

## Test

In [29]:
# Evaluate in inference mode, then restore whatever mode the model was in so
# that re-running the training cell afterwards is unaffected.
was_training = model.training
model.eval()

correct = 0
total = 0
try:
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in test_loader:
            pred = model(images)
            # Index of the largest logit along the class dimension is the
            # predicted class for each sample.
            predicted = pred.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    model.train(was_training)

print(f"Accuracy of the network on 10K test images: {100 * correct / total}")

Accuracy of the network on 10K test images: 97.97


## Save Model