<a href="https://colab.research.google.com/github/vramonlinebsc/neural_operator_surrogates/blob/main/neural_operator_surrogates_checkpointed_version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse as sp
from scipy.sparse.linalg import expm_multiply
from scipy.linalg import expm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import time
from itertools import product
from typing import Tuple, List, Dict, Optional
import json
import pickle
import os
from pathlib import Path
from dataclasses import dataclass, asdict
import hashlib

# Set random seeds for reproducibility: NumPy's global RNG (used for the random
# spin parameters via np.random.uniform) and PyTorch's default generator (used
# for weight initialisation). NOTE(review): this alone may not make GPU runs
# bit-for-bit deterministic — confirm if exact reproducibility on CUDA matters.
np.random.seed(42)
torch.manual_seed(42)

# ============================================================================
# CHECKPOINT MANAGEMENT
# ============================================================================

@dataclass
class ExperimentConfig:
    """Hyperparameters and problem sizes describing one benchmark run.

    ``get_hash()`` yields a short deterministic fingerprint of every field,
    which can be used to tell configurations apart.
    """

    # Problem definition: spin counts and coupling graphs (e.g. 'chain').
    N_values: List[int]
    topologies: List[str]
    # Number of simulated trajectories per split.
    n_train_samples: int
    n_val_samples: int
    # Time grid: T samples spaced dt apart.
    T: int
    dt: float
    # Optimiser settings.
    epochs: int
    batch_size: int
    lr: float
    # Fourier Neural Operator architecture.
    modes: int
    width: int
    n_layers: int

    def get_hash(self) -> str:
        """Return the first 8 hex digits of an MD5 over the sorted-key JSON of all fields.

        MD5 serves purely as a fingerprint here, not as a security primitive.
        """
        canonical_json = json.dumps(asdict(self), sort_keys=True)
        hasher = hashlib.md5()
        hasher.update(canonical_json.encode())
        full_digest = hasher.hexdigest()
        return full_digest[:8]


class CheckpointManager:
    """Manages checkpoints for resumable experiments.

    All artifacts live directly under ``base_dir`` and are keyed by spin count
    ``N`` and ``topology``:

    * ``dataset_N{N}_{topology}_{split}.pkl`` - pickled trajectories plus the
      ``T``/``dt`` they were generated with
    * ``model_N{N}_{topology}_epoch{epoch}.pt`` - model weights per epoch
    * ``history_N{N}_{topology}.json`` - training loss curves
    * ``benchmark_N{N}_{topology}.json`` - final benchmark record

    Every file is first written to a ``.tmp`` sibling and then atomically
    renamed into place. An interrupted write therefore cannot leave a truncated
    checkpoint that would later be picked as the "latest" one.
    """

    def __init__(self, base_dir: str = "checkpoints"):
        self.base_dir = Path(base_dir)
        # parents=True so a nested checkpoint directory works on first use.
        self.base_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _dataset_path(self, N: int, topology: str, split: str) -> Path:
        """Path of the cached dataset for (N, topology, split)."""
        return self.base_dir / f"dataset_N{N}_{topology}_{split}.pkl"

    def _history_path(self, N: int, topology: str) -> Path:
        """Path of the JSON training history for (N, topology)."""
        return self.base_dir / f"history_N{N}_{topology}.json"

    def _benchmark_path(self, N: int, topology: str) -> Path:
        """Path of the JSON benchmark record for (N, topology)."""
        return self.base_dir / f"benchmark_N{N}_{topology}.json"

    @staticmethod
    def _atomic_write(path: Path, mode: str, write_fn) -> None:
        """Write via ``write_fn(file)`` to a temp sibling, then rename over ``path``."""
        tmp = path.with_name(path.name + ".tmp")
        with open(tmp, mode) as f:
            write_fn(f)
        # os.replace is atomic on POSIX and Windows for same-directory renames.
        os.replace(tmp, path)

    @staticmethod
    def _dataset_matches(data, saved_T, saved_dt, T: int, dt: float) -> bool:
        """Return True when cached trajectories are compatible with (T, dt)."""
        if saved_T is not None and saved_T != T:
            return False
        if saved_dt is not None and not np.isclose(saved_dt, dt, rtol=1e-12, atol=0.0):
            return False
        # Legacy caches carry no metadata, but the trajectory length still reveals T.
        if data and len(data[0]['observables']) != T:
            return False
        return True

    # ----------------------------------------------------------------- datasets

    def save_dataset(self, dataset, N: int, topology: str, split: str):
        """Save ``dataset.data`` together with the T/dt it was generated with."""
        path = self._dataset_path(N, topology, split)
        payload = {
            'T': getattr(dataset, 'T', None),
            'dt': getattr(dataset, 'dt', None),
            'data': dataset.data,
        }
        self._atomic_write(path, 'wb', lambda f: pickle.dump(payload, f))
        print(f"  ✓ Saved dataset: {path}")

    def load_dataset(self, N: int, topology: str, split: str, T: int, dt: float):
        """Load a cached dataset, or return None if absent or generated with a different T/dt.

        Accepts both the current ``{'T', 'dt', 'data'}`` payload and the legacy
        format (a bare list of samples).
        """
        path = self._dataset_path(N, topology, split)
        if not path.exists():
            return None

        # SECURITY NOTE: pickle.load can execute arbitrary code. Only load files
        # this manager wrote itself; never point base_dir at untrusted data.
        with open(path, 'rb') as f:
            payload = pickle.load(f)

        if isinstance(payload, dict) and 'data' in payload:
            data = payload['data']
            saved_T, saved_dt = payload.get('T'), payload.get('dt')
        else:
            data, saved_T, saved_dt = payload, None, None

        if not self._dataset_matches(data, saved_T, saved_dt, T, dt):
            print(f"  ✗ Cached dataset {path} does not match T={T}, dt={dt}; regenerating")
            return None

        print(f"  ✓ Loading existing dataset: {path}")
        dataset = NMRDataset(N, topology, 0, T, dt)  # Empty dataset
        dataset.data = data
        return dataset

    # ------------------------------------------------------------------- models

    def save_model(self, model: nn.Module, N: int, topology: str, epoch: int):
        """Save the model's state_dict tagged with ``epoch``."""
        path = self.base_dir / f"model_N{N}_{topology}_epoch{epoch}.pt"
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'N': N,
            'topology': topology
        }
        self._atomic_write(path, 'wb', lambda f: torch.save(checkpoint, f))
        print(f"  ✓ Saved model checkpoint: {path}")

    def load_model(self, model: nn.Module, N: int, topology: str) -> Optional[int]:
        """Load the highest-epoch checkpoint into ``model``; return its epoch or None."""
        # Leftover "*.pt.tmp" files end in ".tmp" and are not matched here.
        pattern = f"model_N{N}_{topology}_epoch*.pt"
        checkpoints = list(self.base_dir.glob(pattern))

        if not checkpoints:
            return None

        latest = max(checkpoints, key=lambda p: int(p.stem.split('epoch')[1]))
        checkpoint = torch.load(latest, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint['epoch']

        print(f"  ✓ Loaded model from epoch {epoch}: {latest}")
        return epoch

    # ------------------------------------------------------------ JSON records

    def save_training_history(self, history: Dict, N: int, topology: str):
        """Save training history as JSON."""
        self._atomic_write(self._history_path(N, topology), 'w',
                           lambda f: json.dump(history, f, indent=2))

    def load_training_history(self, N: int, topology: str) -> Optional[Dict]:
        """Load training history, or None if it was never saved."""
        path = self._history_path(N, topology)
        if path.exists():
            with open(path, 'r') as f:
                return json.load(f)
        return None

    def save_benchmark_result(self, result: Dict, N: int, topology: str):
        """Save a benchmark result as JSON."""
        self._atomic_write(self._benchmark_path(N, topology), 'w',
                           lambda f: json.dump(result, f, indent=2))

    def load_benchmark_result(self, N: int, topology: str) -> Optional[Dict]:
        """Load a benchmark result, or None if it was never saved."""
        path = self._benchmark_path(N, topology)
        if path.exists():
            with open(path, 'r') as f:
                return json.load(f)
        return None

    def is_complete(self, N: int, topology: str) -> bool:
        """An experiment is complete once its benchmark record exists."""
        return self._benchmark_path(N, topology).exists()


# ============================================================================
# EXACT SPIN SYSTEM SIMULATOR (CPU: dense for N <= 10, SciPy sparse for N > 10)
# ============================================================================

class SpinSystemOptimized:
    """Exact state-vector simulator for N coupled spin-1/2 particles.

    The single-site Pauli matrices X_i, Y_i, Z_i are embedded into the
    2**N-dimensional Hilbert space by Kronecker products and built once per
    instance. They are dense NumPy arrays for N <= 10 and SciPy CSR matrices
    for N > 10.

    Hamiltonian::

        H = sum_i Omega_i Z_i + 2*pi*J * sum_{(i,j)} (X_i X_j + Y_i Y_j + Z_i Z_j)

    The coupled pairs (i, j) come from ``topology`` ('chain', 'ring' or 'star').

    NOTE: all arithmetic runs in NumPy/SciPy on the host. ``device`` is stored
    (and forced to 'cpu' for N > 10) but is not otherwise used.
    """

    def __init__(self, N: int, topology: str = 'chain', device: str = 'cpu'):
        self.N = N
        self.dim = 2 ** N
        self.topology = topology
        self.device = device

        # Dense 2^N x 2^N operators become impractically large past N = 10,
        # so switch to sparse CSR storage there.
        self.use_sparse = N > 10
        if self.use_sparse:
            self.device = 'cpu'

        self._build_operators()

    def _kron_list(self, ops: List[np.ndarray], sparse: bool = False):
        """Kronecker product ops[0] (x) ops[1] (x) ... (CSR when ``sparse``)."""
        if sparse:
            result = sp.csr_matrix(ops[0])
            for op in ops[1:]:
                result = sp.kron(result, op, format='csr')
            return result
        result = ops[0]
        for op in ops[1:]:
            result = np.kron(result, op)
        return result

    def _build_operators(self):
        """Populate self.Ix / self.Iy / self.Iz with site-embedded Pauli matrices (list index = site)."""
        sx = np.array([[0, 1], [1, 0]], dtype=complex)
        sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
        sz = np.array([[1, 0], [0, -1]], dtype=complex)
        identity = np.eye(2, dtype=complex)

        if self.use_sparse:
            sx = sp.csr_matrix(sx)
            sy = sp.csr_matrix(sy)
            sz = sp.csr_matrix(sz)
            identity = sp.eye(2, dtype=complex, format='csr')

        def embed(single_site_op, site: int):
            ops = [identity] * self.N
            ops[site] = single_site_op
            return self._kron_list(ops, self.use_sparse)

        self.Ix = [embed(sx, i) for i in range(self.N)]
        self.Iy = [embed(sy, i) for i in range(self.N)]
        self.Iz = [embed(sz, i) for i in range(self.N)]

    def get_coupling_pairs(self) -> List[Tuple[int, int]]:
        """Return the coupled (i, j) site pairs for this topology.

        Raises:
            ValueError: if ``topology`` is not 'chain', 'ring' or 'star'.
        """
        if self.topology == 'chain':
            return [(i, i + 1) for i in range(self.N - 1)]
        if self.topology == 'ring':
            pairs = [(i, i + 1) for i in range(self.N - 1)]
            # The wrap-around bond is new only for N >= 3. For N == 2 it would
            # duplicate (0, 1), and for N == 1 it would couple a spin to itself.
            if self.N > 2:
                pairs.append((self.N - 1, 0))
            return pairs
        if self.topology == 'star':
            return [(0, i) for i in range(1, self.N)]
        raise ValueError(
            f"Unknown topology {self.topology!r}; expected 'chain', 'ring' or 'star'"
        )

    def build_hamiltonian(self, Omega: np.ndarray, J: float):
        """Build H (dense ndarray for N <= 10, CSR matrix otherwise).

        Args:
            Omega: per-site coefficients multiplying Z_i (at least N entries).
            J: coupling constant; each coupled pair adds 2*pi*J*(XX + YY + ZZ).
        """
        if self.use_sparse:
            H = sp.csr_matrix((self.dim, self.dim), dtype=complex)
        else:
            H = np.zeros((self.dim, self.dim), dtype=complex)

        # Chemical shift terms
        for i in range(self.N):
            H = H + Omega[i] * self.Iz[i]

        # J-coupling terms. `@` is the matrix product for both ndarrays and
        # SciPy sparse matrices; sparse `.multiply` would be element-wise.
        coupling = 2 * np.pi * J
        for i, j in self.get_coupling_pairs():
            H = H + coupling * (
                self.Ix[i] @ self.Ix[j] +
                self.Iy[i] @ self.Iy[j] +
                self.Iz[i] @ self.Iz[j]
            )

        return H

    def simulate(self, Omega: np.ndarray, J: float, T: int, dt: float = 1e-4) -> Dict:
        """Propagate the uniform superposition |+>^N under H for T steps.

        Returns:
            dict with 'Mx' / 'My' (expectation of sum_i X_i / sum_i Y_i),
            'I1z' (expectation of Z on site 0), each of length T sampled at
            'times' = arange(T) * dt, and 'elapsed_time' (wall-clock seconds
            for propagation and measurement; building H is not timed).

        The state is advanced by exp(-i H dt) each step. The dense path uses a
        precomputed propagator; the sparse path uses expm_multiply, so the cost
        per step stays constant instead of growing with t.
        """
        H = self.build_hamiltonian(Omega, J)

        # Initial state
        psi0 = np.ones(self.dim, dtype=complex) / np.sqrt(self.dim)

        times = np.arange(T) * dt
        Mx = np.zeros(T)
        My = np.zeros(T)
        I1z = np.zeros(T)

        start_time = time.time()

        Ix_sum = sum(self.Ix)
        Iy_sum = sum(self.Iy)
        Iz_first = self.Iz[0]

        if self.use_sparse:
            step_generator = (-1j * dt) * H

            def advance(psi):
                return expm_multiply(step_generator, psi)
        else:
            U = expm(-1j * H * dt)

            def advance(psi):
                return U @ psi

        psi_t = psi0.copy()
        for t_idx in range(T):
            # np.vdot conjugates its first argument: <psi|O|psi>.
            Mx[t_idx] = np.real(np.vdot(psi_t, Ix_sum @ psi_t))
            My[t_idx] = np.real(np.vdot(psi_t, Iy_sum @ psi_t))
            I1z[t_idx] = np.real(np.vdot(psi_t, Iz_first @ psi_t))

            if self.use_sparse and t_idx % 50 == 0:
                print(f"    Progress: {t_idx}/{T} steps", end='\r')

            # Skip the propagation after the last sample; it would be discarded.
            if t_idx + 1 < T:
                psi_t = advance(psi_t)

        if self.use_sparse and T > 0:
            print()  # terminate the '\r' progress line

        elapsed = time.time() - start_time

        return {
            'Mx': Mx,
            'My': My,
            'I1z': I1z,
            'times': times,
            'elapsed_time': elapsed
        }


# ============================================================================
# DATASET WITH CACHING
# ============================================================================

class NMRDataset(Dataset):
    """Torch dataset of simulated NMR trajectories.

    Each entry of ``data`` is a dict:

    * ``'params'``: the N chemical shifts Omega followed by the coupling J
      (length N + 1)
    * ``'observables'``: array of shape (T, 3) whose columns are Mx, My, I1z

    Entries are produced by ``generate_data()``, or assigned to ``data``
    directly when restored from a cache.
    """

    def __init__(self, N: int, topology: str, n_samples: int, T: int, dt: float):
        self.N, self.topology, self.n_samples = N, topology, n_samples
        self.T, self.dt = T, dt
        self.data = []

    def generate_data(self):
        """Simulate ``n_samples`` random spin systems and append them to ``data``.

        Omega ~ U(-100, 100) * 2*pi per spin and J ~ U(5, 20) are drawn from
        NumPy's global RNG. Calling this again appends another batch.
        """
        if self.n_samples == 0:
            return

        print(f"Generating {self.n_samples} trajectories for N={self.N}, topology={self.topology}...")

        simulator = SpinSystemOptimized(self.N, self.topology)

        for produced in range(1, self.n_samples + 1):
            omega = np.random.uniform(-100, 100, self.N) * 2 * np.pi
            coupling = np.random.uniform(5, 20)

            trajectory = simulator.simulate(omega, coupling, self.T, self.dt)

            self.data.append({
                'params': np.concatenate([omega, [coupling]]),
                'observables': np.column_stack(
                    (trajectory['Mx'], trajectory['My'], trajectory['I1z'])
                ),
            })

            if produced % 10 == 0:
                print(f"  Generated {produced}/{self.n_samples}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return (
            torch.tensor(record['params'], dtype=torch.float32),
            torch.tensor(record['observables'], dtype=torch.float32),
        )


# ============================================================================
# NEURAL OPERATOR (Fourier Neural Operator surrogate)
# ============================================================================

class SpectralConv1d(nn.Module):
    """1D Fourier layer.

    The layer takes the rfft of each input channel along the last axis and
    mixes channels with learned complex weights on the lowest ``modes``
    frequency bins. All higher bins are zeroed before the inverse rfft.

    Shapes: input ``(batch, in_channels, length)`` and output
    ``(batch, out_channels, length)``.

    When ``length // 2 + 1 < modes`` there are fewer bins than weights. Only
    the available bins and the matching slice of weights are used, instead of
    failing with a shape mismatch.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        self.scale = 1 / (in_channels * out_channels)
        # Last axis of size 2 holds (real, imag); viewed as complex in forward().
        self.weights = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes, 2, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        length = x.size(-1)
        x_ft = torch.fft.rfft(x, dim=-1)
        n_bins = x_ft.size(-1)  # == length // 2 + 1

        # Short inputs may have fewer bins than learned modes.
        kept = min(self.modes, n_bins)
        weights = torch.view_as_complex(self.weights)[..., :kept].to(x_ft.dtype)

        out_ft = torch.zeros(batch_size, self.out_channels, n_bins,
                             dtype=x_ft.dtype, device=x.device)
        out_ft[:, :, :kept] = torch.einsum("bix,iox->box", x_ft[:, :, :kept], weights)

        return torch.fft.irfft(out_ft, n=length, dim=-1)


class PhysicsInformedFNO(nn.Module):
    """Fourier Neural Operator surrogate with physics-informed penalties.

    Maps a parameter vector ``(batch, n_params)`` to observable trajectories
    ``(batch, time_steps, n_outputs)``. The parameters are encoded into
    ``width`` channels and broadcast over time. A fixed sinusoidal time
    encoding is added, the result passes through ``n_layers`` Fourier layers
    (spectral conv + pointwise 1x1 conv, GELU between layers), and a
    per-time-step MLP projects to the outputs.

    Why the time encoding: the broadcast encoder output is constant along
    time. Spectral filters, 1x1 convolutions and GELU all map a constant
    signal to a constant signal, so without it the model could only predict a
    time-independent trajectory. The encoding has no learnable parameters,
    which leaves ``state_dict`` layouts unchanged. Pass
    ``use_time_encoding=False`` for the previous behaviour.

    ``magnetization_bound`` caps |Mx + i My| in the physics loss. It defaults
    to the spin count N = n_params - 1, because params are [Omega_1..Omega_N, J]
    and Mx/My are sums of N Pauli expectations, each bounded by 1 in magnitude.
    """

    def __init__(self, modes: int = 16, width: int = 64, n_layers: int = 4,
                 n_params: int = 13, n_outputs: int = 3,
                 use_time_encoding: bool = True,
                 magnetization_bound: Optional[float] = None):
        super().__init__()

        self.modes = modes
        self.width = width
        self.n_layers = n_layers
        self.use_time_encoding = use_time_encoding
        if magnetization_bound is None:
            magnetization_bound = float(max(n_params - 1, 1))
        self.magnetization_bound = magnetization_bound

        self.param_encoder = nn.Sequential(
            nn.Linear(n_params, width),
            nn.GELU(),
            nn.Linear(width, width)
        )

        self.spectral_layers = nn.ModuleList([
            SpectralConv1d(width, width, modes) for _ in range(n_layers)
        ])

        self.conv_layers = nn.ModuleList([
            nn.Conv1d(width, width, 1) for _ in range(n_layers)
        ])

        self.output_projection = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, n_outputs)
        )

    def _time_encoding(self, time_steps: int, device, dtype) -> torch.Tensor:
        """Return a fixed ``(width, time_steps)`` sin/cos encoding of normalised time t in [0, 1]."""
        t = torch.linspace(0.0, 1.0, time_steps, device=device, dtype=dtype)
        n_freq = (self.width + 1) // 2
        freqs = np.pi * torch.arange(1, n_freq + 1, device=device, dtype=dtype)
        angles = freqs[:, None] * t[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=0)[: self.width]

    def forward(self, params: torch.Tensor, time_steps: int) -> torch.Tensor:
        """Predict ``(batch, time_steps, n_outputs)`` from ``(batch, n_params)``."""
        x = self.param_encoder(params)                  # (batch, width)
        x = x.unsqueeze(-1).expand(-1, -1, time_steps)  # (batch, width, T)
        if self.use_time_encoding:
            x = x + self._time_encoding(time_steps, x.device, x.dtype)

        for i, (spectral, pointwise) in enumerate(zip(self.spectral_layers, self.conv_layers)):
            x = spectral(x) + pointwise(x)
            if i < self.n_layers - 1:
                x = F.gelu(x)

        x = x.transpose(1, 2)                           # (batch, T, width)
        return self.output_projection(x)

    def compute_physics_loss(self, predictions: torch.Tensor) -> torch.Tensor:
        """Soft physical priors on ``predictions`` of shape (batch, T, >=3) = (Mx, My, I1z, ...).

        Returns ``magnitude + 0.1 * smoothness + 0.1 * diffusion`` where:

        * magnitude: mean excess of |Mx + i My| over ``magnetization_bound``
        * smoothness: mean squared step-to-step change of Mx and My
        * diffusion: mean positive step-to-step change of I1z.
          NOTE(review): coherent dynamics can make I1z oscillate, so this prior
          may conflict with the data; confirm it is intended.
        """
        Mx = predictions[:, :, 0]
        My = predictions[:, :, 1]
        I1z = predictions[:, :, 2]

        bound = self.magnetization_bound
        # Clamping |M|^2 at bound^2 leaves relu(|M| - bound) unchanged but keeps
        # sqrt away from 0, where its gradient is infinite and yields NaN grads.
        M_sq = torch.clamp(Mx**2 + My**2, min=bound**2)
        magnitude_loss = F.relu(torch.sqrt(M_sq) - bound).mean()

        dt_Mx = Mx[:, 1:] - Mx[:, :-1]
        dt_My = My[:, 1:] - My[:, :-1]
        smoothness_loss = (dt_Mx**2 + dt_My**2).mean()

        dt_I1z = I1z[:, 1:] - I1z[:, :-1]
        diffusion_loss = F.relu(dt_I1z).mean()

        return magnitude_loss + 0.1 * smoothness_loss + 0.1 * diffusion_loss


# ============================================================================
# RESUMABLE TRAINING
# ============================================================================

def train_surrogate_resumable(model: nn.Module, train_loader: DataLoader,
                             val_loader: DataLoader, N: int, topology: str,
                             epochs: int = 100, lr: float = 1e-3,
                             device: str = 'cpu',
                             checkpoint_mgr: "Optional[CheckpointManager]" = None) -> Dict:
    """Train a neural surrogate, resuming from the latest checkpoint when available.

    Each batch minimises ``MSE(pred, target) + 0.01 * model.compute_physics_loss(pred)``.
    The optimiser is AdamW (weight decay 1e-5), with gradient-norm clipping at
    1.0 and a cosine LR schedule spanning ``epochs`` epochs.

    Args:
        model: called as ``model(params, time_steps)``; must expose
            ``compute_physics_loss(predictions)``.
        train_loader, val_loader: yield ``(params, observables)`` batches;
            ``observables.shape[1]`` is the number of time steps.
        N, topology: key under which checkpoints are stored.
        epochs: total epoch budget, including epochs finished before a resume.
        lr: peak learning rate.
        device: torch device string.
        checkpoint_mgr: if given, the model and history are restored from it,
            and saved every 10 epochs and at the end.

    Returns:
        Dict of per-epoch lists: 'train_loss' (data MSE only), 'val_loss' and
        'physics_loss'. An epoch whose loader yields no batches records NaN.

    NOTE: only model weights are checkpointed, so AdamW moment estimates
    restart from zero on resume. The LR schedule is fast-forwarded to the
    resumed epoch so it does not restart at its peak.
    """
    import warnings

    def _mean(values):
        # np.mean([]) warns and returns NaN; make the empty-loader case explicit.
        return float(np.mean(values)) if values else float('nan')

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Try to resume
    start_epoch = 0
    history = {'train_loss': [], 'val_loss': [], 'physics_loss': []}

    if checkpoint_mgr is not None:
        loaded_epoch = checkpoint_mgr.load_model(model, N, topology)
        loaded_history = checkpoint_mgr.load_training_history(N, topology)

        if loaded_epoch is not None:
            start_epoch = loaded_epoch + 1
            if loaded_history:
                history = loaded_history
                for key in ('train_loss', 'val_loss', 'physics_loss'):
                    history.setdefault(key, [])
            print(f"  ✓ Resuming from epoch {start_epoch}")

    if start_epoch >= epochs:
        print(f"  ✓ Training already complete!")
        return history

    if start_epoch > 0:
        # Replay the schedule so the LR matches the resumed epoch. Stepping a
        # scheduler before any optimizer.step() emits a harmless UserWarning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(start_epoch):
                scheduler.step()

    for epoch in range(start_epoch, epochs):
        # Training
        model.train()
        train_losses = []
        physics_losses = []

        for params, observables in train_loader:
            params = params.to(device)
            observables = observables.to(device)

            optimizer.zero_grad()

            predictions = model(params, observables.shape[1])

            data_loss = F.mse_loss(predictions, observables)
            physics_loss = model.compute_physics_loss(predictions)

            loss = data_loss + 0.01 * physics_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_losses.append(data_loss.item())
            physics_losses.append(physics_loss.item())

        # Validation
        model.eval()
        val_losses = []

        with torch.no_grad():
            for params, observables in val_loader:
                params = params.to(device)
                observables = observables.to(device)

                predictions = model(params, observables.shape[1])
                loss = F.mse_loss(predictions, observables)
                val_losses.append(loss.item())

        scheduler.step()

        avg_train_loss = _mean(train_losses)
        avg_val_loss = _mean(val_losses)
        avg_physics_loss = _mean(physics_losses)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['physics_loss'].append(avg_physics_loss)

        # Save checkpoint every 10 epochs
        if checkpoint_mgr is not None and (epoch + 1) % 10 == 0:
            checkpoint_mgr.save_model(model, N, topology, epoch)
            checkpoint_mgr.save_training_history(history, N, topology)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs} - Train: {avg_train_loss:.6f}, Val: {avg_val_loss:.6f}, Physics: {avg_physics_loss:.6f}")

    # Final save
    if checkpoint_mgr is not None:
        checkpoint_mgr.save_model(model, N, topology, epochs-1)
        checkpoint_mgr.save_training_history(history, N, topology)

    return history


# ============================================================================
# RESUMABLE BENCHMARKING
# ============================================================================

def benchmark_methods_resumable(N_values: List[int], topology: str = 'chain',
                               config: ExperimentConfig = None,
                               device: str = 'cpu') -> Dict:
    """Benchmark exact simulation against the trained surrogate, with checkpointing.

    For each N the function loads or generates the train/val datasets, trains
    (or resumes) a PhysicsInformedFNO, and times one exact simulation against
    one surrogate forward pass at J = 12.5 with random Omega. Per-N results
    are saved through CheckpointManager, and N values whose results already
    exist are skipped.

    Timing notes: the exact time is SpinSystemOptimized.simulate's own
    'elapsed_time', which excludes operator and Hamiltonian construction. The
    surrogate time is a single forward pass measured with perf_counter after
    one warm-up pass, synchronising CUDA around it.

    Args:
        N_values: spin counts to benchmark.
        topology: coupling topology passed to the datasets and simulator.
        config: experiment configuration (required).
        device: torch device for training and surrogate inference.

    Returns:
        Dict of parallel lists 'N', 'exact_time', 'surrogate_time' and
        'surrogate_error' (combined RMSE over Mx, My, I1z).

    Raises:
        ValueError: if ``config`` is None.
    """
    if config is None:
        raise ValueError("benchmark_methods_resumable requires an ExperimentConfig")

    checkpoint_mgr = CheckpointManager()

    results = {
        'N': [],
        'exact_time': [],
        'surrogate_time': [],
        'surrogate_error': []
    }

    use_cuda = str(device).startswith('cuda')

    for N in N_values:
        print(f"\n{'='*60}")
        print(f"Benchmarking N={N}, topology={topology}")
        print(f"{'='*60}")

        # Check if already complete
        existing = checkpoint_mgr.load_benchmark_result(N, topology)
        if existing:
            print(f"  ✓ Results already exist, skipping...")
            for key in results:
                if key in existing:
                    results[key].append(existing[key])
            continue

        # Load or generate datasets
        train_dataset = checkpoint_mgr.load_dataset(N, topology, 'train', config.T, config.dt)
        if train_dataset is None:
            train_dataset = NMRDataset(N, topology, config.n_train_samples, config.T, config.dt)
            train_dataset.generate_data()
            checkpoint_mgr.save_dataset(train_dataset, N, topology, 'train')

        val_dataset = checkpoint_mgr.load_dataset(N, topology, 'val', config.T, config.dt)
        if val_dataset is None:
            val_dataset = NMRDataset(N, topology, config.n_val_samples, config.T, config.dt)
            val_dataset.generate_data()
            checkpoint_mgr.save_dataset(val_dataset, N, topology, 'val')

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

        # Train or load model
        model = PhysicsInformedFNO(modes=config.modes, width=config.width,
                                   n_layers=config.n_layers,
                                   n_params=N+1, n_outputs=3)

        print("Training surrogate...")
        train_surrogate_resumable(model, train_loader, val_loader, N, topology,
                                 epochs=config.epochs, lr=config.lr, device=device,
                                 checkpoint_mgr=checkpoint_mgr)

        # Benchmark
        print("\nRunning benchmark...")
        system = SpinSystemOptimized(N, topology)
        Omega = np.random.uniform(-100, 100, N) * 2 * np.pi
        J = 12.5

        # Exact method
        print("  Running exact method...")
        exact_result = system.simulate(Omega, J, config.T, config.dt)
        exact_time = exact_result['elapsed_time']

        # Surrogate method
        print("  Running surrogate method...")
        model.eval()
        model = model.to(device)
        params_tensor = torch.tensor(np.concatenate([Omega, [J]]),
                                     dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            # Warm-up pass keeps one-off costs (e.g. FFT plan setup, kernel
            # loading) out of the timed call.
            model(params_tensor, config.T)
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            pred = model(params_tensor, config.T)
            if use_cuda:
                # CUDA kernels run asynchronously; wait before stopping the clock.
                torch.cuda.synchronize()
            surrogate_time = time.perf_counter() - start

        # Drop only the batch axis: squeeze() would also drop time when T == 1.
        pred = pred[0].cpu().numpy()

        # Compute error
        surrogate_error = np.sqrt(
            np.mean((exact_result['Mx'] - pred[:, 0])**2) +
            np.mean((exact_result['My'] - pred[:, 1])**2) +
            np.mean((exact_result['I1z'] - pred[:, 2])**2)
        )

        speedup = exact_time / surrogate_time if surrogate_time > 0 else float('inf')

        # Store results
        result = {
            'N': N,
            'exact_time': exact_time,
            'surrogate_time': surrogate_time,
            'surrogate_error': float(surrogate_error),
            'speedup': speedup
        }

        checkpoint_mgr.save_benchmark_result(result, N, topology)

        for key in results:
            if key in result:
                results[key].append(result[key])

        print(f"\n✓ Results for N={N}:")
        print(f"  Exact time: {exact_time:.4f}s")
        print(f"  Surrogate time: {surrogate_time:.6f}s")
        print(f"  Speedup: {result['speedup']:.1f}x")
        print(f"  Error: {surrogate_error:.6f}")

    return results


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Run the resumable NMR spin-dynamics benchmark end to end.

    The function selects CUDA when available and builds a deliberately small
    ExperimentConfig for quick iteration. It then benchmarks the 'chain'
    topology and writes the aggregated results to ``final_results.json`` in
    the working directory. If at least one N was benchmarked, it saves a
    log-scale timing plot to ``scaling_resumable.png``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nUsing device: {device}")

    # Sizes and epoch counts are kept small so a run finishes quickly.
    config = ExperimentConfig(
        N_values=[4, 6, 8, 10],
        topologies=['chain'],
        n_train_samples=50,
        n_val_samples=10,
        T=100,
        dt=1e-4,
        epochs=50,
        batch_size=8,
        lr=1e-3,
        modes=16,
        width=64,
        n_layers=4,
    )

    rule = "=" * 60
    print("\n" + rule)
    print("NMR SPIN DYNAMICS - RESUMABLE BENCHMARK")
    print(rule)
    print(f"Config: N={config.N_values}, T={config.T}, Epochs={config.epochs}")
    print(f"Device: {device}")
    print(rule)

    results = benchmark_methods_resumable(
        config.N_values, topology='chain', config=config, device=device
    )

    with open('final_results.json', 'w') as out_file:
        json.dump(results, out_file, indent=2)

    print("\n" + rule)
    print("BENCHMARK COMPLETE")
    print(rule)
    print("Results saved to: final_results.json")
    print("Checkpoints saved to: checkpoints/")

    # Nothing to plot if no N produced results.
    if not results['N']:
        return

    plt.figure(figsize=(10, 6))
    curves = (
        ('exact_time', 'o-', 'Exact'),
        ('surrogate_time', 's-', 'Surrogate'),
    )
    for series_key, fmt, series_label in curves:
        plt.semilogy(results['N'], results[series_key], fmt,
                     label=series_label, linewidth=2, markersize=8)
    plt.xlabel('Number of Spins (N)', fontsize=14)
    plt.ylabel('Time (seconds)', fontsize=14)
    plt.title('Computational Scaling (Resumable)', fontsize=16)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('scaling_resumable.png', dpi=300, bbox_inches='tight')
    print("Saved: scaling_resumable.png")


# Script entry point: run the full benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()