In [None]:
# Setup and synthetic data generation.
# Generates 2-ball elastic collision videos for physics validation.

!pip install -q timm einops av opencv-python-headless

import os
import cv2
import numpy as np
from pathlib import Path

SEED = 42
np.random.seed(SEED)  # seeds the global NumPy RNG that draws the initial ball velocities

# Output locations (Colab-style paths under /content).
DATA_ROOT = Path("/content/physion_data")
LATENT_ROOT = Path("/content/latents")
SYNTH_DIR = DATA_ROOT.joinpath("synthetic_collisions")

# Create every output directory up front; exist_ok makes re-running the cell harmless.
for folder in (DATA_ROOT, LATENT_ROOT, SYNTH_DIR):
    folder.mkdir(exist_ok=True, parents=True)

def generate_collision_video(video_id, n_frames=150, size=256):
    """Render a synthetic 2-ball elastic-collision video and write it to disk.

    Two equal-radius discs start on the horizontal mid-line moving towards
    each other with velocities drawn from NumPy's global RNG. They bounce off
    the frame walls and collide elastically (equal masses) with each other.
    Frames are streamed straight to an MP4 file at 30 fps.

    Args:
        video_id: Integer used to build the file name
            ``collision_{video_id:04d}.mp4`` inside ``SYNTH_DIR``.
        n_frames: Number of frames to render.
        size: Width and height of the square frame in pixels.

    Returns:
        The output path as a string.

    Raises:
        RuntimeError: If OpenCV cannot open a video writer for the output
            path (for example, the 'mp4v' codec is unavailable). Previously
            this failed silently and returned a path to a missing file.
    """
    radius = 20
    lo, hi = radius, size - radius  # allowed range for a ball centre

    x1, y1 = 64.0, 128.0
    x2, y2 = 192.0, 128.0
    vx1, vy1 = 2.0 + np.random.rand(), (np.random.rand() - 0.5) * 2
    vx2, vy2 = -2.0 - np.random.rand(), (np.random.rand() - 0.5) * 2

    def bounce(pos, vel):
        # Clamp to the wall and force the velocity to point back inside.
        # Blindly negating the velocity would flip a ball that is still
        # beyond the wall on the next frame back outwards, so it could stick
        # to the wall or leave the frame.
        if pos < lo:
            return float(lo), abs(vel)
        if pos > hi:
            return float(hi), -abs(vel)
        return pos, vel

    path = SYNTH_DIR / f"collision_{video_id:04d}.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(path), fourcc, 30, (size, size))
    if not out.isOpened():
        raise RuntimeError(f"Could not open video writer for {path} (fourcc 'mp4v')")

    try:
        for _ in range(n_frames):
            frame = np.zeros((size, size, 3), dtype=np.uint8)
            cv2.circle(frame, (int(x1), int(y1)), radius, (255, 255, 0), -1)
            cv2.circle(frame, (int(x2), int(y2)), radius, (0, 0, 255), -1)
            out.write(frame)

            x1 += vx1; y1 += vy1
            x2 += vx2; y2 += vy2

            # Ball-to-ball elastic collision (equal masses).
            dx, dy = x2 - x1, y2 - y1
            dist = np.hypot(dx, dy)
            if 0 < dist < 2 * radius:
                nx, ny = dx / dist, dy / dist
                # Normal component of the relative velocity; > 0 means approaching.
                dvn = (vx1 - vx2) * nx + (vy1 - vy2) * ny
                if dvn > 0:
                    # Exchange the normal velocity components.
                    vx1 -= dvn * nx; vy1 -= dvn * ny
                    vx2 += dvn * nx; vy2 += dvn * ny
                # Always push overlapping balls apart symmetrically.
                overlap = 2 * radius - dist
                x1 -= overlap * nx / 2; y1 -= overlap * ny / 2
                x2 += overlap * nx / 2; y2 += overlap * ny / 2

            # Wall collisions last, so every rendered centre stays in-frame.
            x1, vx1 = bounce(x1, vx1)
            y1, vy1 = bounce(y1, vy1)
            x2, vx2 = bounce(x2, vx2)
            y2, vy2 = bounce(y2, vy2)
    finally:
        out.release()
    return str(path)

# Render the synthetic dataset; file paths are kept in generation order.
num_videos = 50
print(f"Generating {num_videos} synthetic physics videos...")
video_files = list(map(generate_collision_video, range(num_videos)))
print(f"Generated {num_videos} videos in: {SYNTH_DIR}")

# Central configuration read by the later cells.
# q_dim / p_dim describe the intended 512/512 split of the 1024-d latent;
# the HNN derives its own split from latent_dim.
CONFIG = {
    "data_root": str(DATA_ROOT),
    "video_dir": str(SYNTH_DIR),
    "latent_dir": str(LATENT_ROOT),
    "video_files": video_files,
    "window_size": 16,   # latent steps per training window
    "stride": 1,         # offset between consecutive windows
    "latent_dim": 1024,
    "q_dim": 512,
    "p_dim": 512,
    "batch_size": 32,
    "num_epochs": 50,
}

print(f"Video count: {len(CONFIG['video_files'])}")
print(f"Sample path: {CONFIG['video_files'][0]}")

In [None]:
# V-JEPA 2 latent extraction (Section 3.1).
# Encodes video clips into 1024-d latent vectors, then creates sliding windows.

!pip install -q torchcodec

import torch
from torchcodec.decoders import VideoDecoder
from transformers import AutoVideoProcessor, AutoModel

# Load the pretrained V-JEPA 2 ViT-L video encoder once. It is used frozen
# (eval mode) by extract_clip_latents below and deleted after extraction.
print("Loading V-JEPA 2 ViT-L model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

MODEL_NAME = "facebook/vjepa2-vitl-fpc16-256-ssv2"
processor = AutoVideoProcessor.from_pretrained(MODEL_NAME)
# Weights are loaded in float16 regardless of the selected device.
# NOTE(review): float16 is aimed at GPU inference; on a CPU-only runtime
# half precision may be very slow or hit unsupported ops -- consider float32
# there (the extractor must then feed matching input dtypes).
model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    attn_implementation="sdpa"
).to(device)
model.eval()

# Parameter count in millions plus currently allocated CUDA memory
# (presumably reads 0.00GB when CUDA is not in use -- confirm).
print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params, "
      f"{torch.cuda.memory_allocated() / 1e9:.2f}GB VRAM")

@torch.no_grad()
def extract_clip_latents(video_path, clip_size=16, stride=4):
    """Encode a video into one V-JEPA 2 latent vector per strided clip.

    Clips of ``clip_size`` consecutive frames start every ``stride`` frames.
    Each clip is passed through the module-level ``processor`` and ``model``.
    Its ``last_hidden_state`` is mean-pooled over dim 1 (the token axis)
    into one vector per clip.

    The extraction stride (default 4) subsamples clips for speed. It is
    distinct from the windowing stride (1) used later to build overlapping
    training windows from this latent sequence.

    Args:
        video_path: Path to the video file.
        clip_size: Frames per clip fed to the encoder.
        stride: Frame offset between consecutive clip starts.

    Returns:
        A float32 CPU tensor with one row per clip, or ``None`` if the video
        has fewer than ``clip_size`` frames. It is expected to be
        [num_clips, hidden_dim], assuming the processor yields a batch of
        one video.
    """
    decoder = VideoDecoder(video_path)
    total_frames = decoder.metadata.num_frames

    if total_frames < clip_size:
        return None

    # Feed floating-point inputs in the encoder's own dtype instead of a
    # hard-coded float16, so a float32 model also works.
    model_dtype = next(model.parameters()).dtype

    all_latents = []
    for start_idx in range(0, total_frames - clip_size + 1, stride):
        # Plain Python ints: the decoder's get_frames_at takes a list of ints.
        frame_indices = list(range(start_idx, start_idx + clip_size))
        video = decoder.get_frames_at(indices=frame_indices).data

        inputs = processor(video, return_tensors="pt")
        inputs = {
            k: v.to(device, dtype=model_dtype) if v.is_floating_point() else v.to(device)
            for k, v in inputs.items()
        }

        outputs = model(**inputs)
        latent = outputs.last_hidden_state.mean(dim=1)
        all_latents.append(latent.cpu().float())

    return torch.cat(all_latents, dim=0)


def create_windows(latents, window_size=16, stride=1):
    """Slice a [T, D] latent sequence into overlapping fixed-length windows.

    Window ``i`` holds rows ``i * stride`` through ``i * stride + window_size - 1``.
    Only complete windows are produced.

    Returns:
        A tensor of shape [num_windows, window_size, D], or ``None`` when not
        even one complete window fits in the sequence.
    """
    seq_len, _ = latents.shape
    n_windows = (seq_len - window_size) // stride + 1
    if n_windows < 1:
        return None
    starts = range(0, n_windows * stride, stride)
    return torch.stack([latents[s:s + window_size] for s in starts], dim=0)


n_videos = len(CONFIG['video_files'])
print(f"\nExtracting latents from {n_videos} videos...")

latent_dir = Path(CONFIG['latent_dir'])
latent_dir.mkdir(parents=True, exist_ok=True)

total_windows = 0
successful = 0
# Latent files produced (or reused) for *this* run's videos, in video order.
# Tracked explicitly instead of globbing latent_dir, so stale files from
# earlier runs are never picked up and the order does not depend on the
# filesystem.
latent_files = []

for idx, video_path in enumerate(CONFIG['video_files']):
    video_name = Path(video_path).stem
    output_path = latent_dir / f"{video_name}_latents.pt"

    if output_path.exists():
        # NOTE: cached windows are reused as-is; delete them after changing
        # window_size/stride or the extraction settings.
        print(f"[{idx+1}/{n_videos}] Skipping {video_name} (exists)")
        # The cache stores a plain tensor, so skip full unpickling.
        data = torch.load(output_path, weights_only=True)
        total_windows += data.shape[0]
        successful += 1
        latent_files.append(output_path)
        continue

    print(f"[{idx+1}/{n_videos}] Processing {video_name}...", end=" ")
    clip_latents = extract_clip_latents(video_path, clip_size=16, stride=4)

    if clip_latents is None or len(clip_latents) < CONFIG['window_size']:
        print("SKIPPED (too short)")
        continue

    windows = create_windows(clip_latents, window_size=CONFIG['window_size'], stride=CONFIG['stride'])
    if windows is None:
        print("SKIPPED (insufficient windows)")
        continue

    torch.save(windows, output_path)
    total_windows += windows.shape[0]
    successful += 1
    latent_files.append(output_path)
    print(f"{windows.shape[0]} windows")

print(f"\nExtraction complete: {successful}/{n_videos} videos, {total_windows} windows")

# Free VRAM. Drop the encoder, collect any reference cycles, then release
# cached CUDA blocks (empty_cache can only return memory already freed).
del model, processor
import gc
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
    print(f"CUDA memory: {torch.cuda.memory_allocated() / 1e9:.2f}GB allocated")

CONFIG['latent_files'] = [str(f) for f in latent_files]
print(f"{len(latent_files)} latent files ready.")

In [None]:
# Hamiltonian world model (Sections 3.2-3.3).
# Learns scalar energy H(q,p) and evolves states via leapfrog integration.

import torch.nn as nn
import torch.nn.functional as F

# Seed the CPU and all CUDA generators, and pin cuDNN to deterministic kernels.
# Benchmark autotuning could otherwise pick different kernels from run to run.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class LatentHNN(nn.Module):
    """Hamiltonian Neural Network: an MLP mapping a state z = (q, p) to a scalar H.

    The first ``latent_dim // 2`` entries of z are the coordinates q and the
    rest are the momenta p. With the default latent_dim=1024 this is a
    512/512 split. Dynamics follow Hamilton's equations
    dq/dt = dH/dp, dp/dt = -dH/dq.

    Args:
        latent_dim: Width of the state vector z.
        hidden_dims: Hidden-layer widths (any iterable of ints). The default
            is a tuple to avoid the shared mutable-default pitfall.
    """

    def __init__(self, latent_dim=1024, hidden_dims=(512, 512, 256)):
        super().__init__()
        self.latent_dim = latent_dim
        self.q_dim = latent_dim // 2
        # p is whatever split_state leaves after q, which is correct for odd widths too.
        self.p_dim = latent_dim - self.q_dim

        layers = []
        in_dim = latent_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            # Softplus is smooth (C-infinity), so dH/dz and the second
            # derivatives needed to backprop through it are well defined.
            layers.append(nn.Softplus(beta=1.0))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-normal weights (gain 0.5) and zero biases for every Linear layer."""
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight, gain=0.5)
                nn.init.zeros_(m.bias)

    def forward(self, z):
        """Return H(z) with a trailing singleton dimension ([..., 1])."""
        return self.net(z)

    def split_state(self, z):
        """Split z along its last dimension into (q, p)."""
        return z[..., :self.q_dim], z[..., self.q_dim:]

    def merge_state(self, q, p):
        """Concatenate q and p along the last dimension."""
        return torch.cat([q, p], dim=-1)

    def time_derivative(self, z):
        """Return dz/dt = (dH/dp, -dH/dq) for states z of shape [..., latent_dim].

        The result keeps its autograd graph (create_graph=True), so losses on
        rolled-out states can be backpropagated into the HNN parameters.
        The gradient of H is always taken inside torch.enable_grad(), so this
        also works when called under torch.no_grad(). The caller's tensor is
        not modified: if z does not already require grad, a detached copy is
        differentiated instead of flipping z's requires_grad flag in place.
        """
        with torch.enable_grad():
            if not z.requires_grad:
                z = z.detach().requires_grad_(True)
            H = self.forward(z)
            dH_dz = torch.autograd.grad(H.sum(), z, create_graph=True)[0]

        dH_dq = dH_dz[..., :self.q_dim]
        dH_dp = dH_dz[..., self.q_dim:]

        dq_dt = dH_dp       # dq/dt =  dH/dp
        dp_dt = -dH_dq      # dp/dt = -dH/dq

        return self.merge_state(dq_dt, dp_dt)


class SymplecticIntegrator(nn.Module):
    """Explicit Stormer-Verlet (leapfrog, kick-drift-kick) integrator for an HNN.

    One step of size dt maps (q, p) -> (q_next, p_next):

        p_half = p      - dt/2 * dH/dq(q, p)
        q_next = q      + dt   * dH/dp(q, p_half)
        p_next = p_half - dt/2 * dH/dq(q_next, p_half)

    This explicit scheme is second order, time-reversible and exactly
    symplectic when H is separable, H(q, p) = T(p) + V(q). The HNN used here
    is an MLP over the concatenated (q, p), so H is generally NOT separable,
    and the scheme is then only approximately symplectic. The exact
    Stormer-Verlet method for non-separable H is implicit. Even when the
    scheme is symplectic, H is not conserved exactly: the energy error stays
    bounded instead of drifting secularly.

    ``dt`` is a learnable scalar clamped to [0.1, 2.0] inside ``step``.
    While the raw parameter lies outside that range, the clamp passes it no
    gradient.
    """

    def __init__(self, hnn, dt=1.0):
        super().__init__()
        self.hnn = hnn
        self.dt = nn.Parameter(torch.tensor(dt), requires_grad=True)
        self.q_dim = hnn.q_dim

    def _compute_gradients(self, q, p):
        """Return (dH/dq, dH/dp) evaluated at the state (q, p).

        The derivative is always taken inside torch.enable_grad(), so
        stepping also works under torch.no_grad(). When the caller has
        gradient tracking enabled, the result keeps its graph
        (create_graph=True) so training losses can backpropagate through the
        rollout. Otherwise no graph is built, which keeps inference cheap.
        """
        build_graph = torch.is_grad_enabled()
        with torch.enable_grad():
            z = self.hnn.merge_state(q, p)
            z = z.requires_grad_(True)
            H = self.hnn(z)
            dH_dz = torch.autograd.grad(H.sum(), z, create_graph=build_graph)[0]
        return dH_dz[..., :self.q_dim], dH_dz[..., self.q_dim:]

    def step(self, z):
        """One kick-drift-kick step: z_t -> z_{t+1}."""
        q, p = self.hnn.split_state(z)
        dt = torch.clamp(self.dt, 0.1, 2.0)

        # Half-step momentum
        dH_dq, _ = self._compute_gradients(q, p)
        p_half = p - (dt / 2) * dH_dq

        # Full-step position
        _, dH_dp = self._compute_gradients(q, p_half)
        q_next = q + dt * dH_dp

        # Half-step momentum
        dH_dq_next, _ = self._compute_gradients(q_next, p_half)
        p_next = p_half - (dt / 2) * dH_dq_next

        return self.hnn.merge_state(q_next, p_next)

    def rollout(self, z0, n_steps):
        """Integrate n_steps forward; returns states stacked on dim 1, including z0."""
        trajectory = [z0]
        z = z0
        for _ in range(n_steps):
            z = self.step(z)
            trajectory.append(z)
        return torch.stack(trajectory, dim=1)


class EulerIntegrator(nn.Module):
    """Forward-Euler integrator (non-symplectic); ablation baseline.

    z_{t+1} = z_t + dt * dz/dt(z_t), where dz/dt comes from
    ``hnn.time_derivative``. ``dt`` is a learnable scalar clamped to
    [0.1, 2.0] inside ``step``.
    """
    def __init__(self, hnn, dt=1.0):
        super().__init__()
        self.hnn = hnn
        self.dt = nn.Parameter(torch.tensor(dt), requires_grad=True)

    def step(self, z):
        """One Euler step z_t -> z_{t+1}.

        The Hamiltonian vector field is computed with autograd, so it is
        evaluated inside torch.enable_grad(). Callers can then roll out under
        torch.no_grad() as well.
        """
        dt = torch.clamp(self.dt, 0.1, 2.0)
        with torch.enable_grad():
            dz_dt = self.hnn.time_derivative(z)
        return z + dt * dz_dt

    def rollout(self, z0, n_steps):
        """Integrate n_steps forward; returns states stacked on dim 1, including z0."""
        trajectory = [z0]
        z = z0
        for _ in range(n_steps):
            z = self.step(z)
            trajectory.append(z)
        return torch.stack(trajectory, dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hnn = LatentHNN(
    latent_dim=CONFIG['latent_dim'],
    hidden_dims=[512, 512, 256]
).to(device)

integrator = SymplecticIntegrator(hnn, dt=1.0).to(device)

print(f"LatentHNN: {sum(p.numel() for p in hnn.parameters()) / 1e6:.2f}M parameters")
print(f"Integrator dt (learnable, clamped to [0.1, 2.0]): {integrator.dt.item():.3f}")
print(f"Device: {device}")

# Sanity check on random states: output shapes plus an energy-drift number
# for the *untrained* model (not meaningful physically). Sizes come from
# CONFIG instead of a hard-coded 1024.
n_check, n_check_steps = 4, 5
z_test = torch.randn(n_check, CONFIG['latent_dim'], device=device)
H_test = hnn(z_test)
print(f"\nSanity check:")
print(f"  H(z) shape: {H_test.shape}")
# Pass a copy: time_derivative may switch on requires_grad of its argument.
print(f"  dz/dt shape: {hnn.time_derivative(z_test.clone()).shape}")

z_next = integrator.step(z_test)
traj = integrator.rollout(z_test, n_steps=n_check_steps)
H_traj = hnn(traj.reshape(-1, CONFIG['latent_dim'])).reshape(n_check, n_check_steps + 1)
energy_drift = (H_traj - H_traj[:, 0:1]).abs().mean().item()
print(f"  Rollout shape: {traj.shape}")
print(f"  Energy drift over {n_check_steps} steps: {energy_drift:.6f}")

# Sanity check Euler integrator
euler_test = EulerIntegrator(hnn, dt=1.0).to(device)
euler_traj = euler_test.rollout(z_test, n_steps=n_check_steps)
print(f"  Euler rollout shape: {euler_traj.shape}")

In [None]:
# Training and evaluation (Section 4).
# Trains the HNN with prediction + energy-conservation loss.
# Uses lambda_energy curriculum: ramps from 0.1 to 1.0 over training.

from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.auto import tqdm


class LatentTrajectoryDataset(Dataset):
    """In-memory dataset of latent trajectory windows.

    Each file is loaded with ``torch.load(..., weights_only=True)`` and split
    along its first dimension. Every slice becomes one sample, in file order.
    """

    def __init__(self, latent_files):
        collected = []
        for path in latent_files:
            stacked = torch.load(path, weights_only=True)
            collected.extend(stacked[j] for j in range(stacked.shape[0]))
        self.trajectories = collected
        print(f"Loaded {len(self.trajectories)} trajectory windows")

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        return self.trajectories[idx]


# Split train/test by *video*, not by window. Windows are cut with stride 1,
# so neighbouring windows of one video share window_size - 1 latents. A random
# window-level split would put near-copies of training windows in the test
# set and make every test metric optimistic.
from torch.utils.data import ConcatDataset

_all_latent_files = sorted(CONFIG['latent_files'])  # fixed order -> reproducible split
if len(_all_latent_files) >= 2:
    _order = torch.randperm(
        len(_all_latent_files), generator=torch.Generator().manual_seed(SEED)
    ).tolist()
    # About 80% of videos for training, with at least one video on each side.
    _n_train_videos = min(max(int(0.8 * len(_all_latent_files)), 1), len(_all_latent_files) - 1)
    train_files = [_all_latent_files[i] for i in _order[:_n_train_videos]]
    test_files = [_all_latent_files[i] for i in _order[_n_train_videos:]]
    train_dataset = LatentTrajectoryDataset(train_files)
    test_dataset = LatentTrajectoryDataset(test_files)
    dataset = ConcatDataset([train_dataset, test_dataset])  # full set, kept for reference
    print(f"Video-level split: {len(train_files)} train videos, {len(test_files)} test videos")
else:
    # A single video cannot be split by video, so fall back to the
    # window-level split (test windows then overlap training windows).
    print("WARNING: fewer than 2 latent files -- using a window-level split; "
          "test metrics will be optimistic.")
    dataset = LatentTrajectoryDataset(_all_latent_files)
    _n_train_windows = int(0.8 * len(dataset))
    train_dataset, test_dataset = random_split(
        dataset, [_n_train_windows, len(dataset) - _n_train_windows],
        generator=torch.Generator().manual_seed(SEED)
    )

train_size = len(train_dataset)
test_size = len(test_dataset)

torch.manual_seed(SEED)  # global RNG drives DataLoader shuffling
print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

BATCH_SIZE = CONFIG['batch_size']
NUM_EPOCHS = CONFIG['num_epochs']

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = torch.optim.AdamW(
    list(hnn.parameters()) + [integrator.dt],
    lr=1e-3, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

# Energy loss weight curriculum: ramp from 0.1 -> 1.0 over training
LAMBDA_ENERGY_START = 0.1
LAMBDA_ENERGY_END = 1.0

def get_lambda_energy(epoch, num_epochs):
    """Return the energy-loss weight for the 1-based ``epoch``.

    Interpolates linearly from LAMBDA_ENERGY_START at epoch 1 to
    LAMBDA_ENERGY_END at epoch ``num_epochs``. A one-epoch run uses the start
    value.
    """
    span = LAMBDA_ENERGY_END - LAMBDA_ENERGY_START
    fraction = (epoch - 1) / max(num_epochs - 1, 1)
    return LAMBDA_ENERGY_START + fraction * span


# The loss has no explicit symplecticity (Jacobian) term. Any Hamiltonian
# structure of the dynamics comes from the HNN parameterisation and the integrator.
def compute_loss(z_pred, z_target, hnn_model, lambda_energy=0.1):
    """Return ``(total, stats)`` with total = L_pred + lambda_energy * L_energy.

    L_pred is the MSE between predicted and target trajectories. L_energy is
    the mean absolute deviation of H along each predicted trajectory from H at
    that trajectory's first predicted state. ``z_pred`` is indexed as
    [batch, time, dim]. ``stats`` holds 'total', 'prediction' and 'energy' as
    Python floats.
    """
    pred_term = F.mse_loss(z_pred, z_target)

    batch, steps, dim = z_pred.shape
    flat_energy = hnn_model(z_pred.reshape(batch * steps, dim))
    energies = flat_energy.reshape(batch, steps)
    energy_term = (energies - energies[:, :1]).abs().mean()

    total = pred_term + lambda_energy * energy_term
    stats = {
        'total': total.item(),
        'prediction': pred_term.item(),
        'energy': energy_term.item(),
    }
    return total, stats


def train_epoch(epoch, log_interval=10):
    """Train the main HNN + leapfrog integrator for one epoch.

    For every batch of windows [B, T, D], the integrator is rolled out from
    the first latent for T-1 steps. The prediction is scored against the
    remaining latents by ``compute_loss`` with this epoch's curriculum
    weight. Gradients are clipped separately for the HNN weights and for
    ``dt``. Uses the module-level hnn, integrator, optimizer, train_loader,
    device and NUM_EPOCHS.

    Args:
        epoch: 1-based epoch index (drives the energy-weight curriculum).
        log_interval: Refresh the progress-bar statistics every this many batches.

    Returns:
        Dict mapping 'total', 'prediction' and 'energy' to their epoch means.
    """
    hnn.train()
    integrator.train()
    epoch_losses = {name: [] for name in ('total', 'prediction', 'energy')}
    lambda_e = get_lambda_energy(epoch, NUM_EPOCHS)
    bar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for batch_idx, batch in enumerate(bar):
        batch = batch.to(device)
        _, T, _ = batch.shape

        rolled = integrator.rollout(batch[:, 0, :], n_steps=T - 1)
        z_pred = rolled[:, 1:, :].contiguous()
        loss, loss_dict = compute_loss(z_pred, batch[:, 1:, :], hnn, lambda_energy=lambda_e)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(hnn.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_([integrator.dt], max_norm=1.0)
        optimizer.step()

        for name, value in loss_dict.items():
            epoch_losses[name].append(value)

        if batch_idx % log_interval == 0:
            bar.set_postfix({
                'loss': f"{loss_dict['total']:.4f}",
                'pred': f"{loss_dict['prediction']:.4f}",
                'energy': f"{loss_dict['energy']:.4f}",
                'dt': f"{integrator.dt.item():.3f}",
                'lam_e': f"{lambda_e:.2f}"
            })

    return {name: np.mean(values) for name, values in epoch_losses.items()}


def evaluate_model(model_or_integrator, model_type, hnn_for_energy):
    """Evaluate a rollout model on the full test set (module-level test_loader).

    Every test window [T, D] is rolled out from its first latent for T-1
    steps and compared with its remaining T-1 latents.

    Args:
        model_or_integrator: Object with ``.rollout(z0, n_steps)`` returning
            [B, n_steps + 1, D] (SymplecticIntegrator, EulerIntegrator or BaselineMLP).
        model_type: 'hnn' runs the rollout under torch.enable_grad() (the
            integrators differentiate H); anything else runs it under torch.no_grad().
        hnn_for_energy: The HNN used to compute energy H(z) of predicted states.

    Returns:
        dict with
          'mse_mean', 'mse_std'      -- over per-window MSE (mean over time and dims),
          'drift_mean', 'drift_std'  -- over per-window max_t |H(z_pred_t) - H(z_0)|,
          'per_sample_mse', 'per_sample_drift'                  -- arrays [N],
          'per_timestep_mse_mean', 'per_timestep_mse_std'       -- arrays [T-1],
          'per_timestep_energy_mean', 'per_timestep_energy_std' -- arrays [T-1],
          'n_samples'.
        Per-timestep statistics are taken over all N test windows, so every
        window has equal weight regardless of the batch it fell in. Averaging
        batch means would over-weight a short final batch, and the std would
        be NaN with a single batch.

    Raises:
        ValueError: If the test loader yields no batches.
    """
    model_or_integrator.eval()
    if hasattr(model_or_integrator, 'hnn'):
        model_or_integrator.hnn.eval()

    ts_mse_chunks = []     # per batch: [B, T-1] squared error averaged over dims
    ts_energy_chunks = []  # per batch: [B, T-1] H of predicted states
    drift_chunks = []      # per batch: [B] max energy deviation from H(z_0)

    for batch in test_loader:
        batch = batch.to(device)
        B, T, D = batch.shape
        z_init = batch[:, 0, :]
        z_target = batch[:, 1:, :]

        # Rollout
        if model_type == 'hnn':
            with torch.enable_grad():
                # Fresh leaf copy, so the loader's batch tensor is left unmodified.
                z0 = z_init.detach().clone().requires_grad_(True)
                z_pred = model_or_integrator.rollout(z0, n_steps=T-1)[:, 1:, :]
            z_pred = z_pred.detach()
        else:
            with torch.no_grad():
                z_pred = model_or_integrator.rollout(z_init, n_steps=T-1)[:, 1:, :]

        with torch.no_grad():
            ts_mse_chunks.append(((z_pred - z_target) ** 2).mean(dim=2).cpu())

            H_pred = hnn_for_energy(z_pred.reshape(-1, D)).reshape(B, T-1)
            H_init = hnn_for_energy(z_init).reshape(B, 1)
            drift_chunks.append((H_pred - H_init).abs().max(dim=1).values.cpu())
            ts_energy_chunks.append(H_pred.cpu())

    if not ts_mse_chunks:
        raise ValueError("evaluate_model: test_loader produced no batches")

    ts_mse = torch.cat(ts_mse_chunks)        # [N, T-1]
    ts_energy = torch.cat(ts_energy_chunks)  # [N, T-1]
    all_sample_mse = ts_mse.mean(dim=1)      # [N]
    all_sample_drift = torch.cat(drift_chunks)

    return {
        'mse_mean': all_sample_mse.mean().item(),
        'mse_std': all_sample_mse.std().item(),
        'drift_mean': all_sample_drift.mean().item(),
        'drift_std': all_sample_drift.std().item(),
        'per_sample_mse': all_sample_mse.numpy(),
        'per_sample_drift': all_sample_drift.numpy(),
        'per_timestep_mse_mean': ts_mse.mean(dim=0).numpy(),
        'per_timestep_mse_std': ts_mse.std(dim=0).numpy(),
        'per_timestep_energy_mean': ts_energy.mean(dim=0).numpy(),
        'per_timestep_energy_std': ts_energy.std(dim=0).numpy(),
        'n_samples': len(all_sample_mse),
    }


def evaluate():
    """Return the main model's test metrics in the form the training loop logs.

    A thin wrapper around ``evaluate_model(integrator, 'hnn', hnn)`` that
    keeps only the mean per-window MSE and the mean energy drift.
    """
    full = evaluate_model(integrator, 'hnn', hnn)
    return dict(mse=full['mse_mean'], energy_drift=full['drift_mean'])


def dream(z_init, n_steps=20):
    """Roll the trained model forward on its own from ``z_init``.

    Args:
        z_init: Starting state(s); this notebook passes a [1, D] tensor.
        n_steps: Number of integrator steps.

    Returns:
        ``(trajectory, energies)``. ``trajectory`` is the detached rollout
        including the start state, with a leading size-1 batch dimension
        squeezed away. ``energies`` is H at every state of it, with the
        trailing singleton dimension squeezed.
    """
    for module in (hnn, integrator):
        module.eval()
    with torch.enable_grad():
        start = z_init.clone().requires_grad_(True)
        path = integrator.rollout(start, n_steps).squeeze(0)
    path = path.detach()
    with torch.no_grad():
        energy = hnn(path).squeeze(-1)
    return path, energy


# Training loop
print(f"Training: {NUM_EPOCHS} epochs, batch size {BATCH_SIZE}, {len(train_loader)} batches/epoch")
print(f"Energy loss curriculum: lambda_energy {LAMBDA_ENERGY_START} -> {LAMBDA_ENERGY_END}")

history = {
    'train_loss': [], 'train_pred': [], 'train_energy': [],
    'test_mse': [], 'test_energy_drift': [], 'dt': [], 'lambda_energy': []
}
best_test_mse = float('inf')

for epoch in range(1, NUM_EPOCHS + 1):
    # Learning rate actually used this epoch (read before the scheduler advances;
    # get_last_lr() after scheduler.step() would report next epoch's value).
    epoch_lr = scheduler.get_last_lr()[0]
    train_losses = train_epoch(epoch)
    test_metrics = evaluate()
    scheduler.step()
    lambda_e = get_lambda_energy(epoch, NUM_EPOCHS)

    history['train_loss'].append(train_losses['total'])
    history['train_pred'].append(train_losses['prediction'])
    history['train_energy'].append(train_losses['energy'])
    history['test_mse'].append(test_metrics['mse'])
    history['test_energy_drift'].append(test_metrics['energy_drift'])
    history['dt'].append(integrator.dt.item())
    history['lambda_energy'].append(lambda_e)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"Train: {train_losses['total']:.4f} (pred {train_losses['prediction']:.4f}, energy {train_losses['energy']:.4f}) | "
          f"Test MSE: {test_metrics['mse']:.4f}, drift: {test_metrics['energy_drift']:.4f} | "
          f"dt: {integrator.dt.item():.4f}, lambda_e: {lambda_e:.2f}, lr: {epoch_lr:.6f}")

    # NOTE: checkpoint selection uses the test split (there is no separate
    # validation split), so best_test_mse is an optimistic estimate.
    if test_metrics['mse'] < best_test_mse:
        best_test_mse = test_metrics['mse']
        torch.save({
            'hnn_state_dict': hnn.state_dict(),
            'integrator_state_dict': integrator.state_dict(),
            'epoch': epoch,
            'test_mse': test_metrics['mse']
        }, '/content/best_hnn_model.pt')
        print(f"  New best model saved (MSE={best_test_mse:.4f})")

print(f"\nBest test MSE: {best_test_mse:.4f}, final dt: {integrator.dt.item():.4f}")

# 20-step dreaming evaluation (uses the final-epoch weights, not the saved best checkpoint)
test_sample = next(iter(test_loader))[0:1].to(device)
z_start = test_sample[:, 0, :]
dream_traj, dream_energies = dream(z_start, n_steps=20)

print(f"\nDreaming rollout (20 steps):")
print(f"  Energy t=0: {dream_energies[0].item():.4f}, t=20: {dream_energies[-1].item():.4f}")
print(f"  Energy drift: {(dream_energies[-1] - dream_energies[0]).abs().item():.6f}")
print(f"  Max deviation: {(dream_energies - dream_energies[0]).abs().max().item():.6f}")

if test_sample.shape[1] >= 21:
    gt_traj = test_sample[0, :21, :]
    print(f"  MSE vs ground truth: {F.mse_loss(dream_traj, gt_traj).item():.4f}")
else:
    # Windows are shorter than the dream horizon (window_size=16 by default),
    # so compare over the steps that do have ground truth instead of skipping.
    n_gt = test_sample.shape[1]
    partial_mse = F.mse_loss(dream_traj[:n_gt], test_sample[0, :n_gt, :]).item()
    print(f"  MSE vs ground truth (first {n_gt - 1} steps available): {partial_mse:.4f}")

DREAM_RESULTS = {
    'trajectory': dream_traj.cpu(),
    'energies': dream_energies.cpu(),
    'z_start': z_start.cpu(),
    'ground_truth': test_sample[0].cpu() if test_sample.shape[1] >= 21 else None
}
TRAINING_HISTORY = history


# ============================================================
# Ablation training: generic helper + two ablation models
# ============================================================

def train_model(model_integrator, model_hnn, model_type, num_epochs, lambda_energy_fn,
                label="Model"):
    """Train an HNN + integrator pair with the same recipe as the main model.

    Uses AdamW (lr 1e-3, wd 1e-5) with cosine annealing over ``num_epochs``,
    the ``compute_loss`` objective on full-window rollouts, and separate
    gradient clipping (max-norm 1.0) for the HNN weights and for ``dt``.
    Prints progress at epoch 1 and every 10th epoch. Uses the module-level
    train_loader and device.

    Args:
        model_integrator: Integrator wrapping ``model_hnn``
            (SymplecticIntegrator or EulerIntegrator); its ``dt`` is trained too.
        model_hnn: The HNN whose parameters are trained.
        model_type: Accepted for interface symmetry; not used by this function.
        num_epochs: Number of training epochs.
        lambda_energy_fn: callable(epoch, num_epochs) -> energy-loss weight.
        label: Display name used in the log lines.
    """
    params = [*model_hnn.parameters(), model_integrator.dt]
    opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs, eta_min=1e-5)

    print(f"\nTraining {label}: {num_epochs} epochs")
    for epoch in range(1, num_epochs + 1):
        model_hnn.train()
        model_integrator.train()
        batch_losses = []
        lambda_e = lambda_energy_fn(epoch, num_epochs)

        for batch in train_loader:
            batch = batch.to(device)
            _, T, _ = batch.shape
            rolled = model_integrator.rollout(batch[:, 0, :], n_steps=T - 1)
            z_pred = rolled[:, 1:, :].contiguous()
            loss, _ = compute_loss(z_pred, batch[:, 1:, :], model_hnn, lambda_energy=lambda_e)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_hnn.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_([model_integrator.dt], max_norm=1.0)
            opt.step()
            batch_losses.append(loss.item())

        sched.step()
        if epoch == 1 or epoch % 10 == 0:
            mean_loss = sum(batch_losses) / len(batch_losses)
            print(f"  [{label}] Epoch {epoch}/{num_epochs}, Loss: {mean_loss:.4f}, "
                  f"lambda_e: {lambda_e:.2f}, dt: {model_integrator.dt.item():.3f}")


# --- Ablation 1: HNN (no energy loss) ---
# Same architecture and leapfrog integrator as the main model, trained with the
# energy term switched off (lambda_energy = 0 for every epoch). Each ablation
# re-seeds with SEED right before construction, so both ablation HNNs start
# from the same initial weights.
print("\n" + "="*60)
print("ABLATION 1: HNN with leapfrog, lambda_energy=0 throughout")
print("="*60)
torch.manual_seed(SEED)
hnn_no_energy = LatentHNN(latent_dim=CONFIG['latent_dim'], hidden_dims=[512, 512, 256]).to(device)
integrator_no_energy = SymplecticIntegrator(hnn_no_energy, dt=1.0).to(device)

train_model(
    integrator_no_energy, hnn_no_energy, model_type='hnn',
    num_epochs=NUM_EPOCHS,
    lambda_energy_fn=lambda epoch, n: 0.0,  # No energy loss
    label="HNN (no energy)"
)

# --- Ablation 2: HNN (Euler integrator) ---
# Same HNN and energy-weight curriculum as the main model, but stepped with
# non-symplectic forward Euler instead of leapfrog.
print("\n" + "="*60)
print("ABLATION 2: HNN with Euler integrator + energy loss")
print("="*60)
torch.manual_seed(SEED)
hnn_euler = LatentHNN(latent_dim=CONFIG['latent_dim'], hidden_dims=[512, 512, 256]).to(device)
integrator_euler = EulerIntegrator(hnn_euler, dt=1.0).to(device)

train_model(
    integrator_euler, hnn_euler, model_type='hnn',
    num_epochs=NUM_EPOCHS,
    lambda_energy_fn=get_lambda_energy,  # Same curriculum as main HNN
    label="HNN (Euler)"
)

# Store ablation models (as (hnn, integrator) pairs) for the evaluation /
# visualization cell that follows.
ABLATION_MODELS = {
    'hnn_no_energy': (hnn_no_energy, integrator_no_energy),
    'hnn_euler': (hnn_euler, integrator_euler),
}
print("\nAll ablation models trained.")

In [None]:
# Visualization (Section 5).
# Energy conservation, phase portraits, HNN vs baseline MLP comparison,
# ablation bar chart, and summary table.
# Baseline trained with multi-step rollout loss to match HNN's training formulation.

import matplotlib.pyplot as plt
import time

plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.dpi': 120,
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
})

# Evaluation horizons: the training horizon (window_size - 1 steps, i.e. 15
# with the default window of 16) and a longer extrapolation horizon.
EVAL_HORIZON = CONFIG['window_size'] - 1
EXTENDED_HORIZON = 50

# ============================================================
# Baseline MLP: trained with multi-step rollout loss
# ============================================================

class BaselineMLP(nn.Module):
    """Unstructured baseline: directly predicts z_{t+1} = f(z_t) with a ReLU MLP.

    There is no Hamiltonian structure and no autograd in the forward pass,
    so it can be rolled out under torch.no_grad().

    Args:
        latent_dim: Width of the latent state (both input and output).
        hidden_dims: Hidden-layer widths (any iterable of ints). The default
            is a tuple to avoid the shared mutable-default pitfall.
    """
    def __init__(self, latent_dim=1024, hidden_dims=(512, 512)):
        super().__init__()
        layers = []
        in_dim = latent_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, latent_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """One prediction step z_t -> z_{t+1}."""
        return self.net(z)

    def rollout(self, z0, n_steps):
        """Apply ``forward`` n_steps times; states are stacked on dim 1, including z0.

        For z0 of shape [B, D] the result is [B, n_steps + 1, D].
        """
        trajectory = [z0]
        z = z0
        for _ in range(n_steps):
            z = self.forward(z)
            trajectory.append(z)
        return torch.stack(trajectory, dim=1)


# Train the unstructured baseline with the same multi-step rollout MSE the HNN
# uses (without the energy term), the same optimizer settings, and the same
# epoch budget.
torch.manual_seed(SEED)
baseline = BaselineMLP(latent_dim=CONFIG['latent_dim']).to(device)
baseline_optimizer = torch.optim.AdamW(baseline.parameters(), lr=1e-3, weight_decay=1e-5)
baseline_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    baseline_optimizer, T_max=NUM_EPOCHS, eta_min=1e-5
)

BASELINE_EPOCHS = NUM_EPOCHS  # Match HNN training budget
print(f"Training baseline MLP with MULTI-STEP rollout loss ({BASELINE_EPOCHS} epochs)...")
print(f"  Optimizer: AdamW (lr=1e-3, wd=1e-5), cosine scheduler, grad clip 1.0")
baseline.train()
for epoch in range(1, BASELINE_EPOCHS + 1):
    epoch_loss = 0
    n_batches = 0
    for batch in train_loader:
        batch = batch.to(device)
        B, T, D = batch.shape
        z_init = batch[:, 0, :]
        z_target = batch[:, 1:, :]  # [B, T-1, D]

        # Multi-step rollout: same formulation as HNN
        z_pred = baseline.rollout(z_init, n_steps=T-1)[:, 1:, :]  # [B, T-1, D]
        loss = F.mse_loss(z_pred, z_target)

        baseline_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(baseline.parameters(), max_norm=1.0)
        baseline_optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    baseline_scheduler.step()
    if epoch % 10 == 0 or epoch == 1:
        print(f"  Epoch {epoch}/{BASELINE_EPOCHS}, Loss: {epoch_loss/n_batches:.4f}")

baseline.eval()

# ============================================================
# Collapse detection: compare step sizes for baseline and HNN
# ============================================================
# Mean L2 norm of consecutive-state differences over an EXTENDED_HORIZON
# rollout from the first test window. A tiny value means the model barely
# moves, which would make its energy statistics trivially "good".
test_sample = next(iter(test_loader))[0:1].to(device)
z_start = test_sample[:, 0, :]

with torch.no_grad():
    baseline_traj_full = baseline.rollout(z_start, n_steps=EXTENDED_HORIZON).squeeze(0)
    baseline_mean_step = torch.norm(
        baseline_traj_full[1:] - baseline_traj_full[:-1], dim=-1
    ).mean().item()

# The leapfrog step differentiates H, so its rollout needs autograd enabled.
with torch.enable_grad():
    z_in = z_start.clone().requires_grad_(True)
    hnn_traj_full = integrator.rollout(z_in, n_steps=EXTENDED_HORIZON).squeeze(0).detach()

hnn_mean_step = torch.norm(
    hnn_traj_full[1:] - hnn_traj_full[:-1], dim=-1
).mean().item()

print(f"\nCollapse detection (mean step size in latent space):")
print(f"  HNN (leapfrog):  {hnn_mean_step:.4f}")
print(f"  Baseline MLP:    {baseline_mean_step:.4f}")
if baseline_mean_step < 0.01:
    print("  >>> WARNING: Baseline may have collapsed (near-zero step sizes).")
    print("  >>> Its 'low energy variance' would be meaningless -- states barely change.")

# ============================================================
# Full test-set evaluation for all 4 models
# ============================================================
# The baseline has no energy function of its own, so its drift is scored with
# the main model's H. Each ablation's drift is scored with its own learned H.
# NOTE(review): separately learned H functions have arbitrary scales, so
# drift magnitudes are not directly comparable across models.
print("\nEvaluating all models on full test set...")

metrics_hnn = evaluate_model(integrator, 'hnn', hnn)
metrics_baseline = evaluate_model(baseline, 'baseline', hnn)

hnn_ne, integ_ne = ABLATION_MODELS['hnn_no_energy']
metrics_no_energy = evaluate_model(integ_ne, 'hnn', hnn_ne)

hnn_eu, integ_eu = ABLATION_MODELS['hnn_euler']
metrics_euler = evaluate_model(integ_eu, 'hnn', hnn_eu)

n_test = metrics_hnn['n_samples']
print(f"  Evaluated over {n_test} test windows")

# ============================================================
# Computational cost timing
# ============================================================
# Benchmarks start from a random latent (batch of 1), not from real data.
# NOTE(review): the "15-step" label is hard-coded; the rollout length is
# EVAL_HORIZON, which equals 15 only for window_size=16.
print("\nTiming inference (15-step rollout)...")
z_bench = torch.randn(1, CONFIG['latent_dim'], device=device)
N_WARMUP = 10
N_TIMED = 100

def time_model(model_or_integ, model_type, label):
    """Time an ``EVAL_HORIZON``-step rollout of one model starting from ``z_bench``.

    Args:
        model_or_integ: Object exposing ``rollout(z, n_steps=...)``.
        model_type: ``'hnn'`` runs the rollout under ``torch.enable_grad()`` on a
            gradient-tracking copy of ``z_bench`` (the HNN uses autograd for
            Hamilton's equations). Any other value runs under ``torch.no_grad()``.
        label: Human-readable model name. Currently unused; kept so call sites
            stay self-describing.

    Returns:
        ``(mean_ms, std_ms)`` over ``N_TIMED`` timed runs, measured after
        ``N_WARMUP`` untimed warm-up runs.
    """
    needs_grad = model_type == 'hnn'
    # Private copy, prepared outside the timed region, so the shared global
    # ``z_bench`` is never switched to requires_grad=True by timing an HNN.
    z_in = z_bench.detach().clone().requires_grad_(needs_grad)
    sync = torch.cuda.is_available()

    def _run_once():
        # Run one rollout in the right grad mode; wait for queued GPU work.
        ctx = torch.enable_grad() if needs_grad else torch.no_grad()
        with ctx:
            model_or_integ.rollout(z_in, n_steps=EVAL_HORIZON)
        if sync:
            torch.cuda.synchronize()

    for _ in range(N_WARMUP):
        _run_once()

    times = []
    for _ in range(N_TIMED):
        if sync:
            # Drain pending GPU work so it is not attributed to this run.
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        _run_once()
        times.append((time.perf_counter() - t0) * 1000)  # ms

    return np.mean(times), np.std(times)

time_hnn_mean, time_hnn_std = time_model(integrator, 'hnn', 'HNN (leapfrog)')
time_euler_mean, time_euler_std = time_model(integ_eu, 'hnn', 'HNN (Euler)')
time_baseline_mean, time_baseline_std = time_model(baseline, 'baseline', 'Baseline MLP')

# One line per model: total rollout time (mean +/- std) and its per-step share.
for _prefix, _mean, _std in (
    ("  HNN (leapfrog): ", time_hnn_mean, time_hnn_std),
    ("  HNN (Euler):    ", time_euler_mean, time_euler_std),
    ("  Baseline MLP:   ", time_baseline_mean, time_baseline_std),
):
    print(f"{_prefix}{_mean:.2f} +/- {_std:.2f} ms ({_mean/EVAL_HORIZON:.2f} ms/step)")

print(f"  HNN/Baseline ratio: {time_hnn_mean/time_baseline_mean:.1f}x")
print(f"  Note: HNN requires torch.enable_grad() at inference (autograd for Hamilton's eqs)")

# ============================================================
# Plot 1: Energy conservation of the single HNN rollout above:
# left = timesteps 0..EVAL_HORIZON, right = 0..EXTENDED_HORIZON
# ============================================================
print("\nPlot 1: Energy conservation...")

# H evaluated on every state of the detached HNN rollout, one value per timestep.
with torch.no_grad():
    energies_ext = hnn(hnn_traj_full).squeeze(-1).cpu().numpy()

E0 = energies_ext[0]  # reference energy of the initial state
timesteps_train = np.arange(EVAL_HORIZON + 1)
timesteps_ext = np.arange(EXTENDED_HORIZON + 1)
drift_train = np.abs(energies_ext[:EVAL_HORIZON+1] - E0)
drift_ext = np.abs(energies_ext - E0)


def _energy_panel(axis, steps, energy, drift, e_ref, title, horizon=None):
    """Plot H(z_t) against the reference H(z_0), shade the gap, and annotate drift stats.

    If ``horizon`` is given, a dotted vertical marker is drawn at that timestep.
    """
    axis.plot(steps, energy, 'b-', linewidth=2, label='H(z_t)')
    axis.axhline(y=e_ref, color='r', linestyle='--', linewidth=1.5, alpha=0.7,
                 label=f'H(z_0) = {e_ref:.4f}')
    if horizon is not None:
        axis.axvline(x=horizon, color='green', linestyle=':', alpha=0.5,
                     label=f'Train horizon ({horizon})')
    axis.fill_between(steps, e_ref, energy, alpha=0.2, color='blue')
    axis.set_xlabel('Timestep')
    axis.set_ylabel('Hamiltonian H(q, p)')
    axis.set_title(title)
    axis.legend(loc='upper right')
    stats = '\n'.join([
        f'Mean drift: {drift.mean():.6f}',
        f'Max drift: {drift.max():.6f}',
        f'Std: {drift.std():.6f}',
    ])
    axis.text(0.02, 0.98, stats,
              transform=axis.transAxes, verticalalignment='top',
              bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5), fontsize=9)


fig1, (ax1a, ax1b) = plt.subplots(1, 2, figsize=(14, 4))
_energy_panel(ax1a, timesteps_train, energies_ext[:EVAL_HORIZON+1], drift_train, E0,
              f'Energy Conservation: {EVAL_HORIZON}-Step (Training Horizon)')
_energy_panel(ax1b, timesteps_ext, energies_ext, drift_ext, E0,
              f'Energy Conservation: {EXTENDED_HORIZON}-Step (Extended)',
              horizon=EVAL_HORIZON)

plt.tight_layout()
plt.savefig('/content/plot1_energy_conservation.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# Plot 2: Phase portrait (single sample, qualitative)
# ============================================================
print("Plot 2: Phase portrait...")

traj_np = hnn_traj_full.cpu().numpy()
# The latent vector is split in half: first half plotted as q, second half as p.
q_dim = traj_np.shape[1] // 2
q = traj_np[:, :q_dim]
p = traj_np[:, q_dim:]

# Show at most 4 (q_i, p_i) pairs. Small latent spaces (latent_dim < 8) have
# fewer pairs, and indexing a fixed 4 would raise IndexError.
n_phase_dims = min(4, q_dim)
# One x value per state in the rollout, so the time axis always matches the
# trajectory length without relying on Plot 1's globals.
phase_steps = np.arange(traj_np.shape[0])

# squeeze=False keeps `axes` 2-D even when only a single column is drawn.
fig2, axes = plt.subplots(2, n_phase_dims, figsize=(3.5 * n_phase_dims, 7), squeeze=False)

# Top row: phase-space trajectory (q_i vs p_i) with start/end markers.
for i in range(n_phase_dims):
    ax = axes[0, i]
    ax.plot(q[:, i], p[:, i], 'b-', linewidth=1, alpha=0.7)
    ax.scatter(q[0, i], p[0, i], c='green', s=100, marker='o', zorder=5, label='Start')
    ax.scatter(q[-1, i], p[-1, i], c='red', s=100, marker='x', zorder=5, label='End')
    ax.set_xlabel(f'$q_{{{i}}}$')
    ax.set_ylabel(f'$p_{{{i}}}$')
    ax.set_title(f'Phase Space: Dim {i}')
    if i == 0:
        ax.legend(fontsize=8)

# Bottom row: q_i and p_i against timestep.
for i in range(n_phase_dims):
    ax = axes[1, i]
    ax.plot(phase_steps, q[:, i], 'b-', linewidth=1.5, label=f'$q_{{{i}}}$')
    ax.plot(phase_steps, p[:, i], 'r--', linewidth=1.5, label=f'$p_{{{i}}}$')
    ax.set_xlabel('Timestep')
    ax.set_ylabel('Value')
    ax.set_title(f'Time Evolution: Dim {i}')
    ax.legend(fontsize=8)

plt.suptitle('Phase Space Dynamics in Learned Hamiltonian Latent Space', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('/content/plot2_phase_portrait.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# Plot 3: HNN vs Baseline — error bars over full test set
# ============================================================
print("Plot 3: HNN vs baseline (with error bars over test set)...")


def _mean_std_curve(axis, xs, mean, std, fmt, color, label):
    """Plot ``mean`` against ``xs`` with a translucent mean +/- 1 std band."""
    axis.plot(xs, mean, fmt, linewidth=2, label=label)
    axis.fill_between(xs, mean - std, mean + std, alpha=0.2, color=color)


fig3, axes = plt.subplots(2, 2, figsize=(12, 10))

# Top-left: per-timestep prediction MSE across the test set (log y-axis).
ax = axes[0, 0]
ts = np.arange(1, len(metrics_hnn['per_timestep_mse_mean']) + 1)
_mean_std_curve(ax, ts,
                metrics_hnn['per_timestep_mse_mean'], metrics_hnn['per_timestep_mse_std'],
                'b-', 'blue', 'HNN (Hamiltonian)')
_mean_std_curve(ax, ts,
                metrics_baseline['per_timestep_mse_mean'], metrics_baseline['per_timestep_mse_std'],
                'r--', 'red', 'Baseline MLP')
ax.set_xlabel('Timestep')
ax.set_ylabel('MSE vs Ground Truth')
ax.set_title(f'Prediction Error ({EVAL_HORIZON}-Step, mean +/- 1 std)')
ax.legend()
ax.set_yscale('log')

# Top-right: per-timestep energy H(z) across the test set.
ax = axes[0, 1]
_mean_std_curve(ax, ts,
                metrics_hnn['per_timestep_energy_mean'], metrics_hnn['per_timestep_energy_std'],
                'b-', 'blue', 'HNN')
_mean_std_curve(ax, ts,
                metrics_baseline['per_timestep_energy_mean'], metrics_baseline['per_timestep_energy_std'],
                'r--', 'red', 'Baseline MLP')
ax.set_xlabel('Timestep')
ax.set_ylabel('Energy H(z)')
ax.set_title(f'Energy Evolution ({EVAL_HORIZON}-Step, mean +/- 1 std)')
ax.legend()

# Bottom-left: the single qualitative sample projected onto latent dims 0 and 1.
ax = axes[1, 0]
baseline_traj_np = baseline_traj_full.cpu().numpy()
hnn_traj_np = hnn_traj_full.cpu().numpy()
gt_for_comparison = test_sample[0].cpu().numpy()
max_t = min(EVAL_HORIZON + 1, gt_for_comparison.shape[0])

for _traj, _fmt, _lw, _lbl in (
    (hnn_traj_np, 'b-', 1.5, 'HNN'),
    (baseline_traj_np, 'r--', 1.5, 'Baseline'),
    (gt_for_comparison, 'g:', 2, 'Ground Truth'),
):
    ax.plot(_traj[:max_t, 0], _traj[:max_t, 1], _fmt, linewidth=_lw, alpha=0.7, label=_lbl)
ax.scatter([hnn_traj_np[0, 0]], [hnn_traj_np[0, 1]], c='black', s=100, marker='o', zorder=5, label='Start')
ax.set_xlabel('$z_0$')
ax.set_ylabel('$z_1$')
ax.set_title('Trajectory in Latent Space (dims 0-1)')
ax.legend()

# Bottom-right: recorded training curves (log y-axis).
ax = axes[1, 1]
epochs = range(1, len(TRAINING_HISTORY['train_loss']) + 1)
for _key, _fmt, _lbl in (
    ('train_loss', 'b-', 'Train Loss'),
    ('test_mse', 'r--', 'Test MSE'),
    ('train_energy', 'g:', 'Energy Loss'),
):
    ax.plot(epochs, TRAINING_HISTORY[_key], _fmt, linewidth=2, label=_lbl)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training History')
ax.legend()
ax.set_yscale('log')

plt.suptitle('Hamiltonian Neural Network vs Baseline MLP', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('/content/plot3_hnn_vs_baseline.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# Plot 4: Ablation bar chart
# ============================================================
print("Plot 4: Ablation study...")

model_names = ['HNN\n(full)', 'HNN\n(no energy)', 'HNN\n(Euler)', 'Baseline\nMLP']
all_metrics = [metrics_hnn, metrics_no_energy, metrics_euler, metrics_baseline]
colors = ['#2196F3', '#4CAF50', '#FF9800', '#F44336']


def _ablation_bars(axis, positions, means, stds, ylabel, title):
    """Draw one ablation bar chart with error bars and per-bar value labels.

    Each value label sits a fixed 3 points above the top of its error bar.
    It therefore never overlaps the error bar and does not depend on the
    data's scale (a fixed data-unit offset breaks for very small or large values).

    Returns the BarContainer produced by ``axis.bar``.
    """
    drawn = axis.bar(positions, means, yerr=stds, capsize=5, color=colors,
                     alpha=0.8, edgecolor='black')
    axis.set_xticks(positions)
    axis.set_xticklabels(model_names)
    axis.set_ylabel(ylabel)
    axis.set_title(title)
    for bar, val, err in zip(drawn, means, stds):
        # A NaN std (e.g. a single window) draws no error bar; anchor at the bar top.
        top = val + float(np.nan_to_num(err))
        axis.annotate(f'{val:.4f}', xy=(bar.get_x() + bar.get_width() / 2, top),
                      xytext=(0, 3), textcoords='offset points',
                      ha='center', va='bottom', fontsize=9)
    # Headroom so the labels above the tallest error bar stay inside the axes.
    axis.margins(y=0.15)
    return drawn


fig4, (ax4a, ax4b) = plt.subplots(1, 2, figsize=(12, 5))
x = np.arange(len(model_names))

# MSE bar chart
mse_means = [m['mse_mean'] for m in all_metrics]
mse_stds = [m['mse_std'] for m in all_metrics]
bars = _ablation_bars(ax4a, x, mse_means, mse_stds, 'MSE vs Ground Truth',
                      f'Prediction Error (mean +/- std, n={n_test})')

# Energy drift bar chart
drift_means = [m['drift_mean'] for m in all_metrics]
drift_stds = [m['drift_std'] for m in all_metrics]
bars = _ablation_bars(ax4b, x, drift_means, drift_stds, 'Max Energy Drift |H(z_t) - H(z_0)|',
                      f'Energy Conservation (mean +/- std, n={n_test})')

plt.suptitle('Ablation Study: Component Contributions', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('/content/plot4_ablation.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================================
# Summary table
# ============================================================
def _n_params(module):
    """Return the total number of elements across ``module.parameters()``."""
    return sum(t.numel() for t in module.parameters())


print("\n" + "="*90)
print(f"{'SUMMARY TABLE':^90}")
print(f"{'(all metrics: mean +/- std over ' + str(n_test) + ' test windows)':^90}")
print("="*90)
print(f"{'Model':<22} {'MSE':>18} {'Energy Drift':>18} {'Params':>10} {'ms/step':>12}")
print("-"*90)

# Ablation parameter counts use hnn_ne / hnn_eu, the networks unpacked from
# ABLATION_MODELS and evaluated above. The no-energy ablation was not timed,
# so its ms/step is None and prints as "N/A".
rows = [
    ("HNN (full)", metrics_hnn, _n_params(hnn), time_hnn_mean/EVAL_HORIZON),
    ("HNN (no energy)", metrics_no_energy, _n_params(hnn_ne), None),
    ("HNN (Euler)", metrics_euler, _n_params(hnn_eu), time_euler_mean/EVAL_HORIZON),
    ("Baseline MLP", metrics_baseline, _n_params(baseline), time_baseline_mean/EVAL_HORIZON),
]
for name, m, params, ms_step in rows:
    ms_str = f"{ms_step:.2f}" if ms_step is not None else "N/A"
    mse_str = f"{m['mse_mean']:.4f} +/- {m['mse_std']:.4f}"
    drift_str = f"{m['drift_mean']:.4f} +/- {m['drift_std']:.4f}"
    # Field widths mirror the header line so every column lines up.
    print(f"{name:<22} {mse_str:>18} {drift_str:>18} {params:>10,} {ms_str:>12}")

print("-"*90)
# Guard the denominators: the HNN metrics are what the ratios divide by.
hnn_mse, hnn_drift = metrics_hnn['mse_mean'], metrics_hnn['drift_mean']
if metrics_baseline['mse_mean'] > 0 and hnn_mse > 0 and hnn_drift > 0:
    mse_ratio = metrics_baseline['mse_mean'] / hnn_mse
    drift_ratio = metrics_baseline['drift_mean'] / hnn_drift
    print(f"HNN vs Baseline: {mse_ratio:.1f}x lower MSE, {drift_ratio:.1f}x lower energy drift")
print(f"Learned dt: {integrator.dt.item():.4f}")
print(f"HNN inference cost: {time_hnn_mean/time_baseline_mean:.1f}x baseline (leapfrog requires 3 gradient passes/step)")

print(f"\nModel statistics:")
print(f"  HNN params: {_n_params(hnn):,}")
print(f"  Baseline params: {_n_params(baseline):,}")
print(f"\nTraining summary:")
print(f"  Final train loss: {TRAINING_HISTORY['train_loss'][-1]:.4f}")
print(f"  Best test MSE: {min(TRAINING_HISTORY['test_mse']):.4f}")

print(f"\nFigures saved to /content/plot{{1,2,3,4}}_*.png")