In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import pandas as pd
import matplotlib.pyplot as plt
import re
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
import matplotlib.colors as mcolors
import numpy as np
import matplotlib.pyplot as plt

def setup_initial_conditions(N):
    """Build the FPU starting state: a half-sine displacement with zero momenta.

    Args:
        N (int): Lattice size. The chain has M = N - 1 movable interior
            particles between two fixed end points.

    Returns:
        y0 (np.ndarray): State ``[q_1..q_M, p_1..p_M]`` of length 2M with
            ``q_k = sin(k*pi/N)`` and every ``p_k = 0``.
        M (int): Number of interior particles, ``N - 1``.
    """
    n_interior = N - 1
    displacements = np.array(
        [math.sin(k * math.pi / N) for k in range(1, n_interior + 1)],
        dtype=float,
    )
    # zeros_like keeps the length equal to the displacement vector
    # (empty when there are no interior particles).
    momenta = np.zeros_like(displacements)
    return np.concatenate([displacements, momenta]), n_interior

def acceleration(q_full, lambda_val):
    """Return FPU-alpha accelerations for a displacement vector padded with its fixed ends.

    Args:
        q_full: Displacements including the two boundary entries
            (``q_full[0]`` and ``q_full[-1]``).
        lambda_val: Coefficient of the quadratic (alpha) force term.

    Returns:
        Array shaped like ``q_full``. The interior entries hold
        ``(q[i+1] - 2q[i] + q[i-1]) + lambda*((q[i+1]-q[i])**2 - (q[i]-q[i-1])**2)``
        and the boundary entries are 0.
    """
    right_bond = q_full[2:] - q_full[1:-1]
    left_bond = q_full[1:-1] - q_full[:-2]
    harmonic = q_full[2:] - 2 * q_full[1:-1] + q_full[:-2]
    anharmonic = lambda_val * (right_bond**2 - left_bond**2)
    result = np.zeros_like(q_full)
    result[1:-1] = harmonic + anharmonic
    return result

def rhs(t, y, M, lambda_val):
    """Time derivative of the FPU state, in the signature ``solve_ivp`` expects.

    Args:
        t: Current time (unused; the system is autonomous).
        y: State ``[q_1..q_M, p_1..p_M]``.
        M: Number of interior particles.
        lambda_val: Nonlinearity coefficient forwarded to ``acceleration``.

    Returns:
        ``[dq/dt, dp/dt] = [p, a(q)]`` as one flat array of length 2M.
    """
    positions, momenta = y[:M], y[M:]
    padded = np.zeros(M + 2)  # q_0 = q_{M+1} = 0: fixed chain ends
    padded[1:-1] = positions
    forces = acceleration(padded, lambda_val)[1:-1]
    return np.concatenate([momenta, forces])

def precompute_modes(N, M):
    """Precompute the sine mode shapes and their linear frequencies.

    Returns:
        sin_modes: List of M arrays; entry ``l-1`` is ``sin(l*k*pi/N)`` for
            k = 1..M.
        omega_modes: List of M floats; entry ``l-1`` is ``2*sin(l*pi/(2N))``.
    """
    sites = np.arange(1, M + 1)
    sin_modes = []
    omega_modes = []
    for mode in range(1, M + 1):
        sin_modes.append(np.sin(mode * sites * math.pi / N))
        omega_modes.append(2 * math.sin(mode * math.pi / (2 * N)))
    return sin_modes, omega_modes

def mode_energy(q_int, p_int, sinL, omega_L, N):
    """Harmonic energy of one normal mode.

    The state is projected onto the sine mode ``sinL`` with normalisation
    ``sqrt(2/N)``. The result is ``0.5*(A_dot**2 + omega_L**2 * A**2)``.
    """
    scale = math.sqrt(2 / N)
    amplitude = scale * np.dot(q_int, sinL)
    velocity = scale * np.dot(p_int, sinL)
    return 0.5 * (velocity**2 + (omega_L**2) * (amplitude**2))

def run_simulation(N=3, lambda_val=0.25, dt=0.1, T_end=3000.0, num_modes_to_plot=3):
    """Integrate the FPU-alpha chain and compute the lowest modal energies.

    Args:
        N (int): Lattice size; the chain has M = N - 1 interior particles.
        lambda_val (float): Coefficient of the quadratic force term.
        dt (float): Spacing of the returned samples. ``solve_ivp`` still
            chooses its own adaptive internal steps; dt only sets ``t_eval``.
        T_end (float): Final integration time.
        num_modes_to_plot (int): How many of the lowest normal modes to
            track. Clipped to M.

    Returns:
        sol: The ``solve_ivp`` result.
        df_fpu_states (pd.DataFrame): Trajectory with columns q_1..q_M,
            p_1..p_M, indexed by ``sol.t``.
        energies (np.ndarray): Modal energies, shape
            (len(sol.t), num_modes_to_plot).
        num_modes_to_plot (int): The clipped number of modes.

    Raises:
        RuntimeError: If the integrator reports failure. Without this check
            a truncated trajectory would be returned silently.
    """
    y0, M = setup_initial_conditions(N)
    # Use round(), not int(). Floating division can land just below an
    # integer (0.3 / 0.1 == 2.9999999999999996), and truncation would drop
    # the last sample and stretch the output spacing.
    steps = int(round(T_end / dt))
    times = np.linspace(0, T_end, steps + 1)

    # Solve ODE
    sol = solve_ivp(rhs,
                    t_span=(0, T_end),
                    y0=y0,
                    args=(M, lambda_val),
                    method='RK45',
                    t_eval=times,
                    atol=1e-9,
                    rtol=1e-9)
    if not sol.success:
        raise RuntimeError(f"FPU integration failed: {sol.message}")

    sin_modes, omega_modes = precompute_modes(N, M)

    q_ts = sol.y[:M, :]
    p_ts = sol.y[M:, :]

    # Limit modes for plotting/energy computation
    num_modes_to_plot = min(num_modes_to_plot, M)

    energies = np.zeros((len(sol.t), num_modes_to_plot))
    # Each column of q_ts / p_ts is the interior state at one output time.
    for idx, (q_now, p_now) in enumerate(zip(q_ts.T, p_ts.T)):
        for ell in range(num_modes_to_plot):
            energies[idx, ell] = mode_energy(q_now, p_now, sin_modes[ell], omega_modes[ell], N)

    # Build DataFrame for states
    state_labels = [f'q_{i}' for i in range(1, M+1)] + [f'p_{i}' for i in range(1, M+1)]
    df_fpu_states = pd.DataFrame(sol.y.T, columns=state_labels, index=sol.t)

    return sol, df_fpu_states, energies, num_modes_to_plot

def plot_modal_energies(sol, energies, num_modes_to_plot):
    """Plot each tracked mode's energy against time on a single set of axes.

    Args:
        sol: ``solve_ivp`` result; only ``sol.t`` is read.
        energies: Array (len(sol.t), >= num_modes_to_plot) of modal energies.
        num_modes_to_plot: Number of leading modes (columns) to draw.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    for mode_idx in range(num_modes_to_plot):
        ax.plot(sol.t, energies[:, mode_idx], label=f'Mode {mode_idx+1}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Energy')
    ax.set_title(f'Modal Energies for Modes 1–{num_modes_to_plot} (t up to {sol.t[-1]:.0f})')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def plot_state_time_series(df_states, num_states_to_plot=3, max_points=1000):
    """Plot q_k (solid) and p_k (dashed) against time for the first few particles.

    Args:
        df_states: DataFrame with columns q_1..q_M, p_1..p_M and a time index.
        num_states_to_plot: Number of particles to show (capped at M).
        max_points: Number of leading rows (taken positionally) to draw.
    """
    n_shown = min(num_states_to_plot, len(df_states.columns) // 2)
    t_axis = df_states.index.values[:max_points]
    fig, ax = plt.subplots(figsize=(12, 6))
    for k in range(1, n_shown + 1):
        ax.plot(t_axis, df_states[f'q_{k}'].values[:max_points], label=f'q_{k}')
        ax.plot(t_axis, df_states[f'p_{k}'].values[:max_points], '--', label=f'p_{k}')
    ax.set_xlabel('Time')
    ax.set_ylabel('State Value')
    ax.set_title(f'FPU System State Time Series (First {n_shown} Particles)')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()
 
def plot_phase_portrait(df_fpu_states, times):
    """Plot the (q_1, p_1) phase-space trajectory of the first particle.

    Args:
        df_fpu_states: DataFrame containing at least 'q_1' and 'p_1' columns.
        times: Sequence of time points. Only its length is used, to cap how
            many leading rows are drawn (never more than the frame holds).

    Raises:
        ValueError: If 'q_1' or 'p_1' is missing.
    """
    if 'q_1' not in df_fpu_states or 'p_1' not in df_fpu_states:
        raise ValueError("DataFrame must contain 'q_1' and 'p_1' columns for phase space plot.")

    n_rows = min(len(times), len(df_fpu_states))
    # Use .iloc (positional). Plain [] slicing on a float time index is
    # label-based in pandas, i.e. it slices by time value, not by row count.
    q1 = df_fpu_states['q_1'].iloc[:n_rows]
    p1 = df_fpu_states['p_1'].iloc[:n_rows]

    plt.figure(figsize=(6, 5))
    plt.plot(q1, p1, lw=0.1)
    plt.xlabel('q_1')
    plt.ylabel('p_1')
    plt.title('Phase Space Trajectory for Particle 1')
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    

def plot_phase_matrix(df_fpu_states, times=None, max_points=1000):
    """
    Plot a matrix of phase portraits (q_i vs p_j) for the FPU system.

    Cell (i, j) shows q_{i+1} against p_{j+1}. Cells below the diagonal
    (j < i) are hidden, so only the upper triangle is drawn.

    Args:
        df_fpu_states: DataFrame with columns q_1,...,q_M, p_1,...,p_M
        times: (optional) sequence whose length caps the number of rows
            plotted; if None, uses the DataFrame length
        max_points: maximum number of time steps (rows) to plot

    Raises:
        ValueError: if there is no q_<k> / p_<k> column pair to plot.
    """
    def sort_by_suffix(cols, prefix):
        """Keep columns named '<prefix>_<int>' and sort them numerically (q_10 after q_9)."""
        pattern = re.compile(fr"{re.escape(prefix)}_(\d+)$")
        pairs = []
        for c in cols:
            m = pattern.match(c)
            if m:
                pairs.append((int(m.group(1)), c))
        return [c for _, c in sorted(pairs)]

    # Skip non-string column labels, which have no .startswith.
    str_cols = [c for c in df_fpu_states.columns if isinstance(c, str)]
    q_cols = sort_by_suffix([c for c in str_cols if c.startswith('q_')], 'q')
    p_cols = sort_by_suffix([c for c in str_cols if c.startswith('p_')], 'p')
    n = min(len(q_cols), len(p_cols))
    if n == 0:
        raise ValueError("DataFrame must contain at least one 'q_<k>' and one 'p_<k>' column.")

    # time window (number of leading rows)
    T = len(times) if times is not None else len(df_fpu_states)
    T = min(T, max_points)

    # squeeze=False always returns a 2-D array of Axes. With the default
    # squeeze=True, n == 1 yields a bare Axes that has no .reshape.
    fig, axes = plt.subplots(n, n, figsize=(3*n, 3*n), sharex=False, sharey=False,
                             squeeze=False)

    for i, q_col in enumerate(q_cols[:n]):
        # .iloc is positional; [] on the float time index would slice by time value.
        q_vals = df_fpu_states[q_col].iloc[:T]
        for j, p_col in enumerate(p_cols[:n]):
            ax = axes[i, j]
            if j < i:
                ax.axis('off')
                continue
            ax.plot(q_vals, df_fpu_states[p_col].iloc[:T], lw=0.3)
            ax.grid(True)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(True)

    # Outer matrix labels: p_j above each column, q_i left of each row
    for j, p_col in enumerate(p_cols[:n]):
        bbox = axes[0, j].get_position()
        fig.text(bbox.x0 + bbox.width/2, bbox.y1 + 0.01, p_col, ha='center', va='bottom')
    for i, q_col in enumerate(q_cols[:n]):
        bbox = axes[i, 0].get_position()
        fig.text(bbox.x0 - 0.01, bbox.y0 + bbox.height/2, q_col, ha='right', va='center', rotation=90)

    plt.show()
    
def plot_phase_space_3d(df_states):
    """Draw the trajectory projected onto the (q_1, q_2, q_3) subspace.

    Raises:
        ValueError: If any of 'q_1', 'q_2', 'q_3' is missing from the columns.
    """
    axes_cols = ['q_1', 'q_2', 'q_3']
    if any(col not in df_states.columns for col in axes_cols):
        raise ValueError("DataFrame must contain 'q_1', 'q_2', and 'q_3' columns.")

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(*(df_states[col] for col in axes_cols), lw=0.7)
    for set_label, col in zip((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel), axes_cols):
        set_label(col)
    ax.set_title('3D Phase Space Trajectory (q_1, q_2, q_3)')
    plt.tight_layout()
    plt.show()
    
def plot_hamiltonian_slice_only(df_fpu_states, N, lam, q_idx=0, p_idx=0, n=121):
    """
    Plot the true FPU Hamiltonian on a 2-D slice (q_{q_idx+1}, p_{p_idx+1}).

    Every other state coordinate is frozen at its time average. The two
    chosen coordinates are swept over the full range the trajectory visits.
    The trajectory's projection onto the same plane is overlaid in white.

    Args:
        df_fpu_states: DataFrame (T, 2M) with columns ordered
            [q_1..q_M, p_1..p_M].
        N: Lattice size. Not used in the computation (M is inferred from the
            column count); kept for interface compatibility.
        lam: Coefficient of the cubic potential term ``lam * dq**3 / 3``.
        q_idx, p_idx: 0-based indices (0..M-1) of the q and p to vary.
        n: Grid resolution per axis.

    Raises:
        ValueError: If q_idx or p_idx is outside 0..M-1. Otherwise an index
            of M or above, or a negative one, would silently address the
            wrong half of the state vector.
    """
    # float64 so the energy landscape is not quantised by float32 rounding.
    Z = df_fpu_states.values.astype(np.float64)
    M = Z.shape[1] // 2
    for name, idx in (('q_idx', q_idx), ('p_idx', p_idx)):
        if not 0 <= idx < M:
            raise ValueError(f"{name} must be in 0..{M - 1}, got {idx}")
    anchor = Z.mean(axis=0)

    q_traj = Z[:, q_idx]
    p_traj = Z[:, M + p_idx]

    # Grid spans the full visited range (min..max) of each swept coordinate.
    q_lin = np.linspace(q_traj.min(), q_traj.max(), n)
    p_lin = np.linspace(p_traj.min(), p_traj.max(), n)
    Q, P = np.meshgrid(q_lin, p_lin, indexing="xy")

    grid = np.tile(anchor, (n*n, 1))
    grid[:, q_idx] = Q.ravel()
    grid[:, M + p_idx] = P.ravel()

    def true_energy_np(z, lam):
        """Total energy 0.5*|p|^2 + sum(0.5*dq^2 + lam*dq^3/3) with fixed ends."""
        z = np.atleast_2d(z)
        m = z.shape[1] // 2
        q, p = z[:, :m], z[:, m:]
        KE = 0.5 * np.sum(p**2, axis=1)
        qfull = np.zeros((z.shape[0], m+2))
        qfull[:, 1:-1] = q
        dq = qfull[:, 1:] - qfull[:, :-1]
        PE = np.sum(0.5 * dq**2 + lam * dq**3 / 3.0, axis=1)
        return KE + PE

    H_vals = true_energy_np(grid, lam=lam).reshape(n, n)

    # plot
    fig, ax = plt.subplots(figsize=(7, 5))
    cs = ax.contourf(Q, P, H_vals, levels=60, alpha=0.5, cmap='inferno')
    ax.set_title(f"True Hamiltonian $H(q_{{{q_idx+1}}},p_{{{p_idx+1}}})$")
    ax.set_xlabel(f"$q_{{{q_idx+1}}}$")
    ax.set_ylabel(f"$p_{{{p_idx+1}}}$")
    cbar = fig.colorbar(cs, ax=ax)
    cbar.set_label("H")

    # overlay the trajectory projection
    ax.plot(q_traj, p_traj, c='w', lw=0.1, alpha=0.9, label="trajectory")
    ax.legend()
    plt.tight_layout()
    plt.show()
    


def animate_instantaneous_contours(df_fpu_states, N, lam,
                                   i=1,              # 1..M (interior index)
                                   stride=1,         # subsample frames
                                   n=181,            # grid resolution
                                   K_frames=None,    # cap frames after striding
                                   interval=60,      # ms/frame
                                   save_path=None,   # "movie.mp4" or ".gif"
                                   # shaded instantaneous field options
                                   levels_bg=30,     # number of filled levels
                                   cmap_bg='inferno',
                                   field_alpha=0.6,
                                   # fading tail (recent) options
                                   tail_len=200,     # recent raw samples to show (not post-stride frames)
                                   tail_alpha_min=0.05,
                                   tail_lw=1.4,
                                   # permanent elapsed trail options
                                   show_elapsed_trail=True,
                                   elapsed_color='white',
                                   elapsed_alpha=0.35,
                                   elapsed_lw=0.5):
    """
    Animate the (q_i, p_i) phase-plane slice of the FPU chain.

    Each frame at sample t draws:
      - a filled contour of H_slice(q_i, p_i). This is the total energy when
        only particle i moves while every other coordinate keeps its value at
        t, shifted so that H_slice equals H_total(t) at the actual state;
      - the contour of H_slice at level H_total(t), coloured with the fill
        colormap;
      - a white fading tail over the last `tail_len` raw samples;
      - a white dot at the current state;
      - optionally, a permanent trail of the whole history up to t.

    Args:
        df_fpu_states: DataFrame (T, 2M) with columns ordered
            [q_1..q_M, p_1..p_M].
        N: Unused; M is inferred from the column count.
        lam: Coefficient of the cubic potential term ``lam * dq**3 / 3``.
        i: 1-based interior particle index (1..M).
        stride, K_frames: Frame subsampling and an optional cap on the
            number of frames.
        n: Grid resolution per axis.
        interval: Milliseconds between frames.
        save_path: If given, the animation is written there with ``ani.save``.
        Remaining keyword arguments style the fill, the tail and the trail.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object.

    Raises:
        ValueError: If i is outside 1..M.
    """
    from matplotlib.artist import Artist  # used to detect the ContourSet API generation

    Z = df_fpu_states.values.astype(np.float64)
    M = Z.shape[1] // 2
    if not 1 <= i <= M:
        raise ValueError(f"i must be in 1..{M}")
    qi = i - 1

    q = Z[:, :M]
    p = Z[:, M:]
    T = Z.shape[0]

    # neighbours of particle i over time (fixed ends → 0)
    q_prev_t = np.zeros(T) if qi - 1 < 0 else q[:, qi - 1]
    q_next_t = np.zeros(T) if qi + 1 > M - 1 else q[:, qi + 1]

    # total energy H(t) along the trajectory
    qfull = np.zeros((T, M + 2))
    qfull[:, 1:-1] = q
    dq = qfull[:, 1:] - qfull[:, :-1]
    KE = 0.5 * np.sum(p**2, axis=1)
    PE = np.sum(0.5 * dq**2 + lam * dq**3 / 3.0, axis=1)
    H_total = KE + PE

    # grid on (q_i, p_i): 1st–99th percentile window plus 5 % padding
    qmin, qmax = np.percentile(q[:, qi], [1, 99])
    pmin, pmax = np.percentile(p[:, qi], [1, 99])
    pad_q = 0.05 * (qmax - qmin + 1e-12)
    pad_p = 0.05 * (pmax - pmin + 1e-12)
    qmin, qmax = qmin - pad_q, qmax + pad_q
    pmin, pmax = pmin - pad_p, pmax + pad_p

    q_lin = np.linspace(qmin, qmax, n)
    p_lin = np.linspace(pmin, pmax, n)
    Q, P = np.meshgrid(q_lin, p_lin, indexing="xy")

    def H_slice_given_neighbors(q_prev, q_next, const):
        """Energy over the (Q, P) grid: const + kinetic(i) + the two bonds touching i."""
        dqL = Q - q_prev
        dqR = q_next - Q
        pot = 0.5 * (dqL**2 + dqR**2) + lam * (dqL**3 + dqR**3) / 3.0
        return const + 0.5 * (P**2) + pot

    def instantaneous_H_slice_at_t(t):
        """H_slice at sample t. const absorbs every term that does not involve particle i."""
        qi_now = q[t, qi]
        pi_now = p[t, qi]
        dqL_now = qi_now - q_prev_t[t]
        dqR_now = q_next_t[t] - qi_now
        pot_i_now = 0.5 * (dqL_now**2 + dqR_now**2) + lam * (dqL_now**3 + dqR_now**3) / 3.0
        const = H_total[t] - (0.5 * pi_now**2 + pot_i_now)
        return H_slice_given_neighbors(q_prev_t[t], q_next_t[t], const)

    # frame indices (raw sample numbers)
    stride = max(1, int(stride))
    idx = np.arange(0, T, stride)
    if K_frames is not None:
        idx = idx[:int(K_frames)]
    if len(idx) == 0:
        idx = np.array([0])

    # figure & artists
    fig, ax = plt.subplots(figsize=(8.8, 5.8))
    ax.set_xlabel(f"$q_{i}$"); ax.set_ylabel(f"$p_{i}$")
    ax.set_xlim(qmin, qmax);   ax.set_ylim(pmin, pmax)
    ax.set_title(f"(q{i}, p{i}) instantaneous energy level")

    # single-element lists so the nested callbacks can rebind them
    cs_fill = [None]
    cs_line = [None]
    cbar    = [None]

    def contour_artists(cs):
        """Artists making up a ContourSet, for both old and new matplotlib."""
        # matplotlib >= 3.8: ContourSet is itself a Collection (Artist).
        # Older versions expose one collection per level via .collections,
        # which was removed in 3.10.
        return [cs] if isinstance(cs, Artist) else list(cs.collections)

    def clear_fill():
        if cs_fill[0] is not None:
            for art in contour_artists(cs_fill[0]):
                art.remove()
            cs_fill[0] = None

    def clear_line():
        if cs_line[0] is not None:
            for art in contour_artists(cs_line[0]):
                art.remove()
            cs_line[0] = None

    # permanent elapsed trail line (low alpha)
    if show_elapsed_trail:
        elapsed_line, = ax.plot([], [], color=elapsed_color, alpha=elapsed_alpha,
                                lw=elapsed_lw, zorder=2)
    else:
        elapsed_line = None

    # white fading tail as a LineCollection (recent history)
    tail_coll = LineCollection([], linewidths=tail_lw, zorder=3)
    ax.add_collection(tail_coll)
    tail_rgb = mcolors.to_rgb('white')

    # white current point
    point, = ax.plot([], [], "o", ms=5, color="white", zorder=4)

    # legend: the tail is intentionally not listed, so handles and labels match 1:1
    legend_items = [point]
    legend_labels = ["current state"]
    if show_elapsed_trail:
        legend_items.insert(0, elapsed_line)
        legend_labels.insert(0, "elapsed trajectory")
    ax.legend(legend_items, legend_labels, framealpha=0.3, loc="upper right")

    def init():
        clear_fill(); clear_line()
        tail_coll.set_segments([]); tail_coll.set_color([])
        point.set_data([], [])
        if elapsed_line is None:
            return tail_coll, point
        elapsed_line.set_data([], [])
        return elapsed_line, tail_coll, point

    def update(k):
        t = idx[k]

        # shaded instantaneous field
        H_slice = instantaneous_H_slice_at_t(t)
        clear_fill()
        cs_fill[0] = ax.contourf(Q, P, H_slice, levels=levels_bg,
                                 cmap=cmap_bg, alpha=field_alpha, antialiased=True, zorder=0)
        if cbar[0] is None:
            cbar[0] = fig.colorbar(cs_fill[0], ax=ax)
            cbar[0].set_label("H")
        else:
            cbar[0].update_normal(cs_fill[0])

        # bold instantaneous energy level, clipped inside the field's range
        vmin, vmax = float(np.nanmin(H_slice)), float(np.nanmax(H_slice))
        level = float(np.clip(H_total[t], vmin + 1e-12, vmax - 1e-12))
        clear_line()

        # colour the level line with the fill's norm & cmap, fully opaque
        norm = cs_fill[0].norm
        cmap = cs_fill[0].cmap
        rgba = list(cmap(norm(level)))
        rgba[-1] = 1.0

        cs_line[0] = ax.contour(Q, P, H_slice,
                                levels=[level],
                                colors=[tuple(rgba)],
                                linewidths=0.9,
                                zorder=1)  # above the shading, below the tails

        # permanent elapsed trail up to the current sample
        if elapsed_line is not None:
            elapsed_line.set_data(q[:t+1, qi], p[:t+1, qi])

        # fading tail over the last `tail_len` raw samples (independent of stride)
        t0 = max(0, t - tail_len)
        x_tail = q[t0:t+1, qi]
        y_tail = p[t0:t+1, qi]
        if len(x_tail) >= 2:
            segs = np.stack(
                [np.column_stack([x_tail[:-1], y_tail[:-1]]),
                 np.column_stack([x_tail[1:],  y_tail[1:]])],
                axis=1
            )
            alphas = np.linspace(tail_alpha_min, 1.0, segs.shape[0])
            tail_coll.set_segments(segs)
            tail_coll.set_color([(*tail_rgb, a) for a in alphas])
        else:
            tail_coll.set_segments([]); tail_coll.set_color([])

        # current point
        point.set_data([q[t, qi]], [p[t, qi]])

        ax.set_title(f"(q{i}, p{i}) instantaneous energy level | t={t}")
        # Return artists (blit=False, so this is informational only)
        artists = [*contour_artists(cs_fill[0]), *contour_artists(cs_line[0]), tail_coll, point]
        if elapsed_line is not None:
            artists.insert(0, elapsed_line)
        return artists

    ani = animation.FuncAnimation(fig, update, init_func=init,
                                  frames=len(idx), interval=interval, blit=False)

    if save_path is not None:
        ani.save(save_path, dpi=120)  # needs pillow (gif) or ffmpeg (mp4)
        print(f"Saved animation to {save_path}")

    plt.show()
    return ani



if __name__ == "__main__":
    # Notebook driver. Inside a notebook __name__ is "__main__", so this runs
    # and leaves N, lambda_value, sol, df_states, energies and num_modes as
    # module globals; the training cells below read N, lambda_value and
    # df_states.
    N = 10
    lambda_value = 0.25
    # dt is left at run_simulation's default (0.1). The training cell sets
    # dt = 0.1 separately, so keep the two in sync if you change it here.
    # num_modes_to_plot=10 is clipped to M = N - 1 inside run_simulation.
    sol, df_states, energies, num_modes = run_simulation(N=N, T_end=3000.0, lambda_val=lambda_value, num_modes_to_plot=10)
    plot_modal_energies(sol, energies, num_modes)
    plot_state_time_series(df_states, num_states_to_plot=10, max_points=1000)
    plot_phase_portrait(df_states, sol.t)
    # plot_phase_matrix(df_states, sol.t, max_points=100)
    plot_hamiltonian_slice_only(df_states, N=N, lam=lambda_value, q_idx=0, p_idx=0 , n=121)

    # Optional: render the (q_1, p_1) energy-level animation (slow; writes a GIF).
    # animate_instantaneous_contours(
    #     df_states, N=N, lam=lambda_value,
    #     i=1, stride=1, n=181, K_frames=800,
    #     levels_bg=30, cmap_bg='inferno', field_alpha=0.65,
    #     tail_len=300, tail_alpha_min=0.06, tail_lw=1.4,
    #     show_elapsed_trail=True,      # <- permanent trail on
    #     elapsed_color='white',        # <- white
    #     elapsed_alpha=0.35,           # <- translucent
    #     elapsed_lw=0.5,               # <- line weight 0.5
    #     save_path=f'hamiltonian_contours_n={N}.gif'
    # )



In [None]:
# ==============================
# HYPERPARAMETER SWEEP: SYMPNET vs LSTM on FPU
# ==============================
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# --------- reproducibility ----------
def set_seed(seed=123):
    """Seed NumPy, PyTorch CPU and every CUDA device so runs are repeatable."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
set_seed(123)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --------- data (from your simulation) ----------
# Expects df_states, N and lambda_value from the simulation cell.
# dt is (re)defined further down in this cell, before training.
df_fpu_states = df_states  # from previous simulation
Z_all = df_fpu_states.values.astype(np.float32)
nf    = Z_all.shape[1]                   # = 2*M
T     = len(Z_all)
split = int(0.8 * T)                     # chronological split (first 80 % train)

# One-step pairs (for SympNet): input row k, target row k+1.
# The pair straddling the split (Z_all[split-1] -> Z_all[split]) is in neither set.
Z_tr  = Z_all[:split]
Z_te  = Z_all[split:]
Z_tr_in, Z_tr_out = Z_tr[:-1], Z_tr[1:]
Z_te_in, Z_te_out = Z_te[:-1], Z_te[1:]

train_ds_symp = TensorDataset(torch.from_numpy(Z_tr_in), torch.from_numpy(Z_tr_out))
test_ds_symp  = TensorDataset(torch.from_numpy(Z_te_in), torch.from_numpy(Z_te_out))

# Sequence data (for LSTM)
def make_sequences(Z, seq_len=10):
    """Cut Z into overlapping (window, next-step window) pairs for LSTM training.

    Window k covers rows k .. k+seq_len-1. Its target is the same window
    shifted forward by one row (rows k+1 .. k+seq_len). For a 2-D Z of
    shape (T, D), both outputs are float32 arrays of shape
    (T - seq_len, seq_len, D); when T <= seq_len both are empty.
    """
    n_windows = len(Z) - seq_len
    inputs = [Z[k:k + seq_len] for k in range(n_windows)]
    targets = [Z[k + 1:k + 1 + seq_len] for k in range(n_windows)]
    return np.asarray(inputs, np.float32), np.asarray(targets, np.float32)

# Windows never cross the train/test split: each set is windowed on its own.
seq_len = 10
X_tr, Y_tr = make_sequences(Z_tr, seq_len)
X_te, Y_te = make_sequences(Z_te, seq_len)

train_ds_lstm = TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(Y_tr))
test_ds_lstm  = TensorDataset(torch.from_numpy(X_te), torch.from_numpy(Y_te))

# --------- Energy helper (true Hamiltonian) ----------
def energy_np(z_batch, N, lam):
    """Total FPU-alpha energy of each state in a batch.

    Args:
        z_batch: One state ``[q, p]`` or a 2-D batch of them (rows).
        N: Unused; the number of interior particles is inferred from the width.
        lam: Coefficient of the cubic bond potential.

    Returns:
        1-D array of ``0.5*|p|^2 + sum(0.5*dq^2 + lam*dq^3/3)``, where dq are
        the bond stretches including the two fixed ends.
    """
    states = np.atleast_2d(z_batch)
    half = states.shape[1] // 2
    positions, momenta = states[:, :half], states[:, half:]
    padded = np.zeros((states.shape[0], half + 2), dtype=states.dtype)
    padded[:, 1:-1] = positions
    bonds = np.diff(padded, axis=1)
    kinetic = 0.5 * np.sum(momenta**2, axis=1)
    potential = np.sum(0.5 * bonds**2 + lam * bonds**3 / 3.0, axis=1)
    return kinetic + potential

# ---------- SympNet (parametric) ----------

class SymplecticNN_Generator(nn.Module):
    """MLP scalar generator S(z) whose symplectic gradient is the learned vector field.

    The network maps a state z = [q, p] of size ``dim`` to a scalar S(z). The
    forward pass returns (dS/dp, -dS/dq), i.e. Hamilton's equations with S
    in the role of the Hamiltonian.

    Args:
        dim: State dimension; the first ``dim // 2`` entries are q, the rest p.
        width: Hidden-layer width.
        num_hidden: Number of Linear + Tanh hidden layers.
    """

    def __init__(self, dim, width=64, num_hidden=3):
        super().__init__()
        self.dim = dim
        self.half = dim // 2
        layers = []
        in_dim = dim
        for _ in range(num_hidden):
            layers += [nn.Linear(in_dim, width), nn.Tanh()]
            in_dim = width
        layers += [nn.Linear(in_dim, 1)]
        self.S = nn.Sequential(*layers)

    def forward(self, z, create_graph: bool = True):
        """Return the vector field (dS/dp, -dS/dq) at each row of z.

        Works inside ``torch.no_grad()`` and never flips the caller's
        tensor's ``requires_grad`` flag.

        Args:
            z: Batch of states, shape (B, dim).
            create_graph: Keep the graph of the gradient so the output can be
                differentiated w.r.t. the parameters (needed for training).
        """
        # The vector field *is* a gradient, so autograd must be on even
        # when the caller is in no_grad mode.
        with torch.enable_grad():
            if not z.requires_grad:
                # Use a fresh leaf rather than z.requires_grad_(True). That
                # would mutate the caller's tensor, and it raises on non-leaf
                # tensors.
                z = z.detach().requires_grad_(True)
            S_val = self.S(z)
            grads = torch.autograd.grad(
                S_val.sum(), z, create_graph=create_graph)[0]
            dqdt = grads[:, self.half:]
            dpdt = -grads[:, :self.half]
            return torch.cat([dqdt, dpdt], dim=1)


def symp_midpoint_step(model, z0, h, n_iter=5, create_graph=True):
    """Advance z0 by one implicit-midpoint step z1 = z0 + h * f((z0 + z1) / 2).

    The implicit equation is solved by ``n_iter`` fixed-point iterations
    starting from z1 = z0. Each iteration evaluates the model at a detached
    midpoint, so parameter gradients flow only through the final
    evaluation. Autograd is force-enabled because the model differentiates
    its input, even during evaluation.

    Args:
        model: Callable ``model(z, create_graph=...)`` returning dz/dt.
        z0: Current state batch.
        h: Step size.
        n_iter: Number of fixed-point iterations.
        create_graph: Forwarded to the model.
    """
    z_next = z0.clone()
    iteration = 0
    while iteration < n_iter:
        with torch.enable_grad():
            midpoint = 0.5 * (z0 + z_next).detach().requires_grad_(True)
            z_next = z0 + h * model(midpoint, create_graph=create_graph)
        iteration += 1
    return z_next


mse = nn.MSELoss()  # shared one-step loss used by train_sympnet
# Step size of the learned integrator. NOTE(review): this must equal the
# sampling interval of the simulation (run_simulation's default dt = 0.1);
# it is not read from the simulation, so update both together.
dt = 0.1

def train_sympnet(dim, width, num_hidden, train_ds, test_ds, epochs=200, lr=1e-3, batch_size=256, h=dt):
    """Train a SympNet one-step predictor and collect evaluation metrics.

    Training fits ``symp_midpoint_step(model, x_k, h) ≈ x_{k+1}`` with MSE.
    After training, the model is rolled out autoregressively over the
    training horizon (``Z_tr``) to measure energy drift.

    Relies on notebook globals: ``device``, ``mse``, ``Z_tr``, ``N``,
    ``lambda_value``, ``energy_np``, ``symp_midpoint_step`` and
    ``SymplecticNN_Generator``.

    Args:
        dim: State dimension (2M).
        width, num_hidden: Generator MLP shape.
        train_ds, test_ds: TensorDatasets of (x_k, x_{k+1}) pairs.
        epochs, lr, batch_size: Optimisation settings.
        h: Integrator step size.

    Returns:
        (model, metrics), where metrics has these keys:
          - train_mse: mean loss of the final epoch; NaN if no epoch ran or
            train_ds is empty.
          - test_mse: one-step MSE on test_ds; NaN if test_ds is empty.
          - drift_rms, drift_max: RMS and max of H_pred(t) - H_pred(0).
          - params: parameter count.
          - seconds: wall time for training, evaluation and rollout.
    """
    model = SymplecticNN_Generator(dim, width=width, num_hidden=num_hidden).to(device).float()
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    # --- train ---
    t0 = time.time()
    train_mse = float('nan')  # reported as NaN if the loop below never runs
    for _ in range(epochs):
        model.train()
        running = 0.0
        for x0, x1 in train_loader:
            x0 = x0.to(device); x1 = x1.to(device)
            opt.zero_grad()
            x1_hat = symp_midpoint_step(model, x0, h, create_graph=True)
            loss = mse(x1_hat, x1)
            loss.backward(); opt.step()
            running += loss.item() * x0.size(0)
        train_mse = running / len(train_ds) if len(train_ds) else float('nan')

    # --- eval on test (one-step MSE) ---
    model.eval()
    running_test = 0.0
    for x0, x1 in test_loader:
        x0 = x0.to(device); x1 = x1.to(device)
        x1_hat = symp_midpoint_step(model, x0, h, create_graph=False)
        running_test += mse(x1_hat, x1).item() * x0.size(0)
    test_mse = running_test / len(test_ds) if len(test_ds) else float('nan')

    # --- energy drift on an autoregressive rollout over the train horizon ---
    Ztr = Z_tr.astype(np.float32)
    z_t = torch.from_numpy(Ztr[:1]).to(device)
    pred = np.zeros_like(Ztr)
    for t in range(len(Ztr)):
        pred[t] = z_t.squeeze(0).cpu().numpy()  # store the current state, then advance
        z_t = symp_midpoint_step(model, z_t, h, create_graph=False)
    H_pred = energy_np(pred, N, lambda_value)
    dH = H_pred - H_pred[0]
    drift_rms = float(np.sqrt(np.mean(dH**2)))
    drift_max = float(np.max(np.abs(dH)))

    n_params = sum(p.numel() for p in model.parameters())
    secs = time.time() - t0
    return model, dict(train_mse=train_mse, test_mse=test_mse,
                       drift_rms=drift_rms, drift_max=drift_max,
                       params=n_params, seconds=secs)

# ---------- LSTM (parametric) ----------
class FPU_LSTM(nn.Module):
    """LSTM followed by a per-timestep linear read-out back to the state size.

    Input and output are batch-first, [B, L, D] -> [B, L, D]. The final
    LSTM (h, c) state is returned too, so a caller can carry it across calls.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc   = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, hc=None):
        hidden_seq, state = self.lstm(x, hc)   # hidden_seq: [B, L, H]
        return self.fc(hidden_seq), state      # prediction: [B, L, D]

def train_lstm(input_dim, hidden_dim, train_ds, test_ds,
               epochs=20, lr=1e-3, batch_size=128):
    """Train a single-layer FPU_LSTM with teacher forcing and collect metrics.

    After training, the model runs freely over the training horizon
    (``Z_tr``): each output becomes the next input and the LSTM state is
    carried between steps. That rollout is used to measure energy drift.

    Relies on notebook globals: ``device``, ``Z_tr``, ``N``,
    ``lambda_value``, ``energy_np`` and ``FPU_LSTM``.

    Returns:
        (model, metrics), where metrics has these keys:
          - train_mse: mean loss of the final epoch; NaN if no epoch ran or
            train_ds is empty.
          - test_mse: teacher-forced MSE on test_ds; NaN if test_ds is empty.
          - drift_rms, drift_max: RMS and max of H_pred(t) - H_pred(0).
          - params: parameter count.
          - seconds: wall time.
    """
    model = FPU_LSTM(input_dim, hidden_dim=hidden_dim, num_layers=1).to(device).float()
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    lossf = nn.MSELoss()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    t0 = time.time()
    train_mse = float('nan')  # reported as NaN if the loop below never runs
    for _ in range(epochs):
        model.train()
        running = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad()
            out, _ = model(xb)
            loss = lossf(out, yb)
            loss.backward(); opt.step()
            running += loss.item() * xb.size(0)
        train_mse = running / len(train_ds) if len(train_ds) else float('nan')

    # test mse (teacher-forced)
    model.eval()
    with torch.no_grad():
        running = 0.0
        for xb, yb in test_loader:
            xb = xb.to(device); yb = yb.to(device)
            out, _ = model(xb)
            running += lossf(out, yb).item() * xb.size(0)
        test_mse = running / len(test_ds) if len(test_ds) else float('nan')

    # energy drift on an autoregressive rollout over the train horizon
    with torch.no_grad():
        z_t = torch.from_numpy(Z_tr[:1]).to(device)      # [1, D]
        hc  = None
        pred = np.zeros_like(Z_tr)
        for t in range(len(Z_tr)):
            out, hc = model(z_t.unsqueeze(1), hc)        # length-1 sequence -> [1, 1, D]
            pred[t] = z_t.squeeze(0).cpu().numpy()       # store current state
            z_t = out[:, -1, :].detach()                 # next input = prediction
        H_pred = energy_np(pred, N, lambda_value)
        dH = H_pred - H_pred[0]
        drift_rms = float(np.sqrt(np.mean(dH**2)))
        drift_max = float(np.max(np.abs(dH)))

    n_params = sum(p.numel() for p in model.parameters())
    secs = time.time() - t0
    return model, dict(train_mse=train_mse, test_mse=test_mse,
                       drift_rms=drift_rms, drift_max=drift_max,
                       params=n_params, seconds=secs)

# ---------- sweep space ----------
width_mults   = [1, 2, 4, 8]          # × nf
hidden_layers = [1, 2, 4, 8]          # SympNet only

# (tune these to your budget)
EPOCHS_SYMP = 10     # e.g., 200–1000; higher = better but slower
EPOCHS_LSTM = 10      # teacher-forced; 50–200 typical
BATCH_SYMP  = 256
BATCH_LSTM  = 128
LR_SYMP     = 1e-3
LR_LSTM     = 1e-3

# ---------- run sweep ----------
# One dict of metrics per trained model; `model` is overwritten every
# iteration, so only the last trained network survives the loops.
results = []

# SympNet grid: every (hidden layers × width) combination
for L in hidden_layers:
    for mult in width_mults:
        width = mult * nf
        model, metrics = train_sympnet(
            dim=nf, width=width, num_hidden=L,
            train_ds=train_ds_symp, test_ds=test_ds_symp,
            epochs=EPOCHS_SYMP, lr=LR_SYMP, batch_size=BATCH_SYMP, h=float(dt)
        )
        row = dict(model='SympNet', layers=L, width=width, **metrics)
        results.append(row)
        print(f"[SympNet] layers={L:>2}, width={width:>4} | "
              f"train={metrics['train_mse']:.3e}  test={metrics['test_mse']:.3e}  "
              f"drift_rms={metrics['drift_rms']:.3e}  params={metrics['params']}")

# LSTM grid: single layer, hidden size swept over the same multiples of nf
for mult in width_mults:
    hid = mult * nf
    model, metrics = train_lstm(
        input_dim=nf, hidden_dim=hid,
        train_ds=train_ds_lstm, test_ds=test_ds_lstm,
        epochs=EPOCHS_LSTM, lr=LR_LSTM, batch_size=BATCH_LSTM
    )
    row = dict(model='LSTM', layers=1, width=hid, **metrics)
    results.append(row)
    print(f"[LSTM]    hidden={hid:>4} | "
          f"train={metrics['train_mse']:.3e}  test={metrics['test_mse']:.3e}  "
          f"drift_rms={metrics['drift_rms']:.3e}  params={metrics['params']}")

# ---------- collate & view ----------
df_sweep = pd.DataFrame(results)
df_sweep = df_sweep[['model','layers','width','params','train_mse','test_mse','drift_rms','drift_max','seconds']]
# NOTE(review): `display` is presumably the IPython/Jupyter builtin; it is
# not defined when this runs as a plain script.
display(df_sweep.sort_values(['model','test_mse']).reset_index(drop=True))


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# --- 1) Metric vs parameter count (one chart per metric) ---
def plot_metric_vs_params(df, metric, title=None):
    """Log-log plot of ``metric`` against parameter count, one line per model family.

    Args:
        df: Sweep results with 'model', 'params' and ``metric`` columns.
        metric: Column to put on the y axis.
        title: Optional figure title (skipped when falsy).
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    for family, group in df.groupby('model'):
        ordered = group.sort_values('params')
        ax.plot(ordered['params'].values, ordered[metric].values, 'o-', label=family)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('# parameters')
    ax.set_ylabel(metric)
    if title:
        ax.set_title(title)
    ax.grid(True, which='both', linestyle=':')
    ax.legend()
    fig.tight_layout()
    plt.show()

# 'train_mse' is the mean one-step training loss of each model's final
# epoch, as returned by train_sympnet / train_lstm.
plot_metric_vs_params(df_sweep, 'test_mse',  title='Test MSE vs #params')
plot_metric_vs_params(df_sweep, 'train_mse', title='Reconstruction MSE vs #params')
plot_metric_vs_params(df_sweep, 'drift_rms', title='Energy drift (RMS) vs #params')

# --- 2) Pareto-style tradeoff: Test MSE vs Energy Drift (size ~ params) ---
def plot_pareto(df, x='test_mse', y='drift_rms'):
    """Scatter ``y`` against ``x`` on log-log axes, one colour per model family.

    Marker area grows with log10(params), offset so that the smallest model
    in the whole frame gets size 30.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    for family, group in df.groupby('model'):
        log_params = np.log10(group['params'].values)
        marker_sizes = 30 * (log_params - np.log10(df['params'].min()) + 1.0)
        ax.scatter(group[x].values, group[y].values, s=marker_sizes, alpha=0.8, label=family)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_title('Tradeoff: Test MSE vs Energy Drift (marker size ~ #params)')
    ax.grid(True, which='both', linestyle=':')
    ax.legend()
    fig.tight_layout()
    plt.show()

plot_pareto(df_sweep)  # default axes: test_mse (x) vs drift_rms (y)

# --- 3) SympNet heatmaps: test MSE / drift across (layers × width_mult) ---
# width_mult = width / nf, recovered as an integer from the swept widths.
symp = df_sweep[df_sweep['model']=='SympNet'].copy()
symp['width_mult'] = (symp['width'] / float(nf)).round().astype(int)

def heatmap_metric(symp_df, metric, title):
    """Heatmap of log10(metric) over the SympNet (hidden layers × width_mult) grid.

    Cells sit at integer positions and the ticks carry the actual layer and
    width_mult values. The swept values (1, 2, 4, 8) are not evenly spaced,
    so stretching them over a linear ``extent`` would leave cells offset
    from their tick labels.

    Args:
        symp_df: SympNet rows with 'layers', 'width_mult' and ``metric`` columns.
        metric: Column to plot; duplicates in a cell are reduced with min.
        title: Figure title.
    """
    pt = symp_df.pivot_table(index='layers', columns='width_mult', values=metric, aggfunc='min')
    pt = pt.sort_index(axis=0).sort_index(axis=1)
    layers_sorted = list(pt.index)
    cols_sorted = list(pt.columns)

    plt.figure(figsize=(6, 4.8))
    # origin='lower': the smallest layer count is drawn at the bottom row.
    im = plt.imshow(np.log10(pt.values.astype(float)), aspect='auto', origin='lower')
    cbar = plt.colorbar(im)
    cbar.set_label(f'log10({metric})')
    plt.xticks(range(len(cols_sorted)), cols_sorted)
    plt.yticks(range(len(layers_sorted)), layers_sorted)
    plt.xlabel('width_mult (× nf)'); plt.ylabel('hidden layers')
    plt.title(title)
    plt.tight_layout()
    plt.show()

if not symp.empty:  # skip when the sweep produced no SympNet rows
    heatmap_metric(symp, 'test_mse',  'SympNet: test MSE across (layers × width_mult)')
    heatmap_metric(symp, 'drift_rms', 'SympNet: drift RMS across (layers × width_mult)')

# --- 4) Param-matched head-to-head (nearest #params LSTM for each SympNet row) ---
def nearest_param_pairs(df):
    """Pair every SympNet run with the LSTM run whose parameter count is closest.

    Args:
        df: sweep table with 'model', 'params', 'layers', 'width', 'test_mse'
            and 'drift_rms' columns.

    Returns:
        DataFrame with one row per SympNet run (ties in parameter distance go
        to the first LSTM row). If either family is missing, the result is
        empty but still carries all output columns, so callers can sort or
        select by them.
    """
    columns = ['symp_params', 'lstm_params', 'symp_layers', 'symp_width', 'lstm_width',
               'symp_test_mse', 'lstm_test_mse', 'symp_drift_rms', 'lstm_drift_rms']
    symp = df[df.model=='SympNet']
    lstm = df[df.model=='LSTM']
    if symp.empty or lstm.empty:
        # idxmin() on an empty Series raises, and a column-less frame would
        # break a later sort_values('symp_params').
        return pd.DataFrame(columns=columns)

    rows = []
    for _, r in symp.iterrows():
        j = (lstm['params'] - r['params']).abs().idxmin()
        m = lstm.loc[j]
        rows.append({
            'symp_params': int(r['params']),
            'lstm_params': int(m['params']),
            'symp_layers': int(r['layers']),
            'symp_width' : int(r['width']),
            'lstm_width' : int(m['width']),
            'symp_test_mse': r['test_mse'],
            'lstm_test_mse': m['test_mse'],
            'symp_drift_rms': r['drift_rms'],
            'lstm_drift_rms': m['drift_rms'],
        })
    return pd.DataFrame(rows, columns=columns)

df_match = nearest_param_pairs(df_sweep)
# `display` is presumably the Jupyter/IPython rich-output helper (not imported
# here). Sorting only affects the shown table; df_match itself keeps SympNet
# row order, which is the order used by the bar charts below.
display(df_match.sort_values('symp_params').reset_index(drop=True))

# Optional quick bar plots for matched pairs
def bar_compare(df_match, metric, title):
    """Grouped log-scale bar chart of `metric` for each SympNet/LSTM matched pair.

    Args:
        df_match: pairs table with 'symp_params', 'symp_<metric>' and
            'lstm_<metric>' columns (as produced by `nearest_param_pairs`).
        metric: metric suffix, e.g. 'test_mse' or 'drift_rms'.
        title: figure title.
    """
    x = np.arange(len(df_match))
    width = 0.4
    # Exact count below 1000 params; integer division by 1000 would label every
    # such model "0k" and make the pairs indistinguishable.
    labels = [f"{p / 1000:.1f}k" if p >= 1000 else str(int(p))
              for p in df_match['symp_params']]
    plt.figure(figsize=(8,4.5))
    plt.bar(x - width/2, df_match[f'symp_{metric}'].values, width, label='SympNet')
    plt.bar(x + width/2, df_match[f'lstm_{metric}'].values, width, label='LSTM')
    plt.yscale('log')
    plt.xticks(x, labels, rotation=0)
    plt.xlabel('~Matched parameter count (SympNet side)')
    plt.ylabel(metric)
    plt.title(title)
    plt.grid(True, which='both', axis='y', linestyle=':')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Skip the bar charts when no matched pairs exist.
if len(df_match):
    bar_compare(df_match, 'test_mse',   'Param-matched: Test MSE (log scale)')
    bar_compare(df_match, 'drift_rms',  'Param-matched: Energy drift RMS (log scale)')


In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# First CUDA device when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Shared loss used by every trainer below.
mse = nn.MSELoss()

# ----- data from your sim -----
Z_all = df_fpu_states.values.astype(np.float32)   # [T, 2M]
nf    = Z_all.shape[1]
T     = len(Z_all)
split = int(0.8 * T)

# SympNet one-step pairs
# Chronological 80/20 split (no shuffling): the test set is the last 20% of
# rows, and pairs are formed inside each split so none crosses the boundary.
Z_tr  = Z_all[:split]
Z_te  = Z_all[split:]
Z_tr_in, Z_tr_out = Z_tr[:-1], Z_tr[1:]
Z_te_in, Z_te_out = Z_te[:-1], Z_te[1:]
train_ds_symp = TensorDataset(torch.from_numpy(Z_tr_in), torch.from_numpy(Z_tr_out))
test_ds_symp  = TensorDataset(torch.from_numpy(Z_te_in), torch.from_numpy(Z_te_out))
# LSTM sequences
def make_sequences(Z, seq_len=10):
    """Cut a trajectory into overlapping input/target windows for sequence models.

    Window i is X[i] = Z[i : i+seq_len]; its target is the same window shifted
    one step ahead, Y[i] = Z[i+1 : i+seq_len+1].

    Args:
        Z: array-like with time along the first axis, shape [T, ...].
        seq_len: window length.

    Returns:
        (X, Y): float32 arrays of shape [max(T - seq_len, 0), seq_len, ...].
        A trajectory too short for a single window yields empty arrays that
        still carry the (seq_len, ...) trailing shape.
    """
    Z = np.asarray(Z, dtype=np.float32)
    n_windows = max(len(Z) - seq_len, 0)
    # Row i of `idx` holds the time indices i, i+1, ..., i+seq_len-1.
    idx = np.arange(n_windows)[:, None] + np.arange(seq_len)[None, :]
    return Z[idx], Z[idx + 1]

# LSTM windows are built separately inside each split, so no window mixes
# training and test states.
seq_len = 10
X_tr, Y_tr = make_sequences(Z_tr, seq_len)
X_te, Y_te = make_sequences(Z_te, seq_len)
train_ds_lstm = TensorDataset(torch.from_numpy(X_tr), torch.from_numpy(Y_tr))
test_ds_lstm  = TensorDataset(torch.from_numpy(X_te), torch.from_numpy(Y_te))

# ----- energy helper -----
def energy_np(z_batch, N, lam):
    """Total energy of the FPU chain for one state or a batch of states.

    H = sum_i p_i^2 / 2 + sum_bonds [ dq^2 / 2 + lam * dq^3 / 3 ], where the
    bond stretches dq are taken along the chain padded with fixed ends
    q_0 = q_{M+1} = 0.

    Args:
        z_batch: state(s) laid out as [q_1..q_M, p_1..p_M]; 1-D input is
            treated as a single state.
        N: not used by the computation (M is inferred from the state width);
            kept for call-site compatibility.
        lam: cubic coupling strength.

    Returns:
        1-D array with one energy per state, in the input's dtype.
    """
    states = np.atleast_2d(z_batch)
    M = states.shape[1] // 2
    q = states[:, :M]
    p = states[:, M:]
    kinetic = 0.5 * np.sum(p**2, axis=1)
    q_padded = np.pad(q, ((0, 0), (1, 1)))    # zero-displacement end particles
    stretch = np.diff(q_padded, axis=1)
    potential = np.sum(0.5*stretch**2 + lam*stretch**3/3.0, axis=1)
    return kinetic + potential

# ----- SympNet model & step -----
class SymplecticNN_Generator(nn.Module):
    """Vector field derived from a learned scalar generator S(z).

    S is a Tanh MLP mapping R^dim -> R. For a state z = [q, p] split at
    dim // 2, `forward` returns (dS/dp, -dS/dq), i.e. Hamilton's equations
    with S as the Hamiltonian, computed by autograd.
    """
    def __init__(self, dim, width=64, num_hidden=3):
        super().__init__()
        self.dim = dim
        self.half = dim // 2
        sizes = [dim] + [width] * num_hidden
        blocks = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.Tanh())
        blocks.append(nn.Linear(sizes[-1], 1))   # scalar generator output
        self.S = nn.Sequential(*blocks)

    def forward(self, z, create_graph=True):
        """Return d/dt [q, p] for a batch `z`; create_graph=True keeps it differentiable."""
        # Marks the caller's tensor as requiring grad (in place) so ∇S exists.
        z = z.requires_grad_(True)
        (grad_S,) = torch.autograd.grad(self.S(z).sum(), z, create_graph=create_graph)
        dS_dq = grad_S[:, :self.half]
        dS_dp = grad_S[:, self.half:]
        return torch.cat([dS_dp, -dS_dq], dim=1)

def symp_midpoint_step(model, z0, h, n_iter=5, create_graph=True):
    """Approximate one implicit-midpoint step z1 = z0 + h * f((z0 + z1) / 2).

    The implicit equation is solved by a fixed number of fixed-point
    iterations starting from z1 = z0; there is no convergence check.

    Args:
        model: vector field called as model(z, create_graph=...), e.g. a
            SymplecticNN_Generator.
        z0: current batch of states.
        h: step size.
        n_iter: number of fixed-point iterations.
        create_graph: forwarded to the model. True keeps the model output
            differentiable w.r.t. its parameters (training); False makes the
            model return gradients without a graph (evaluation/rollouts).

    Returns:
        The approximate next state z1.
    """
    z1 = z0.clone()
    for _ in range(n_iter):
        # The model needs autograd to form ∇S even if the caller disabled grad.
        with torch.enable_grad():
            # Parsed as 0.5 * ((z0 + z1).detach().requires_grad_(True)).
            # Because of the detach, gradients reach the model parameters only
            # through the last iteration's f_mid, not through earlier iterates.
            z_mid = 0.5*(z0 + z1).detach().requires_grad_(True)
            f_mid = model(z_mid, create_graph=create_graph)
            z1 = z0 + h * f_mid
    return z1

# ----- LSTM model -----
class FPU_LSTM(nn.Module):
    """LSTM with a per-step linear read-out back to the state dimension.

    The LSTM is built with batch_first=True, so inputs and outputs are
    (batch, seq, input_dim); the read-out maps hidden_dim -> input_dim.
    """
    def __init__(self, input_dim, hidden_dim=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, hc=None):
        """Return (per-step predictions, final (h, c)); `hc` optionally seeds the state."""
        hidden_seq, state = self.lstm(x, hc)
        return self.fc(hidden_seq), state

# ----- drift evaluators (short horizon to keep it fast) -----
def drift_rms_symp(model, Z_ref, N, lam, h, H=1000):
    """RMS energy drift of a free-running rollout of a SympNet-style model.

    Starting from Z_ref[0], the state is advanced with `symp_midpoint_step`
    until H states have been collected. The result is
    sqrt(mean((E_t - E_0)^2)), with E computed by `energy_np`.

    Args:
        model: vector field accepted by `symp_midpoint_step`.
        Z_ref: reference trajectory; only its first state and its length
            (which caps H) are used.
        N, lam: forwarded to `energy_np`.
        h: integrator step size.
        H: maximum number of rollout states, including the initial one.

    Returns:
        The RMS energy deviation as a Python float.
    """
    H = min(H, len(Z_ref))
    # Cast so a float64 trajectory cannot clash with float32 model weights
    # (the trainers in this notebook build models with .float()).
    z_t = torch.from_numpy(np.asarray(Z_ref[:1], dtype=np.float32)).to(device)
    pred = np.zeros((H, Z_ref.shape[1]), dtype=np.float32)
    for t in range(H):
        pred[t] = z_t.squeeze(0).cpu().numpy()
        if t + 1 < H:   # the state after the last stored one would be discarded
            z_t = symp_midpoint_step(model, z_t, h, create_graph=False)
    dH = energy_np(pred, N, lam) - energy_np(pred[:1], N, lam)[0]
    return float(np.sqrt(np.mean(dH**2)))

def drift_rms_lstm(model, Z_ref, H=1000, N=None, lam=None):
    """RMS energy drift of an autoregressive LSTM rollout.

    Starting from Z_ref[0], each prediction is fed back as the next input
    while the LSTM state (h, c) is carried forward, until H states have been
    collected. The result is sqrt(mean((E_t - E_0)^2)) using `energy_np`,
    i.e. the same FPU energy used for the ground truth.

    Args:
        model: FPU_LSTM-style module returning (outputs, (h, c)).
        Z_ref: reference trajectory; only its first state and its length
            (which caps H) are used.
        H: maximum number of rollout states, including the initial one.
        N, lam: forwarded to `energy_np`. When omitted they fall back to the
            notebook globals `N` and `lambda_value`, which preserves the old
            behaviour of existing calls.

    Returns:
        The RMS energy deviation as a Python float.
    """
    module_globals = globals()
    N = module_globals['N'] if N is None else N
    lam = module_globals['lambda_value'] if lam is None else lam

    H = min(H, len(Z_ref))
    z_t = torch.from_numpy(np.asarray(Z_ref[:1], dtype=np.float32)).to(device)
    hc  = None
    pred = np.zeros((H, Z_ref.shape[1]), dtype=np.float32)
    # no_grad: otherwise the carried recurrent state `hc` keeps every step's
    # autograd graph alive (only `out` used to be detached), so memory grew
    # linearly with H.
    with torch.no_grad():
        for t in range(H):
            pred[t] = z_t.squeeze(0).cpu().numpy()
            if t + 1 < H:   # the prediction after the last stored state is unused
                out, hc = model(z_t.unsqueeze(1), hc)     # [1,1,D]
                z_t = out[:, -1, :]
    dH = energy_np(pred, N, lam) - energy_np(pred[:1], N, lam)[0]
    return float(np.sqrt(np.mean(dH**2)))

# ----- training with history -----
def train_sympnet_with_history(dim, width, num_hidden, train_ds, test_ds,
                               epochs=200, lr=1e-3, batch_size=256, h=dt,
                               eval_every=5, drift_H=800):
    """Train a SymplecticNN_Generator as a one-step map and record learning curves.

    Each batch of (z_t, z_{t+1}) pairs is advanced by one `symp_midpoint_step`
    of size `h` and fitted to z_{t+1} with MSE, using Adam.

    Args:
        dim: state dimension given to the model (the sweep passes nf).
        width, num_hidden: width and depth of the generator MLP S.
        train_ds, test_ds: datasets yielding (z_t, z_{t+1}) pairs.
        epochs, lr, batch_size: optimisation settings.
        h: step size. The default is the global `dt` as it was when this
            `def` ran (Python evaluates defaults once), not at call time.
        eval_every: compute test MSE and drift every `eval_every` epochs and
            on the final epoch.
        drift_H: rollout length passed to `drift_rms_symp`.

    Returns:
        (model, hist): hist['train'] is the mean training MSE of every epoch;
        hist['test'] and hist['drift'] are lists of (epoch, value) pairs.
    """
    model = SymplecticNN_Generator(dim, width=width, num_hidden=num_hidden).to(device).float()
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    tl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    vl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    hist = dict(train=[], test=[], drift=[])
    for ep in range(epochs):
        model.train()
        run = 0.0
        for x0, x1 in tl:
            x0 = x0.to(device); x1 = x1.to(device)
            opt.zero_grad()
            # create_graph=True keeps ∇S differentiable w.r.t. the parameters.
            x1_hat = symp_midpoint_step(model, x0, h, create_graph=True)
            loss = mse(x1_hat, x1)
            loss.backward(); opt.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            run += loss.item() * x0.size(0)
        train_mse = run / len(train_ds)
        hist['train'].append(train_mse)

        if (ep % eval_every == 0) or (ep == epochs-1):
            model.eval()
            # test MSE (needs grad enabled for ∇S, but no backprop graph)
            test_run = 0.0
            for x0, x1 in vl:
                x0 = x0.to(device); x1 = x1.to(device)
                x1_hat = symp_midpoint_step(model, x0, h, create_graph=False)
                test_run += mse(x1_hat, x1).item() * x0.size(0)
            test_mse = test_run / len(test_ds)
            hist['test'].append((ep, test_mse))

            # drift RMS on a short reconstruction
            # The rollout starts from the first *training* state (global Z_tr)
            # and uses the globals N and lambda_value for the energy.
            d_rms = drift_rms_symp(model, Z_tr.astype(np.float32), N, lambda_value, h, H=drift_H)
            hist['drift'].append((ep, d_rms))
    return model, hist

def train_lstm_with_history(input_dim, hidden_dim, train_ds, test_ds,
                            epochs=200, lr=1e-3, batch_size=128,
                            eval_every=5, drift_H=800):
    """Train a single-layer FPU_LSTM with teacher forcing and record learning curves.

    The loss is the MSE between the model's per-step outputs for an input
    window and the target window (e.g. windows from `make_sequences`, whose
    targets are the inputs shifted one step ahead). Optimiser: Adam.

    Args:
        input_dim: state dimension.
        hidden_dim: LSTM hidden size.
        train_ds, test_ds: datasets yielding (input window, target window).
        epochs, lr, batch_size: optimisation settings.
        eval_every: compute test MSE and drift every `eval_every` epochs and
            on the final epoch.
        drift_H: rollout length passed to `drift_rms_lstm`.

    Returns:
        (model, hist): hist['train'] is the mean training MSE of every epoch;
        hist['test'] and hist['drift'] are lists of (epoch, value) pairs.
    """
    model = FPU_LSTM(input_dim, hidden_dim=hidden_dim, num_layers=1).to(device).float()
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    tl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    vl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    hist = dict(train=[], test=[], drift=[])
    for ep in range(epochs):
        model.train()
        run = 0.0
        for xb, yb in tl:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad()
            out, _ = model(xb)
            loss = mse(out, yb)
            loss.backward(); opt.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            run += loss.item() * xb.size(0)
        train_mse = run / len(train_ds)
        hist['train'].append(train_mse)

        if (ep % eval_every == 0) or (ep == epochs-1):
            model.eval()
            # Evaluation needs no gradients. Without no_grad, every test batch
            # and every rollout step would build an autograd graph.
            with torch.no_grad():
                test_run = 0.0
                for xb, yb in vl:
                    xb = xb.to(device); yb = yb.to(device)
                    out, _ = model(xb)
                    test_run += mse(out, yb).item() * xb.size(0)
                test_mse = test_run / len(test_ds)
                hist['test'].append((ep, test_mse))

                # Rollout starts from the first *training* state (global Z_tr).
                d_rms = drift_rms_lstm(model, Z_tr.astype(np.float32), H=drift_H)
                hist['drift'].append((ep, d_rms))
    return model, hist

# ---------- Neural ODE (discrete RK4) ----------
class ODEFunc(nn.Module):
    """Unconstrained learned vector field z' = f_theta(z): a Tanh MLP mapping dim -> dim."""
    def __init__(self, dim, width=64, num_hidden=3):
        super().__init__()
        sizes = [dim] + [width] * num_hidden
        modules = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        modules.append(nn.Linear(sizes[-1], dim))   # back to the state dimension
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Evaluate f_theta on a batch of states."""
        return self.net(z)

def rk4_step(func, z, h):
    """Advance z by one classical 4th-order Runge-Kutta step of size h for z' = func(z).

    Works for any `z` supporting +, scalar * (floats, NumPy arrays, tensors).
    """
    half_h = 0.5*h
    k1 = func(z)
    k2 = func(z + half_h*k1)
    k3 = func(z + half_h*k2)
    k4 = func(z + h*k3)
    weighted_slope = k1 + 2*k2 + 2*k3 + k4
    return z + (h/6.0)*weighted_slope


def train_neural_ode_with_history(dim, width, num_hidden,
                                  train_ds, test_ds,
                                  epochs=200, lr=1e-3, batch_size=256, h=dt,
                                  eval_every=10, drift_H=800):
    """Train an ODEFunc through one RK4 step per sample and record learning curves.

    Each (z_t, z_{t+1}) pair is fitted by MSE between rk4_step(model, z_t, h)
    and z_{t+1}, using Adam.

    Args:
        dim: state dimension.
        width, num_hidden: vector-field MLP size.
        train_ds, test_ds: datasets yielding (z_t, z_{t+1}) pairs.
        epochs, lr, batch_size: optimisation settings.
        h: step size. The default is the global `dt` as it was when this
            `def` ran, not at call time.
        eval_every: compute test MSE and drift every `eval_every` epochs and
            on the final epoch (default 10, unlike the other trainers' 5).
        drift_H: maximum number of rollout states for the drift metric.

    Returns:
        (model, hist): hist['train'] is the mean training MSE of every epoch;
        hist['test'] and hist['drift'] are lists of (epoch, value) pairs.
    """
    model = ODEFunc(dim, width=width, num_hidden=num_hidden).to(device).float()
    opt   = torch.optim.Adam(model.parameters(), lr=lr)
    tl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    vl = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    hist = dict(train=[], test=[], drift=[])
    for ep in range(epochs):
        model.train()
        run = 0.0
        for x0, x1 in tl:
            x0 = x0.to(device); x1 = x1.to(device)
            opt.zero_grad()
            x1_hat = rk4_step(model, x0, h)
            loss = mse(x1_hat, x1)
            loss.backward(); opt.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            run += loss.item() * x0.size(0)
        hist['train'].append(run / len(train_ds))

        if (ep % eval_every == 0) or (ep == epochs-1):
            model.eval()
            # Plain RK4 needs no autograd at evaluation time.
            with torch.no_grad():
                # test one-step MSE
                # (`run` is reused; the training value was already appended.)
                run = 0.0
                for x0, x1 in vl:
                    x0 = x0.to(device); x1 = x1.to(device)
                    x1_hat = rk4_step(model, x0, h)
                    run += mse(x1_hat, x1).item() * x0.size(0)
                hist['test'].append((ep, run / len(test_ds)))

                # drift RMS on short autoregressive reconstruction
                # Same recipe as drift_rms_symp, but stepping with rk4_step:
                # start from the first *training* state (global Z_tr) and use
                # the globals N and lambda_value for the energy.
                Ztr = Z_tr.astype(np.float32)
                H = min(drift_H, len(Ztr))
                z_t = torch.from_numpy(Ztr[:1]).to(device)
                pred = np.zeros((H, Ztr.shape[1]), dtype=np.float32)
                for t in range(H):
                    pred[t] = z_t.squeeze(0).cpu().numpy()
                    z_t = rk4_step(model, z_t, h)
                dH = energy_np(pred, N, lambda_value) - energy_np(pred[:1], N, lambda_value)[0]
                hist['drift'].append((ep, float(np.sqrt(np.mean(dH**2)))))
    return model, hist



In [None]:
def plot_convergence_multi(hists: dict, title="Convergence"):
    """Stack train MSE, test MSE and energy-drift curves for several models.

    Args:
        hists: model name -> history dict as returned by the
            ``train_*_with_history`` functions. 'train' is a per-epoch list of
            floats; 'test' and 'drift' are lists of (epoch, value) pairs.
        title: figure suptitle.
    """
    palette = dict(SympNet='C0', LSTM='C1', NeuralODE='C2')
    fig, (ax_train, ax_test, ax_drift) = plt.subplots(3,1, figsize=(9,9), sharex=True)

    # Train loss is logged every epoch: plain line indexed by epoch number.
    for name, hist in hists.items():
        train_curve = np.array(hist['train'])
        ax_train.plot(np.arange(len(train_curve)), train_curve, lw=1.4, label=name,
                      color=palette.get(name))
    ax_train.set_ylabel('Train MSE')
    ax_train.set_yscale('log')
    ax_train.grid(True)
    ax_train.legend()

    # Test MSE and drift are logged only on evaluation epochs: marked (epoch, value) points.
    sampled_panels = ((ax_test, 'test', 'Test MSE'),
                      (ax_drift, 'drift', 'Drift RMS (ΔH)'))
    for ax, key, ylabel in sampled_panels:
        for name, hist in hists.items():
            points = np.array(hist[key])
            ax.plot(points[:,0], points[:,1], 'o-', lw=1.4, ms=4, label=name,
                    color=palette.get(name))
        ax.set_ylabel(ylabel)
        ax.set_yscale('log')
        ax.grid(True)
        ax.legend()
    ax_drift.set_xlabel('Epoch')

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()



def plot_metric_vs_params(df, metric='test_mse', title=None):
    """Log-log scatter of `metric` against parameter count, one series per model.

    Args:
        df: sweep table with 'model', 'params' and `metric` columns.
        metric: column plotted on the y-axis.
        title: optional axes title (skipped when falsy).
    """
    marker_for = {'SympNet':'o', 'LSTM':'s', 'NeuralODE':'^'}
    fig, ax = plt.subplots(figsize=(7.5,5.5))
    for family in sorted(df['model'].unique()):
        rows = df.loc[df['model'] == family]
        ax.scatter(rows['params'], rows[metric], s=50, alpha=0.9,
                   label=family, marker=marker_for.get(family, 'o'))
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('# parameters (log)')
    ax.set_ylabel(metric + ' (log)')
    ax.grid(True, which='both', ls='--', alpha=0.3)
    if title:
        ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.show()




In [None]:
# =========================
# SWEEP + METRICS + PLOTS
# =========================
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch

def n_params(m):
    """Total element count over all parameters of module `m` (trainable or not)."""
    total = 0
    for tensor in m.parameters():
        total += tensor.numel()
    return total
def last_metric(hist, key):
    """Return the value of the latest (epoch, value) entry in hist[key], or NaN.

    NaN is returned when `key` is absent or its list is empty.
    """
    if key not in hist:
        return np.nan
    entries = hist[key]
    if len(entries) == 0:
        return np.nan
    return entries[-1][1]

# grids
# NOTE(review): both grid definitions below are commented out, yet the sweep
# loops iterate over `hidden_layers` and `width_mults`. They must come from
# another cell, otherwise this cell raises NameError.
# width_mults   = [1, 2, 4, 8]
# hidden_layers = [1, 2, 4, 8]   # for SympNet & Neural ODE
# width_mults   = [1]
# hidden_layers = [1]   # for SympNet & Neural ODE

# budgets (tune to your patience)
# NOTE(review): SympNet/LSTM get 10 epochs but the Neural ODE gets 200, which
# biases the cross-model comparison — confirm this is intended.
EPOCHS_SYMP = 10
EPOCHS_LSTM = 10
EPOCHS_ODE  = 200
BATCH_SYMP  = 256   # also used for the Neural ODE runs
BATCH_LSTM  = 128
LR_SYMP     = 1e-3
LR_LSTM     = 1e-3
LR_ODE      = 1e-3
EVAL_EVERY  = 10    # test MSE / drift logged every N epochs (and on the last)
DRIFT_H     = 800   # rollout length for the drift metric

# One dict per trained configuration, accumulated across all three sweeps.
results = []

# ---- SympNet sweep ----
# Full grid over depth (hidden_layers) and width; widths scale with the state
# dimension nf (W = mult * nf).
for L in hidden_layers:
    for mult in width_mults:
        W = mult * nf
        model, hist = train_sympnet_with_history(
            dim=nf, width=W, num_hidden=L,
            train_ds=train_ds_symp, test_ds=test_ds_symp,
            epochs=EPOCHS_SYMP, lr=LR_SYMP, batch_size=BATCH_SYMP, h=float(dt),
            eval_every=EVAL_EVERY, drift_H=DRIFT_H
        )
        # Only final-epoch metrics are kept; drift_max and seconds are NaN
        # placeholders that are not measured here.
        row = dict(
            model='SympNet', layers=L, width=W, params=n_params(model),
            train_mse=hist['train'][-1],
            test_mse=last_metric(hist,'test'),
            drift_rms=last_metric(hist,'drift'),
            drift_max=np.nan, seconds=np.nan
        )
        results.append(row)
        print(f"[SympNet] L={L} W={W} | train={row['train_mse']:.3e}  test={row['test_mse']:.3e}  drift_rms={row['drift_rms']:.3e}  params={row['params']}")

# ---- LSTM sweep (hidden size only; num_layers=1) ----
# The LSTM hidden size is stored in the 'width' column so every family shares
# one table schema.
for mult in width_mults:
    H = mult * nf
    model, hist = train_lstm_with_history(
        input_dim=nf, hidden_dim=H,
        train_ds=train_ds_lstm, test_ds=test_ds_lstm,
        epochs=EPOCHS_LSTM, lr=LR_LSTM, batch_size=BATCH_LSTM,
        eval_every=EVAL_EVERY, drift_H=DRIFT_H
    )
    row = dict(
        model='LSTM', layers=1, width=H, params=n_params(model),
        train_mse=hist['train'][-1],
        test_mse=last_metric(hist,'test'),
        drift_rms=last_metric(hist,'drift'),
        drift_max=np.nan, seconds=np.nan
    )
    results.append(row)
    print(f"[LSTM]    H={H} | train={row['train_mse']:.3e}  test={row['test_mse']:.3e}  drift_rms={row['drift_rms']:.3e}  params={row['params']}")

# ---- Neural ODE sweep ----
# Same grid and the same one-step (z_t, z_{t+1}) datasets as the SympNet sweep.
for L in hidden_layers:
    for mult in width_mults:
        W = mult * nf
        model, hist = train_neural_ode_with_history(
            dim=nf, width=W, num_hidden=L,
            train_ds=train_ds_symp, test_ds=test_ds_symp,
            epochs=EPOCHS_ODE, lr=LR_ODE, batch_size=BATCH_SYMP, h=float(dt),
            eval_every=EVAL_EVERY, drift_H=DRIFT_H
        )
        # Convert history -> metrics for the sweep table
        metrics = {
            'train_mse': float(hist['train'][-1]),
            'test_mse':  float(last_metric(hist, 'test')),
            'drift_rms': float(last_metric(hist, 'drift')),
            'drift_max': np.nan,            # not tracked in *_with_history (optional to add)
            'params':    n_params(model),
            'seconds':   np.nan             # add timing if you want
        }
        row = dict(model='NeuralODE', layers=L, width=W, **metrics)
        results.append(row)
        print(f"[NeuralODE] L={L} W={W} | "
              f"train={metrics['train_mse']:.3e}  "
              f"test={metrics['test_mse']:.3e}  "
              f"drift_rms={metrics['drift_rms']:.3e}  "
              f"params={metrics['params']}")


# Gather this cell's sweep results into the table read by the plots and the
# best-config selection below. Without this, `df_sweep` keeps pointing at
# whatever an earlier cell built (or is undefined), and the figures would not
# reflect `results`.
df_sweep = pd.DataFrame(results)

plot_metric_vs_params(df_sweep, metric='test_mse',  title='Test MSE vs Parameters')
plot_metric_vs_params(df_sweep, metric='drift_rms', title='Energy Drift (RMS ΔH) vs Parameters')

# Best config per model family = the row with the lowest final test MSE.
# head(1) keeps that row intact. GroupBy.first() takes the first *non-null*
# value column by column, so it can splice values from different runs when the
# best row has NaNs (e.g. a diverged drift_rms). Sorting by model keeps the
# same row order that first() produced.
best_rows = (df_sweep
             .sort_values('test_mse')
             .groupby('model')
             .head(1)
             .sort_values('model')
             .reset_index(drop=True))

# Retrain the best configuration of each family from scratch to obtain full
# learning curves (the sweep table keeps only final-epoch metrics). Weights are
# re-initialised and no seed is set in this cell, so the curves need not end at
# exactly the sweep's numbers.
# Training budgets for the convergence runs (can reuse your sweep budgets)
EPOCHS_CONV_SYMP = EPOCHS_SYMP
EPOCHS_CONV_LSTM = EPOCHS_LSTM
EPOCHS_CONV_ODE  = EPOCHS_ODE

# model family -> history dict returned by the matching *_with_history trainer
hists = {}

for _, r in best_rows.iterrows():
    if r['model'] == 'SympNet':
        L, W = int(r['layers']), int(r['width'])
        symp_best_model, symp_best_hist = train_sympnet_with_history(
            dim=nf, width=W, num_hidden=L,
            train_ds=train_ds_symp, test_ds=test_ds_symp,
            epochs=EPOCHS_CONV_SYMP, lr=LR_SYMP, batch_size=BATCH_SYMP, h=float(dt),
            eval_every=EVAL_EVERY, drift_H=DRIFT_H
        )
        hists['SympNet'] = symp_best_hist

    elif r['model'] == 'LSTM':
        Hsz = int(r['width'])  # we stored hidden size in 'width' for LSTM
        lstm_best_model, lstm_best_hist = train_lstm_with_history(
            input_dim=nf, hidden_dim=Hsz,
            train_ds=train_ds_lstm, test_ds=test_ds_lstm,
            epochs=EPOCHS_CONV_LSTM, lr=LR_LSTM, batch_size=BATCH_LSTM,
            eval_every=EVAL_EVERY, drift_H=DRIFT_H
        )
        hists['LSTM'] = lstm_best_hist

    elif r['model'] == 'NeuralODE':
        L, W = int(r['layers']), int(r['width'])
        ode_best_model, ode_best_hist = train_neural_ode_with_history(
            dim=nf, width=W, num_hidden=L,
            train_ds=train_ds_symp, test_ds=test_ds_symp,
            epochs=EPOCHS_CONV_ODE, lr=LR_ODE, batch_size=BATCH_SYMP, h=float(dt),
            eval_every=EVAL_EVERY, drift_H=DRIFT_H
        )
        hists['NeuralODE'] = ode_best_hist

# 4) Plot all three on one convergence figure
plot_convergence_multi(
    hists,
    title="Convergence (best configs by sweep)"
)


