In [1]:
"""
NH₃ QML Comparison: Four Methods
=================================

Compares four approaches for predicting NH₃ energy and forces:
1. Rotationally Equivariant QML - EQNN-style with SO(3) equivariant encoding
2. Non-Equivariant QML - Simple QNN with basic rotations (no symmetry)
3. Graph Permutation Equivariant QML - Graph-based encoding with permutation symmetry
4. Classical Rotationally Equivariant NN - Classical MLP on pairwise distances (E(3) invariant)

Usage (command line):
    python run_comparison_nh3.py --n_runs 3 --n_epochs 400 --output_dir results

Usage (Jupyter):
    from run_comparison_nh3 import main
    results = main(n_runs=3, n_epochs=400, output_dir='nh3_comparison_results')
"""

import pennylane as qml
import numpy as np
import json
import os
import argparse
from datetime import datetime

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def huber(residual, delta=1.0):
    """Huber penalty, evaluated element by element.

    Args:
        residual: Array of residuals (prediction minus target).
        delta: Threshold separating the quadratic and linear regimes.

    Returns:
        Array with the same shape as ``residual`` holding ``0.5 * r**2`` where
        ``|r| <= delta`` and ``delta * (|r| - 0.5 * delta)`` elsewhere.
    """
    magnitude = jnp.abs(residual)
    in_quadratic_zone = magnitude <= delta
    return jnp.where(
        in_quadratic_zone,
        0.5 * residual**2,
        delta * (magnitude - 0.5 * delta),
    )


# =============================================================================
# 1. ROTATIONALLY EQUIVARIANT QML MODEL (SO(3))
# =============================================================================

class RotationallyEquivariantQML:
    """
    Rotationally Equivariant QML for NH₃.

    The circuit prepares two-qubit singlets, encodes each 3D input vector as a
    spin-1/2 rotation about that vector, and interleaves the encodings with
    trainable Heisenberg-type (XX + YY + ZZ) two-qubit interactions. The
    measured observable is the Heisenberg coupling on wires 0 and 1; a
    trainable linear head (``head_scale``, ``head_bias``) maps the expectation
    value to an energy.

    Args:
        depth: Number of interaction + re-encoding layers.
        rep: Number of qubits used per active atom.
        active_atoms: Number of encoded atoms (the three H atoms by default).
        seed: Seed for the NumPy RNG used to initialise the weights.
    """

    def __init__(self, depth=6, rep=2, active_atoms=3, seed=42):
        self.depth = depth
        self.rep = rep
        self.active_atoms = active_atoms  # H1, H2, H3 (N is fixed at origin)
        self.n_qubits = active_atoms * rep  # 6 qubits with the defaults
        self.seed = seed

        self.dev = qml.device("default.qubit", wires=self.n_qubits)

        # Observable for energy measurement: Heisenberg coupling on wires 0, 1
        self.observable = (
            qml.PauliX(0) @ qml.PauliX(1)
            + qml.PauliY(0) @ qml.PauliY(1)
            + qml.PauliZ(0) @ qml.PauliZ(1)
        )

        self._build_circuit()
        self._init_params()

    def _singlet(self, wires):
        """Prepare the two-qubit singlet (|01> - |10>)/sqrt(2) on ``wires``."""
        w0, w1 = wires
        qml.Hadamard(wires=w0)
        qml.PauliZ(wires=w0)
        qml.PauliX(wires=w1)
        qml.CNOT(wires=[w0, w1])

    def _equivariant_encoding(self, alpha, vec3, wire):
        """Encode a 3D vector as the spin rotation exp(-i * alpha/2 * r·sigma).

        The previous implementation passed ``theta * n`` to ``qml.Rot``, which
        interprets its three arguments as RZ-RY-RZ Euler angles. That gate is
        not the rotation about axis ``n`` and does not transform covariantly
        when the input vector is rotated. Here the rotation is built in closed
        form,

            U = cos(theta/2) * I - i * sin(theta/2) * (n_x X + n_y Y + n_z Z),

        with ``theta = alpha * |r|`` and ``n = r / |r|``.

        Args:
            alpha: Trainable scalar scaling the rotation angle.
            vec3: Length-3 vector to encode.
            wire: Wire the single-qubit unitary acts on.
        """
        r = jnp.asarray(vec3, dtype=jnp.float64)
        norm = jnp.linalg.norm(r) + 1e-12  # avoid division by zero
        n = r / norm
        half_angle = 0.5 * alpha * norm
        c = jnp.cos(half_angle)
        s = jnp.sin(half_angle)
        unitary = jnp.array([
            [c - 1j * s * n[2], s * (-1j * n[0] - n[1])],
            [s * (-1j * n[0] + n[1]), c + 1j * s * n[2]],
        ])
        qml.QubitUnitary(unitary, wires=wire)

    def _pair_layer(self, weight, wires):
        """Trainable 2-qubit Heisenberg-like interaction (XX, YY, ZZ share one angle)."""
        qml.IsingXX(weight, wires=wires)
        qml.IsingYY(weight, wires=wires)
        qml.IsingZZ(weight, wires=wires)

    def _build_circuit(self):
        """Build the QNode and its batched (vmapped over samples) version."""

        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(coords, params):
            """
            coords: (3, 3) - H1, H2, H3 positions relative to N
            params: {"weights": (n_qubits, depth), "alphas": (n_qubits, depth+1),
                     "head_scale": scalar, "head_bias": scalar}

            Returns the raw expectation value of the observable; the linear
            head is applied outside the circuit.
            """
            weights = params["weights"]
            alphas = params["alphas"]

            # Initialize singlets on pairs (0,1), (2,3), (4,5)
            for i in range(0, self.n_qubits - 1, 2):
                self._singlet([i, i + 1])

            # Initial encoding: qubit i carries atom i % active_atoms
            for i in range(self.n_qubits):
                self._equivariant_encoding(alphas[i, 0], coords[i % self.active_atoms], i)

            # D layers of pair interactions + re-encoding
            for d in range(self.depth):
                qml.Barrier()

                # Even pairings
                for i in range(0, self.n_qubits - 1, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])

                # Odd pairings (the last one wraps around to wire 0)
                for i in range(1, self.n_qubits, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])

                # Re-encode geometry
                for i in range(self.n_qubits):
                    self._equivariant_encoding(alphas[i, d + 1], coords[i % self.active_atoms], i)

            return qml.expval(self.observable)

        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)

    def _init_params(self):
        """Initialize parameters.

        Only the first row of ``weights`` is drawn at random (uniform in
        [0, pi)); every other weight starts at zero, all ``alphas`` start at
        one, and the linear head starts as the identity map.
        """
        np.random.seed(self.seed)

        weights0 = np.zeros((self.n_qubits, self.depth), dtype=np.float64)
        weights0[0] = np.random.uniform(0.0, np.pi, size=(self.depth,))

        self.params = {
            "weights": jnp.array(weights0),
            "alphas": jnp.ones((self.n_qubits, self.depth + 1), dtype=jnp.float64),
            "head_scale": jnp.array(1.0, dtype=jnp.float64),
            "head_bias": jnp.array(0.0, dtype=jnp.float64),
        }

    def get_params(self):
        """Return the current parameter dictionary."""
        return self.params

    def set_params(self, params):
        """Replace the current parameter dictionary."""
        self.params = params

    def predict_energy(self, coords, params):
        """Predict energies for a batch of geometries using the linear head."""
        raw_E = self.vec_circuit(coords, params)
        return params["head_scale"] * raw_E + params["head_bias"]

    def predict_forces(self, coords, params):
        """Predict forces as the negative gradient of the energy w.r.t. ``coords``.

        The bias does not depend on the coordinates, so only ``head_scale``
        multiplies the raw circuit gradient.
        """
        grad_fn = jax.grad(lambda c, p: self.circuit(c, p), argnums=0)
        vec_grad = jax.vmap(grad_fn, in_axes=(0, None), out_axes=0)
        raw_F = -vec_grad(coords, params)
        return params["head_scale"] * raw_F


# =============================================================================
# 2. NON-EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """
    Simple non-equivariant QNN for NH₃.
    Uses basic rotations without symmetry preservation.

    Args:
        num_qubits: Number of wires in the circuit.
        depth: Number of encode / entangle / rotate layers.
        seed: Seed for the NumPy RNG used to initialise the weights.
    """

    def __init__(self, num_qubits=6, depth=4, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.seed = seed

        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        self._init_params()

    def _create_circuit(self):
        """Create the QNode and its batched (vmapped over samples) version."""
        num_qubits = self.num_qubits
        depth = self.depth

        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """
            positions: (4, 3) - [N, H1, H2, H3] coordinates
            params: {"weights": (depth, num_qubits, 3)}

            Returns the expectation value of the sum of PauliZ over all wires.
            """
            weights = params["weights"]

            # Compute simple geometric features (not respecting symmetry)
            # Bond lengths
            d1 = jnp.linalg.norm(positions[1] - positions[0])
            d2 = jnp.linalg.norm(positions[2] - positions[0])
            d3 = jnp.linalg.norm(positions[3] - positions[0])

            # Average bond length
            avg_dist = (d1 + d2 + d3) / 3.0

            # Features are assigned to wires cyclically. For 6 wires this is
            # d1, d2, d3, avg, d1, d2 - the original hard-coded layout - but it
            # no longer breaks (or leaves wires unencoded) for other sizes.
            features = (d1, d2, d3, avg_dist)

            # Initialize
            for i in range(num_qubits):
                qml.RY(0.5, wires=i)

            # Simple layers
            for layer in range(depth):
                # Encode distances directly (breaks permutation symmetry)
                for i in range(num_qubits):
                    qml.RY(weights[layer, i, 0] * features[i % len(features)], wires=i)

                # Entangle along a linear chain
                for i in range(num_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])

                # More rotations (input-independent trainable angles)
                for i in range(num_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)

            return qml.expval(sum(qml.PauliZ(i) for i in range(num_qubits)))

        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)

    def _init_params(self):
        """Initialize weights from a normal distribution (mean 0, std 0.1)."""
        np.random.seed(self.seed)
        weights = np.random.normal(0, 0.1, (self.depth, self.num_qubits, 3))
        self.params = {"weights": jnp.array(weights)}

    def get_params(self):
        """Return the current parameter dictionary."""
        return self.params

    def set_params(self, params):
        """Replace the current parameter dictionary."""
        self.params = params


# =============================================================================
# 3. GRAPH PERMUTATION EQUIVARIANT QML MODEL
# =============================================================================

class GraphPermutationEquivariantQML:
    """
    Graph Permutation Equivariant QML for NH₃.
    Encodes N-H bonds as graph edges with geometric features
    (bond lengths and H-N-H angles).

    Each of the three N-H bonds is mapped to a pair of wires: (0, 1), (2, 3)
    and (4, 5). The circuit layout is written for exactly this 6-wire
    arrangement.

    NOTE(review): the trainable weights are indexed per wire, so the output is
    only unchanged under relabelling of the H atoms if the corresponding
    weights happen to be equal; the circuit does not enforce that tie.

    Args:
        n_qubits: Number of wires (3 bonds × 2 qubits per bond).
        depth: Number of layers.
        seed: Seed for the NumPy RNG used to initialise the weights.
    """

    def __init__(self, n_qubits=6, depth=4, seed=42):
        self.n_qubits = n_qubits  # 3 bonds × 2 qubits per bond
        self.depth = depth
        self.seed = seed

        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
        self._init_params()

    def _build_circuit(self):
        """Build the graph-based QNode and its batched (vmapped) version."""

        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """
            positions: (4, 3) - [N, H1, H2, H3] coordinates
            params: {"weights": (depth, n_qubits, 3)}

            Returns the expectation value of the sum of PauliZ over all wires.
            """
            weights = params["weights"]

            # N at index 0, H atoms at indices 1, 2, 3
            N_pos = positions[0]
            H_positions = positions[1:]  # (3, 3)

            # Compute bond vectors and distances
            bonds = H_positions - N_pos[None, :]  # (3, 3)
            distances = jnp.linalg.norm(bonds, axis=1)  # (3,)

            # Compute angles between bonds
            def compute_angle(v1, v2):
                cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + 1e-12)
                return jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))

            angle_01 = compute_angle(bonds[0], bonds[1])
            angle_02 = compute_angle(bonds[0], bonds[2])
            angle_12 = compute_angle(bonds[1], bonds[2])

            # (wire, weight column, angle): the first wire of each bond pair
            # receives the two angles its bond takes part in, one per weight
            # column. Wire 4 previously used column 2 for both of its angles,
            # collapsing them onto a single trainable weight; angle_02 now
            # uses column 1, matching wires 0 and 2.
            angle_terms = (
                (0, 1, angle_01),
                (2, 1, angle_01),
                (0, 2, angle_02),
                (4, 1, angle_02),
                (2, 2, angle_12),
                (4, 2, angle_12),
            )

            # Initialize qubits
            for i in range(self.n_qubits):
                qml.RY(0.5, wires=i)

            # Apply layers
            for layer in range(self.depth):
                # Encode bond distances: wires 2b and 2b+1 both see bond b
                for wire in range(6):
                    qml.RY(weights[layer, wire, 0] * distances[wire // 2], wires=wire)

                # Entangle within bonds
                qml.CNOT(wires=[0, 1])
                qml.CNOT(wires=[2, 3])
                qml.CNOT(wires=[4, 5])

                # Encode angular information
                for wire, column, angle in angle_terms:
                    qml.RZ(weights[layer, wire, column] * angle, wires=wire)

                # Cross-bond entanglement (closes the ring 1->2, 3->4, 5->0)
                qml.CNOT(wires=[1, 2])
                qml.CNOT(wires=[3, 4])
                qml.CNOT(wires=[5, 0])

                # Additional rotations; these reuse weight columns 1 and 2 as
                # input-independent angles
                for i in range(self.n_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)

            return qml.expval(qml.sum(*(qml.PauliZ(i) for i in range(self.n_qubits))))

        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)

    def _init_params(self):
        """Initialize weights from a normal distribution (mean 0, std 0.1)."""
        np.random.seed(self.seed)
        self.params = {
            "weights": jnp.array(np.random.normal(0, 0.1, (self.depth, self.n_qubits, 3)))
        }

    def get_params(self):
        """Return the current parameter dictionary."""
        return self.params

    def set_params(self, params):
        """Replace the current parameter dictionary."""
        self.params = params


# =============================================================================
# 4. CLASSICAL ROTATIONALLY EQUIVARIANT NN (E(3) INVARIANT)
# =============================================================================

class ClassicalRotationallyEquivariantNN:
    """
    Improved Classical Rotationally Equivariant Neural Network for NH₃.

    Key improvements over basic MLP:
    1. Expanded features: distances, inverse distances, angles, Morse-like terms
    2. SiLU activation (smooth, better gradients for force autodiff)
    3. Feature normalization for stable training
    4. Larger network with a skip connection into the last hidden layer
    5. Radial basis function encoding for physics inductive bias

    Uses E(3) invariant features (distances, angles).
    Energy predicted from invariants, forces via autodiff.

    Args:
        hidden_dims: Sizes of the hidden layers. Must contain at least one
            entry. The sequence is copied, so instances never share (or
            mutate) the caller's object or the default.
        seed: Seed for the NumPy RNG used to initialise the parameters.

    Attributes set by the constructor:
        params: Parameter dictionary (see ``_init_params``).
        energy_fn / force_fn: Single-geometry energy and force functions.
        vec_energy / vec_force: The same functions vmapped over a batch.
    """

    def __init__(self, hidden_dims=(128, 128, 64), seed=42):
        # Copy into a fresh list: the previous mutable list default was stored
        # by reference and therefore shared between all default instances.
        self.hidden_dims = list(hidden_dims)
        self.seed = seed
        self._init_params()
        self._create_model()

    def _init_params(self):
        """Initialize parameters with He initialization.

        Creates ``weights`` / ``biases`` for the MLP, ``skip_weights`` for the
        input-to-last-hidden skip connection, ``rbf_coeffs`` / ``rbf_output``
        for the pairwise radial-basis term, and the scalar output head
        (``output_scale``, ``output_bias``).
        """
        np.random.seed(self.seed)

        # Expanded feature set:
        # 6 distances + 6 inverse distances + 6 Morse-like + 3 angles + 3 cos(angles) = 24 features
        n_features = 24

        # Main MLP
        layer_sizes = [n_features] + self.hidden_dims + [1]

        weights = []
        biases = []
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            # He initialization for SiLU
            std = np.sqrt(2.0 / fan_in)
            weights.append(jnp.array(np.random.normal(0, std, (fan_in, fan_out))))
            biases.append(jnp.array(np.zeros(fan_out)))

        # Skip connection from input to final hidden layer
        skip_std = np.sqrt(2.0 / n_features)
        skip_weights = jnp.array(
            np.random.normal(0, skip_std, (n_features, self.hidden_dims[-1]))
        )

        # Radial basis function parameters for pairwise terms
        # 6 distance pairs x 8 RBF centers
        rbf_coeffs = jnp.array(np.random.normal(0, 0.1, (6, 8)))
        rbf_output = jnp.array(np.random.normal(0, 0.1, (6,)))

        self.params = {
            "weights": weights,
            "biases": biases,
            "skip_weights": skip_weights,
            "output_scale": jnp.array(1.0),
            "output_bias": jnp.array(0.0),
            "rbf_coeffs": rbf_coeffs,
            "rbf_output": rbf_output,
        }

    def _create_model(self):
        """Create the energy/force functions and their batched versions."""

        def silu(x):
            """SiLU/Swish activation - smooth, better for gradients."""
            return x * jax.nn.sigmoid(x)

        def compute_features(positions):
            """
            Compute E(3) invariant features from positions.
            positions: (4, 3) - [N, H1, H2, H3]

            Returns: (features, distances) where features has 24 entries and
            distances holds the 6 pairwise distances.
            """
            # Compute all 6 pairwise distances with numerical stability
            eps = 1e-8
            pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
            distances = jnp.array([
                jnp.linalg.norm(positions[j] - positions[i]) + eps for i, j in pairs
            ])

            # Normalized distances (typical N-H ~ 1.0 Å, H-H ~ 1.6 Å)
            dist_norm = distances / 1.5  # Scale to ~1

            # Inverse distances (Coulomb-like)
            inv_dist_norm = 1.0 / distances

            # Morse-like features: exp(-alpha * (r - r_eq))
            # This captures the shape of bond potentials
            r_eq = jnp.array([1.01, 1.01, 1.01, 1.63, 1.63, 1.63])  # Approx equilibrium
            alpha = 2.0
            morse = jnp.exp(-alpha * (distances - r_eq))

            # Compute H-N-H angles using dot products (more stable than law of cosines)
            def compute_angle(p1, p2, p_center):
                """Angle at p_center between p1-p_center-p2."""
                v1 = p1 - p_center
                v2 = p2 - p_center
                cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + eps)
                # Clip strictly inside [-1, 1] so arccos keeps a finite gradient
                return jnp.arccos(jnp.clip(cos_angle, -1.0 + eps, 1.0 - eps))

            # H-N-H angles (at N atom)
            angles = jnp.array([
                compute_angle(positions[1], positions[2], positions[0]),
                compute_angle(positions[1], positions[3], positions[0]),
                compute_angle(positions[2], positions[3], positions[0]),
            ])
            angles_norm = angles / jnp.pi  # Normalize to [0, 1]

            # Cosine of angles (often more useful than raw angles)
            cos_angles = jnp.cos(angles)

            # Concatenate all features
            features = jnp.concatenate([
                dist_norm,      # 6: normalized distances
                inv_dist_norm,  # 6: inverse distances
                morse,          # 6: Morse-like features
                angles_norm,    # 3: normalized angles
                cos_angles,     # 3: cosine of angles
            ])

            return features, distances

        def rbf_energy(distances, params):
            """
            Radial basis function energy contribution.
            Provides physics-based inductive bias for pairwise interactions.
            """
            # RBF centers (typical bond lengths in Angstrom)
            centers = jnp.linspace(0.8, 2.5, 8)
            width = 0.2

            # Compute RBF features for each distance
            # Shape: (6 distances, 8 centers)
            rbf = jnp.exp(-((distances[:, None] - centers[None, :]) ** 2) / (2 * width ** 2))

            # Weighted sum per distance pair
            pair_energies = jnp.sum(rbf * params["rbf_coeffs"], axis=1)  # (6,)

            # Total RBF contribution
            return jnp.dot(pair_energies, params["rbf_output"])

        def mlp_forward(features, params):
            """MLP with SiLU activations and a scaled input skip connection."""
            weights = params["weights"]
            biases = params["biases"]

            h = features
            for W, b in zip(weights[:-1], biases[:-1]):
                h = silu(jnp.dot(h, W) + b)

            # Scaled skip connection from the raw features into the last
            # hidden representation (added after its activation)
            h = h + 0.1 * jnp.dot(features, params["skip_weights"])

            # Output layer
            h = jnp.dot(h, weights[-1]) + biases[-1]
            return h.squeeze(-1)

        def energy_from_positions(positions, params):
            """
            Compute energy from atomic positions.
            Combines MLP prediction with RBF pairwise terms.
            """
            features, distances = compute_features(positions)

            # MLP energy
            mlp_energy = mlp_forward(features, params)

            # RBF pairwise energy (physics bias)
            rbf_contrib = rbf_energy(distances, params)

            # Combined with learnable scaling
            return params["output_scale"] * mlp_energy + rbf_contrib + params["output_bias"]

        energy_grad = jax.grad(energy_from_positions, argnums=0)

        def force_from_positions(positions, params):
            """Compute forces on all 4 atoms as the negative energy gradient."""
            return -energy_grad(positions, params)

        self.energy_fn = energy_from_positions
        self.force_fn = force_from_positions
        self.vec_energy = jax.vmap(energy_from_positions, (0, None), 0)
        self.vec_force = jax.vmap(force_from_positions, (0, None), 0)

    def get_params(self):
        """Return the current parameter dictionary."""
        return self.params

    def set_params(self, params):
        """Replace the current parameter dictionary."""
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_rotationally_equivariant(model, pos_H, E_train, F_train, n_epochs=400, lr=3e-3, 
                                    wE=1.0, wF_max=5.0, warmup_frac=0.4):
    """
    Train rotationally equivariant QML model with force warmup curriculum.

    The loss is ``wE * MSE(energy) + wF * Huber(force / F_rms)``, where the
    force weight ``wF`` rises linearly from 0 to ``wF_max`` over the first
    ``warmup_frac * n_epochs`` steps. Energies and forces are both passed
    through the model's linear head (``head_scale`` / ``head_bias``).

    Args:
        model: Object exposing ``circuit``, ``vec_circuit``, ``get_params``
            and ``set_params``.
        pos_H: Batch of coordinates fed to the circuit.
        E_train: Target energies.
        F_train: Target forces, same shape as the gradient w.r.t. ``pos_H``.
        n_epochs: Number of full-batch Adam steps.
        lr: Adam learning rate.
        wE: Energy loss weight.
        wF_max: Final force loss weight.
        warmup_frac: Fraction of ``n_epochs`` over which ``wF`` ramps up.

    Returns:
        History dict with keys ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``, sampled about 20 times. All values are computed on
        the training data passed in. The trained parameters are written back
        to ``model``.
    """
    # Forces are compared in units of their RMS over the training set
    force_rms = jnp.sqrt(jnp.mean(F_train**2)) + 1e-12
    force_targets = F_train / force_rms

    # Batched gradient of the raw circuit output w.r.t. the coordinates
    energy_grad = jax.vmap(
        jax.grad(lambda c, p: model.circuit(c, p), argnums=0),
        in_axes=(0, None),
        out_axes=0,
    )

    def objective(p, force_weight):
        head_scale = p["head_scale"]
        head_bias = p["head_bias"]

        circuit_E = model.vec_circuit(pos_H, p)
        circuit_F = -energy_grad(pos_H, p)

        energy_pred = head_scale * circuit_E + head_bias
        force_pred = (head_scale * circuit_F) / force_rms

        energy_term = jnp.mean((energy_pred - E_train)**2)
        force_term = jnp.mean(huber(force_pred - force_targets, delta=1.0))

        return wE * energy_term + force_weight * force_term, (energy_term, force_term)

    loss_and_grad = jax.value_and_grad(objective, has_aux=True)

    adam_init, adam_update, adam_params = optimizers.adam(lr)
    state = adam_init(model.get_params())

    ramp_steps = int(warmup_frac * n_epochs)
    log_every = max(1, n_epochs // 20)
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}

    for step in range(n_epochs):
        if ramp_steps > 0:
            force_weight = wF_max * min(step / ramp_steps, 1.0)
        else:
            force_weight = wF_max

        (total, (energy_term, force_term)), grads = loss_and_grad(
            adam_params(state), force_weight
        )
        state = adam_update(step, grads, state)

        if (step + 1) % log_every == 0:
            history["epoch"].append(step + 1)
            history["train_loss"].append(float(total))
            history["test_E_loss"].append(float(energy_term))
            history["test_F_loss"].append(float(force_term))

    model.set_params(adam_params(state))
    return history


def train_non_equivariant(model, positions, E_train, F_train, n_epochs=400, lr=0.01,
                          lambda_E=2.0, lambda_F=1.0):
    """Train the non-equivariant QML model on energies and forces.

    Minimises ``lambda_E * MSE(energy) + lambda_F * MSE(force)`` with
    full-batch Adam and global-norm gradient clipping at 10. Forces are the
    negative gradient of ``model.circuit`` w.r.t. its coordinates; the first
    atom's row is dropped before comparing with ``F_train``.

    Args:
        model: Object exposing ``circuit``, ``vec_circuit``, ``params`` and
            ``set_params``.
        positions: Batch of coordinates, one (atoms, 3) array per sample.
        E_train: Target energies.
        F_train: Target forces for all atoms except the first.
        n_epochs: Number of optimisation steps.
        lr: Adam learning rate.
        lambda_E: Energy loss weight.
        lambda_F: Force loss weight.

    Returns:
        History dict with keys ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``, sampled about 20 times. All values are computed on
        the training data passed in. The trained parameters are written back
        to ``model``.
    """
    max_grad_norm = 10.0

    def energy_single(coords, params):
        return model.circuit(coords, params)

    def force_single(coords, params):
        grad_fn = jax.grad(energy_single, argnums=0)
        return -grad_fn(coords, params)

    vec_force = jax.vmap(force_single, (0, None), 0)

    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)

        F_pred_full = vec_force(positions, params)
        F_pred_H = F_pred_full[:, 1:, :]  # H atoms only
        F_loss = jnp.mean((F_pred_H - F_target) ** 2)

        # Keeps the reported loss finite; it does NOT make the gradient
        # finite, which is why the update below checks the gradient norm.
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)

        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)

    def _guarded_update(index, grads, opt_state):
        """Clip by global norm and step; leave the state alone on NaN/inf."""
        leaves = jax.tree_util.tree_leaves(grads)
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
        if not bool(jnp.isfinite(grad_norm)):
            # A single non-finite gradient would permanently poison Adam's
            # moment estimates and the parameters, so skip this step.
            return opt_state
        if grad_norm > max_grad_norm:
            factor = max_grad_norm / grad_norm
            grads = jax.tree_util.tree_map(lambda g: g * factor, grads)
        return opt_update(index, grads, opt_state)

    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)

    loss_and_grad = jax.value_and_grad(combined_loss, has_aux=True)
    log_every = max(1, n_epochs // 20)
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}

    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = loss_and_grad(
            get_params(opt_state), positions, E_train, F_train
        )

        opt_state = _guarded_update(epoch, grads, opt_state)

        if (epoch + 1) % log_every == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))

    model.set_params(get_params(opt_state))
    return history


def train_graph_permutation_equivariant(model, positions, E_train, F_train, 
                                         n_epochs_energy=200, n_epochs_combined=200, 
                                         lr=0.01):
    """
    Train graph permutation equivariant QML with two-phase training.

    Phase 1 fits energies only (MSE). Phase 2 restarts Adam from the phase-1
    parameters and minimises ``2 * MSE(energy) + MSE(force)``. Forces are the
    negative gradient of ``model.circuit`` w.r.t. its coordinates; the first
    atom's row is dropped before comparing with ``F_train``. Both phases clip
    gradients to a global norm of 10.

    Args:
        model: Object exposing ``circuit``, ``vec_circuit``, ``params`` and
            ``set_params``.
        positions: Batch of coordinates, one (atoms, 3) array per sample.
        E_train: Target energies.
        F_train: Target forces for all atoms except the first.
        n_epochs_energy: Number of phase-1 steps.
        n_epochs_combined: Number of phase-2 steps.
        lr: Adam learning rate (used in both phases).

    Returns:
        History dict with keys ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``. Only phase 2 is recorded; its epoch numbers continue
        after ``n_epochs_energy``. All values are computed on the training
        data passed in. The trained parameters are written back to ``model``.
    """
    max_grad_norm = 10.0

    opt_init, opt_update, get_params = optimizers.adam(lr)

    def _guarded_update(index, grads, opt_state):
        """Clip by global norm and step; leave the state alone on NaN/inf."""
        leaves = jax.tree_util.tree_leaves(grads)
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
        if not bool(jnp.isfinite(grad_norm)):
            # The NaN guards on the loss values do not make the gradient
            # finite; applying a NaN gradient would permanently poison Adam's
            # moment estimates and the parameters, so skip this step.
            return opt_state
        if grad_norm > max_grad_norm:
            factor = max_grad_norm / grad_norm
            grads = jax.tree_util.tree_map(lambda g: g * factor, grads)
        return opt_update(index, grads, opt_state)

    # Phase 1: Energy only
    @jax.jit
    def energy_loss(params, positions, E_target):
        E_pred = model.vec_circuit(positions, params)
        return jnp.mean((E_pred - E_target)**2)

    energy_loss_and_grad = jax.value_and_grad(energy_loss)
    opt_state = opt_init(model.params)

    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}

    for step in range(n_epochs_energy):
        params = get_params(opt_state)
        loss, grads = energy_loss_and_grad(params, positions, E_train)
        opt_state = _guarded_update(step, grads, opt_state)

    # Phase 2: Combined
    trained_params = get_params(opt_state)

    def force_single(coords, params):
        grad_fn = jax.grad(lambda c, p: model.circuit(c, p), argnums=0)
        return -grad_fn(coords, params)

    vec_force = jax.vmap(force_single, in_axes=(0, None), out_axes=0)

    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target)**2)

        F_pred_full = vec_force(positions, params)
        F_pred_H = F_pred_full[:, 1:, :]  # H atoms only
        F_loss = jnp.mean((F_pred_H - F_target)**2)

        # Keep the reported loss finite even if a prediction is NaN
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)

        return 2.0 * E_loss + 1.0 * F_loss, (E_loss, F_loss)

    combined_loss_and_grad = jax.value_and_grad(combined_loss, has_aux=True)
    log_every = max(1, n_epochs_combined // 20)

    # Fresh optimizer state: Adam's moments from phase 1 are discarded
    opt_state = opt_init(trained_params)

    for step in range(n_epochs_combined):
        params = get_params(opt_state)
        (loss, (E_loss, F_loss)), grads = combined_loss_and_grad(
            params, positions, E_train, F_train
        )

        opt_state = _guarded_update(step, grads, opt_state)

        if (step + 1) % log_every == 0:
            history["epoch"].append(n_epochs_energy + step + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))

    model.set_params(get_params(opt_state))
    return history


def train_classical_equivariant(model, positions, E_train, F_train, n_epochs=400, 
                                 lr=0.003, lambda_E=1.0, lambda_F=2.0, warmup_frac=0.3):
    """
    Train the classical rotationally equivariant NN with improved strategy.

    Uses two-phase training:
    1. Energy-only warmup phase (first ``warmup_frac * n_epochs`` epochs)
    2. Combined energy + force training; the force weight rises linearly from
       0 to ``lambda_F`` over the first half of phase 2

    Forces use a Huber loss (delta = 0.5) on all atoms except the first.
    Gradients are clipped to a global norm of 5 in both phases.

    Args:
        model: Object exposing ``vec_energy``, ``vec_force``, ``params`` and
            ``set_params``.
        positions: Batch of coordinates, one (atoms, 3) array per sample.
        E_train: Target energies.
        F_train: Target forces for all atoms except the first.
        n_epochs: Total number of epochs across both phases.
        lr: Adam learning rate.
        lambda_E: Energy loss weight.
        lambda_F: Final force loss weight.
        warmup_frac: Fraction of ``n_epochs`` spent in the energy-only phase.

    Returns:
        History dict with keys ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``, sampled about 20 times. All values are computed on
        the training data passed in, and ``test_F_loss`` is the Huber force
        loss in both phases. The trained parameters are written back to
        ``model``.
    """
    warmup_epochs = int(n_epochs * warmup_frac)
    max_grad_norm = 5.0
    log_every = max(1, n_epochs // 20)

    def _nan_to_one(value):
        # Keeps reported losses finite; it does not make gradients finite.
        return jnp.where(jnp.isnan(value), 1.0, value)

    def _energy_term(params, positions, E_target):
        E_pred = model.vec_energy(positions, params)
        return _nan_to_one(jnp.mean((E_pred - E_target) ** 2))

    def _force_term(params, positions, F_target):
        F_pred_H = model.vec_force(positions, params)[:, 1:, :]  # H atoms only
        # Huber loss for forces (more robust to outliers)
        return _nan_to_one(jnp.mean(huber(F_pred_H - F_target, delta=0.5)))

    # Phase 1 objective, and the force metric used for logging during warmup
    energy_loss = jax.jit(_energy_term)
    force_loss = jax.jit(_force_term)

    # Phase 2 objective
    @jax.jit
    def combined_loss(params, positions, E_target, F_target, wF):
        E_loss = _energy_term(params, positions, E_target)
        F_loss = _force_term(params, positions, F_target)
        total_loss = lambda_E * E_loss + wF * F_loss
        return total_loss, (E_loss, F_loss)

    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)

    def _guarded_update(index, grads, opt_state):
        """Clip by global norm and step; leave the state alone on NaN/inf."""
        leaves = jax.tree_util.tree_leaves(grads)
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
        if not bool(jnp.isfinite(grad_norm)):
            # Applying a NaN gradient would permanently poison Adam's moment
            # estimates and the parameters, so skip this step.
            return opt_state
        if grad_norm > max_grad_norm:
            factor = max_grad_norm / grad_norm
            grads = jax.tree_util.tree_map(lambda g: g * factor, grads)
        return opt_update(index, grads, opt_state)

    energy_loss_and_grad = jax.value_and_grad(energy_loss)
    combined_loss_and_grad = jax.value_and_grad(combined_loss, has_aux=True)

    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}

    # Phase 1: Energy warmup
    for epoch in range(warmup_epochs):
        loss, grads = energy_loss_and_grad(get_params(opt_state), positions, E_train)
        opt_state = _guarded_update(epoch, grads, opt_state)

        if (epoch + 1) % log_every == 0:
            # Log the same Huber force metric as phase 2 (previously MSE here),
            # so the recorded force-loss curve is one consistent quantity.
            F_loss = float(force_loss(get_params(opt_state), positions, F_train))

            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(loss))
            history["test_F_loss"].append(F_loss)

    # Phase 2: Combined training with force weight ramp
    for epoch in range(warmup_epochs, n_epochs):
        # Gradually increase force weight
        progress = (epoch - warmup_epochs) / max(1, n_epochs - warmup_epochs)
        wF = lambda_F * min(1.0, progress * 2)  # Ramp up over first half of phase 2

        (loss, (E_loss, F_loss)), grads = combined_loss_and_grad(
            get_params(opt_state), positions, E_train, F_train, wF
        )
        opt_state = _guarded_update(epoch, grads, opt_state)

        if (epoch + 1) % log_every == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))

    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# EVALUATION FUNCTIONS
# =============================================================================

def _regression_metrics(pred, true):
    """Return (MAE, RMSE, R²) of `pred` against `true` (equally shaped arrays)."""
    err = pred - true
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err**2))
    r2 = 1 - np.sum(err**2) / np.sum((true - true.mean())**2)
    return mae, rmse, r2


def evaluate_model(E_pred, F_pred, E_true, F_true, indices_test,
                   energy_scaler, force_scaler, indices_train,
                   E_pred_train=None, F_pred_train=None, E_train_true=None, F_train_true=None):
    """
    Evaluate model with post-correction and compute metrics.

    Predictions are optionally post-corrected with a map fitted on the training
    subset (quadratic for energies, linear for forces), mapped back through the
    scalers' ``inverse_transform``, and scored on the rows selected by
    ``indices_test``.

    Args:
        E_pred: Scaled energy predictions for all samples (1-D array).
        F_pred: Scaled force predictions for all samples (array, first axis = sample).
        E_true: Scaled reference energies, aligned with ``E_pred``.
        F_true: Scaled reference forces, aligned with ``F_pred``.
        indices_test: Integer numpy array selecting the test rows.
        energy_scaler: Object whose ``inverse_transform`` maps an (n, 1) column of
            scaled energies back to original units.
        force_scaler: Same, for force components.
        indices_train: Accepted for interface compatibility; not used in the body.
        E_pred_train, E_train_true: Training predictions/targets used to fit the
            quadratic energy correction. The correction is skipped if either is None.
        F_pred_train, F_train_true: Training predictions/targets used to fit the
            linear force correction. The correction is skipped if either is None.

    Returns:
        (metrics, predictions): ``metrics`` holds float MAE/RMSE/R² values for
        energies (keys suffixed ``_Ha`` and ``_eV``) and forces; ``predictions``
        holds the test-set predictions, targets and indices as plain lists.
    """
    hartree_to_ev = 27.2114  # factor between the "_Ha" and "_eV" energy metrics

    # Post-correction for energy (quadratic)
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c

    E_pred_corr = E_pred
    if E_pred_train is not None and E_train_true is not None:
        try:
            popt_E, _ = curve_fit(corr_E, E_pred_train, E_train_true)
            E_pred_corr = corr_E(E_pred, *popt_E)
        except Exception as exc:
            # Best effort: a failed fit falls back to the raw predictions, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print(f"  WARNING: energy post-correction failed ({exc!r}); using uncorrected predictions")

    # Post-correction for forces (linear)
    F_pred_corr = F_pred
    if F_pred_train is not None and F_train_true is not None:
        try:
            F_pred_flat = F_pred_train.flatten().reshape(-1, 1)
            F_true_flat = F_train_true.flatten().reshape(-1, 1)
            lr_model = LinearRegression()
            lr_model.fit(F_pred_flat, F_true_flat)
            F_pred_corr = lr_model.predict(F_pred.flatten().reshape(-1, 1)).reshape(F_pred.shape)
        except Exception as exc:
            print(f"  WARNING: force post-correction failed ({exc!r}); using uncorrected predictions")

    # Inverse transform to original units (the scalers operate on single columns)
    E_pred_final = energy_scaler.inverse_transform(E_pred_corr.reshape(-1, 1)).flatten()
    E_true_orig = energy_scaler.inverse_transform(E_true.reshape(-1, 1)).flatten()

    F_pred_final = force_scaler.inverse_transform(F_pred_corr.flatten().reshape(-1, 1)).reshape(F_pred_corr.shape)
    F_true_orig = force_scaler.inverse_transform(F_true.flatten().reshape(-1, 1)).reshape(F_true.shape)

    # Compute metrics on test set
    E_test_pred = E_pred_final[indices_test]
    E_test_true = E_true_orig[indices_test]
    E_mae, E_rmse, E_r2 = _regression_metrics(E_test_pred, E_test_true)

    # Forces are scored component-wise over the flattened test rows
    F_mae, F_rmse, F_r2 = _regression_metrics(
        F_pred_final[indices_test].flatten(),
        F_true_orig[indices_test].flatten(),
    )

    metrics = {
        "E_mae_Ha": float(E_mae),
        "E_mae_eV": float(E_mae * hartree_to_ev),
        "E_rmse_Ha": float(E_rmse),
        "E_rmse_eV": float(E_rmse * hartree_to_ev),
        "E_r2": float(E_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
        "F_r2": float(F_r2),
    }
    predictions = {
        "E_pred": E_test_pred.tolist(),
        "E_true": E_test_true.tolist(),
        "F_pred": F_pred_final[indices_test].tolist(),
        "F_true": F_true_orig[indices_test].tolist(),
        "indices_test": indices_test.tolist(),
    }
    return metrics, predictions


# =============================================================================
# PLOTTING
# =============================================================================

def _summary_stat(all_results, methods, metric, stat):
    """Collect one summary statistic (e.g. "mean") of `metric` for every method, in order."""
    return [all_results[m]["metrics_summary"][metric][stat] for m in methods]


def _r2_lower_limit(means, stds):
    """Lower y-limit for an R² axis: 0 unless a bar (minus its error bar) goes negative."""
    lows = [m - s for m, s in zip(means, stds) if np.isfinite(m - s)]
    lowest = min(lows, default=0.0)
    # R² is unbounded below; a fixed lower limit of 0 would hide the failing models.
    return 0.0 if lowest >= 0.0 else 1.1 * lowest


def _draw_r2_panel(ax, x_pos, means, stds, colors, method_names, ylabel, title):
    """Draw one per-method R² bar panel with error bars and value labels."""
    bars = ax.bar(x_pos, means, yerr=stds, color=colors, alpha=0.8, capsize=5)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(method_names, fontsize=9, rotation=15, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim([_r2_lower_limit(means, stds), 1.15])
    ax.grid(True, alpha=0.3, axis='y')

    for bar, mean in zip(bars, means):
        # Anchor at the zero line for negative bars so the label stays readable
        ax.text(bar.get_x() + bar.get_width()/2, max(bar.get_height(), 0.0) + 0.02,
                f'{mean:.3f}', ha='center', va='bottom', fontsize=9)


def _save_training_curves(all_results, methods, method_names, colors, output_dir):
    """Figure 1: log-scale loss curves of each method's first run (2x2 grid)."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for ax, method, name, color in zip(axes.flatten(), methods, method_names, colors):
        runs = all_results[method]["runs"]
        if len(runs) == 0:
            continue
        history = runs[0]["history"]
        if not history.get("epoch"):
            continue

        epochs = history["epoch"]
        ax.plot(epochs, history["train_loss"], '-', color=color, lw=2, label='Train Loss')
        if history.get("test_E_loss"):
            ax.plot(epochs, history["test_E_loss"], '--', color=color, lw=2, alpha=0.7, label='E Loss')
        if history.get("test_F_loss"):
            ax.plot(epochs, history["test_F_loss"], ':', color=color, lw=2, alpha=0.7, label='F Loss')

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{name}\nTraining Curves')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'training_curves.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _save_metric_bars(all_results, methods, method_names, colors, output_dir):
    """Figure 2: energy R², force R² and MAE bar charts side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    x_pos = np.arange(len(methods))

    # Energy R² comparison
    _draw_r2_panel(axes[0], x_pos,
                   _summary_stat(all_results, methods, "E_r2", "mean"),
                   _summary_stat(all_results, methods, "E_r2", "std"),
                   colors, method_names, 'Energy R²', 'Energy Prediction R²')

    # Force R² comparison
    _draw_r2_panel(axes[1], x_pos,
                   _summary_stat(all_results, methods, "F_r2", "mean"),
                   _summary_stat(all_results, methods, "F_r2", "std"),
                   colors, method_names, 'Force R²', 'Force Prediction R²')

    # MAE comparison (energy MAE is rescaled by 1000 to match the mHa label)
    ax = axes[2]
    width = 0.35
    E_mae_means = [v * 1000 for v in _summary_stat(all_results, methods, "E_mae_Ha", "mean")]
    F_mae_means = _summary_stat(all_results, methods, "F_mae", "mean")

    ax.bar(x_pos - width/2, E_mae_means, width, color=colors, alpha=0.6, label='Energy MAE (mHa)')
    ax.bar(x_pos + width/2, F_mae_means, width, color=colors, alpha=1.0, hatch='//', label='Force MAE (eV/Å)')

    ax.set_xticks(x_pos)
    ax.set_xticklabels(method_names, fontsize=9, rotation=15, ha='right')
    ax.set_ylabel('MAE')
    ax.set_title('Mean Absolute Errors')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'metrics_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _save_prediction_scatter(all_results, methods, method_names, colors, output_dir):
    """Figure 3: predicted-vs-true scatter of each method's first run (energy row, force row)."""
    fig, axes = plt.subplots(2, len(methods), figsize=(16, 8))

    for idx, (method, name, color) in enumerate(zip(methods, method_names, colors)):
        runs = all_results[method]["runs"]
        if len(runs) == 0:
            continue
        predictions = runs[0]["predictions"]
        metrics = runs[0]["metrics"]

        panels = [
            (np.array(predictions["E_true"]), np.array(predictions["E_pred"]),
             'True Energy (Ha)', 'Predicted Energy (Ha)', f'{name}\nE R²={metrics["E_r2"]:.3f}'),
            (np.array(predictions["F_true"]).flatten(), np.array(predictions["F_pred"]).flatten(),
             'True Force (eV/Å)', 'Predicted Force (eV/Å)', f'{name}\nF R²={metrics["F_r2"]:.3f}'),
        ]
        for row, (true, pred, xlabel, ylabel, title) in enumerate(panels):
            ax = axes[row, idx]
            ax.scatter(true, pred, c=color, alpha=0.6, s=20)
            # Identity line spanning the joint range of targets and predictions
            lims = [min(true.min(), pred.min()), max(true.max(), pred.max())]
            ax.plot(lims, lims, 'k--', lw=2)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'predictions_scatter.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _save_summary_bars(all_results, methods, method_names, output_dir):
    """Figure 4: grouped energy/force R² bars for all methods on one axis."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    width = 0.35
    x = np.arange(len(methods))

    E_r2 = _summary_stat(all_results, methods, "E_r2", "mean")
    F_r2 = _summary_stat(all_results, methods, "F_r2", "mean")
    E_r2_err = _summary_stat(all_results, methods, "E_r2", "std")
    F_r2_err = _summary_stat(all_results, methods, "F_r2", "std")

    bars1 = ax.bar(x - width/2, E_r2, width, yerr=E_r2_err, label='Energy R²',
                   color='steelblue', alpha=0.8, capsize=4)
    bars2 = ax.bar(x + width/2, F_r2, width, yerr=F_r2_err, label='Force R²',
                   color='coral', alpha=0.8, capsize=4)

    ax.set_ylabel('R² Score')
    ax.set_title('NH₃ Method Comparison: Energy vs Force Prediction Performance')
    ax.set_xticks(x)
    ax.set_xticklabels(method_names, fontsize=10)
    ax.legend()
    ax.set_ylim([_r2_lower_limit(E_r2 + F_r2, E_r2_err + F_r2_err), 1.1])
    ax.grid(True, alpha=0.3, axis='y')

    for bar in list(bars1) + list(bars2):
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                    xy=(bar.get_x() + bar.get_width() / 2, max(height, 0.0)),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'summary_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def create_comparison_plots(all_results, output_dir):
    """
    Create comparison plots for all four methods.

    Writes four PNG files into ``output_dir``: ``training_curves.png``,
    ``metrics_comparison.png``, ``predictions_scatter.png`` and
    ``summary_comparison.png``.

    Args:
        all_results: Results dictionary with one entry per method key, each
            holding a ``"runs"`` list (history, metrics, predictions per run)
            and a ``"metrics_summary"`` dict with ``mean``/``std`` per metric.
            Curves and scatter plots use only the first run of each method.
        output_dir: Existing directory the figures are saved into.

    R² axes start at 0 unless a score is negative, in which case the axis is
    extended downwards so the bar remains visible.
    """
    methods = ["rotationally_equivariant", "non_equivariant", "graph_permutation_equivariant", "classical_equivariant"]
    method_names = ["Rot. Equiv. QML", "Non-Equiv. QML", "Graph Perm. QML", "Classical Equiv. NN"]
    colors = ["#2ecc71", "#e74c3c", "#3498db", "#9b59b6"]

    _save_training_curves(all_results, methods, method_names, colors, output_dir)
    _save_metric_bars(all_results, methods, method_names, colors, output_dir)
    _save_prediction_scatter(all_results, methods, method_names, colors, output_dir)
    _save_summary_bars(all_results, methods, method_names, output_dir)

    print(f"  Plots saved to: {output_dir}/")


# =============================================================================
# MAIN COMPARISON
# =============================================================================

def run_comparison(n_runs=2, n_epochs=400, output_dir="nh3_comparison_results",
                   data_dir="eqnn_force_field_data_nh3_new"):
    """
    Run comparison between all four methods on NH₃ data.

    Loads ``Energy.npy``, ``Forces.npy`` and ``Positions.npy`` from ``data_dir``,
    rescales energies and H-atom forces to [-1, 1], and for every run draws a
    seeded 80/20 train/test split, trains the four models in a fixed order and
    scores them with ``evaluate_model``. Afterwards it aggregates per-metric
    statistics over the runs, writes ``results.json``, ``metrics.npz`` and the
    comparison plots into ``output_dir``, and prints a summary table.

    Args:
        n_runs: Number of independent runs (seed ``42 + 100 * run``); must be >= 1.
        n_epochs: Training epochs per model and run.
        output_dir: Directory for all outputs; created if missing.
        data_dir: Directory containing the three ``.npy`` input files.

    Returns:
        The full results dictionary, or ``None`` if the data files are missing.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    # Fail fast: with zero runs the summary step has no metrics to aggregate.
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # (key in the results dict, prefix of the corresponding arrays in metrics.npz)
    method_specs = [
        ("rotationally_equivariant", "rot_eq"),
        ("non_equivariant", "neq"),
        ("graph_permutation_equivariant", "gpe"),
        ("classical_equivariant", "classical"),
    ]
    method_keys = [key for key, _ in method_specs]

    print("="*90)
    print("NH₃ Energy/Force Prediction: Four Methods Comparison")
    print("="*90)
    print("Methods:")
    print("  1. Rotationally Equivariant QML (SO(3) symmetry)")
    print("  2. Non-Equivariant QML (baseline)")
    print("  3. Graph Permutation Equivariant QML (permutation symmetry)")
    print("  4. Classical Rotationally Equivariant NN (E(3) invariant MLP)")
    print("="*90)

    os.makedirs(output_dir, exist_ok=True)

    # Load data
    print("\nLoading data...")
    try:
        energy = np.load(os.path.join(data_dir, "Energy.npy"))
        forces = np.load(os.path.join(data_dir, "Forces.npy"))
        positions = np.load(os.path.join(data_dir, "Positions.npy"))
        print(f"  Loaded {len(energy)} samples")
        print(f"  Positions shape: {positions.shape}")
        print(f"  Forces shape: {forces.shape}")
    except FileNotFoundError:
        print(f"  ERROR: Data not found in {data_dir}")
        return None

    N_samples = len(energy)

    # Normalize energies
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()

    # Normalize forces (H atoms only); all components share one scaler
    forces_H = forces[:, 1:, :]
    force_scaler = MinMaxScaler((-1, 1))
    forces_scaled = force_scaler.fit_transform(forces_H.reshape(-1, 1)).reshape(forces_H.shape)

    # Extract H positions for equivariant model (relative to N)
    positions_H = positions[:, 1:, :]

    # Results storage
    all_results = {
        "config": {
            "n_runs": n_runs,
            "n_epochs": n_epochs,
            "n_samples": N_samples,
            "timestamp": datetime.now().isoformat(),
        },
    }
    for key in method_keys:
        all_results[key] = {"runs": [], "metrics_summary": {}}

    def circuit_forces_on_H(model, coords):
        """Forces as the negative coordinate gradient of ``model.circuit``, atom 0 dropped."""
        def force_single(single_coords, params):
            return -jax.grad(lambda c, p: model.circuit(c, p), argnums=0)(single_coords, params)
        forces_full = np.array(jax.vmap(force_single, (0, None), 0)(coords, model.params))
        return forces_full[:, 1:, :]

    # Run experiments
    for run in range(n_runs):
        print(f"\n{'='*90}")
        print(f"RUN {run+1}/{n_runs}")
        print(f"{'='*90}")

        run_seed = 42 + run * 100

        # Train/test split
        rng = np.random.default_rng(run_seed)
        indices = np.arange(N_samples)
        rng.shuffle(indices)

        n_train = int(0.8 * N_samples)
        indices_train = indices[:n_train]
        indices_test = indices[n_train:]

        print(f"  Train: {len(indices_train)}, Test: {len(indices_test)}")

        # Prepare data
        pos_H_train = jnp.array(positions_H[indices_train])
        pos_H_all = jnp.array(positions_H)
        pos_full_train = jnp.array(positions[indices_train])
        pos_full_all = jnp.array(positions)

        E_train = jnp.array(energy_scaled[indices_train])
        F_train = jnp.array(forces_scaled[indices_train])

        def record_run(method_key, history, E_pred, F_pred):
            """Score one trained model on this run's split, log it and store the result."""
            metrics, predictions = evaluate_model(
                E_pred, F_pred, energy_scaled, forces_scaled,
                indices_test, energy_scaler, force_scaler, indices_train,
                E_pred[indices_train], F_pred[indices_train],
                energy_scaled[indices_train], forces_scaled[indices_train]
            )

            print(f"  Energy: MAE={metrics['E_mae_Ha']:.6f} Ha, R²={metrics['E_r2']:.4f}")
            print(f"  Force:  MAE={metrics['F_mae']:.4f} eV/Å, R²={metrics['F_r2']:.4f}")

            all_results[method_key]["runs"].append({
                "run_id": run,
                "seed": run_seed,
                "history": history,
                "metrics": metrics,
                "predictions": predictions,
            })

        # --- 1. Rotationally Equivariant QML ---
        print("\n[1/4] Rotationally Equivariant QML")
        rot_eq_model = RotationallyEquivariantQML(depth=6, rep=2, active_atoms=3, seed=run_seed)

        print(f"  Training for {n_epochs} epochs...")
        rot_eq_history = train_rotationally_equivariant(
            rot_eq_model, pos_H_train, E_train, F_train,
            n_epochs=n_epochs, lr=3e-3, wE=1.0, wF_max=5.0, warmup_frac=0.4
        )

        print("  Evaluating...")
        E_pred_rot = np.array(rot_eq_model.predict_energy(pos_H_all, rot_eq_model.params))
        F_pred_rot = np.array(rot_eq_model.predict_forces(pos_H_all, rot_eq_model.params))
        record_run("rotationally_equivariant", rot_eq_history, E_pred_rot, F_pred_rot)

        # --- 2. Non-Equivariant QML ---
        print("\n[2/4] Non-Equivariant QML")
        neq_model = NonEquivariantQML(num_qubits=6, depth=4, seed=run_seed)

        print(f"  Training for {n_epochs} epochs...")
        neq_history = train_non_equivariant(
            neq_model, pos_full_train, E_train, F_train,
            n_epochs=n_epochs, lr=0.01, lambda_E=2.0, lambda_F=1.0
        )

        print("  Evaluating...")
        E_pred_neq = np.array(neq_model.vec_circuit(pos_full_all, neq_model.params))
        F_pred_neq = circuit_forces_on_H(neq_model, pos_full_all)
        record_run("non_equivariant", neq_history, E_pred_neq, F_pred_neq)

        # --- 3. Graph Permutation Equivariant QML ---
        print("\n[3/4] Graph Permutation Equivariant QML")
        gpe_model = GraphPermutationEquivariantQML(n_qubits=6, depth=4, seed=run_seed)

        # Split the budget so both phases always add up to n_epochs (also for odd values)
        n_epochs_energy = int(n_epochs * 0.5)
        n_epochs_combined = n_epochs - n_epochs_energy

        print(f"  Training for {n_epochs} epochs (2-phase)...")
        gpe_history = train_graph_permutation_equivariant(
            gpe_model, pos_full_train, E_train, F_train,
            n_epochs_energy=n_epochs_energy, n_epochs_combined=n_epochs_combined, lr=0.01
        )

        print("  Evaluating...")
        E_pred_gpe = np.array(gpe_model.vec_circuit(pos_full_all, gpe_model.params))
        F_pred_gpe = circuit_forces_on_H(gpe_model, pos_full_all)
        record_run("graph_permutation_equivariant", gpe_history, E_pred_gpe, F_pred_gpe)

        # --- 4. Classical Rotationally Equivariant NN ---
        print("\n[4/4] Classical Rotationally Equivariant NN")
        classical_model = ClassicalRotationallyEquivariantNN(hidden_dims=[128, 128, 64], seed=run_seed)

        print(f"  Training for {n_epochs} epochs...")
        classical_history = train_classical_equivariant(
            classical_model, pos_full_train, E_train, F_train,
            n_epochs=n_epochs, lr=0.003, lambda_E=1.0, lambda_F=2.0, warmup_frac=0.3
        )

        print("  Evaluating...")
        E_pred_classical = np.array(classical_model.vec_energy(pos_full_all, classical_model.params))
        F_pred_classical_full = np.array(classical_model.vec_force(pos_full_all, classical_model.params))
        F_pred_classical = F_pred_classical_full[:, 1:, :]
        record_run("classical_equivariant", classical_history, E_pred_classical, F_pred_classical)

    # Compute summary statistics
    for model_type in method_keys:
        metrics_list = [r["metrics"] for r in all_results[model_type]["runs"]]

        summary = {}
        for key in metrics_list[0].keys():
            values = [m[key] for m in metrics_list]
            summary[key] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "min": float(np.min(values)),
                "max": float(np.max(values)),
                "values": values,
            }
        all_results[model_type]["metrics_summary"] = summary

    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")

    # Save numpy arrays: one per (method prefix, metric), e.g. "rot_eq_E_r2"
    npz_metrics = [("E_r2", "E_r2"), ("F_r2", "F_r2"), ("E_mae", "E_mae_Ha"), ("F_mae", "F_mae")]
    metric_arrays = {
        f"{prefix}_{suffix}": [r["metrics"][metric] for r in all_results[key]["runs"]]
        for key, prefix in method_specs
        for suffix, metric in npz_metrics
    }
    np.savez(os.path.join(output_dir, "metrics.npz"), **metric_arrays)

    # Create plots
    print("\nGenerating comparison plots...")
    create_comparison_plots(all_results, output_dir)

    # Print summary
    print("\n" + "="*100)
    print("SUMMARY")
    print("="*100)
    print(f"\n{'Metric':<15} {'Rot. Equiv. QML':<20} {'Non-Equiv. QML':<20} {'Graph Perm. QML':<20} {'Classical Equiv.':<20}")
    print("-"*95)

    for metric in ["E_r2", "E_mae_Ha", "F_r2", "F_mae"]:
        cells = []
        for key in method_keys:
            stats = all_results[key]["metrics_summary"][metric]
            cells.append(f"{stats['mean']:.4f}±{stats['std']:.4f}")
        print(f"{metric:<15} " + "       ".join(cells))

    print("="*100)

    return all_results


# =============================================================================
# MAIN
# =============================================================================

def main(n_runs=2, n_epochs=400, output_dir="nh3_comparison_results", 
         data_dir="eqnn_force_field_data_nh3_new"):
    """
    Shared entry point for notebook and command-line use.

    Forwards all four settings unchanged to ``run_comparison`` and hands back
    whatever it returns.

    Args:
        n_runs: Number of runs for each model
        n_epochs: Training epochs per run
        output_dir: Directory to save results
        data_dir: Directory containing NH₃ data (.npy files)

    Returns:
        The value returned by ``run_comparison`` (the results dictionary, or
        ``None`` when the data files cannot be found).
    """
    settings = dict(
        n_runs=n_runs,
        n_epochs=n_epochs,
        output_dir=output_dir,
        data_dir=data_dir,
    )
    return run_comparison(**settings)


if __name__ == "__main__":
    import sys

    if 'ipykernel' not in sys.modules:
        # Plain interpreter: read the settings from the command line.
        parser = argparse.ArgumentParser(description="Compare Four Methods on NH₃")
        parser.add_argument("--n_runs", type=int, default=2, help="Number of runs")
        parser.add_argument("--n_epochs", type=int, default=400, help="Training epochs per run")
        parser.add_argument("--output_dir", type=str, default="nh3_comparison_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_nh3_new", help="Data directory")

        args = parser.parse_args()

        # The option names match main()'s keyword arguments one-to-one.
        results = main(**vars(args))
    else:
        # Inside a notebook kernel argv belongs to Jupyter, so only print a usage hint.
        print("Running in Jupyter notebook. Call main() directly with parameters:")
        print("  results = main(n_runs=2, n_epochs=400, output_dir='nh3_results', data_dir='eqnn_force_field_data_nh3_new')")

Running in Jupyter notebook. Call main() directly with parameters:
  results = main(n_runs=2, n_epochs=400, output_dir='nh3_results', data_dir='eqnn_force_field_data_nh3_new')




In [None]:
# Single run: the summary statistics are computed over one value per metric,
# so every reported std is 0. Raise n_runs for a meaningful spread.
results = main(
    n_runs=1,
    n_epochs=400,
    output_dir='nh3_results',
    data_dir='eqnn_force_field_data_nh3_new',
)



NH₃ Energy/Force Prediction: Four Methods Comparison
Methods:
  1. Rotationally Equivariant QML (SO(3) symmetry)
  2. Non-Equivariant QML (baseline)
  3. Graph Permutation Equivariant QML (permutation symmetry)
  4. Classical Rotationally Equivariant NN (E(3) invariant MLP)

Loading data...
  Loaded 2400 samples
  Positions shape: (2400, 4, 3)
  Forces shape: (2400, 4, 3)

RUN 1/1
  Train: 1920, Test: 480

[1/4] Rotationally Equivariant QML
  Training for 400 epochs...
