# Multilayer Perceptron

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Data preparation
# Per-image preprocessing: convert to a tensor, then normalise the single
# channel with mean 0.1307 and standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Both splits share the storage location, download flag and preprocessing;
# only the `train` flag differs.
mnist_options = dict(root="../data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

# Mini-batch loaders: only the training split is shuffled.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

In [3]:
# Model architecture
class MLP(nn.Module):
  """Fully connected classifier with two equally sized ReLU hidden layers.

  Args:
    input_dim: Number of features per sample once the input is flattened.
      Defaults to 28 * 28.
    hidden_dim: Width of both hidden layers. Defaults to 16.
    num_classes: Number of output scores per sample. Defaults to 10.
  """

  def __init__(self, input_dim=28 * 28, hidden_dim=16, num_classes=10):
    super().__init__()
    # Attribute names fc1/fc2/fc3 are kept so existing state dicts still load.
    self.fc1 = nn.Linear(input_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, hidden_dim)
    self.fc3 = nn.Linear(hidden_dim, num_classes)

  def forward(self, x):
    """Compute raw class scores for a batch of inputs.

    Args:
      x: Tensor whose total element count is a multiple of the first
        layer's input size; it is flattened to ``(-1, input_dim)``.

    Returns:
      Tensor of shape ``(N, num_classes)`` with unnormalised scores
      (no softmax is applied here).
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. a
    # transposed or permuted batch, and only copies when it has to.
    x = x.reshape(-1, self.fc1.in_features)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.fc3(x)

In [4]:
# Loss function and optimizer
# Seed PyTorch's global RNG before the model is built so the randomly
# initialised weights (and later draws from the default generator) are the
# same on every fresh "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Train loop
# Runs `num_epochs` passes over `train_loader`, updating `model` in place and
# printing the mean per-sample training loss after each epoch.
num_epochs = 5

for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0  # sum of per-sample losses seen this epoch
  num_samples = 0     # number of samples seen this epoch

  for X, y in train_loader:
    optimizer.zero_grad()
    y_pred = model(X)

    # Loss computation
    loss = criterion(y_pred, y)

    # Backpropagation
    loss.backward()

    # Update weights
    optimizer.step()

    # The criterion (default "mean" reduction) yields the batch average, so
    # weight it by the batch size; otherwise a smaller final batch would count
    # as much as a full one in the epoch average.
    batch_len = y.size(0)
    running_loss += loss.item() * batch_len
    num_samples += batch_len

  avg_loss = running_loss / num_samples
  print(f"Epoch [{epoch + 1}/{num_epochs}]  Loss: {avg_loss:.6f}")

Epoch [1/5]  Loss: 0.468677
Epoch [2/5]  Loss: 0.241734
Epoch [3/5]  Loss: 0.205192
Epoch [4/5]  Loss: 0.183026
Epoch [5/5]  Loss: 0.167910


In [6]:
# Evaluating
# Put the model in evaluation mode and count how many test samples receive
# the correct label; gradient tracking is disabled for this pass.
model.eval()
total = 0
correct = 0

with torch.no_grad():
  for images, targets in test_loader:
    # Predicted class = index of the largest score along dim 1.
    predicted = model(images).max(dim=1).indices
    hits = predicted == targets
    correct += hits.sum().item()
    total += targets.size(0)

accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")

Accuracy: 94.82%
