In [1]:
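# Run once in a fresh environment. `%pip` installs into the environment of the running kernel.
# %pip install torch torchvision
# %pip install wandb scikit-learn numpy
# %pip install nbformat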

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# hyperparameters
# Everything a reader might want to tune for a run lives in this cell.

# Seed PyTorch's global RNG so weight initialisation and the shuffled batch
# order are repeatable after Restart & Run All.
# NOTE(review): GPU kernels may still introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

batch_size = 64        # samples per mini-batch
learning_rate = 0.001  # optimizer step size
epochs = 30            # full passes over the training data

In [3]:
import wandb
# Start a new wandb run to track this script. `run` is the handle used later
# to upload the model artifact and to finish the run.
run = wandb.init(
    project="mnist-basic",
    # Settings recorded alongside the run so results can be traced back to them.
    config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        # The model in this notebook is fully connected
        # (Flatten -> Linear -> ReLU -> Linear); it has no convolutional layers.
        "architecture": "MLP",
        "dataset": "MNIST",
        "epochs": epochs,
    },
)




[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mzeri[0m ([33mzeri-university-of-michigan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Every sample is converted to a tensor as it is read from the dataset.
transform = transforms.ToTensor()

# Both splits share the same storage root and preprocessing; only `train` differs.
mnist_options = {"root": './data', "download": True, "transform": transform}
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)


In [5]:

# Training batches are shuffled; the test split is read in its stored order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [6]:
class SimpleNN(nn.Module):
    """Small fully connected classifier for 28x28 inputs.

    The input is flattened to 784 features, passed through a 128-unit hidden
    layer with ReLU, and projected to 10 output scores. No softmax is applied
    inside the model; callers receive the raw scores.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        # Stored under `fc` so parameter names remain `fc.<index>.weight/bias`.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class scores for each item in the batch `x`."""
        logits = self.fc(x)
        return logits


In [7]:

# 4. Initialize model, loss, and optimizer
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = SimpleNN()
model = model.to(device)

# Cross-entropy on the raw scores, optimised with Adam at the configured step size.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)


In [8]:
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
import numpy as np

# Train for `epochs` passes. After each pass, evaluate on `test_loader` and log
# loss, accuracy and macro PR-AUC to stdout and to wandb.
for epoch in range(epochs):
    # -------- TRAINING --------
    model.train()
    running_loss = 0.0
    train_samples = 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        # Forward pass
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch's loss by its size so the epoch figure is a true
        # per-sample mean even when the final batch is shorter than the rest.
        # NOTE(review): assumes `criterion` returns the batch mean (its default
        # reduction) - revisit if the reduction is ever changed.
        batch_count = targets.size(0)
        running_loss += loss.item() * batch_count
        train_samples += batch_count

    avg_train_loss = running_loss / train_samples

    # -------- EVALUATION --------
    model.eval()
    correct = 0
    total = 0
    eval_loss = 0.0
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            batch_count = targets.size(0)

            # Same size-weighted accumulation as in training.
            loss = criterion(outputs, targets)
            eval_loss += loss.item() * batch_count

            # Class probabilities are only needed for PR-AUC; keep them on the CPU.
            probs = torch.softmax(outputs, dim=1)  # For multi-class
            all_probs.append(probs.cpu())
            all_targets.append(targets.cpu())

            _, predicted = outputs.max(1)
            correct += (predicted == targets).sum().item()
            total += batch_count

    # Convert to NumPy for PR-AUC
    probs_np = torch.cat(all_probs).numpy()          # shape: [N, 10]
    targets_np = torch.cat(all_targets).numpy()      # shape: [N]
    # One-vs-rest indicator matrix, one column per digit class.
    all_targets_bin = label_binarize(targets_np, classes=np.arange(10))  # [N, 10]

    pr_auc = average_precision_score(all_targets_bin, probs_np, average='macro')
    accuracy = 100 * correct / total
    avg_eval_loss = eval_loss / total

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} "
          f"Val Loss: {avg_eval_loss:.4f} "
          f"Accuracy: {accuracy:.2f}% "
          f"PR-AUC: {pr_auc:.4f}")

    # NOTE(review): the "val_*" metrics are computed on `test_loader`; this
    # notebook has no separate validation split.
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_loss": avg_eval_loss,
        "val_accuracy": accuracy,
        "val_pr_auc": pr_auc
    })

Epoch [1/30] Train Loss: 0.3368 Val Loss: 0.1804 Accuracy: 94.87% PR-AUC: 0.9863
Epoch [2/30] Train Loss: 0.1513 Val Loss: 0.1193 Accuracy: 96.44% PR-AUC: 0.9930
Epoch [3/30] Train Loss: 0.1052 Val Loss: 0.0998 Accuracy: 96.98% PR-AUC: 0.9950
Epoch [4/30] Train Loss: 0.0794 Val Loss: 0.0870 Accuracy: 97.42% PR-AUC: 0.9960
Epoch [5/30] Train Loss: 0.0627 Val Loss: 0.0814 Accuracy: 97.53% PR-AUC: 0.9965
Epoch [6/30] Train Loss: 0.0504 Val Loss: 0.0759 Accuracy: 97.69% PR-AUC: 0.9968
Epoch [7/30] Train Loss: 0.0416 Val Loss: 0.0770 Accuracy: 97.70% PR-AUC: 0.9968
Epoch [8/30] Train Loss: 0.0338 Val Loss: 0.0728 Accuracy: 97.76% PR-AUC: 0.9973
Epoch [9/30] Train Loss: 0.0287 Val Loss: 0.0699 Accuracy: 97.80% PR-AUC: 0.9975
Epoch [10/30] Train Loss: 0.0241 Val Loss: 0.0669 Accuracy: 97.86% PR-AUC: 0.9978
Epoch [11/30] Train Loss: 0.0188 Val Loss: 0.0742 Accuracy: 97.79% PR-AUC: 0.9973
Epoch [12/30] Train Loss: 0.0168 Val Loss: 0.0770 Accuracy: 97.81% PR-AUC: 0.9972
Epoch [13/30] Train Loss:

In [9]:

# 6. Evaluation
# One more pass over the test loader to report plain accuracy.
model.eval()
correct = 0
total = 0

with torch.no_grad():  # inference only, no gradients needed
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)
        outputs = model(data)
        # Index of the highest score in each row is the predicted class.
        predicted = outputs.max(1).indices
        correct += int((predicted == targets).sum())
        total += targets.size(0)

test_accuracy = 100 * correct / total
print(f"Accuracy on test set: {test_accuracy:.2f}%")

Accuracy on test set: 97.88%


In [12]:
# Describe the trained model as a wandb artifact. The metadata records the
# settings that produced these weights.
artifact = wandb.Artifact(
    name="mnist-basic",
    type="model",
    metadata={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        # The model in this notebook is fully connected
        # (Flatten -> Linear -> ReLU -> Linear); it has no convolutional layers.
        "architecture": "MLP",
        "dataset": "MNIST",
        "epochs": epochs,
    }
)

In [13]:
# save model
# Write the weights to disk, attach that file to the artifact, then log the
# artifact on the active run.
weights_path = "mnist-basic.pth"
torch.save(model.state_dict(), weights_path)
artifact.add_file(weights_path)
run.log_artifact(artifact)


<Artifact mnist-basic>

In [14]:
# Mark the wandb run as finished now that the metrics and the model artifact
# have been logged.
run.finish()

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train_loss,█▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▇▇▇▇▇███████▇█▇███▇██████▇█
val_loss,█▄▃▂▂▂▂▁▁▁▁▂▂▁▂▂▂▂▂▂▃▃▂▃▃▃▃▄▄▃
val_pr_auc,▁▅▆▇▇▇▇█████████████████████▇█

0,1
epoch,30.0
train_loss,0.00518
val_accuracy,97.88
val_loss,0.10672
val_pr_auc,0.99722
