# PyTorch MLP Implementation

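This notebook trains and evaluates a PyTorch multilayer perceptron (MLP) for MNIST digit classification, covering data loading, training, validation, and evaluation with a confusion matrix and per-class accuracy. It expects the MNIST data under `../data/training` and `../data/validation`, and the `pytorch_mlp` helper module under `../src`.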


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Make the project's `src/` directory importable. It is resolved as a sibling
# of the current working directory, so this assumes the notebook is run from a
# directory next to `src/` (e.g. `notebooks/`). The membership check keeps the
# cell idempotent: re-running it does not keep appending duplicate entries.
src_dir = os.path.join(os.path.dirname(os.getcwd()), 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from pytorch_mlp import MNISTCustomDataset, CustomMLP, train, validate, report_accuracy, compute_confusion_matrix

# Seed PyTorch's global RNG (seeds all devices). It is the default source of
# randomness for torch layer initialisation and for DataLoader shuffling, so
# Restart & Run All gives repeatable results. Bit-exact CUDA runs may still
# need deterministic-algorithm settings.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


In [None]:
# Preprocessing: ToTensor converts each image to a float tensor (scaled to
# [0, 1] for 8-bit input), then Normalize(mean=0.5, std=0.5) maps that range
# to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Dataset / loader configuration, kept in named constants instead of
# repeating magic numbers in the printouts below.
data_dir = '../data'
BATCH_SIZE = 32
NUM_CLASSES = 10
IMAGE_HEIGHT, IMAGE_WIDTH = 28, 28
INPUT_FEATURES = IMAGE_HEIGHT * IMAGE_WIDTH  # flattened pixel count fed to the MLP

# Create datasets from the `training` and `validation` sub-directories.
train_dataset = MNISTCustomDataset(os.path.join(data_dir, 'training'), transform=transform)
val_dataset = MNISTCustomDataset(os.path.join(data_dir, 'validation'), transform=transform)

# Shuffle only the training data; keep validation order fixed so evaluation is repeatable.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Image size: {IMAGE_HEIGHT}x{IMAGE_WIDTH} pixels")
print(f"Input features: {INPUT_FEATURES} ({IMAGE_HEIGHT}*{IMAGE_WIDTH})")

# Fail fast on a wrong `data_dir` instead of silently training on an empty dataset.
assert len(train_dataset) > 0, f"No training samples found under {os.path.join(data_dir, 'training')}"
assert len(val_dataset) > 0, f"No validation samples found under {os.path.join(data_dir, 'validation')}"


In [None]:
# Build the MLP (784 = 28*28 input features, 10 output classes, dropout 0.1)
# and place its parameters on the selected device. nn.Module.to() works in
# place and returns the module itself, so rebinding `model` is equivalent.
model = CustomMLP(input_size=784, output_size=10, dropout_rate=0.1)
model = model.to(device)

# Materialise the parameter list once and derive both counts from it.
param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)

# Objective and optimiser: cross-entropy loss, Adam with an L2 penalty via weight_decay.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

print("Model Architecture:", model, sep="\n")
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


In [None]:
# Training loop: each epoch runs one training pass and one validation pass,
# and the per-epoch results are recorded for plotting.
num_epochs = 10
train_losses = []
val_losses = []
val_accuracies = []

print(f"Training for {num_epochs} epochs...")
print("Epoch\tTrain Loss\tVal Loss\tVal Acc")
print("-" * 45)

for epoch in range(num_epochs):
    # NOTE(review): assumes `train` returns the epoch's training loss and
    # `validate` returns (loss, accuracy in percent) — confirm in pytorch_mlp.
    train_loss = train(model, train_loader, optimizer, loss_function, device)
    val_loss, val_acc = validate(model, val_loader, loss_function, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"{epoch+1}\t{train_loss:.4f}\t\t{val_loss:.4f}\t\t{val_acc:.2f}%")

print("\nTraining completed!")
if val_accuracies:
    # max() keeps the earliest epoch on ties. `model` still holds the
    # final-epoch weights, not the weights from the best epoch.
    best_epoch, best_val_acc = max(enumerate(val_accuracies, start=1), key=lambda item: item[1])
    print(f"Best validation accuracy: {best_val_acc:.2f}% (epoch {best_epoch})")
else:
    # Guard: max() on an empty list would raise ValueError.
    print("No epochs were run (num_epochs == 0).")


In [None]:
# Visualize training progress: loss curves (left) and validation accuracy (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Build the x-axis from the recorded history rather than `num_epochs`, so the
# plots still render if training was interrupted or `num_epochs` was edited
# after the loop ran. Otherwise matplotlib raises on mismatched x/y lengths.
num_recorded = min(len(train_losses), len(val_losses), len(val_accuracies))
epochs = range(1, num_recorded + 1)

# Plot training and validation loss
ax1.plot(epochs, train_losses[:num_recorded], label='Training Loss', marker='o', linewidth=2)
ax1.plot(epochs, val_losses[:num_recorded], label='Validation Loss', marker='s', linewidth=2)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot validation accuracy
ax2.plot(epochs, val_accuracies[:num_recorded], label='Validation Accuracy', marker='o', linewidth=2, color='green')
ax2.set_title('Validation Accuracy', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
# Generate a confusion matrix on the validation set using the current
# (final-epoch) model weights.
model.eval()  # evaluation mode (e.g. disables dropout layers)
all_predictions = []
all_labels = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        predictions = model(images)
        # Predicted class = index of the largest output for each sample.
        predicted = predictions.argmax(dim=1)

        all_predictions.extend(predicted.cpu().numpy())
        # Labels are only collected, never fed to the model, so skip the device round-trip.
        all_labels.extend(labels.cpu().numpy())

# NOTE(review): assumes compute_confusion_matrix(predictions, labels) returns
# rows = true labels and columns = predicted labels, matching the axis labels
# below — confirm the argument order in pytorch_mlp. np.asarray accepts an
# ndarray, nested list or CPU tensor, so .max()/.shape below always work.
confusion_matrix = np.asarray(compute_confusion_matrix(all_predictions, all_labels))
num_classes = confusion_matrix.shape[0]

# Plot confusion matrix
plt.figure(figsize=(10, 8))
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, tick_marks)
plt.yticks(tick_marks, tick_marks)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')

# Annotate each cell with its count. Use white text on dark cells for contrast.
thresh = confusion_matrix.max() / 2.
for i, j in np.ndindex(confusion_matrix.shape):
    # int() so the 'd' format also works if the matrix has a float dtype.
    plt.text(j, i, f"{int(confusion_matrix[i, j]):d}",
             ha="center", va="center",
             color="white" if confusion_matrix[i, j] > thresh else "black")

plt.tight_layout()
plt.show()

# Per-class accuracy (recall): correct predictions / total true samples per row.
print("\nPer-class accuracy:")
for i in range(num_classes):
    class_correct = int(confusion_matrix[i, i])
    class_total = int(confusion_matrix[i, :].sum())
    if class_total == 0:
        # Avoid a division by zero for classes absent from the validation set.
        print(f"Class {i}: n/a (0/0)")
        continue
    accuracy = class_correct / class_total * 100
    print(f"Class {i}: {accuracy:.2f}% ({class_correct}/{class_total})")
