In [None]:
# !pip install torch torchvision matplotlib onnx onnxscript

In [None]:
# Install first (run in terminal if not done):
# pip install torch torchvision matplotlib onnx onnxscript

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Step 1: Build the MNIST training set and a shuffled mini-batch loader.
# ToTensor converts every image into a torch tensor as it is read.
transform = transforms.ToTensor()
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,  # fetch the files into ./data if they are not there yet
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
)

# Step 2: Define a simple MLP
class MLP(nn.Module):
    """Three-layer fully connected classifier operating on flattened inputs.

    The defaults reproduce the original architecture: 28*28 inputs, hidden
    layers of 128 and 64 units, and 10 output logits.

    Args:
        in_features: Number of values in one flattened sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden1: int = 128,
        hidden2: int = 64,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys do not change.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to rows of ``in_features`` values and return raw logits.

        Args:
            x: Tensor whose total element count is a multiple of
                ``in_features`` (e.g. ``(N, 1, 28, 28)`` with the defaults).

        Returns:
            Tensor of shape ``(N, num_classes)``; no softmax is applied.
        """
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # (e.g. a transposed batch), reshape() copies only when it has to.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the network with its default layer sizes.
model = MLP()

# Step 3: Loss function and optimizer.
# The model returns raw logits, which is the input CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
)

# Step 4: Train for 1 epoch (keep it short)
# Select training mode explicitly. It is the default for a fresh module, but
# stating it keeps this step correct if it is re-run after the model has been
# switched to eval mode, or if mode-dependent layers are added to the model.
model.train()
for images, labels in train_loader:
    optimizer.zero_grad()              # drop gradients from the previous batch
    outputs = model(images)            # forward pass -> logits
    loss = criterion(outputs, labels)
    loss.backward()                    # back-propagate
    optimizer.step()                   # update the weights
print("✅ Training done!")

# Step 5: Test on a few images
# Take the batch from the MNIST *test* split (train=False). Sampling from the
# training loader would only show images the model has already been fitted on.
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)
test_images, test_labels = next(iter(test_loader))

model.eval()  # inference mode for any layer that behaves differently in training
with torch.no_grad():  # no gradients are needed for prediction
    preds = model(test_images)
    predicted = preds.argmax(dim=1)  # index of the largest logit = predicted digit

# Step 6: Show results
# Plot the first nine images of the batch on a 3x3 grid; each title shows the
# predicted digit next to the true label.
grid_side = 3
plt.figure(figsize=(8, 8))
for i in range(grid_side * grid_side):
    plt.subplot(grid_side, grid_side, i + 1)
    # squeeze() removes size-1 dimensions (e.g. a single channel axis).
    digit = test_images[i].squeeze()
    plt.imshow(digit, cmap='gray')
    plt.title(f"Pred: {predicted[i].item()}, True: {test_labels[i].item()}")
    plt.axis('off')
plt.show()

# Step 7: Export to ONNX
# Export in inference mode.
model.eval()
# The example input only fixes shapes and dtypes; its values do not matter.
# A batch of 2 (rather than 1) is used because some exporters specialise
# size-1 dimensions, which would conflict with the dynamic batch axis below.
dummy_input = torch.randn(2, 1, 28, 28)
torch.onnx.export(
    model,
    dummy_input,
    "mlp_mnist.onnx",
    input_names=['input'],
    output_names=['output'],
    # Mark axis 0 as dynamic so the exported graph accepts any batch size
    # instead of being tied to the example batch.
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
)
print("💾 Model exported to mlp_mnist.onnx")
