# Surrogate Evaluation Notebook

This notebook documents the prototype workflow for training and evaluating
surrogate models against curated waveform datasets. It is designed to run
deterministically by fixing random seeds and reusing manifests generated via
`tools/datasets/curate_waveforms.py`.

In [None]:
import json
from pathlib import Path

import numpy as np
import torch

from igsoa_analysis.surrogates.pytorch_surrogate import FeedForwardSurrogate

# Single source of truth for the run's hyper-parameters. The mapping is kept
# under version control so that a given commit always reproduces the same run.
TRAINING_CONFIG = {
    # Value handed to every random number generator seeded in this cell.
    "seed": 42,
    # Width of one input / target row; the remaining keys are interpreted by
    # FeedForwardSurrogate.from_config.
    "input_dim": 128,
    "output_dim": 64,
    "hidden_layers": [256, 256],
    "learning_rate": 1e-3,
    "batch_size": 64,
    "epochs": 25,
}

# Seed PyTorch first and NumPy second, both from the configured value.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(TRAINING_CONFIG["seed"])

In [None]:
# Where the curated dataset lives and which manifest describes it.
DATASET_ROOT = Path("analysis/datasets")
DATASET_NAME = "example_waveforms"
MANIFEST_PATH = Path(DATASET_ROOT, DATASET_NAME, "manifest.json")

# Stop early, with a hint about the remedy, when the manifest has not been
# generated yet.
if not MANIFEST_PATH.exists():
    raise FileNotFoundError(
        f"Manifest {MANIFEST_PATH} not found. Run the curation script to generate it."
    )

# Parse the UTF-8 encoded JSON manifest and report the sample count it declares.
with MANIFEST_PATH.open(encoding="utf-8") as manifest_file:
    manifest = json.load(manifest_file)
print(f'Loaded {manifest["sample_count"]} samples for training/evaluation.')

In [None]:
# Build the surrogate from the shared training configuration and print it for
# inspection.
# NOTE(review): construction presumably draws initial weights from the seeded
# torch RNG, so keep it ahead of the random draws below — confirm.
model = FeedForwardSurrogate.from_config(TRAINING_CONFIG)
print(model)

# Placeholder training loop: replace with DataLoader built from manifest entries.
# Until then, 128 rows of random noise stand in for the inputs and the targets;
# their widths come from TRAINING_CONFIG. The manifest loaded earlier is not
# used here yet.
dummy_inputs = torch.randn(128, TRAINING_CONFIG["input_dim"])
dummy_targets = torch.randn(128, TRAINING_CONFIG["output_dim"])
# NOTE(review): epochs is fixed at 2 here instead of TRAINING_CONFIG["epochs"],
# presumably to keep the placeholder run short — confirm before real training.
loss_history = model.train_on_arrays(dummy_inputs, dummy_targets, epochs=2)
print('Recorded loss trajectory:', loss_history)

In [None]:
# Evaluate surrogate predictions for accuracy benchmarking and emit the result
# as a JSON record.
#
# The scored batch is the first 16 rows of dummy_inputs / dummy_targets, which
# are rows the model was trained on. The figure is therefore a training-set
# error and is labelled as such instead of "hold-out".
# TODO: score a genuine hold-out split once the manifest-backed DataLoader
# replaces the placeholder tensors.
# NOTE(review): only accuracy is measured here; nothing in this cell times the
# forward pass, so add timing separately if latency benchmarking is needed.
# NOTE(review): if FeedForwardSurrogate contains dropout or batch-norm layers,
# it should be switched to eval mode before scoring — confirm against its
# definition.
with torch.no_grad():
    # Gradients are not needed for scoring.
    predictions = model(dummy_inputs[:16])
    # Mean of the element-wise squared error, converted to a Python float.
    mse = torch.mean((predictions - dummy_targets[:16]) ** 2).item()
print(f'MSE on training batch (first 16 samples): {mse:.6f}')

# Record which dataset and configuration produced the metric value.
benchmark_record = {
    "dataset": DATASET_NAME,
    "config": TRAINING_CONFIG,
    "metric": "mse",
    "value": mse,
}
print(json.dumps(benchmark_record, indent=2))