<a href="https://colab.research.google.com/github/vramonlinebsc/neural_operator_surrogates/blob/main/surrogate_neural_operators.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse as sp
from scipy.sparse.linalg import expm_multiply
from scipy.linalg import expm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import time
from typing import Tuple, List, Dict, Optional
import json
import pickle
import pandas as pd
from pathlib import Path
from dataclasses import dataclass, asdict
import hashlib
import warnings
warnings.filterwarnings('ignore')

# Fix the seed of every RNG used below (PyTorch CPU, NumPy, and CUDA when
# a GPU is present) so dataset generation and weight init are repeatable.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# ============================================================================
# CONFIGURATION & CHECKPOINTING
# ============================================================================

@dataclass
class ExperimentConfig:
    """Settings for one benchmark campaign: sweep, data, training, network."""
    # --- sweep ---
    N_values: List[int]      # spin counts to benchmark
    topologies: List[str]    # coupling topology names
    # --- dataset sizes ---
    n_train_samples: int
    n_val_samples: int
    # --- trajectory sampling ---
    T: int                   # number of time points per trajectory
    dt: float                # spacing between consecutive time points
    # --- optimisation ---
    epochs: int
    batch_size: int
    lr: float
    # --- network shape ---
    modes: int
    width: int
    n_layers: int

    def get_hash(self) -> str:
        """Return the first 8 hex digits of the MD5 of the key-sorted JSON dump."""
        payload = json.dumps(asdict(self), sort_keys=True).encode()
        digest = hashlib.md5(payload).hexdigest()
        return digest[:8]


class CheckpointManager:
    """Save and restore datasets, training checkpoints and benchmark results.

    Checkpoint artefacts are written under ``base_dir``; CSV outputs go to a
    ``results`` directory relative to the current working directory.
    """

    def __init__(self, base_dir: str = "checkpoints"):
        self.base_dir = Path(base_dir)
        # parents=True lets callers pass a nested path such as "runs/exp1/ckpt"
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.results_dir = Path("results")
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def save_dataset(self, dataset, N: int, topology: str, split: str):
        """Pickle ``dataset.data`` to ``dataset_N{N}_{topology}_{split}.pkl``."""
        path = self.base_dir / f"dataset_N{N}_{topology}_{split}.pkl"
        with open(path, 'wb') as f:
            pickle.dump(dataset.data, f)
        print(f"  ✓ Saved dataset: {path.name}")

    def load_dataset(self, N: int, topology: str, split: str, T: int, dt: float):
        """Return an ``NMRDataset`` filled from the cached pickle, or ``None``.

        NOTE(review): the cache file name encodes only N, topology and split,
        so a pickle produced with a different T, dt or sample count is reused
        as-is — confirm that is acceptable or clear the cache between configs.
        """
        path = self.base_dir / f"dataset_N{N}_{topology}_{split}.pkl"
        if not path.exists():
            return None

        print(f"  ✓ Loading dataset: {path.name}")
        # n_samples=0: the samples come from disk, nothing is to be generated.
        dataset = NMRDataset(N, topology, 0, T, dt)
        # SECURITY: unpickling runs arbitrary code; only load files that this
        # pipeline wrote itself.
        with open(path, 'rb') as f:
            dataset.data = pickle.load(f)
        return dataset

    def save_model(self, model: nn.Module, optimizer, scheduler, N: int,
                   topology: str, epoch: int, history: Dict):
        """Write model/optimizer/scheduler state and history for ``epoch``."""
        path = self.base_dir / f"model_N{N}_{topology}_epoch{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
            'N': N,
            'topology': topology
        }, path)
        print(f"  ✓ Checkpoint: epoch {epoch}")

    def load_model(self, model: nn.Module, optimizer, scheduler, N: int, topology: str):
        """Restore the highest-epoch checkpoint for (N, topology) in place.

        Returns:
            ``(epoch, history)`` of the restored checkpoint, or
            ``(None, None)`` when no checkpoint file exists.
        """
        pattern = f"model_N{N}_{topology}_epoch*.pt"
        checkpoints = list(self.base_dir.glob(pattern))

        if not checkpoints:
            return None, None

        # rsplit keeps the parse correct even if the topology name itself
        # happens to contain the substring "epoch".
        latest = max(checkpoints, key=lambda p: int(p.stem.rsplit('epoch', 1)[1]))
        checkpoint = torch.load(latest, map_location='cpu')

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print(f"  ✓ Resumed from epoch {checkpoint['epoch']}: {latest.name}")
        return checkpoint['epoch'], checkpoint.get('history', None)

    def save_benchmark(self, result: Dict, N: int, topology: str):
        """Write one benchmark result dict as ``benchmark_N{N}_{topology}.json``."""
        path = self.base_dir / f"benchmark_N{N}_{topology}.json"
        with open(path, 'w') as f:
            json.dump(result, f, indent=2)

    def load_benchmark(self, N: int, topology: str) -> Optional[Dict]:
        """Return the cached benchmark dict for (N, topology), or ``None``."""
        path = self.base_dir / f"benchmark_N{N}_{topology}.json"
        if not path.exists():
            return None
        # Context manager so the file handle is closed deterministically.
        with open(path) as f:
            return json.load(f)

    def save_results_csv(self, results: Dict, name: str):
        """Save results as CSV for publication and return the written path."""
        df = pd.DataFrame(results)
        path = self.results_dir / f"{name}.csv"
        df.to_csv(path, index=False, float_format='%.6f')
        print(f"  ✓ Saved CSV: {path}")
        return path


# ============================================================================
# OPTIMIZED SPIN SIMULATOR
# ============================================================================

class SpinSystemOptimized:
    """Exact simulator for N coupled two-level spins (dense or sparse operators).

    Single-spin operators are Kronecker products of the 2x2 Pauli matrices
    with identities, so each has eigenvalues +/-1.
    NOTE(review): spin-1/2 operators are conventionally sigma/2 — confirm
    that the Pauli normalisation used here is intended.
    """

    def __init__(self, N: int, topology: str = 'chain', use_sparse: Optional[bool] = None):
        """Build the operator lists.

        Args:
            N: number of spins; the state space has dimension ``2**N``.
            topology: ``'chain'``, ``'ring'`` or ``'star'``.
            use_sparse: force sparse (True) or dense (False) operators;
                ``None`` selects sparse when ``N > 10``.
        """
        self.N = N
        self.dim = 2 ** N
        self.topology = topology
        self.use_sparse = use_sparse if use_sparse is not None else (N > 10)
        self._build_operators()

    def _kron_list(self, ops: List, sparse: bool = False):
        """Return the left-to-right Kronecker product of ``ops``."""
        if sparse:
            result = sp.csr_matrix(ops[0])
            for op in ops[1:]:
                # Request CSR explicitly; sp.kron otherwise picks its own format.
                result = sp.kron(result, op, format='csr')
            return result
        result = ops[0]
        for op in ops[1:]:
            result = np.kron(result, op)
        return result

    def _build_operators(self):
        """Populate ``self.Ix``, ``self.Iy``, ``self.Iz`` with one operator per spin."""
        sx = np.array([[0, 1], [1, 0]], dtype=complex)
        sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
        sz = np.array([[1, 0], [0, -1]], dtype=complex)
        identity = np.eye(2, dtype=complex)

        if self.use_sparse:
            sx, sy, sz = sp.csr_matrix(sx), sp.csr_matrix(sy), sp.csr_matrix(sz)
            identity = sp.eye(2, dtype=complex, format='csr')

        self.Ix, self.Iy, self.Iz = [], [], []

        for i in range(self.N):
            # Identity on every site except site i, which gets the Pauli matrix.
            ops = [identity] * self.N
            ops[i] = sx
            self.Ix.append(self._kron_list(ops, self.use_sparse))
            ops[i] = sy
            self.Iy.append(self._kron_list(ops, self.use_sparse))
            ops[i] = sz
            self.Iz.append(self._kron_list(ops, self.use_sparse))

    def get_coupling_pairs(self) -> List[Tuple[int, int]]:
        """Return the coupled spin index pairs; an unknown topology yields ``[]``."""
        if self.topology == 'chain':
            return [(i, i+1) for i in range(self.N-1)]
        elif self.topology == 'ring':
            return [(i, (i+1) % self.N) for i in range(self.N)]
        elif self.topology == 'star':
            return [(0, i) for i in range(1, self.N)]
        return []

    def build_hamiltonian(self, Omega: np.ndarray, J: float):
        """Assemble ``sum_i Omega[i]*Iz_i + 2*pi*J * sum_<ij> (IxIx + IyIy + IzIz)``.

        Args:
            Omega: per-spin coefficients of ``Iz``; indexed ``0..N-1``.
            J: coupling constant applied to every pair of the topology.

        Returns:
            A ``(dim, dim)`` complex matrix: CSR when ``use_sparse`` is set,
            otherwise a dense ndarray.
        """
        if self.use_sparse:
            H = sp.csr_matrix((self.dim, self.dim), dtype=complex)
        else:
            H = np.zeros((self.dim, self.dim), dtype=complex)

        for i in range(self.N):
            H = H + self.Iz[i] * Omega[i]

        coupling = 2 * np.pi * J
        for i, j in self.get_coupling_pairs():
            # `@` is the matrix product for both ndarrays and scipy sparse
            # matrices. (Sparse `.multiply()` is element-wise and would drop
            # the XX and YY terms, making the sparse H disagree with the dense H.)
            H = H + (
                self.Ix[i] @ self.Ix[j]
                + self.Iy[i] @ self.Iy[j]
                + self.Iz[i] @ self.Iz[j]
            ) * coupling
        return H

    def simulate(self, Omega: np.ndarray, J: float, T: int, dt: float = 1e-4) -> Dict:
        """Evolve the uniform superposition state and record observables.

        Args:
            Omega: per-spin coefficients passed to ``build_hamiltonian``.
            J: coupling constant passed to ``build_hamiltonian``.
            T: number of time points; sample ``k`` is taken at ``k * dt``.
            dt: spacing between time points.

        Returns:
            Dict with length-``T`` arrays ``'Mx'`` and ``'My'`` (expectation of
            the summed x / y operators), ``'I1z'`` (expectation of spin 0's z
            operator), ``'times'``, and ``'elapsed_time'`` in seconds for the
            propagation loop only.
        """
        H = self.build_hamiltonian(Omega, J)
        psi0 = np.ones(self.dim, dtype=complex) / np.sqrt(self.dim)
        times = np.arange(T) * dt

        Mx, My, I1z = np.zeros(T), np.zeros(T), np.zeros(T)

        Ix_sum = sum(self.Ix)
        Iy_sum = sum(self.Iy)
        Iz_first = self.Iz[0]

        def expectation(op, psi):
            # <psi|op|psi>; np.vdot conjugates its first argument.
            return np.real(np.vdot(psi, op @ psi))

        start = time.perf_counter()

        if self.use_sparse:
            # Apply exp(-iHt) to psi0 directly for each sample time, without
            # ever forming the dense propagator.
            for t_idx, t in enumerate(times):
                psi_t = expm_multiply(-1j * H * t, psi0)
                Mx[t_idx] = expectation(Ix_sum, psi_t)
                My[t_idx] = expectation(Iy_sum, psi_t)
                I1z[t_idx] = expectation(Iz_first, psi_t)
        else:
            # One dense single-step propagator, applied repeatedly.
            U = expm(-1j * H * dt)
            psi_t = psi0.copy()
            for t_idx in range(T):
                Mx[t_idx] = expectation(Ix_sum, psi_t)
                My[t_idx] = expectation(Iy_sum, psi_t)
                I1z[t_idx] = expectation(Iz_first, psi_t)
                psi_t = U @ psi_t

        return {
            'Mx': Mx, 'My': My, 'I1z': I1z,
            'times': times, 'elapsed_time': time.perf_counter() - start
        }


# ============================================================================
# DATASET
# ============================================================================

class NMRDataset(Dataset):
    """Simulated trajectories paired with the parameters that produced them.

    Each entry of ``self.data`` is a dict with ``'params'`` (the N Omega
    values followed by J) and ``'observables'`` (one row per time point,
    columns Mx, My, I1z).
    """

    def __init__(self, N: int, topology: str, n_samples: int, T: int, dt: float):
        self.N = N
        self.topology = topology
        self.n_samples = n_samples
        self.T = T
        self.dt = dt
        self.data = []

    def generate_data(self):
        """Simulate ``n_samples`` random systems and append them to ``self.data``."""
        if self.n_samples == 0:
            return
        print(f"Generating {self.n_samples} trajectories (N={self.N}, {self.topology})...")

        simulator = SpinSystemOptimized(self.N, self.topology)
        for done in range(1, self.n_samples + 1):
            # Draw order (Omega first, then J) fixes the random stream.
            Omega = np.random.uniform(-100, 100, self.N) * 2 * np.pi
            J = np.random.uniform(5, 20)
            trajectory = simulator.simulate(Omega, J, self.T, self.dt)

            sample = {
                'params': np.concatenate([Omega, [J]]),
                'observables': np.stack(
                    [trajectory['Mx'], trajectory['My'], trajectory['I1z']], axis=1
                ),
            }
            self.data.append(sample)

            # Progress line after every tenth trajectory.
            if done % 10 == 0:
                print(f"  {done}/{self.n_samples} complete")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(params, observables)`` for sample ``idx`` as float32 tensors."""
        sample = self.data[idx]
        params = torch.tensor(sample['params'], dtype=torch.float32)
        observables = torch.tensor(sample['observables'], dtype=torch.float32)
        return (params, observables)


# ============================================================================
# NEURAL OPERATOR
# ============================================================================

class SpectralConv1d(nn.Module):
    """1-D Fourier layer: learned channel mixing on the lowest rFFT modes.

    ``weights`` holds one complex ``in_channels x out_channels`` matrix per
    retained mode, stored as real/imaginary pairs in the last axis.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super().__init__()
        self.modes = modes
        scale = 1 / (in_channels * out_channels)
        self.weights = nn.Parameter(
            scale * torch.rand(in_channels, out_channels, modes, 2, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``(batch, in_channels, length)``.

        Returns a real tensor of shape ``(batch, out_channels, length)``.
        Frequencies above the retained modes are zeroed.
        """
        n_points = x.size(-1)
        x_ft = torch.fft.rfft(x, dim=-1)
        n_freqs = x_ft.size(-1)  # equals n_points // 2 + 1

        # A short sequence has fewer rFFT bins than `modes`; use only the bins
        # that exist so the einsum operands keep matching sizes.
        kept = min(self.modes, n_freqs)
        weights = torch.view_as_complex(self.weights)[:, :, :kept]

        out_ft = torch.zeros(x.shape[0], self.weights.shape[1], n_freqs,
                             dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :kept] = torch.einsum("bix,iox->box", x_ft[:, :, :kept], weights)
        return torch.fft.irfft(out_ft, n=n_points, dim=-1)


class PhysicsInformedFNO(nn.Module):
    """Fourier-neural-operator surrogate mapping system parameters to a trajectory.

    The parameter vector is lifted to ``width`` channels, combined with fixed
    sinusoidal time features, passed through ``n_layers`` blocks of
    (spectral conv + 1x1 conv), and projected to ``n_outputs`` per time step.
    """

    def __init__(self, modes: int = 16, width: int = 64, n_layers: int = 4,
                 n_params: int = 13, n_outputs: int = 3):
        super().__init__()
        self.modes, self.width, self.n_layers = modes, width, n_layers

        self.param_encoder = nn.Sequential(
            nn.Linear(n_params, width), nn.GELU(), nn.Linear(width, width)
        )
        self.spectral_layers = nn.ModuleList([
            SpectralConv1d(width, width, modes) for _ in range(n_layers)
        ])
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(width, width, 1) for _ in range(n_layers)
        ])
        self.output_projection = nn.Sequential(
            nn.Linear(width, width), nn.GELU(), nn.Linear(width, n_outputs)
        )

    def _time_features(self, time_steps: int, device, dtype) -> torch.Tensor:
        """Return fixed ``(width, time_steps)`` sin/cos features of normalised time.

        Row ``k`` of the sine half is ``sin(2*pi*(k+1)*t/time_steps)`` and the
        cosine half likewise, so every feature is periodic on the sample grid.
        The features have no learnable parameters, so the module's
        ``state_dict`` layout is unaffected.
        """
        tau = torch.arange(time_steps, device=device, dtype=dtype) / max(time_steps, 1)
        n_pairs = (self.width + 1) // 2
        freqs = torch.arange(1, n_pairs + 1, device=device, dtype=dtype)
        angles = 2 * np.pi * freqs.unsqueeze(1) * tau.unsqueeze(0)
        # Slice handles odd widths, where sin+cos would give one row too many.
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=0)[: self.width]

    def forward(self, params: torch.Tensor, time_steps: int) -> torch.Tensor:
        """Predict observables for ``params`` of shape ``(batch, n_params)``.

        Returns a tensor of shape ``(batch, time_steps, n_outputs)``.
        """
        x = self.param_encoder(params)
        # Broadcasting the encoded parameters alone gives a signal that is
        # constant in time, and none of the following layers can introduce
        # time dependence (the FFT of a constant has only a DC bin; the 1x1
        # convolution and GELU act point-wise). Adding time features gives the
        # network something to modulate: (B, width, 1) + (width, T).
        x = x.unsqueeze(-1) + self._time_features(time_steps, x.device, x.dtype)

        for i in range(self.n_layers):
            x = self.spectral_layers[i](x) + self.conv_layers[i](x)
            if i < self.n_layers - 1:
                x = F.gelu(x)

        return self.output_projection(x.transpose(1, 2))

    def compute_physics_loss(self, pred: torch.Tensor) -> torch.Tensor:
        """Return the soft-constraint penalty for ``pred`` of shape ``(batch, T, >=3)``.

        Sum of: transverse magnitude above 1, 0.1 x squared step-to-step change
        of Mx/My, and 0.1 x any increase of I1z between consecutive steps.

        NOTE(review): the simulator in this file starts from the uniform
        superposition with Pauli-normalised operators, so the reference Mx is
        N at t=0; the unit magnitude bound therefore penalises the training
        targets themselves for N >= 2 — confirm the intended bound.
        """
        Mx, My, I1z = pred[:, :, 0], pred[:, :, 1], pred[:, :, 2]
        # The epsilon keeps the sqrt gradient finite where Mx = My = 0.
        magnitude = torch.sqrt(Mx**2 + My**2 + 1e-12)
        mag_loss = F.relu(magnitude - 1.0).mean()
        smooth_loss = ((Mx[:, 1:] - Mx[:, :-1])**2 + (My[:, 1:] - My[:, :-1])**2).mean()
        diffusion_loss = F.relu(I1z[:, 1:] - I1z[:, :-1]).mean()
        return mag_loss + 0.1 * smooth_loss + 0.1 * diffusion_loss


# ============================================================================
# TRAINING
# ============================================================================

def train_surrogate(model: nn.Module, train_loader, val_loader, N: int,
                   topology: str, epochs: int, lr: float, device: str,
                   ckpt_mgr: CheckpointManager) -> Dict:
    """Train ``model`` with MSE + weighted physics penalty, resuming if possible.

    Args:
        model: network exposing ``model(params, time_steps)`` and
            ``compute_physics_loss(pred)``.
        train_loader, val_loader: iterables of ``(params, observables)`` batches.
        N, topology: identify the checkpoint files to resume from / write to.
        epochs: total number of epochs (counting any already completed).
        lr: initial AdamW learning rate.
        device: torch device string.
        ckpt_mgr: checkpoint manager used for loading and saving state.

    Returns:
        History dict with per-epoch ``'train_loss'``, ``'val_loss'`` and
        ``'physics_loss'`` lists.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    start_epoch = 0
    history = {'train_loss': [], 'val_loss': [], 'physics_loss': []}

    loaded_epoch, loaded_history = ckpt_mgr.load_model(model, optimizer, scheduler, N, topology)
    if loaded_epoch is not None:
        # Checkpoints record the last finished epoch; continue with the next.
        start_epoch = loaded_epoch + 1
        if loaded_history:
            history = loaded_history

    if start_epoch >= epochs:
        print("  ✓ Training complete")
        return history

    def _mean(values):
        # Plain Python floats keep the history JSON-serialisable and free of
        # NumPy scalar objects inside the saved checkpoint; an empty loader
        # yields NaN explicitly.
        return float(np.mean(values)) if values else float('nan')

    for epoch in range(start_epoch, epochs):
        model.train()
        train_losses, physics_losses = [], []

        for params, observables in train_loader:
            params, observables = params.to(device), observables.to(device)
            optimizer.zero_grad()

            pred = model(params, observables.shape[1])
            data_loss = F.mse_loss(pred, observables)
            physics_loss = model.compute_physics_loss(pred)
            loss = data_loss + 0.01 * physics_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Only the data term is logged as the training loss.
            train_losses.append(data_loss.item())
            physics_losses.append(physics_loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for params, observables in val_loader:
                params, observables = params.to(device), observables.to(device)
                val_losses.append(F.mse_loss(model(params, observables.shape[1]), observables).item())

        scheduler.step()

        history['train_loss'].append(_mean(train_losses))
        history['val_loss'].append(_mean(val_losses))
        history['physics_loss'].append(_mean(physics_losses))

        if (epoch + 1) % 10 == 0:
            ckpt_mgr.save_model(model, optimizer, scheduler, N, topology, epoch, history)
            print(f"Epoch {epoch+1}/{epochs} - Train: {history['train_loss'][-1]:.6f}, "
                  f"Val: {history['val_loss'][-1]:.6f}, Physics: {history['physics_loss'][-1]:.6f}")

    # The periodic save above already covered the last epoch when `epochs` is
    # a multiple of 10; otherwise write the final state now.
    if epochs % 10 != 0:
        ckpt_mgr.save_model(model, optimizer, scheduler, N, topology, epochs-1, history)
    return history


# ============================================================================
# BENCHMARKING
# ============================================================================

def benchmark_all_methods(N_values: List[int], topology: str, config: ExperimentConfig,
                         device: str) -> Dict:
    """Train a surrogate per system size and time it against both simulators.

    For each ``N`` a cached benchmark result is reused when present. Otherwise
    datasets are loaded or generated, the surrogate is trained, and one random
    test system is solved with the dense simulator (reference), the sparse
    simulator and the surrogate.

    Args:
        N_values: system sizes to benchmark.
        topology: coupling topology name.
        config: experiment settings (dataset sizes, T, dt, network shape, ...).
        device: torch device string used for training and surrogate inference.

    Returns:
        Dict of equally ordered lists keyed by ``'N'``, ``'exact_time'``,
        ``'krylov_time'``, ``'surrogate_time'``, ``'krylov_error'`` and
        ``'surrogate_error'``.
    """
    ckpt_mgr = CheckpointManager()

    results = {
        'N': [], 'exact_time': [], 'krylov_time': [], 'surrogate_time': [],
        'krylov_error': [], 'surrogate_error': []
    }

    # GPU kernels launch asynchronously, so wall-clock timing must wait for
    # the device before reading the clock.
    on_cuda = torch.cuda.is_available() and str(device).startswith('cuda')

    for N in N_values:
        print(f"\n{'='*70}")
        print(f"BENCHMARKING N={N}, topology={topology}")
        print(f"{'='*70}")

        existing = ckpt_mgr.load_benchmark(N, topology)
        if existing:
            print("  ✓ Using cached results")
            for k in results:
                if k in existing:
                    results[k].append(existing[k])
            continue

        # Load/generate datasets (falsy = nothing cached, or an empty cache)
        train_ds = ckpt_mgr.load_dataset(N, topology, 'train', config.T, config.dt)
        if not train_ds:
            train_ds = NMRDataset(N, topology, config.n_train_samples, config.T, config.dt)
            train_ds.generate_data()
            ckpt_mgr.save_dataset(train_ds, N, topology, 'train')

        val_ds = ckpt_mgr.load_dataset(N, topology, 'val', config.T, config.dt)
        if not val_ds:
            val_ds = NMRDataset(N, topology, config.n_val_samples, config.T, config.dt)
            val_ds.generate_data()
            ckpt_mgr.save_dataset(val_ds, N, topology, 'val')

        train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=config.batch_size)

        # Train model; the input vector is the N Omega values plus J
        model = PhysicsInformedFNO(config.modes, config.width, config.n_layers, N+1, 3)
        print("\nTraining surrogate...")
        train_surrogate(model, train_loader, val_loader, N, topology,
                       config.epochs, config.lr, device, ckpt_mgr)

        # Benchmark on a single randomly drawn test system
        print("\nBenchmarking methods...")
        Omega = np.random.uniform(-100, 100, N) * 2 * np.pi
        J = 12.5

        # Exact (dense) — the reference for both error figures
        print("  [1/3] Exact method (dense matrix)...")
        sys_exact = SpinSystemOptimized(N, topology, use_sparse=False)
        exact_res = sys_exact.simulate(Omega, J, config.T, config.dt)
        exact_time = exact_res['elapsed_time']
        print(f"        Time: {exact_time:.4f}s")

        # Krylov (sparse)
        print("  [2/3] Krylov method (sparse)...")
        sys_krylov = SpinSystemOptimized(N, topology, use_sparse=True)
        krylov_res = sys_krylov.simulate(Omega, J, config.T, config.dt)
        krylov_time = krylov_res['elapsed_time']
        krylov_err = np.sqrt(
            np.mean((exact_res['Mx'] - krylov_res['Mx'])**2) +
            np.mean((exact_res['My'] - krylov_res['My'])**2) +
            np.mean((exact_res['I1z'] - krylov_res['I1z'])**2)
        )
        print(f"        Time: {krylov_time:.4f}s, Error: {krylov_err:.2e}")

        # Surrogate
        print("  [3/3] Neural surrogate...")
        model.eval()
        model = model.to(device)
        params_t = torch.tensor(np.concatenate([Omega, [J]]), dtype=torch.float32).unsqueeze(0).to(device)

        # Warmup (one-off initialisation cost must not be counted)
        with torch.no_grad():
            _ = model(params_t, config.T)

        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            pred = model(params_t, config.T)
        if on_cuda:
            # Without this the timer would stop at kernel launch, not completion.
            torch.cuda.synchronize()
        surrogate_time = time.perf_counter() - start

        # squeeze(0) drops only the batch axis, so T == 1 keeps its time axis.
        pred = pred.squeeze(0).cpu().numpy()
        surrogate_err = np.sqrt(
            np.mean((exact_res['Mx'] - pred[:, 0])**2) +
            np.mean((exact_res['My'] - pred[:, 1])**2) +
            np.mean((exact_res['I1z'] - pred[:, 2])**2)
        )
        print(f"        Time: {surrogate_time:.6f}s, Error: {surrogate_err:.6f}")

        result = {
            'N': N, 'exact_time': exact_time, 'krylov_time': krylov_time,
            'surrogate_time': surrogate_time, 'krylov_error': float(krylov_err),
            'surrogate_error': float(surrogate_err),
            'speedup_vs_exact': exact_time / surrogate_time,
            'speedup_vs_krylov': krylov_time / surrogate_time
        }

        ckpt_mgr.save_benchmark(result, N, topology)
        for k in results:
            if k in result:
                results[k].append(result[k])

        print(f"\n{'Method':<15} {'Time(s)':<12} {'Error':<12} {'Speedup'}")
        print("-" * 55)
        print(f"{'Exact':<15} {exact_time:>11.4f}  {'-':<12} {'1.0×'}")
        print(f"{'Krylov':<15} {krylov_time:>11.4f}  {krylov_err:>11.2e}  {exact_time/krylov_time:.1f}×")
        print(f"{'Surrogate':<15} {surrogate_time:>11.6f}  {surrogate_err:>11.6f}  {result['speedup_vs_exact']:.1f}×")

    return results


# ============================================================================
# PLOTTING
# ============================================================================

def generate_publication_figures(results: Dict, ckpt_mgr: CheckpointManager):
    """Generate publication-quality figures.

    Draws a five-panel summary (timing, speedup, error, error x time, and a
    results table) and saves it as ``comprehensive_benchmark.png`` in
    ``ckpt_mgr.results_dir``.

    Args:
        results: equally ordered lists under ``'N'``, ``'exact_time'``,
            ``'krylov_time'``, ``'surrogate_time'``, ``'krylov_error'`` and
            ``'surrogate_error'``.
        ckpt_mgr: supplies the output directory.
    """
    # The seaborn styles were renamed with a "v0_8" infix in Matplotlib 3.6;
    # accept either spelling and keep the default style if neither exists.
    for style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    n_vals = results['N']
    exact_t = results['exact_time']
    krylov_t = results['krylov_time']
    surr_t = results['surrogate_time']
    krylov_e = results['krylov_error']
    surr_e = results['surrogate_error']

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.35)

    # Main scaling plot
    ax1 = fig.add_subplot(gs[0, :])
    ax1.semilogy(n_vals, exact_t, 'o-',
                label='Exact (Dense)', linewidth=3, markersize=12, color='#1f77b4')
    ax1.semilogy(n_vals, krylov_t, 's-',
                label='Krylov (Sparse)', linewidth=3, markersize=12, color='#ff7f0e')
    ax1.semilogy(n_vals, surr_t, 'd-',
                label='Neural Surrogate', linewidth=3, markersize=12, color='#2ca02c')
    ax1.set_xlabel('Number of Spins (N)', fontsize=16, fontweight='bold')
    ax1.set_ylabel('Wall-Clock Time (s)', fontsize=16, fontweight='bold')
    ax1.set_title('Computational Scaling Comparison', fontsize=18, fontweight='bold', pad=20)
    ax1.legend(fontsize=14, framealpha=0.95, loc='upper left')
    ax1.grid(True, alpha=0.3, which='both')
    ax1.tick_params(labelsize=14)

    # Speedup comparison
    ax2 = fig.add_subplot(gs[1, 0])
    x = np.arange(len(n_vals))
    width = 0.35
    speedup_krylov = [e / k for e, k in zip(exact_t, krylov_t)]
    speedup_surr = [e / s for e, s in zip(exact_t, surr_t)]

    ax2.bar(x - width/2, speedup_krylov, width, label='Krylov', color='#ff7f0e', alpha=0.8, edgecolor='black')
    ax2.bar(x + width/2, speedup_surr, width, label='Surrogate', color='#2ca02c', alpha=0.8, edgecolor='black')
    ax2.set_xlabel('System Size (N)', fontsize=14, fontweight='bold')
    ax2.set_ylabel('Speedup Factor', fontsize=14, fontweight='bold')
    ax2.set_title('Speedup vs Exact Method', fontsize=15, fontweight='bold')
    ax2.set_xticks(x)
    ax2.set_xticklabels(n_vals)
    ax2.legend(fontsize=12)
    ax2.grid(True, alpha=0.3, axis='y')
    ax2.set_yscale('log')
    ax2.tick_params(labelsize=12)

    # Error analysis
    ax3 = fig.add_subplot(gs[1, 1])
    ax3.semilogy(n_vals, krylov_e, 's-',
                label='Krylov Error', linewidth=2.5, markersize=10, color='#ff7f0e')
    ax3.semilogy(n_vals, surr_e, 'd-',
                label='Surrogate Error', linewidth=2.5, markersize=10, color='#2ca02c')
    ax3.set_xlabel('System Size (N)', fontsize=14, fontweight='bold')
    ax3.set_ylabel('RMSE vs Exact', fontsize=14, fontweight='bold')
    ax3.set_title('Prediction Accuracy', fontsize=15, fontweight='bold')
    ax3.legend(fontsize=12)
    ax3.grid(True, alpha=0.3, which='both')
    ax3.tick_params(labelsize=12)

    # Efficiency plot: error multiplied by time, per method
    ax4 = fig.add_subplot(gs[1, 2])
    efficiency_surr = [err * t for err, t in zip(surr_e, surr_t)]
    efficiency_krylov = [err * t for err, t in zip(krylov_e, krylov_t)]

    ax4.semilogy(n_vals, efficiency_krylov, 's-',
                label='Krylov', linewidth=2.5, markersize=10, color='#ff7f0e')
    ax4.semilogy(n_vals, efficiency_surr, 'd-',
                label='Surrogate', linewidth=2.5, markersize=10, color='#2ca02c')
    ax4.set_xlabel('System Size (N)', fontsize=14, fontweight='bold')
    ax4.set_ylabel('Error × Time', fontsize=14, fontweight='bold')
    ax4.set_title('Computational Efficiency', fontsize=15, fontweight='bold')
    ax4.legend(fontsize=12)
    ax4.grid(True, alpha=0.3, which='both')
    ax4.tick_params(labelsize=12)

    # Time breakdown table
    ax5 = fig.add_subplot(gs[2, :])
    ax5.axis('tight')
    ax5.axis('off')

    table_data = [['N', 'Exact (s)', 'Krylov (s)', 'Surrogate (s)', 'KrylovErr', 'Surr Err', 'Speedup']]
    for n, t_exact, t_krylov, t_surr, e_krylov, e_surr in zip(
            n_vals, exact_t, krylov_t, surr_t, krylov_e, surr_e):
        table_data.append([
            f"{n}",
            f"{t_exact:.4f}",
            f"{t_krylov:.4f}",
            f"{t_surr:.6f}",
            f"{e_krylov:.2e}",
            f"{e_surr:.6f}",
            f"{t_exact/t_surr:.1f}×"
        ])

    table = ax5.table(cellText=table_data, cellLoc='center', loc='center',
                     colWidths=[0.1, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15])
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2.5)

    # Style header row
    for col in range(len(table_data[0])):
        table[(0, col)].set_facecolor('#4CAF50')
        table[(0, col)].set_text_props(weight='bold', color='white')

    # The figure is left open so an inline notebook backend can still show it.
    fig.savefig(ckpt_mgr.results_dir / 'comprehensive_benchmark.png',
                dpi=300, bbox_inches='tight', facecolor='white')
    print("\n✓ Saved: comprehensive_benchmark.png")


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Run the benchmark sweep, then write the CSV, JSON and figure outputs."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")

    rule = "=" * 70

    # Publication-scale settings
    config = ExperimentConfig(
        # sweep
        N_values=[4, 6, 8, 10, 12],
        topologies=['chain'],
        # dataset sizes
        n_train_samples=200,
        n_val_samples=50,
        # trajectory sampling
        T=300,
        dt=1e-4,
        # optimisation
        epochs=200,
        batch_size=16,
        lr=1e-3,
        # network shape
        modes=24,
        width=128,
        n_layers=6,
    )

    print("\n" + rule)
    print("NMR SPIN DYNAMICS - PUBLICATION BENCHMARK")
    print(rule)
    print(f"Config: N={config.N_values}, T={config.T}, Epochs={config.epochs}")
    print(f"Samples: {config.n_train_samples} train, {config.n_val_samples} val")
    print(f"Network: {config.n_layers} layers, width {config.width}, {config.modes} modes")
    print(rule)

    ckpt_mgr = CheckpointManager()

    # Only the 'chain' topology is benchmarked here.
    results = benchmark_all_methods(config.N_values, 'chain', config, device)

    # Persist the aggregated results as CSV and JSON
    ckpt_mgr.save_results_csv(results, 'benchmark_results')
    json_path = ckpt_mgr.results_dir / 'benchmark_results.json'
    with open(json_path, 'w') as f:
        json.dump(results, f, indent=2)

    # Summary figure
    generate_publication_figures(results, ckpt_mgr)

    print("\n" + rule)
    print("BENCHMARK COMPLETE")
    print(rule)
    print(f"Results: {ckpt_mgr.results_dir}")
    print(f"Checkpoints: {ckpt_mgr.base_dir}")
    print(rule)


# Run the benchmark only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()



Device: cuda

NMR SPIN DYNAMICS - PUBLICATION BENCHMARK
Config: N=[4, 6, 8, 10, 12], T=300, Epochs=200
Samples: 200 train, 50 val
Network: 6 layers, width 128, 24 modes

BENCHMARKING N=4, topology=chain
Generating 200 trajectories (N=4, chain)...
  10/200 complete
  20/200 complete
  30/200 complete
  40/200 complete
  50/200 complete
  60/200 complete
  70/200 complete
  80/200 complete
  90/200 complete
  100/200 complete
  110/200 complete
  120/200 complete
  130/200 complete
  140/200 complete
  150/200 complete
  160/200 complete
  170/200 complete
  180/200 complete
  190/200 complete
  200/200 complete
  ✓ Saved dataset: dataset_N4_chain_train.pkl
Generating 50 trajectories (N=4, chain)...
  10/50 complete
  20/50 complete
  30/50 complete
  40/50 complete
  50/50 complete
  ✓ Saved dataset: dataset_N4_chain_val.pkl

Training surrogate...
  ✓ Checkpoint: epoch 9
Epoch 10/200 - Train: 0.832000, Val: 1.182016, Physics: 0.000387
  ✓ Checkpoint: epoch 19
Epoch 20/200 - Train: 0.80