# Small CNN — MNIST Classification with Observer

In [1]:
import logging
import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# observer.py lives in the parent directory (neural_network/). A notebook has no
# __file__, so the parent is resolved relative to the kernel's working directory
# (usually the notebook's own folder). Only prepend it when it is not already the
# first entry, so re-running this cell does not keep growing sys.path.
observer_dir = os.path.abspath(os.pardir)
if sys.path[:1] != [observer_dir]:
    sys.path.insert(0, observer_dir)
from observer import Observer, ObserverConfig

## Configuration & Hyperparameters

In [2]:
batch_size = 64
num_epochs = 5
lr = 1e-3

# Pick the fastest available backend: CUDA first, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat.
seed = 42
torch.manual_seed(seed)

print(f"Device: {device}")

Device: mps


## Observer Setup

In [3]:
# Enable every Observer tracker. The profiler is sampled rather than always on:
# profile_every_n_steps=100 profiles global steps 0, 100, 200, ...
observer_config = ObserverConfig(
    track_profiler=True,
    profile_every_n_steps=100,
    track_memory=True,
    track_throughput=True,
    track_loss=True,
    track_console_logs=True,
    track_error_logs=True,
    track_hyperparameters=True,
    track_system_resources=True,
    track_layer_graph=True,
)

# NOTE(review): the Observer's init log reports device=cpu while training runs on
# `device` (mps in the saved output) — confirm whether Observer should be told the
# training device.
observer = Observer(config=observer_config, project_id="1")

# Hyperparameters recorded with the run so it can be compared with later runs.
run_hparams = {
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": lr,
    "optimizer": "Adam",
    "dataset": "MNIST",
    "seed": seed,
    "device": device,
}
observer.log_hyperparameters(run_hparams)

[Observer] Initialized | project=1 | run=run_20260221_200204 | device=cpu
[Observer] Backend session created | session_id=5
[Observer] Hyperparameters logged: ['batch_size', 'num_epochs', 'learning_rate', 'optimizer', 'dataset', 'seed', 'device']


## Dataset

In [4]:
# Convert PIL images to tensors, then normalise with the commonly used MNIST
# pixel mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Build the train split first, then the test split (downloaded into ./data if missing).
train_dataset, test_dataset = (
    datasets.MNIST("data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training stream is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False
)

print(f"Training samples: {len(train_dataset):,}")
print(f"Test samples:     {len(test_dataset):,}")
print(f"Batches per epoch: {len(train_loader)}")

Training samples: 60,000
Test samples:     10,000
Batches per epoch: 938


## Model Definition

In [5]:
class SmallCNN(nn.Module):
    """Two-stage convolutional classifier for single-channel MNIST digits.

    Shapes for a batch of ``B`` 28x28 images:
        conv1 -> ReLU -> MaxPool(2)   (B, 16, 14, 14)
        conv2 -> ReLU -> MaxPool(2)   (B, 32, 7, 7)
        flatten                       (B, 32*7*7)
        fc1 -> ReLU                   (B, 128)
        fc2                           (B, 10)  class logits

    ``forward(x, targets=None)`` always returns a ``(logits, loss)`` pair, so the
    model can be driven by ``Observer.profile_step()``, which calls ``model(x, y)``.
    ``loss`` is the mean cross-entropy, or ``None`` when ``targets`` is omitted.
    """

    def __init__(self):
        super().__init__()
        # Attribute names define the state_dict keys — keep them stable.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, targets=None):
        # Each conv stage halves the spatial size through the shared pooling layer.
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        flat = features.view(features.size(0), -1)
        logits = self.fc2(F.relu(self.fc1(flat)))

        if targets is None:
            return logits, None
        return logits, F.cross_entropy(logits, targets)

In [6]:
model = SmallCNN()
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module

# Count every scalar in every parameter tensor (weights and biases).
num_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"Total parameters: {num_params:,}")

# Hand the model to the Observer (its log reports parameter count and layer count).
observer.register_model(model)

Total parameters: 206,922


[Observer] Model registered | 206,922 params (0.21M) | 4 param layers
[Observer] Model registered in backend | model_id=5


## Training

In [7]:
@torch.no_grad()
def evaluate(model, loader, device=None):
    """Compute the sample-weighted average loss and accuracy of ``model`` on ``loader``.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)`` with ``loss``
            being the batch-mean loss (as ``SmallCNN`` does).
        loader: iterable of ``(x, y)`` batches, e.g. a ``DataLoader``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so the helper no longer depends
            on a notebook-global ``device``.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats.

    Raises:
        ValueError: if ``loader`` yields no samples.

    The model's previous train/eval mode is restored on exit, even on error.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            batch_n = x.size(0)
            # loss is a batch mean; re-weight by batch size so a short final
            # batch does not skew the dataset average.
            total_loss += loss.item() * batch_n
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_n
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate() received an empty loader")
    return total_loss / total, correct / total

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print("Starting training...")
training_start = time.perf_counter()  # monotonic clock, suited to durations
global_step = 0

# NOTE(review): the per-epoch "validation" metrics are computed on test_loader, the
# same split used for the final test evaluation, so they are not an independent
# held-out estimate. Consider carving a validation split out of train_dataset.
for epoch in range(num_epochs):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Clear gradients before computing new ones, on BOTH paths. Previously the
        # plain path zeroed only before its own backward, leaving .grad populated
        # after optimizer.step(); a following profiled step then accumulated onto
        # the previous batch's gradients (unless profile_step zeroes them itself,
        # which is not visible here).
        optimizer.zero_grad(set_to_none=True)

        if observer.should_profile(global_step):
            # Assumed to run forward + backward under the profiler and return
            # (logits, loss); defined in observer.py.
            logits, loss = observer.profile_step(model, x, y)
        else:
            logits, loss = model(x, y)
            loss.backward()

        optimizer.step()

        observer.step(global_step, loss, batch_size=x.size(0))
        global_step += 1

    # Evaluate at the end of each epoch and hand the metrics to the Observer.
    val_loss, val_acc = evaluate(model, test_loader)
    step_report = observer.flush(val_metrics={
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    elapsed = time.perf_counter() - training_start
    print(
        f"Epoch {epoch}: "
        f"train_loss={step_report['loss']['train_mean']:.4f}  "
        f"val_loss={val_loss:.4f}  val_acc={val_acc:.4f}  "
        f"({elapsed:.1f}s)"
    )

training_time = time.perf_counter() - training_start
print(f"\nTraining completed in {training_time:.2f}s ({training_time/60:.2f} min)")

STAGE:2026-02-21 20:02:04 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


Starting training...


STAGE:2026-02-21 20:02:04 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:04 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:06 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:06 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:06 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:07 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:07 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:07 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:08 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:08 29268:1091655 ActivityProfilerCo

Epoch 0: train_loss=0.1698  val_loss=0.0564  val_acc=0.9809  (24.0s)


STAGE:2026-02-21 20:02:29 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:29 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:29 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:30 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:30 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:30 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:31 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:31 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:31 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:32 29268:1091655 ActivityProfilerCo

Epoch 1: train_loss=0.0525  val_loss=0.0363  val_acc=0.9877  (46.9s)


STAGE:2026-02-21 20:02:51 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:51 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:51 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:52 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:52 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:52 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:54 29268:1091655 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2026-02-21 20:02:54 29268:1091655 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2026-02-21 20:02:54 29268:1091655 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2026-02-21 20:02:55 29268:1091655 ActivityProfilerCo

KeyboardInterrupt: 

## Evaluation

In [None]:
test_loss, test_acc = evaluate(model, test_loader)

# Final held-out metrics; accuracy is shown both as a fraction and as a percentage.
final_report_lines = (
    f"Final test loss:     {test_loss:.4f}",
    f"Final test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)",
)
for report_line in final_report_lines:
    print(report_line)

## Observer Report

In [None]:
# Export the full JSON report under observer_reports/ (relative to the working
# directory), then print a readable digest of it.
report_dir = "observer_reports"
report_path = os.path.join(report_dir, f"{observer.run_id}.json")
# Create the folder up front; nothing visible here guarantees export() does so.
os.makedirs(report_dir, exist_ok=True)

try:
    report = observer.export(report_path)

    # ── Print summary ──
    # Summary fields are optional, so every lookup is defensive.
    summary = report.get("summary", {})
    print("=" * 60)
    print("OBSERVER SUMMARY")
    print("=" * 60)
    print(f"Total steps recorded:   {summary.get('total_steps', 0)}")
    print(f"Total training time:    {summary.get('total_duration_s', 0):.2f}s")

    if "loss_trend" in summary:
        lt = summary["loss_trend"]
        print("\nLoss trend:")
        print(f"  First interval:  {lt['first']:.4f}")
        print(f"  Last interval:   {lt['last']:.4f}")
        print(f"  Best:            {lt['best']:.4f}")
        print(f"  Improved:        {lt['improved']}")

    if "avg_tokens_per_sec" in summary:
        # NOTE(review): observer.step() is fed batch_size=x.size(0) images, so this
        # figure is presumably images/sec despite the "tokens" key — confirm in observer.py.
        print(f"\nAvg throughput:  {summary['avg_tokens_per_sec']:.0f} tokens/sec")

    if "profiler_highlight" in summary:
        ph = summary["profiler_highlight"]
        print("\nProfiler highlight:")
        print(f"  Top operation:       {ph.get('top_op', 'N/A')}")
        print(f"  Top op % of total:   {ph.get('top_op_pct', 0):.1f}%")
        print(f"  Fwd/Bwd time ratio:  {ph.get('fwd_bwd_ratio', 'N/A')}")

    # ── Print per-step profiler categories ──
    # Only step records carrying a "profiler" entry have category data; show the
    # most recent such record, categories sorted by CPU time (largest first).
    print("\n" + "=" * 60)
    print("PROFILER: OPERATION CATEGORIES (last step)")
    print("=" * 60)
    last_profiled = next(
        (rec for rec in reversed(report.get("steps", [])) if "profiler" in rec),
        None,
    )
    if last_profiled is not None:
        cats = last_profiled["profiler"].get("operation_categories", {})
        for cat_name, cat_data in sorted(
            cats.items(), key=lambda item: item[1]["cpu_time_ms"], reverse=True
        ):
            print(f"  {cat_name:<20s}  {cat_data['cpu_time_ms']:>8.1f}ms  ({cat_data['pct_cpu']:>5.1f}%)")

    print("=" * 60)
    # Report the exact path that was exported (OS-specific separator included).
    print(f"Full report saved to: {report_path}")
finally:
    # Always release the Observer, even if exporting or printing fails
    # (presumably this closes the backend session it opened at init).
    observer.close()