# MNIST with Wandb

In [1]:
# Cell 1: Import libraries and initialize W&B
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import wandb

# Initialize W&B with hyperparameters
wandb.init(project="mnist-cnn", 
           name="mnist-cnn-test2",
           tags=["mnist-cnn", "test2"],
           config={
    "learning_rate": 0.001,
    "epochs": 5,
    "batch_size": 64,
    "val_split": 0.2,
    "architecture": "SimpleCNN"
})

[34m[1mwandb[0m: Currently logged in as: [33mjourneyofbabo[0m ([33mjourneyofbabo-hanyang-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Cell 2: Setup device and data transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Simple data augmentation and normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

Using device: cpu


In [3]:
# Cell 3: Load and split dataset
# Load MNIST dataset
full_train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Create train/validation split
val_split = wandb.config.val_split
val_size = int(len(full_train_dataset) * val_split)
train_size = len(full_train_dataset) - val_size

train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])
print(f'Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}')


Train: 48000, Val: 12000, Test: 10000


In [4]:
# Cell 4: Create data loaders
batch_size = wandb.config.batch_size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Cell 5: Define the CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        
        # Pooling and activations
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(-1, 64 * 3 * 3)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [6]:
# Cell 6: Initialize model, loss, and optimizer
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

# Track model with W&B
wandb.watch(model, criterion, log="all", log_freq=100)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 130,890


In [7]:
# Cell 7: Training function with real-time metrics
def train_epoch(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # Calculate batch metrics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        running_total += labels.size(0)
        running_correct += (predicted == labels).sum().item()
        
        # Log every 50 batches for real-time visualization
        if i % 50 == 49:
            avg_loss = running_loss / 50
            avg_acc = 100 * running_correct / running_total
            
            wandb.log({
                "batch_train_loss": avg_loss,
                "batch_train_acc": avg_acc,
                "batch": epoch * len(train_loader) + i
            })
            
            running_loss = 0.0
            running_correct = 0
            running_total = 0
    
    return loss.item()

In [8]:
# Cell 8: Validation function with metrics
def validate(model, val_loader, criterion, epoch):
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    avg_val_loss = val_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    
    # Log validation metrics
    wandb.log({
        "val_loss": avg_val_loss,
        "val_accuracy": accuracy * 100,
        "val_f1_score": f1,
        "epoch": epoch
    })
    
    return avg_val_loss, accuracy, f1

In [9]:
# Cell 9: Training loop
print("Starting training...")
best_val_acc = 0.0

for epoch in range(wandb.config.epochs):
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, epoch)
    
    # Validate
    val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, epoch)
    
    print(f'Epoch {epoch+1}/{wandb.config.epochs} - '
          f'Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, '
          f'Val Acc: {val_acc*100:.2f}%, '
          f'Val F1: {val_f1:.4f}')
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        wandb.save('best_model.pth')

Starting training...
Epoch 1/5 - Train Loss: 0.0681, Val Loss: 0.0717, Val Acc: 97.89%, Val F1: 0.9788
Epoch 2/5 - Train Loss: 0.1401, Val Loss: 0.0623, Val Acc: 97.96%, Val F1: 0.9795
Epoch 3/5 - Train Loss: 0.1916, Val Loss: 0.0413, Val Acc: 98.88%, Val F1: 0.9887
Epoch 4/5 - Train Loss: 0.1003, Val Loss: 0.0357, Val Acc: 98.99%, Val F1: 0.9899
Epoch 5/5 - Train Loss: 0.0213, Val Loss: 0.0354, Val Acc: 99.03%, Val F1: 0.9903


In [10]:
# Cell 10: Test evaluation
print("\nEvaluating on test set...")
model.eval()
test_preds = []
test_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

# Calculate test metrics
test_accuracy = accuracy_score(test_labels, test_preds)
test_f1 = f1_score(test_labels, test_preds, average='macro')

print(f"Test Accuracy: {test_accuracy*100:.2f}%")
print(f"Test F1 Score: {test_f1:.4f}")

# Log test results
wandb.log({
    "test_accuracy": test_accuracy * 100,
    "test_f1_score": test_f1
})


Evaluating on test set...
Test Accuracy: 99.15%
Test F1 Score: 0.9914


In [11]:
# Cell 11: Log confusion matrix and sample predictions
# Confusion matrix
cm = confusion_matrix(test_labels, test_preds)
wandb.log({
    "test_confusion_matrix": wandb.plot.confusion_matrix(
        probs=None,
        y_true=test_labels,
        preds=test_preds,
        class_names=[str(i) for i in range(10)]
    )
})

In [12]:
# Cell 12: Visualize sample predictions
def log_predictions(model, test_dataset, num_samples=16):
    model.eval()
    indices = np.random.choice(len(test_dataset), num_samples, replace=False)
    images = []
    
    for idx in indices:
        image, label = test_dataset[idx]
        image_input = image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = model(image_input)
            probs = torch.softmax(output, dim=1)
            confidence, predicted = torch.max(probs, 1)
            pred_label = predicted.item()
        
        # Create wandb Image with caption
        caption = f"True: {label}, Pred: {pred_label} ({confidence.item():.2f})"
        images.append(wandb.Image(image.squeeze().numpy(), caption=caption))
    
    wandb.log({"predictions": images})

log_predictions(model, test_dataset)




In [13]:
# Cell 13: Final summary and cleanup
wandb.summary.update({
    "best_val_accuracy": best_val_acc * 100,
    "final_test_accuracy": test_accuracy * 100,
    "final_test_f1": test_f1,
    "total_parameters": sum(p.numel() for p in model.parameters())
})

print("\nTraining complete!")
print(f"Best validation accuracy: {best_val_acc*100:.2f}%")
print(f"Final test accuracy: {test_accuracy*100:.2f}%")
print(f"Final test F1 score: {test_f1:.4f}")

wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.



Training complete!
Best validation accuracy: 99.03%
Final test accuracy: 99.15%
Final test F1 score: 0.9914


0,1
batch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
batch_train_acc,▁▆▇▇▇███████████████████████████████████
batch_train_loss,█▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▃▅▆█
test_accuracy,▁
test_f1_score,▁
val_accuracy,▁▁▇██
val_f1_score,▁▁▇██
val_loss,█▆▂▁▁

0,1
batch,3749.0
batch_train_acc,98.8125
batch_train_loss,0.04963
best_val_accuracy,99.03333
epoch,4.0
final_test_accuracy,99.15
final_test_f1,0.99143
test_accuracy,99.15
test_f1_score,0.99143
total_parameters,130890.0
