# 01. Tiny Image Classifier

This notebook implements a minimal image classifier using a simple fully connected neural network to classify MNIST digits.

## Experiment Overview
- **Goal**: Classify MNIST digits using a minimal fully connected network
- **Model**: 2-layer MLP (784 → 128 → 10) with a ReLU activation and dropout (p = 0.2) on the hidden layer
- **Features**: Basic training loop, accuracy tracking, confusion matrix
- **Learning**: Understanding basic neural network training and evaluation

## What You'll Learn
- How to load and preprocess image data
- Building a simple neural network architecture
- Training loops with PyTorch
- Model evaluation and visualization


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path
sys.path.append('../scripts')
from utils import load_mnist_data, plot_training_history, plot_confusion_matrix, get_device, set_seed
from train import train_model
from evaluate import evaluate_model

# Reproducibility: fix the random seed once, via the project helper, before anything stochastic runs
SEED = 42
set_seed(SEED)

# Resolve the compute device a single time; later cells move the model and batches onto it
device = get_device()
print(f"Using device: {device}")


In [None]:
# Define the Tiny Image Classifier model
class TinyImageClassifier(nn.Module):
    """Two-layer fully connected classifier: flatten -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: Number of features per sample after flattening
            (784 = 28 * 28 for single-channel MNIST images).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
        dropout_p: Dropout probability applied after the hidden ReLU.
            Defaults to 0.2, the value previously hard-coded.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10, dropout_p=0.2):
        super().__init__()
        # Registration order (fc1, fc2, dropout) is kept so the printed repr and
        # state_dict key order match earlier versions of this model.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes).

        Every dimension after the first is flattened, so image tensors such as
        (batch, 1, 28, 28) and pre-flattened (batch, 784) tensors both work.
        """
        # torch.flatten copies when needed, so unlike .view() it also accepts
        # non-contiguous inputs (e.g. tensors produced by transpose/permute).
        x = torch.flatten(x, start_dim=1)

        # Hidden layer: ReLU, then dropout (active only in train() mode)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)

        # Output layer: logits; no softmax is applied here
        return self.fc2(x)

# Create model instance
model = TinyImageClassifier().to(device)

# Print model summary.
# The size uses each parameter's real element size instead of assuming 4-byte
# float32, so it stays correct if the model is cast (e.g. .half()).
num_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print("Model Architecture:")
print(model)
print(f"\nTotal parameters: {num_params:,}")
print(f"Model size: {param_bytes / 1024 / 1024:.2f} MB")


In [None]:
# Load MNIST dataset
print("Loading MNIST dataset...")
train_loader, val_loader, test_loader = load_mnist_data(batch_size=64, test_split=0.2)

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

# Visualize some training samples.
# Draw ONE batch and show its first 10 images. Calling next(iter(...)) inside the
# loop would rebuild the loader iterator on every pass (re-spawning any worker
# processes), and for an unshuffled loader would display the same image ten times.
sample_images, sample_labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 5, figsize=(12, 6))
# zip stops early if the batch holds fewer than 10 samples instead of raising
for ax, image, label in zip(axes.flat, sample_images, sample_labels):
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f'Label: {label.item()}')
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# Train the model.
# NOTE: `model` is created in an earlier cell, so re-running only this cell
# presumably continues from the already-trained weights; re-run the model cell
# first for a fresh start.
print("Starting training...")
train_config = dict(
    task_type='classification',
    epochs=10,
    lr=0.001,
    device=device,
    save_dir='../results/logs/tiny_classifier',
)
trainer = train_model(model=model, train_loader=train_loader, val_loader=val_loader, **train_config)

# Plot the recorded loss/metric curves and save the figure
trainer.plot_training_history(save_path='../results/plots/tiny_classifier_training.png')


In [None]:
# Evaluate the model on test set
print("Evaluating on test set...")
results = evaluate_model(
    model=model,
    data_loader=test_loader,
    task_type='classification',
    device=device,
    save_dir='../results/plots/tiny_classifier'
)

# Show some predictions from the first test batch.
# eval() disables the model's dropout layer; only the forward pass needs no_grad.
model.eval()
with torch.no_grad():
    data, target = next(iter(test_loader))
    # Move predictions back to the CPU once rather than per plotted sample
    pred = model(data.to(device)).argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(12, 6))
# zip stops at the shortest input, so a test batch with fewer than 10 samples
# no longer raises IndexError (the fixed range(10) indexing did)
for ax, image, true_label, pred_label in zip(axes.flat, data, target, pred):
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f'True: {true_label.item()}, Pred: {pred_label.item()}')
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()

print(f"\nFinal Test Accuracy: {results['accuracy']:.4f}")
print(f"Final Test F1-Score: {results['f1_score']:.4f}")
