In [3]:
"""
NH₃ QML Comparison: Graph Equivariant vs Non-Equivariant
=========================================================

Compares two QML approaches for predicting NH₃ energy and forces:
1. Equivariant QML: EQNN-style with SO(3) equivariant encoding, singlets, Heisenberg interactions
2. Non-Equivariant (Graph-based): Simple graph QNN encoding N-H bonds with geometric features

Usage (command line):
    python run_comparison_nh3.py --n_runs 3 --n_epochs 10000 --output_dir results

Usage (Jupyter):
    from run_comparison_nh3 import main
    results = main(n_runs=3, n_epochs=400, output_dir='nh3_comparison_results')
"""

import pennylane as qml
import numpy as np
import json
import os
import argparse
from datetime import datetime

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def huber(residual, delta=1.0):
    """Elementwise Huber loss for robust force training."""
    abs_r = jnp.abs(residual)
    quad = 0.5 * residual**2
    lin = delta * (abs_r - 0.5 * delta)
    return jnp.where(abs_r <= delta, quad, lin)

# =============================================================================
# EQUIVARIANT QML MODEL (EQNN-style)
# =============================================================================

class EquivariantQML:
    """
    EQNN-style equivariant QML for NH₃.
    Uses singlet initialization, equivariant encoding, and Heisenberg-like interactions.
    """
    
    def __init__(self, depth=6, rep=2, active_atoms=3):
        self.depth = depth
        self.rep = rep
        self.active_atoms = active_atoms  # H1, H2, H3 (N is fixed at origin)
        self.n_qubits = active_atoms * rep  # 6 qubits
        
        self.dev = qml.device("default.qubit", wires=self.n_qubits)
        
        # Observable for energy measurement
        self.observable = (
            qml.PauliX(0) @ qml.PauliX(1)
            + qml.PauliY(0) @ qml.PauliY(1)
            + qml.PauliZ(0) @ qml.PauliZ(1)
        )
        
        self._build_circuit()
    
    def _singlet(self, wires):
        """Two-qubit singlet state."""
        w0, w1 = wires
        qml.Hadamard(wires=w0)
        qml.PauliZ(wires=w0)
        qml.PauliX(wires=w1)
        qml.CNOT(wires=[w0, w1])
    
    def _equivariant_encoding(self, alpha, vec3, wire):
        """Equivariant encoding of a 3D vector."""
        r = jnp.array(vec3, dtype=jnp.float64)
        norm = jnp.linalg.norm(r) + 1e-12
        n = r / norm
        theta = alpha * norm
        qml.Rot(theta * n[0], theta * n[1], theta * n[2], wires=wire)
    
    def _pair_layer(self, weight, wires):
        """Trainable 2-qubit Heisenberg-like interaction."""
        qml.IsingXX(weight, wires=wires)
        qml.IsingYY(weight, wires=wires)
        qml.IsingZZ(weight, wires=wires)
    
    def _build_circuit(self):
        """Build the quantum circuit as a QNode."""
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(coords, params):
            """
            coords: (3, 3) - H1, H2, H3 positions relative to N
            params: {"weights": (n_qubits, depth), "alphas": (n_qubits, depth+1),
                     "head_scale": scalar, "head_bias": scalar}
            """
            weights = params["weights"]
            alphas = params["alphas"]
            
            # Initialize singlets on pairs (0,1), (2,3), (4,5)
            for i in range(0, self.n_qubits - 1, 2):
                self._singlet([i, i + 1])
            
            # Initial encoding
            for i in range(self.n_qubits):
                self._equivariant_encoding(alphas[i, 0], coords[i % self.active_atoms], i)
            
            # D layers of pair interactions + re-encoding
            for d in range(self.depth):
                qml.Barrier()
                
                # Even pairings
                for i in range(0, self.n_qubits - 1, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                
                # Odd pairings
                for i in range(1, self.n_qubits, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                
                # Re-encode geometry
                for i in range(self.n_qubits):
                    self._equivariant_encoding(alphas[i, d + 1], coords[i % self.active_atoms], i)
            
            return qml.expval(self.observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def init_params(self, seed=42):
        """Initialize parameters."""
        np.random.seed(seed)
        
        weights0 = np.zeros((self.n_qubits, self.depth), dtype=np.float64)
        weights0[0] = np.random.uniform(0.0, np.pi, size=(self.depth,))
        
        return {
            "weights": jnp.array(weights0),
            "alphas": jnp.ones((self.n_qubits, self.depth + 1), dtype=jnp.float64),
            "head_scale": jnp.array(1.0, dtype=jnp.float64),
            "head_bias": jnp.array(0.0, dtype=jnp.float64),
        }
    
    def predict_energy(self, coords, params):
        """Predict energy with linear head."""
        raw_E = self.vec_circuit(coords, params)
        return params["head_scale"] * raw_E + params["head_bias"]
    
    def predict_forces(self, coords, params):
        """Predict forces as negative gradient of energy."""
        grad_fn = jax.grad(lambda c, p: self.circuit(c, p), argnums=0)
        vec_grad = jax.vmap(grad_fn, in_axes=(0, None), out_axes=0)
        raw_F = -vec_grad(coords, params)
        return params["head_scale"] * raw_F


# =============================================================================
# NON-EQUIVARIANT (GRAPH-BASED) QML MODEL
# =============================================================================

class GraphQML:
    """
    Graph-based QML for NH₃ (non-equivariant).
    Encodes N-H bonds as graph edges with geometric features.
    """
    
    def __init__(self, n_qubits=6, depth=4):
        self.n_qubits = n_qubits  # 3 bonds × 2 qubits per bond
        self.depth = depth
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
    
    def _build_circuit(self):
        """Build the graph-based quantum circuit."""
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """
            positions: (4, 3) - [N, H1, H2, H3] coordinates
            params: {"weights": (depth, n_qubits, 3)}
            """
            weights = params["weights"]
            
            # N at index 0, H atoms at indices 1, 2, 3
            N_pos = positions[0]
            H_positions = positions[1:]  # (3, 3)
            
            # Compute bond vectors and distances
            bonds = H_positions - N_pos[None, :]  # (3, 3)
            distances = jnp.linalg.norm(bonds, axis=1)  # (3,)
            
            # Compute angles between bonds
            def compute_angle(v1, v2):
                cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + 1e-12)
                return jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))
            
            angle_01 = compute_angle(bonds[0], bonds[1])
            angle_02 = compute_angle(bonds[0], bonds[2])
            angle_12 = compute_angle(bonds[1], bonds[2])
            
            # Initialize qubits
            for i in range(self.n_qubits):
                qml.RY(0.5, wires=i)
            
            # Apply layers
            for layer in range(self.depth):
                # Encode bond distances
                qml.RY(weights[layer, 0, 0] * distances[0], wires=0)
                qml.RY(weights[layer, 1, 0] * distances[0], wires=1)
                qml.RY(weights[layer, 2, 0] * distances[1], wires=2)
                qml.RY(weights[layer, 3, 0] * distances[1], wires=3)
                qml.RY(weights[layer, 4, 0] * distances[2], wires=4)
                qml.RY(weights[layer, 5, 0] * distances[2], wires=5)
                
                # Entangle within bonds
                qml.CNOT(wires=[0, 1])
                qml.CNOT(wires=[2, 3])
                qml.CNOT(wires=[4, 5])
                
                # Encode angular information
                qml.RZ(weights[layer, 0, 1] * angle_01, wires=0)
                qml.RZ(weights[layer, 2, 1] * angle_01, wires=2)
                qml.RZ(weights[layer, 0, 2] * angle_02, wires=0)
                qml.RZ(weights[layer, 4, 2] * angle_02, wires=4)
                qml.RZ(weights[layer, 2, 2] * angle_12, wires=2)
                qml.RZ(weights[layer, 4, 2] * angle_12, wires=4)
                
                # Cross-bond entanglement
                qml.CNOT(wires=[1, 2])
                qml.CNOT(wires=[3, 4])
                qml.CNOT(wires=[5, 0])
                
                # Additional rotations
                for i in range(self.n_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)
            
            return qml.expval(qml.sum(*(qml.PauliZ(i) for i in range(self.n_qubits))))
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def init_params(self, seed=42):
        """Initialize parameters."""
        np.random.seed(seed)
        return {
            "weights": jnp.array(np.random.normal(0, 0.1, (self.depth, self.n_qubits, 3)))
        }
    
    def predict_energy(self, positions, params):
        """Predict energy."""
        return self.vec_circuit(positions, params)
    
    def predict_forces(self, positions, params):
        """Predict forces as negative gradient of energy."""
        grad_fn = jax.grad(lambda p, params: self.circuit(p, params), argnums=0)
        vec_grad = jax.vmap(grad_fn, in_axes=(0, None), out_axes=0)
        return -vec_grad(positions, params)


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_equivariant(model, pos_H, E_train, F_train, n_epochs=400, lr=3e-3, 
                      wE=1.0, wF_max=5.0, warmup_frac=0.4, seed=42):
    """
    Train equivariant model with force warmup curriculum.
    
    Args:
        model: EquivariantQML instance
        pos_H: (N, 3, 3) H atom positions relative to N
        E_train: (N,) scaled energies
        F_train: (N, 3, 3) scaled forces on H atoms
        n_epochs: number of training steps
        lr: learning rate
        wE: energy loss weight
        wF_max: maximum force loss weight
        warmup_frac: fraction of training for force warmup
        seed: random seed
    
    Returns:
        trained_params, history
    """
    params = model.init_params(seed)
    
    # Force RMS for normalization
    F_rms = jnp.sqrt(jnp.mean(F_train**2)) + 1e-12
    F_train_norm = F_train / F_rms
    
    # Gradient function for forces
    grad_energy_single = jax.grad(lambda c, p: model.circuit(c, p), argnums=0)
    vec_grad = jax.vmap(grad_energy_single, in_axes=(0, None), out_axes=0)
    
    def loss_fn(params, coords, E_target, F_target_norm, wF):
        raw_E = model.vec_circuit(coords, params)
        raw_F = -vec_grad(coords, params)
        
        scale = params["head_scale"]
        bias = params["head_bias"]
        
        E_pred = scale * raw_E + bias
        F_pred_scaled = scale * raw_F
        F_pred_norm = F_pred_scaled / F_rms
        
        lE = jnp.mean((E_pred - E_target)**2)
        resid_F = F_pred_norm - F_target_norm
        lF = jnp.mean(huber(resid_F, delta=1.0))
        
        return wE * lE + wF * lF, (lE, lF)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(params)
    
    warmup_steps = int(warmup_frac * n_epochs)
    history = {"total": [], "energy": [], "force": []}
    
    for step in range(n_epochs):
        wF = wF_max * min(step / warmup_steps, 1.0) if warmup_steps > 0 else wF_max
        
        params = get_params(opt_state)
        
        def wrapped(p):
            total, comps = loss_fn(p, pos_H, E_train, F_train_norm, wF)
            return total, comps
        
        (loss_val, (lE, lF)), grads = jax.value_and_grad(wrapped, has_aux=True)(params)
        opt_state = opt_update(step, grads, opt_state)
        
        history["total"].append(float(loss_val))
        history["energy"].append(float(lE))
        history["force"].append(float(lF))
        
        if step % 50 == 0:
            print(f"  [Equivariant] Step {step:4d} | Total: {loss_val:.6f} | E: {lE:.6f} | F: {lF:.6f}")
    
    return get_params(opt_state), history


def train_graph_qml(model, positions, E_train, F_train, n_epochs_energy=400, 
                    n_epochs_combined=2500, lr=0.01, seed=42):
    """
    Train graph-based QML with two-phase training: energy first, then combined.
    
    Args:
        model: GraphQML instance
        positions: (N, 4, 3) full positions [N, H1, H2, H3]
        E_train: (N,) scaled energies
        F_train: (N, 3, 3) scaled forces on H atoms
        n_epochs_energy: epochs for energy-only phase
        n_epochs_combined: epochs for combined phase
        lr: learning rate
        seed: random seed
    
    Returns:
        trained_params, history
    """
    params = model.init_params(seed)
    
    # Phase 1: Energy only
    @jax.jit
    def energy_loss(params, positions, E_target):
        E_pred = model.vec_circuit(positions, params)
        return jnp.mean((E_pred - E_target)**2)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(params)
    
    history = {"energy_phase": [], "combined_E": [], "combined_F": []}
    
    print("  [Graph QML] Phase 1: Energy only")
    for step in range(n_epochs_energy):
        params = get_params(opt_state)
        loss, grads = jax.value_and_grad(energy_loss)(params, positions, E_train)
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(step, grads, opt_state)
        history["energy_phase"].append(float(loss))
        
        if step % 50 == 0:
            print(f"    Step {step:4d} | Energy Loss: {loss:.6f}")
    
    # Phase 2: Combined energy + forces
    trained_params = get_params(opt_state)
    
    # Force gradient
    def force_single(coords, params):
        grad_fn = jax.grad(lambda c, p: model.circuit(c, p), argnums=0)
        return -grad_fn(coords, params)
    
    vec_force = jax.vmap(force_single, in_axes=(0, None), out_axes=0)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target)**2)
        
        F_pred_full = vec_force(positions, params)
        F_pred_H = F_pred_full[:, 1:, :]  # Only H atoms
        F_loss = jnp.mean((F_pred_H - F_target)**2)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return 2.0 * E_loss + 1.0 * F_loss, (E_loss, F_loss)
    
    opt_state = opt_init(trained_params)
    
    print("  [Graph QML] Phase 2: Energy + Forces")
    for step in range(n_epochs_combined):
        params = get_params(opt_state)
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            params, positions, E_train, F_train
        )
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(step, grads, opt_state)
        
        history["combined_E"].append(float(E_loss))
        history["combined_F"].append(float(F_loss))
        
        if step % 50 == 0:
            print(f"    Step {step:4d} | E: {E_loss:.6f} | F: {F_loss:.6f}")
    
    return get_params(opt_state), history


# =============================================================================
# EVALUATION FUNCTIONS
# =============================================================================

def evaluate_model(E_pred, F_pred, E_true, F_true, indices_test, 
                   energy_scaler, force_scaler, indices_train, 
                   E_pred_train=None, F_pred_train=None, E_train_true=None, F_train_true=None):
    """
    Evaluate model with post-correction and compute metrics.
    
    Returns:
        dict with metrics and predictions
    """
    # Post-correction for energy (quadratic)
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c
    
    try:
        if E_pred_train is not None and E_train_true is not None:
            popt_E, _ = curve_fit(corr_E, E_pred_train, E_train_true)
            E_pred_corr = corr_E(E_pred, *popt_E)
        else:
            E_pred_corr = E_pred
    except:
        E_pred_corr = E_pred
    
    # Post-correction for forces (linear)
    try:
        if F_pred_train is not None and F_train_true is not None:
            F_pred_flat = F_pred_train.flatten().reshape(-1, 1)
            F_true_flat = F_train_true.flatten().reshape(-1, 1)
            lr_model = LinearRegression()
            lr_model.fit(F_pred_flat, F_true_flat)
            F_pred_corr = lr_model.predict(F_pred.flatten().reshape(-1, 1)).reshape(F_pred.shape)
        else:
            F_pred_corr = F_pred
    except:
        F_pred_corr = F_pred
    
    # Inverse transform to original units
    E_pred_final = energy_scaler.inverse_transform(E_pred_corr.reshape(-1, 1)).flatten()
    E_true_orig = energy_scaler.inverse_transform(E_true.reshape(-1, 1)).flatten()
    
    F_pred_final = force_scaler.inverse_transform(F_pred_corr.flatten().reshape(-1, 1)).reshape(F_pred_corr.shape)
    F_true_orig = force_scaler.inverse_transform(F_true.flatten().reshape(-1, 1)).reshape(F_true.shape)
    
    # Compute metrics on test set
    E_test_pred = E_pred_final[indices_test]
    E_test_true = E_true_orig[indices_test]
    
    E_mae = np.mean(np.abs(E_test_pred - E_test_true))
    E_rmse = np.sqrt(np.mean((E_test_pred - E_test_true)**2))
    E_r2 = 1 - np.sum((E_test_pred - E_test_true)**2) / np.sum((E_test_true - E_test_true.mean())**2)
    
    F_test_pred = F_pred_final[indices_test].flatten()
    F_test_true = F_true_orig[indices_test].flatten()
    
    F_mae = np.mean(np.abs(F_test_pred - F_test_true))
    F_rmse = np.sqrt(np.mean((F_test_pred - F_test_true)**2))
    F_r2 = 1 - np.sum((F_test_pred - F_test_true)**2) / np.sum((F_test_true - F_test_true.mean())**2)
    
    return {
        "energy": {
            "mae_ha": float(E_mae),
            "mae_ev": float(E_mae * 27.2114),
            "rmse_ha": float(E_rmse),
            "rmse_ev": float(E_rmse * 27.2114),
            "r2": float(E_r2),
        },
        "force": {
            "mae": float(F_mae),
            "rmse": float(F_rmse),
            "r2": float(F_r2),
        },
        "predictions": {
            "E_pred": E_pred_final[indices_test].tolist(),
            "E_true": E_test_true.tolist(),
            "F_pred": F_pred_final[indices_test].tolist(),
            "F_true": F_true_orig[indices_test].tolist(),
        }
    }


# =============================================================================
# MAIN COMPARISON
# =============================================================================

def run_comparison(n_runs=2, n_epochs=5000, output_dir="nh3_comparison_results", 
                   data_dir="eqnn_force_field_data_nh3_new"):
    """
    Run comparison between equivariant and graph-based QML models.
    
    Args:
        n_runs: Number of runs for each model
        n_epochs: Training epochs (equivariant uses this directly, 
                  graph QML splits between energy and combined phases)
        output_dir: Directory to save results
        data_dir: Directory containing NH₃ data files
    
    Returns:
        results dictionary
    """
    print(f"\n{'='*70}")
    print(f"NH₃ QML Comparison: Equivariant vs Graph-based")
    print(f"{'='*70}")
    print(f"Runs: {n_runs}, Epochs: {n_epochs}")
    print(f"Data directory: {data_dir}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*70}\n")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Load data
    print("Loading data...")
    energy = np.load(os.path.join(data_dir, "Energy.npy"))
    forces = np.load(os.path.join(data_dir, "Forces.npy"))
    positions = np.load(os.path.join(data_dir, "Positions.npy"))
    
    N_samples = len(energy)
    print(f"  Samples: {N_samples}")
    print(f"  Positions shape: {positions.shape}")
    print(f"  Forces shape: {forces.shape}\n")
    
    # Normalize energies
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()
    
    # Normalize forces (H atoms only)
    forces_H = forces[:, 1:, :]  # (N, 3, 3)
    force_scaler = MinMaxScaler((-1, 1))
    forces_flat = forces_H.reshape(-1, 1)
    forces_scaled = force_scaler.fit_transform(forces_flat).reshape(forces_H.shape)
    
    # Extract H positions for equivariant model (relative to N)
    positions_H = positions[:, 1:, :]  # (N, 3, 3)
    
    # Results storage
    results = {
        "metadata": {
            "n_runs": n_runs,
            "n_epochs": n_epochs,
            "n_samples": N_samples,
            "timestamp": datetime.now().isoformat(),
        },
        "equivariant": {"runs": []},
        "graph_qml": {"runs": []},
    }
    
    # Run experiments
    for run in range(n_runs):
        print(f"\n{'='*70}")
        print(f"RUN {run + 1}/{n_runs}")
        print(f"{'='*70}")
        
        # Train/test split with different seed per run
        rng = np.random.default_rng(run)
        indices = np.arange(N_samples)
        rng.shuffle(indices)
        
        n_train = int(0.8 * N_samples)
        indices_train = indices[:n_train]
        indices_test = indices[n_train:]
        
        # Prepare data
        pos_H_train = jnp.array(positions_H[indices_train])
        pos_H_all = jnp.array(positions_H)
        pos_full_train = jnp.array(positions[indices_train])
        pos_full_all = jnp.array(positions)
        
        E_train = jnp.array(energy_scaled[indices_train])
        F_train = jnp.array(forces_scaled[indices_train])
        
        # ===== EQUIVARIANT MODEL =====
        print(f"\n--- Equivariant QML ---")
        eq_model = EquivariantQML(depth=6, rep=2, active_atoms=3)
        
        eq_params, eq_history = train_equivariant(
            eq_model, pos_H_train, E_train, F_train,
            n_epochs=n_epochs, lr=3e-3, wE=1.0, wF_max=5.0,
            warmup_frac=0.4, seed=run
        )
        
        # Predictions
        E_pred_eq = np.array(eq_model.predict_energy(pos_H_all, eq_params))
        F_pred_eq = np.array(eq_model.predict_forces(pos_H_all, eq_params))
        
        E_pred_eq_train = E_pred_eq[indices_train]
        F_pred_eq_train = F_pred_eq[indices_train]
        
        eq_metrics = evaluate_model(
            E_pred_eq, F_pred_eq, energy_scaled, forces_scaled,
            indices_test, energy_scaler, force_scaler, indices_train,
            E_pred_eq_train, F_pred_eq_train, 
            energy_scaled[indices_train], forces_scaled[indices_train]
        )
        
        print(f"\n  Equivariant Results:")
        print(f"    Energy R²: {eq_metrics['energy']['r2']:.4f}")
        print(f"    Force R²:  {eq_metrics['force']['r2']:.4f}")
        
        results["equivariant"]["runs"].append({
            "run": run + 1,
            "metrics": eq_metrics,
            "history": eq_history,
        })
        
        # ===== GRAPH-BASED QML =====
        print(f"\n--- Graph-based QML ---")
        graph_model = GraphQML(n_qubits=6, depth=4)
        
        # Split epochs between phases
        n_epochs_energy = int(n_epochs * 0.55)
        n_epochs_combined = int(n_epochs * 0.45)
        
        graph_params, graph_history = train_graph_qml(
            graph_model, pos_full_train, E_train, F_train,
            n_epochs_energy=n_epochs_energy, n_epochs_combined=n_epochs_combined,
            lr=0.01, seed=run
        )
        
        # Predictions
        E_pred_graph = np.array(graph_model.predict_energy(pos_full_all, graph_params))
        F_pred_graph_full = np.array(graph_model.predict_forces(pos_full_all, graph_params))
        F_pred_graph = F_pred_graph_full[:, 1:, :]  # H atoms only
        
        E_pred_graph_train = E_pred_graph[indices_train]
        F_pred_graph_train = F_pred_graph[indices_train]
        
        graph_metrics = evaluate_model(
            E_pred_graph, F_pred_graph, energy_scaled, forces_scaled,
            indices_test, energy_scaler, force_scaler, indices_train,
            E_pred_graph_train, F_pred_graph_train,
            energy_scaled[indices_train], forces_scaled[indices_train]
        )
        
        print(f"\n  Graph QML Results:")
        print(f"    Energy R²: {graph_metrics['energy']['r2']:.4f}")
        print(f"    Force R²:  {graph_metrics['force']['r2']:.4f}")
        
        results["graph_qml"]["runs"].append({
            "run": run + 1,
            "metrics": graph_metrics,
            "history": graph_history,
        })
    
    # Compute summary statistics
    for model_name in ["equivariant", "graph_qml"]:
        runs = results[model_name]["runs"]
        
        E_r2 = [r["metrics"]["energy"]["r2"] for r in runs]
        E_mae = [r["metrics"]["energy"]["mae_ha"] for r in runs]
        F_r2 = [r["metrics"]["force"]["r2"] for r in runs]
        F_mae = [r["metrics"]["force"]["mae"] for r in runs]
        
        results[model_name]["summary"] = {
            "energy_r2": {"mean": float(np.mean(E_r2)), "std": float(np.std(E_r2))},
            "energy_mae_ha": {"mean": float(np.mean(E_mae)), "std": float(np.std(E_mae))},
            "force_r2": {"mean": float(np.mean(F_r2)), "std": float(np.std(F_r2))},
            "force_mae": {"mean": float(np.mean(F_mae)), "std": float(np.std(F_mae))},
        }
    
    # Print summary
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    
    for model_name, display_name in [("equivariant", "Equivariant"), ("graph_qml", "Graph QML")]:
        s = results[model_name]["summary"]
        print(f"\n{display_name}:")
        print(f"  Energy R²:  {s['energy_r2']['mean']:.4f} ± {s['energy_r2']['std']:.4f}")
        print(f"  Energy MAE: {s['energy_mae_ha']['mean']:.6f} ± {s['energy_mae_ha']['std']:.6f} Ha")
        print(f"  Force R²:   {s['force_r2']['mean']:.4f} ± {s['force_r2']['std']:.4f}")
        print(f"  Force MAE:  {s['force_mae']['mean']:.4f} ± {s['force_mae']['std']:.4f} eV/Å")
    
    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {results_path}")
    
    # Save numpy arrays for quick access
    np.savez(
        os.path.join(output_dir, "metrics.npz"),
        eq_E_r2=np.array([r["metrics"]["energy"]["r2"] for r in results["equivariant"]["runs"]]),
        eq_F_r2=np.array([r["metrics"]["force"]["r2"] for r in results["equivariant"]["runs"]]),
        eq_E_mae=np.array([r["metrics"]["energy"]["mae_ha"] for r in results["equivariant"]["runs"]]),
        eq_F_mae=np.array([r["metrics"]["force"]["mae"] for r in results["equivariant"]["runs"]]),
        graph_E_r2=np.array([r["metrics"]["energy"]["r2"] for r in results["graph_qml"]["runs"]]),
        graph_F_r2=np.array([r["metrics"]["force"]["r2"] for r in results["graph_qml"]["runs"]]),
        graph_E_mae=np.array([r["metrics"]["energy"]["mae_ha"] for r in results["graph_qml"]["runs"]]),
        graph_F_mae=np.array([r["metrics"]["force"]["mae"] for r in results["graph_qml"]["runs"]]),
    )
    print(f"Metrics saved to {os.path.join(output_dir, 'metrics.npz')}")
    
    return results


# =============================================================================
# MAIN
# =============================================================================

def main(n_runs=2, n_epochs=400, output_dir="nh3_comparison_results", 
         data_dir="eqnn_force_field_data_nh3_new"):
    """
    Main function - can be called directly from Jupyter or command line.
    
    Args:
        n_runs: Number of runs for each model
        n_epochs: Training epochs per run
        output_dir: Directory to save results
        data_dir: Directory containing NH₃ data (.npy files)
    
    Returns:
        results dictionary
    """
    return run_comparison(
        n_runs=n_runs,
        n_epochs=n_epochs,
        output_dir=output_dir,
        data_dir=data_dir
    )


if __name__ == "__main__":
    import sys
    
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter notebook. Call main() directly with parameters:")
        print("  results = main(n_runs=2, n_epochs=400, output_dir='results', data_dir='eqnn_force_field_data_nh3_new')")
    else:
        parser = argparse.ArgumentParser(description="Compare Equivariant vs Graph QML on NH₃")
        parser.add_argument("--n_runs", type=int, default=2, help="Number of runs")
        parser.add_argument("--n_epochs", type=int, default=400, help="Training epochs per run")
        parser.add_argument("--output_dir", type=str, default="nh3_comparison_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_nh3_new", help="Data directory")
        
        args = parser.parse_args()
        
        results = main(
            n_runs=args.n_runs,
            n_epochs=args.n_epochs,
            output_dir=args.output_dir,
            data_dir=args.data_dir
        )

Running in Jupyter notebook. Call main() directly with parameters:
  reader = main('nh3_comparison_results', plot='all', save_dir='./nh3_results/', show=True)

Or use the ResultsReader class directly:
  reader = ResultsReader('nh3_comparison_results')
  reader.print_summary()
  reader.plot_comparison_bars()


In [2]:
results = main(n_runs=1, n_epochs=400, output_dir='nh3_results', data_dir='eqnn_force_field_data_nh3_new')




NH₃ QML Comparison: Equivariant vs Graph-based
Runs: 1, Epochs: 400
Data directory: eqnn_force_field_data_nh3_new
Output directory: results

Loading data...
  Samples: 2400
  Positions shape: (2400, 4, 3)
  Forces shape: (2400, 4, 3)


RUN 1/1

--- Equivariant QML ---
  [Equivariant] Step    0 | Total: 0.359228 | E: 0.359228 | F: 4.019605
  [Equivariant] Step   50 | Total: 0.986429 | E: 0.169406 | F: 0.522894
  [Equivariant] Step  100 | Total: 0.641949 | E: 0.116195 | F: 0.168241
  [Equivariant] Step  150 | Total: 0.559816 | E: 0.113019 | F: 0.095317
  [Equivariant] Step  200 | Total: 0.360961 | E: 0.108589 | F: 0.050474
  [Equivariant] Step  250 | Total: 0.254419 | E: 0.101725 | F: 0.030539
  [Equivariant] Step  300 | Total: 0.237241 | E: 0.101009 | F: 0.027246
  [Equivariant] Step  350 | Total: 0.224346 | E: 0.100891 | F: 0.024691

  Equivariant Results:
    Energy R²: 0.9578
    Force R²:  0.9566

--- Graph-based QML ---
  [Graph QML] Phase 1: Energy only
    Step    0 | Energy Los