# Model Definition

In [None]:
import torch
import torch.nn as nn
from torchtyping import TensorType

# Seed PyTorch's global RNG so weight initialization (and later torch
# randomness in this notebook) is repeatable across runs.
torch.manual_seed(0)


class DigitRecognition(nn.Module):
    """MLP that scores flattened 28x28 digit images against 10 classes.

    Architecture: Linear(784 -> 512) -> ReLU -> Dropout(0.2) -> Linear(512 -> 10).

    ``forward`` returns raw, unnormalized logits. This is what
    ``nn.CrossEntropyLoss`` expects, since it applies log-softmax internally.
    Squashing the scores with a sigmoid first would bound them to (0, 1),
    keep the softmax close to uniform, and stall training.
    """

    def __init__(self):
        super().__init__()
        self.first_layer = nn.Linear(784, 512)
        self.reLU = nn.ReLU()
        self.dropout = nn.Dropout(0.2)  # active only in train() mode
        self.projection = nn.Linear(512, 10)
        # Retained for attribute compatibility only; it has no parameters and is
        # intentionally NOT applied in forward() (see class docstring).
        self.sigmoid = nn.Sigmoid()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of flattened images.

        Args:
            images: float tensor whose last dimension is 784 (a flattened
                28x28 image); the loops in this notebook pass (batch, 784).

        Returns:
            Tensor of shape (..., 10) holding unnormalized logits. Use
            ``argmax`` for the predicted digit, or ``softmax`` for probabilities.
        """
        hidden = self.dropout(self.reLU(self.first_layer(images)))
        return self.projection(hidden)

# MNIST Data Loading and Preprocessing

In [None]:
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Convert images to tensors, then normalize each pixel as (x - 0.5) / 0.5,
# which maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Download (if not already cached under ./data) and load both MNIST splits.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Batch the data. Only training is shuffled; a fixed evaluation order keeps
# the displayed sample images stable between runs.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Must wrap test_dataset (not train_dataset) so reported accuracy is measured
# on images the model never trained on.
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Training Loop

In [None]:
model = DigitRecognition()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Train for a fixed number of full passes over the training set and report
# the mean per-batch loss at the end of each pass.
epochs = 5
num_batches = len(train_dataloader)
for epoch in range(epochs):
    model.train()  # training mode: the model's Dropout layer is active
    epoch_loss_total = 0.0
    for batch_images, batch_labels in train_dataloader:
        # Flatten each 28x28 image into a 784-element row.
        flat = batch_images.view(batch_images.shape[0], 784)

        optimizer.zero_grad()
        batch_loss = loss_function(model(flat), batch_labels)
        batch_loss.backward()
        optimizer.step()

        epoch_loss_total += batch_loss.item()
    mean_loss = epoch_loss_total / num_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

# Testing Loop

In [None]:
import matplotlib.pyplot as plt

# Evaluation mode disables the model's Dropout layer.
model.eval()

correct = 0  # number of correctly classified test images
total = 0    # number of test images seen

# Score the WHOLE test set for accuracy, but only visualize the first batch.
# Gradients are not needed for inference, so skip building the autograd graph.
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_dataloader):
        images = images.view(images.shape[0], 784)  # Flatten images from 28x28 to 784

        model_prediction = model(images)
        # Predicted class = index of the highest score in each row.
        _, idx = torch.max(model_prediction, dim=1)

        correct += (idx == labels).sum().item()
        total += labels.size(0)

        # Display each image of the first batch with its predicted class (for brevity).
        if batch_idx == 0:
            for i in range(len(images)):
                plt.imshow(images[i].view(28, 28).numpy(), cmap='gray')  # Reshape back to 28x28 and plot
                plt.show()
                print(f'Predicted label: {idx[i].item()}')

# Accuracy over every image in test_dataloader.
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy:.2f}%')

# ONNX Export

In [None]:
# Dummy input used only to trace the model for export. Its shape must match
# what forward() receives in this notebook: (batch, 784), i.e. one flattened
# 28x28 grayscale image per row.
dummy_input = torch.randn(1, 784)  # batch of 1, 784 flattened pixels

# Export the model to an ONNX file. The trace uses batch size 1; marking
# dimension 0 as dynamic keeps that size from being baked into the graph,
# so the exported model accepts any batch size.
torch.onnx.export(
    model,
    dummy_input,
    "digit_recognition_model.onnx",
    verbose=True,
    input_names=["images"],
    output_names=["output"],
    dynamic_axes={"images": {0: "batch"}, "output": {0: "batch"}},
)