# GNN Training Diagnostics

Visualise the loss curves and validation metrics recorded during training. If the cells below report that no history was found, re-run the training pipeline with the updated logging enabled so that the history files are written to the run directory.

In [None]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import sys

def _locate_project_root(start: Path) -> Path:
    """Return the closest directory at or above *start* that holds a ``src`` directory.

    Args:
        start: Directory to begin the upward search from. May be relative;
            it is resolved to an absolute path before the walk begins.

    Returns:
        The resolved path of the first directory (``start`` itself, then each
        ancestor in turn) that contains a sub-directory named ``src``.

    Raises:
        RuntimeError: If no directory up to the filesystem root contains a
            ``src`` directory.
    """
    # Resolve once up front: a relative path such as Path('.') has no
    # parents, so walking it unresolved would only ever test one directory.
    origin = start.resolve()
    for candidate in (origin, *origin.parents):
        # is_dir() rather than exists(): a plain file that happens to be
        # called 'src' does not mark a project root.
        if (candidate / 'src').is_dir():
            return candidate
    raise RuntimeError('Could not locate project root containing "src" directory')

# Find the project root from the current working directory and put it at the
# front of sys.path (once) so that the `src` package can be imported below.
PROJECT_ROOT = _locate_project_root(Path.cwd())
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))


# seaborn is an optional dependency: use its 'whitegrid' theme when it can be
# imported, otherwise fall back to a matplotlib style.
try:
    import seaborn as sns
    sns.set_theme(style='whitegrid')
except ImportError:  # pragma: no cover
    # Leave `sns` bound to None so the name always exists after this block,
    # and switch matplotlib to its 'ggplot' style instead.
    sns = None
    plt.style.use('ggplot')

from src.analysis.performance import load_training_history, load_step_losses

# Directory holding the artefacts of the training run to inspect.
RUN_DIR = Path('outputs/gnn_runs/baseline')
# The notebook may be started from a sub-directory of the project (which is
# why PROJECT_ROOT is searched for above). In that case the cwd-relative path
# does not exist, so fall back to the same path anchored at the project root.
# NOTE(review): assumes the training pipeline writes `outputs/` under the
# project root -- confirm against the training entry point.
if not RUN_DIR.is_absolute() and not RUN_DIR.exists():
    RUN_DIR = PROJECT_ROOT / RUN_DIR

# Both frames default to None so the plotting cells can detect missing data.
history_df = None
step_df = None
# Each file is loaded independently: a missing epoch history must not prevent
# the step-level losses from being shown, and vice versa.
try:
    history_df = load_training_history(RUN_DIR)
except FileNotFoundError as exc:
    print(f'Training history not found: {exc}')
try:
    step_df = load_step_losses(RUN_DIR)
except FileNotFoundError as exc:
    print(f'Step-level losses not found: {exc}')

# Preview the first rows of whatever was loaded.
if history_df is not None and not history_df.empty:
    display(history_df.head())
else:
    print('Run training to generate history metrics.')
if step_df is not None and not step_df.empty:
    display(step_df.head())


## Epoch-level Metrics


In [None]:
# Plot the training loss and any recorded validation metrics against the epoch
# number, marking the epoch(s) flagged as best.
if history_df is not None and not history_df.empty:
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(history_df['epoch'], history_df['train_loss'], marker='o', label='Train loss')
    # Validation columns are optional; draw only the ones present in the frame.
    for col, label in [('val_dispatch_error', 'Val dispatch error'),
                       ('val_cost_gap', 'Val cost gap'),
                       ('val_violation_rate', 'Val violation rate')]:
        if col in history_df:
            ax.plot(history_df['epoch'], history_df[col], marker='o', label=label)
    if 'is_best' in history_df:
        # Build an explicit boolean mask. Indexing with the raw column only
        # works for a clean bool dtype: a 0/1 integer column would be treated
        # as a list of column labels, and a column containing NaN raises.
        # `.eq(True)` maps True/1 to True and everything else (including NaN)
        # to False.
        best_mask = history_df['is_best'].eq(True)
        best = history_df[best_mask]
        if not best.empty:
            ax.scatter(best['epoch'], best['train_loss'], s=160, marker='*', color='gold', edgecolor='black', label='Best epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Metric value')
    ax.set_title('Training vs validation metrics')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
else:
    print('No epoch history to plot.')


## Step-level Loss


In [None]:
# Plot the per-step (mini-batch) loss together with a 10-step rolling mean.
if step_df is not None and not step_df.empty:
    fig, ax = plt.subplots(figsize=(8, 4))
    # Choose the x-axis before it is first used: the 'step' column when it
    # exists, otherwise the frame's row index. Checking for the column only
    # after `step_df['step']` has already been read can never be False.
    steps = step_df['step'] if 'step' in step_df else step_df.index
    ax.plot(steps, step_df['loss'], linewidth=1, alpha=0.8, label='Instantaneous loss')
    # min_periods=1 makes the first points average over however many samples
    # exist so far, so the smoothed curve starts at the first step.
    rolling = step_df['loss'].rolling(window=10, min_periods=1).mean()
    ax.plot(steps, rolling.values, color='red', linewidth=2, label='Rolling mean (10 steps)')
    ax.set_xlabel('Global step')
    ax.set_ylabel('Loss')
    ax.set_title('Mini-batch loss trajectory')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
else:
    print('No step-level loss data available.')
