In [3]:
"""
Comparison of Four Methods for LiH Energy/Force Prediction

This script runs four methods on LiH molecular data and compares their performance.
All results are saved to an output directory for later analysis.

Methods compared:
1. Rotationally Equivariant QML - Uses SO(3) equivariant encoding with Heisenberg observable
2. Non-Equivariant QML - Simple QNN with basic rotations
3. Graph Permutation Equivariant QML - Uses graph-based permutation-symmetric encoding
4. Classical Rotationally Equivariant NN - Classical MLP on pairwise distances (E(3) invariant)

Usage:
    python run_comparison_four_methods.py --n_runs 3 --n_epochs 100 --output_dir results
"""

import pennylane as qml
import numpy as np
import json
import os
from datetime import datetime
import argparse

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp

from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# ROTATIONALLY EQUIVARIANT QML MODEL (SO(3))
# =============================================================================

class RotationallyEquivariantQML:
    """
    Rotationally Equivariant Quantum Machine Learning model for LiH.
    Uses SO(3) equivariant encoding with native PennyLane gates.
    
    Architecture features:
    - Multiple qubits with singlet initialization
    - Learnable head_scale and head_bias for output
    - Proper alpha initialization in [0.5, 1.5] range
    """
    
    def __init__(self, n_qubits=6, depth=6, seed=42):
        self.n_qubits = n_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        
        # Heisenberg observable
        self.observable = (
            qml.PauliX(0) @ qml.PauliX(1)
            + qml.PauliY(0) @ qml.PauliY(1)
            + qml.PauliZ(0) @ qml.PauliZ(1)
        )
        
        self._build_circuit()
        self._init_params()
    
    def _singlet(self, wires):
        """Create singlet state on two qubits."""
        w0, w1 = wires
        qml.Hadamard(wires=w0)
        qml.PauliZ(wires=w0)
        qml.PauliX(wires=w1)
        qml.CNOT(wires=[w0, w1])
    
    def _equivariant_encoding(self, alpha, vec3, wire):
        """SO(3) equivariant encoding using qml.Rot."""
        r = jnp.array(vec3, dtype=jnp.float64)
        norm = jnp.linalg.norm(r) + 1e-12
        n = r / norm
        theta = alpha * norm
        qml.Rot(theta * n[0], theta * n[1], theta * n[2], wires=wire)
    
    def _pair_layer(self, weight, wires):
        """Trainable Heisenberg-like interaction."""
        qml.IsingXX(weight, wires=wires)
        qml.IsingYY(weight, wires=wires)
        qml.IsingZZ(weight, wires=wires)
    
    def _build_circuit(self):
        """Build the quantum circuit."""
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(coords, params):
            """
            coords: (1, 3) - H position relative to Li
            params: {"weights", "alphas", "head_scale", "head_bias"}
            """
            weights = params["weights"]
            alphas = params["alphas"]
            
            # Initialize singlets on pairs of qubits
            for i in range(0, self.n_qubits - 1, 2):
                self._singlet([i, i + 1])
            
            # Initial encoding - all qubits encode the same H position
            for i in range(self.n_qubits):
                self._equivariant_encoding(alphas[i, 0], coords[0], i)
            
            # Variational layers
            for d in range(self.depth):
                qml.Barrier()
                # Even pairs
                for i in range(0, self.n_qubits - 1, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                # Odd pairs
                for i in range(1, self.n_qubits, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                # Re-encoding
                for i in range(self.n_qubits):
                    self._equivariant_encoding(alphas[i, d + 1], coords[0], i)
            
            return qml.expval(self.observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        """Initialize parameters with proper ranges."""
        np.random.seed(self.seed)
        
        # Weights: small initial values, only first row non-zero initially
        weights = np.zeros((self.n_qubits, self.depth), dtype=np.float64)
        weights[0] = np.random.uniform(0.0, np.pi, size=(self.depth,))
        
        # Alphas: in [0.5, 1.5] range for stable encoding
        alphas = np.random.uniform(0.5, 1.5, size=(self.n_qubits, self.depth + 1))
        
        self.params = {
            "weights": jnp.array(weights),
            "alphas": jnp.array(alphas),
            "head_scale": jnp.array(1.0),
            "head_bias": jnp.array(0.0),
        }
    
    def energy(self, coords, params):
        """Compute energy with head transformation."""
        raw = self.circuit(coords, params)
        return params["head_scale"] * raw + params["head_bias"]
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# NON-EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """
    Simple non-equivariant QNN for molecular properties.
    Uses basic rotations without symmetry preservation.
    """
    
    def __init__(self, num_qubits=4, depth=3, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        self._init_params()
    
    def _create_circuit(self):
        """Create the simple QNN circuit."""
        num_qubits = self.num_qubits
        depth = self.depth
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            weights = params["weights"]
            
            # Single feature: bond length
            dist = jnp.linalg.norm(positions[1] - positions[0])
            
            # Initialize
            for i in range(num_qubits):
                qml.RY(0.5, wires=i)
            
            # Simple layers
            for layer in range(depth):
                for i in range(num_qubits):
                    qml.RY(weights[layer, i, 0] * dist, wires=i)
                
                for i in range(num_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                
                for i in range(num_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)
            
            return qml.expval(qml.PauliZ(0) + qml.PauliZ(1) + qml.PauliZ(2) + qml.PauliZ(3))
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)
    
    def _init_params(self):
        """Initialize parameters."""
        np.random.seed(self.seed)
        weights = np.random.normal(0, 0.1, (self.depth, self.num_qubits, 3))
        self.params = {"weights": jnp.array(weights)}
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# GRAPH PERMUTATION EQUIVARIANT QML MODEL
# =============================================================================

class GraphPermutationEquivariantQML:
    """
    Graph Permutation Equivariant Quantum Machine Learning model.
    
    Uses permutation-symmetric encoding based on graph structure:
    - Encodes interatomic distances (permutation invariant features)
    - Uses symmetric pooling operations
    - Circuit structure respects graph connectivity
    """
    
    def __init__(self, num_qubits=4, depth=4, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        self._init_params()
    
    def _create_circuit(self):
        """Create the graph permutation equivariant circuit."""
        num_qubits = self.num_qubits
        depth = self.depth
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """
            Graph-based permutation equivariant circuit.
            """
            weights = params["weights"]       # (depth, num_qubits, 4)
            edge_weights = params["edge_weights"]  # (depth, num_edges, 2)
            global_weights = params["global_weights"]  # (depth, 3)
            
            # Extract graph features (permutation invariant)
            n_atoms = positions.shape[0]
            
            # Compute pairwise distances (permutation invariant features)
            dist_LiH = jnp.linalg.norm(positions[1] - positions[0])
            
            # Compute direction cosines for directional information
            direction = (positions[1] - positions[0]) / (dist_LiH + 1e-8)
            
            # Symmetric aggregated features
            center_of_mass = jnp.mean(positions, axis=0)
            spread = jnp.std(positions)
            
            # Feature vector (all permutation invariant/equivariant)
            features = jnp.array([
                dist_LiH,
                spread,
                jnp.linalg.norm(center_of_mass),
                direction[2]  # z-component for force direction
            ])
            
            # === Symmetric Initial State ===
            for i in range(num_qubits):
                qml.Hadamard(wires=i)
            
            # === Graph-based Encoding Layers ===
            for layer in range(depth):
                # Node update
                for i in range(num_qubits):
                    angle_y = weights[layer, i, 0] * features[0] + weights[layer, i, 1] * features[1]
                    angle_z = weights[layer, i, 2] * features[2] + weights[layer, i, 3] * features[3]
                    
                    qml.RY(angle_y, wires=i)
                    qml.RZ(angle_z, wires=i)
                
                # Edge operations
                edges = [(i, (i+1) % num_qubits) for i in range(num_qubits)]
                
                for e_idx, (i, j) in enumerate(edges):
                    edge_angle = edge_weights[layer, e_idx % edge_weights.shape[1], 0] * dist_LiH
                    
                    qml.CNOT(wires=[i, j])
                    qml.RZ(edge_angle, wires=j)
                    qml.CNOT(wires=[i, j])
                
                # Global pooling layer
                global_angle = global_weights[layer, 0] * dist_LiH + global_weights[layer, 1]
                for i in range(num_qubits):
                    qml.RY(global_angle * global_weights[layer, 2], wires=i)
            
            # === Permutation Symmetric Measurement ===
            obs = sum(qml.PauliZ(i) for i in range(num_qubits))
            return qml.expval(obs)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)
    
    def _init_params(self):
        """Initialize parameters with Xavier-like initialization."""
        np.random.seed(self.seed)
        
        num_edges = self.num_qubits
        
        limit = np.sqrt(2.0 / (self.num_qubits + 4))
        weights = np.random.uniform(-limit, limit, (self.depth, self.num_qubits, 4))
        edge_weights = np.random.uniform(-0.5, 0.5, (self.depth, num_edges, 2))
        global_weights = np.random.uniform(-0.3, 0.3, (self.depth, 3))
        
        self.params = {
            "weights": jnp.array(weights),
            "edge_weights": jnp.array(edge_weights),
            "global_weights": jnp.array(global_weights)
        }
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# CLASSICAL ROTATIONALLY EQUIVARIANT NN (E(3) INVARIANT)
# =============================================================================

class ClassicalRotationallyEquivariantNN:
    """
    Classical Rotationally Equivariant Neural Network for LiH.
    Uses physics-inspired E(3) invariant features with smooth activations.
    
    Key improvements:
    - SiLU activation (smooth for autodiff force computation)
    - Multiple physics-inspired features (not just raw distance)
    - Larger network with skip connections
    """
    
    def __init__(self, hidden_dims=[128, 128, 64], seed=42):
        self.hidden_dims = hidden_dims
        self.seed = seed
        
        # Physics parameters for LiH
        self.r_eq = 1.6  # Equilibrium Li-H distance in Å
        self.morse_alpha = 2.0
        
        # RBF parameters
        self.rbf_centers = jnp.linspace(0.8, 3.0, 8)  # 8 Gaussians
        self.rbf_width = 0.3
        
        # Number of features: distance + 1/r + Morse + 8 RBF = 11
        self.n_features = 11
        
        self._init_params()
        self._create_model()
    
    def _init_params(self):
        """Initialize MLP parameters with Xavier initialization."""
        np.random.seed(self.seed)
        
        # Feature dimension -> hidden -> output
        layer_sizes = [self.n_features] + self.hidden_dims + [1]
        
        params = {"weights": [], "biases": []}
        
        for i in range(len(layer_sizes) - 1):
            fan_in = layer_sizes[i]
            fan_out = layer_sizes[i + 1]
            limit = np.sqrt(6.0 / (fan_in + fan_out))
            
            W = np.random.uniform(-limit, limit, (fan_in, fan_out))
            b = np.zeros(fan_out)
            
            params["weights"].append(jnp.array(W))
            params["biases"].append(jnp.array(b))
        
        # Skip connection weights (from input features to last hidden layer)
        skip_dim = layer_sizes[-2]  # Last hidden layer dimension
        params["skip_weight"] = jnp.array(
            np.random.uniform(-0.1, 0.1, (self.n_features, skip_dim))
        )
        
        self.params = params
    
    def _create_model(self):
        """Create the forward pass function with physics-inspired features."""
        
        def compute_features(positions):
            """Compute physics-inspired invariant features from positions."""
            # Li at index 0, H at index 1
            r_vec = positions[1] - positions[0]
            r = jnp.linalg.norm(r_vec) + 1e-12
            
            # Feature 1: Normalized distance
            f_dist = r / 2.0  # Normalize by typical scale
            
            # Feature 2: Inverse distance (Coulomb-like)
            f_inv = 1.0 / r
            
            # Feature 3: Morse-like term
            f_morse = jnp.exp(-self.morse_alpha * (r - self.r_eq))
            
            # Features 4-11: RBF encoding (8 Gaussians)
            f_rbf = jnp.exp(-((r - self.rbf_centers) ** 2) / (2 * self.rbf_width ** 2))
            
            # Concatenate all features
            features = jnp.concatenate([
                jnp.array([f_dist, f_inv, f_morse]),
                f_rbf
            ])
            
            return features
        
        def mlp_forward(x, params):
            """MLP forward pass with SiLU activation and skip connection."""
            weights = params["weights"]
            biases = params["biases"]
            skip_weight = params["skip_weight"]
            
            h = x
            for i in range(len(weights) - 1):
                h = jnp.dot(h, weights[i]) + biases[i]
                # SiLU activation: x * sigmoid(x) - smooth for autodiff!
                h = h * jax.nn.sigmoid(h)
                
                # Add skip connection to last hidden layer
                if i == len(weights) - 2:
                    h = h + 0.1 * jnp.dot(x, skip_weight)
            
            # Output layer (no activation)
            h = jnp.dot(h, weights[-1]) + biases[-1]
            return h.squeeze(-1)
        
        def energy_from_positions(positions, params):
            """Compute energy from atomic positions."""
            features = compute_features(positions)
            energy = mlp_forward(features, params)
            return energy
        
        def force_from_positions(positions, params):
            """Compute forces as negative gradient of energy."""
            grad_fn = jax.grad(energy_from_positions, argnums=0)
            return -grad_fn(positions, params)
        
        self.compute_features = compute_features
        self.energy_fn = energy_from_positions
        self.force_fn = force_from_positions
        self.vec_energy = jax.vmap(energy_from_positions, (0, None), 0)
        self.vec_force = jax.vmap(force_from_positions, (0, None), 0)
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_rotationally_equivariant(model, data_train, E_train, F_train, data_test, E_test, F_test,
                                    n_epochs=200, lr=3e-3, wE=1.0, wF_max=5.0, warmup_frac=0.4):
    """Train the rotationally equivariant QML model with force warmup curriculum."""
    warmup_epochs = int(n_epochs * warmup_frac)
    
    def raw_energy(coords, params):
        """Raw circuit output."""
        return model.circuit(coords, params)
    
    vec_raw_energy = jax.vmap(raw_energy, (0, None), 0)
    
    def vec_force_fn(coords_batch, params):
        """Compute forces as -grad(energy)."""
        def single_force(coords):
            grad_fn = jax.grad(raw_energy, argnums=0)
            return -grad_fn(coords, params)
        return jax.vmap(single_force)(coords_batch)
    
    @jax.jit
    def loss_fn(params, coords, E_target, F_target, wF):
        # Energy with head transformation
        E_raw = vec_raw_energy(coords, params)
        E_pred = params["head_scale"] * E_raw + params["head_bias"]
        L_E = jnp.mean((E_pred - E_target) ** 2)
        
        # Forces (scaled by head_scale)
        F_raw = vec_force_fn(coords, params)
        F_pred = params["head_scale"] * F_raw
        F_pred_z = F_pred[:, 0, 2]  # H atom z-component
        L_F = jnp.mean((F_pred_z - F_target) ** 2)
        
        # Handle NaNs
        L_E = jnp.where(jnp.isnan(L_E), 1.0, L_E)
        L_F = jnp.where(jnp.isnan(L_F), 1.0, L_F)
        
        return wE * L_E + wF * L_F, (L_E, L_F)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        # Warmup curriculum: gradually increase force weight
        if epoch < warmup_epochs:
            wF = wF_max * (epoch / warmup_epochs)
        else:
            wF = wF_max
        
        (loss, (L_E, L_F)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), data_train, E_train, F_train, wF
        )
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            test_params = get_params(opt_state)
            E_raw_test = np.array(vec_raw_energy(data_test, test_params))
            E_pred_test = float(test_params["head_scale"]) * E_raw_test + float(test_params["head_bias"])
            F_raw_test = np.array(vec_force_fn(data_test, test_params))[:, 0, 2]
            F_pred_test = float(test_params["head_scale"]) * F_raw_test
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_non_equivariant(model, pos_train, E_train, F_train, pos_test, E_test, F_test,
                          n_epochs=200, lr=0.01, lambda_E=2.0, lambda_F=1.0):
    """Train the non-equivariant QML model."""
    
    def energy_single(coords, params):
        return model.circuit(coords, params)
    
    def force_single(coords, params):
        grad_fn = jax.grad(energy_single, argnums=0)
        return -grad_fn(coords, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred_full = vec_force(positions, params)
        F_pred_z = F_pred_full[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            get_params(opt_state), pos_train, E_train, F_train
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(model.vec_circuit(pos_test, test_params))
            F_pred_test = np.array(vec_force(pos_test, test_params))[:, 1, 2]
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_graph_permutation_equivariant(model, pos_train, E_train, F_train, pos_test, E_test, F_test,
                                         n_epochs=200, lr=0.01, lambda_E=1.5, lambda_F=1.5):
    """Train the graph permutation equivariant QML model."""
    
    def energy_single(coords, params):
        return model.circuit(coords, params)
    
    def force_single(coords, params):
        grad_fn = jax.grad(energy_single, argnums=0)
        return -grad_fn(coords, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred_full = vec_force(positions, params)
        F_pred_z = F_pred_full[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            get_params(opt_state), pos_train, E_train, F_train
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(model.vec_circuit(pos_test, test_params))
            F_pred_test = np.array(vec_force(pos_test, test_params))[:, 1, 2]
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_classical_equivariant(model, pos_train, E_train, F_train, pos_test, E_test, F_test,
                                 n_epochs=200, lr=3e-3, wE=1.0, wF_max=2.0, warmup_frac=0.3):
    """
    Train the classical rotationally equivariant NN with two-phase training.
    
    Phase 1 (warmup): Energy-only training to establish good features
    Phase 2: Combined energy + forces with gradual force weight ramp
    """
    warmup_epochs = int(n_epochs * warmup_frac)
    
    def huber_loss(pred, target, delta=0.5):
        """Huber loss - robust to outliers."""
        diff = pred - target
        abs_diff = jnp.abs(diff)
        return jnp.mean(jnp.where(abs_diff <= delta, 
                                   0.5 * diff**2, 
                                   delta * (abs_diff - 0.5 * delta)))
    
    @jax.jit
    def loss_fn(params, positions, E_target, F_target, wF):
        E_pred = model.vec_energy(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred_full = model.vec_force(positions, params)
        F_pred_z = F_pred_full[:, 1, 2]
        F_loss = huber_loss(F_pred_z, F_target, delta=0.5)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return wE * E_loss + wF * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        # Two-phase training
        if epoch < warmup_epochs:
            wF = 0.0  # Energy-only warmup
        else:
            phase2_progress = (epoch - warmup_epochs) / max(1, (n_epochs - warmup_epochs) / 2)
            wF = min(wF_max, wF_max * phase2_progress)
        
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), pos_train, E_train, F_train, wF
        )
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 5.0:
            grads = jax.tree.map(lambda g: g * (5.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(model.vec_energy(pos_test, test_params))
            F_pred_test = np.array(model.vec_force(pos_test, test_params))[:, 1, 2]
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# DATA LOADING
# =============================================================================

def load_lih_data(data_dir="eqnn_force_field_data_LiH"):
    """Load LiH molecular data."""
    energy = np.load(os.path.join(data_dir, "Energy.npy"))
    forces = np.load(os.path.join(data_dir, "Forces.npy"))
    positions = np.load(os.path.join(data_dir, "Positions.npy"))
    
    return energy, forces, positions


def prepare_data(energy, forces, positions, test_split=0.2, seed=42):
    """Prepare and scale data for training."""
    shape = positions.shape
    
    # Scale energy
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()
    
    # Center molecule positions
    n_atoms_total = positions.shape[1]
    positions_centered = np.zeros((shape[0], n_atoms_total - 1, 3))
    positions_centered[:, 0, :] = positions[:, 1, :] - positions[:, 0, :]
    
    # Scale forces (z-component of H atom)
    forces_H = forces[:, 1:, :]
    force_scaler = MinMaxScaler((-1, 1))
    forces_z_only = forces_H[:, 0, 2].reshape(-1, 1)
    forces_z_scaled = force_scaler.fit_transform(forces_z_only).flatten()
    
    forces_scaled = np.zeros_like(forces_H)
    forces_scaled[:, 0, 2] = forces_z_scaled
    
    # Train/test split
    np.random.seed(seed)
    n_samples = shape[0]
    indices_train = np.random.choice(np.arange(n_samples), size=int((1-test_split) * n_samples), replace=False)
    indices_test = np.setdiff1d(np.arange(n_samples), indices_train)
    
    data = {
        "energy_scaler": energy_scaler,
        "force_scaler": force_scaler,
        "energy_scaled": energy_scaled,
        "forces_scaled": forces_scaled,
        "positions_centered": positions_centered,
        "positions_raw": positions,
        "forces_H": forces_H,
        "indices_train": indices_train,
        "indices_test": indices_test,
    }
    
    return data


# =============================================================================
# EVALUATION
# =============================================================================

def evaluate_model(model, data, model_type="rotationally_equivariant"):
    """Evaluate model and compute metrics."""
    
    positions_centered = data["positions_centered"]
    positions_raw = data["positions_raw"]
    energy_scaled = data["energy_scaled"]
    forces_z_scaled = data["forces_scaled"][:, 0, 2]
    energy_scaler = data["energy_scaler"]
    force_scaler = data["force_scaler"]
    indices_train = data["indices_train"]
    indices_test = data["indices_test"]
    forces_H = data["forces_H"]
    
    if model_type == "rotationally_equivariant":
        # Apply head transformation for rotationally equivariant model
        params = model.params
        E_raw = np.array(model.vec_circuit(jnp.array(positions_centered), params))
        E_pred_scaled = float(params["head_scale"]) * E_raw + float(params["head_bias"])
        
        def raw_energy(coords, params):
            return model.circuit(coords, params)
        def force_single(coords, params):
            return -jax.grad(raw_energy, argnums=0)(coords, params)
        vec_force = jax.vmap(force_single, (0, None), 0)
        
        F_raw = np.array(vec_force(jnp.array(positions_centered), params))
        F_pred_z_scaled = float(params["head_scale"]) * F_raw[:, 0, 2]
        
    elif model_type == "classical_equivariant":
        E_pred_scaled = np.array(model.vec_energy(jnp.array(positions_raw), model.params))
        F_pred_all = np.array(model.vec_force(jnp.array(positions_raw), model.params))
        F_pred_z_scaled = F_pred_all[:, 1, 2]
        
    else:
        # For non-equivariant and graph permutation equivariant QML
        E_pred_scaled = np.array(model.vec_circuit(jnp.array(positions_raw), model.params))
        
        def energy_single(coords, params):
            return model.circuit(coords, params)
        def force_single(coords, params):
            return -jax.grad(energy_single, argnums=0)(coords, params)
        vec_force = jax.vmap(force_single, (0, None), 0)
        
        F_pred_all = np.array(vec_force(jnp.array(positions_raw), model.params))
        F_pred_z_scaled = F_pred_all[:, 1, 2]
    
    # Post-correction for energy
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c
    
    try:
        popt_E, _ = curve_fit(corr_E, E_pred_scaled[indices_train], energy_scaled[indices_train])
        E_pred_corrected = corr_E(E_pred_scaled, *popt_E)
    except:
        E_pred_corrected = E_pred_scaled
    
    # Post-correction for force
    try:
        lr_model = LinearRegression()
        lr_model.fit(F_pred_z_scaled[indices_train].reshape(-1, 1), forces_z_scaled[indices_train])
        F_pred_corrected = lr_model.predict(F_pred_z_scaled.reshape(-1, 1)).flatten()
    except:
        F_pred_corrected = F_pred_z_scaled
    
    # Inverse transform
    E_pred_original = energy_scaler.inverse_transform(E_pred_corrected.reshape(-1, 1)).flatten()
    F_pred_original = force_scaler.inverse_transform(F_pred_corrected.reshape(-1, 1)).flatten()
    
    E_true_original = energy_scaler.inverse_transform(energy_scaled.reshape(-1, 1)).flatten()
    F_true_original = forces_H[:, 0, 2]
    
    # Compute metrics on test set
    E_mae = np.mean(np.abs(E_pred_original[indices_test] - E_true_original[indices_test]))
    E_rmse = np.sqrt(np.mean((E_pred_original[indices_test] - E_true_original[indices_test]) ** 2))
    E_r2 = 1 - np.sum((E_pred_original[indices_test] - E_true_original[indices_test])**2) / \
               np.sum((E_true_original[indices_test] - E_true_original[indices_test].mean())**2)
    
    F_mae = np.mean(np.abs(F_pred_original[indices_test] - F_true_original[indices_test]))
    F_rmse = np.sqrt(np.mean((F_pred_original[indices_test] - F_true_original[indices_test]) ** 2))
    F_r2 = 1 - np.sum((F_pred_original[indices_test] - F_true_original[indices_test])**2) / \
               np.sum((F_true_original[indices_test] - F_true_original[indices_test].mean())**2)
    
    metrics = {
        "E_mae_Ha": float(E_mae),
        "E_mae_eV": float(E_mae * 27.2114),
        "E_rmse_Ha": float(E_rmse),
        "E_rmse_eV": float(E_rmse * 27.2114),
        "E_r2": float(E_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
        "F_r2": float(F_r2),
    }
    
    predictions = {
        "E_pred": E_pred_original.tolist(),
        "E_true": E_true_original.tolist(),
        "F_pred": F_pred_original.tolist(),
        "F_true": F_true_original.tolist(),
        "indices_test": indices_test.tolist(),
    }
    
    return metrics, predictions


# =============================================================================
# PLOTTING
# =============================================================================

def create_comparison_plots(all_results, output_dir):
    """Create comparison plots for all four methods."""
    
    methods = ["rotationally_equivariant", "non_equivariant", "graph_permutation_equivariant", "classical_equivariant"]
    method_names = ["Rot. Equiv. QML", "Non-Equiv. QML", "Graph Perm. QML", "Classical Equiv. NN"]
    colors = ["#2ecc71", "#e74c3c", "#3498db", "#9b59b6"]
    
    # =====================================================
    # Figure 1: Training curves (2x2)
    # =====================================================
    fig1, axes1 = plt.subplots(2, 2, figsize=(12, 10))
    axes1 = axes1.flatten()
    
    for idx, (method, name, color) in enumerate(zip(methods, method_names, colors)):
        ax = axes1[idx]
        
        if len(all_results[method]["runs"]) > 0:
            history = all_results[method]["runs"][0]["history"]
            epochs = history["epoch"]
            
            ax.plot(epochs, history["train_loss"], '-', color=color, lw=2, label='Train Loss')
            ax.plot(epochs, history["test_E_loss"], '--', color=color, lw=2, alpha=0.7, label='Test E Loss')
            ax.plot(epochs, history["test_F_loss"], ':', color=color, lw=2, alpha=0.7, label='Test F Loss')
            
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(f'{name}\nTraining Curves')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            ax.set_yscale('log')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # =====================================================
    # Figure 2: Bar chart comparisons
    # =====================================================
    fig2, axes2 = plt.subplots(1, 3, figsize=(15, 5))
    
    x_pos = np.arange(len(methods))
    
    # Energy R² comparison
    ax = axes2[0]
    E_r2_means = [all_results[m]["metrics_summary"]["E_r2"]["mean"] for m in methods]
    E_r2_stds = [all_results[m]["metrics_summary"]["E_r2"]["std"] for m in methods]
    
    bars = ax.bar(x_pos, E_r2_means, yerr=E_r2_stds, color=colors, alpha=0.8, capsize=5)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(method_names, fontsize=9, rotation=15, ha='right')
    ax.set_ylabel('Energy R²')
    ax.set_title('Energy Prediction R²')
    ax.set_ylim([0, 1.15])
    ax.grid(True, alpha=0.3, axis='y')
    
    for bar, mean in zip(bars, E_r2_means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                f'{mean:.3f}', ha='center', va='bottom', fontsize=9)
    
    # Force R² comparison
    ax = axes2[1]
    F_r2_means = [all_results[m]["metrics_summary"]["F_r2"]["mean"] for m in methods]
    F_r2_stds = [all_results[m]["metrics_summary"]["F_r2"]["std"] for m in methods]
    
    bars = ax.bar(x_pos, F_r2_means, yerr=F_r2_stds, color=colors, alpha=0.8, capsize=5)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(method_names, fontsize=9, rotation=15, ha='right')
    ax.set_ylabel('Force R²')
    ax.set_title('Force Prediction R²')
    ax.set_ylim([0, 1.15])
    ax.grid(True, alpha=0.3, axis='y')
    
    for bar, mean in zip(bars, F_r2_means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                f'{mean:.3f}', ha='center', va='bottom', fontsize=9)
    
    # MAE comparison
    ax = axes2[2]
    width = 0.35
    E_mae_means = [all_results[m]["metrics_summary"]["E_mae_Ha"]["mean"] * 1000 for m in methods]
    F_mae_means = [all_results[m]["metrics_summary"]["F_mae"]["mean"] for m in methods]
    
    bars1 = ax.bar(x_pos - width/2, E_mae_means, width, color=colors, alpha=0.6, label='Energy MAE (mHa)')
    bars2 = ax.bar(x_pos + width/2, F_mae_means, width, color=colors, alpha=1.0, hatch='//', label='Force MAE (eV/Å)')
    
    ax.set_xticks(x_pos)
    ax.set_xticklabels(method_names, fontsize=9, rotation=15, ha='right')
    ax.set_ylabel('MAE')
    ax.set_title('Mean Absolute Errors')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'metrics_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # =====================================================
    # Figure 3: Scatter plots (2x4)
    # =====================================================
    fig3, axes3 = plt.subplots(2, 4, figsize=(16, 8))
    
    for idx, (method, name, color) in enumerate(zip(methods, method_names, colors)):
        if len(all_results[method]["runs"]) > 0:
            predictions = all_results[method]["runs"][0]["predictions"]
            indices_test = predictions["indices_test"]
            
            E_pred = np.array(predictions["E_pred"])[indices_test]
            E_true = np.array(predictions["E_true"])[indices_test]
            F_pred = np.array(predictions["F_pred"])[indices_test]
            F_true = np.array(predictions["F_true"])[indices_test]
            
            metrics = all_results[method]["runs"][0]["metrics"]
            
            # Energy scatter
            ax = axes3[0, idx]
            ax.scatter(E_true, E_pred, c=color, alpha=0.6, s=20)
            lims = [min(E_true.min(), E_pred.min()), max(E_true.max(), E_pred.max())]
            ax.plot(lims, lims, 'k--', lw=2)
            ax.set_xlabel('True Energy (Ha)')
            ax.set_ylabel('Predicted Energy (Ha)')
            ax.set_title(f'{name}\nE R²={metrics["E_r2"]:.3f}')
            ax.grid(True, alpha=0.3)
            
            # Force scatter
            ax = axes3[1, idx]
            ax.scatter(F_true, F_pred, c=color, alpha=0.6, s=20)
            lims = [min(F_true.min(), F_pred.min()), max(F_true.max(), F_pred.max())]
            ax.plot(lims, lims, 'k--', lw=2)
            ax.set_xlabel('True Force Z (eV/Å)')
            ax.set_ylabel('Predicted Force Z (eV/Å)')
            ax.set_title(f'{name}\nF R²={metrics["F_r2"]:.3f}')
            ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'predictions_scatter.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # =====================================================
    # Figure 4: Summary comparison plot
    # =====================================================
    fig4, ax = plt.subplots(1, 1, figsize=(10, 6))
    
    # Create grouped bar chart for R² values
    width = 0.35
    x = np.arange(len(methods))
    
    E_r2 = [all_results[m]["metrics_summary"]["E_r2"]["mean"] for m in methods]
    F_r2 = [all_results[m]["metrics_summary"]["F_r2"]["mean"] for m in methods]
    E_r2_err = [all_results[m]["metrics_summary"]["E_r2"]["std"] for m in methods]
    F_r2_err = [all_results[m]["metrics_summary"]["F_r2"]["std"] for m in methods]
    
    bars1 = ax.bar(x - width/2, E_r2, width, yerr=E_r2_err, label='Energy R²', 
                   color='steelblue', alpha=0.8, capsize=4)
    bars2 = ax.bar(x + width/2, F_r2, width, yerr=F_r2_err, label='Force R²', 
                   color='coral', alpha=0.8, capsize=4)
    
    ax.set_ylabel('R² Score')
    ax.set_title('Method Comparison: Energy vs Force Prediction Performance')
    ax.set_xticks(x)
    ax.set_xticklabels(method_names, fontsize=10)
    ax.legend()
    ax.set_ylim([0, 1.1])
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bar in bars1:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)
    for bar in bars2:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'summary_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"  Plots saved to: {output_dir}/")


# =============================================================================
# MAIN COMPARISON
# =============================================================================

def run_comparison(n_runs=3, n_epochs=100, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """Run comparison between all four methods."""
    
    os.makedirs(output_dir, exist_ok=True)
    
    print("="*90)
    print("LiH Energy/Force Prediction: Four Methods Comparison")
    print("="*90)
    print("Methods:")
    print("  1. Rotationally Equivariant QML (SO(3) symmetry)")
    print("  2. Non-Equivariant QML (baseline)")
    print("  3. Graph Permutation Equivariant QML (permutation symmetry)")
    print("  4. Classical Rotationally Equivariant NN (E(3) invariant MLP)")
    print("="*90)
    
    # Load data
    print("\nLoading data...")
    try:
        energy, forces, positions = load_lih_data(data_dir)
        print(f"  Loaded {len(energy)} samples")
    except FileNotFoundError:
        print(f"  ERROR: Data not found in {data_dir}")
        print("  Please ensure the LiH data files are available")
        return None
    
    # Prepare data
    data = prepare_data(energy, forces, positions)
    print(f"  Train: {len(data['indices_train'])}, Test: {len(data['indices_test'])}")
    
    # Results storage
    all_results = {
        "config": {
            "n_runs": n_runs,
            "n_epochs": n_epochs,
            "timestamp": datetime.now().isoformat(),
        },
        "rotationally_equivariant": {"runs": [], "metrics_summary": {}},
        "non_equivariant": {"runs": [], "metrics_summary": {}},
        "graph_permutation_equivariant": {"runs": [], "metrics_summary": {}},
        "classical_equivariant": {"runs": [], "metrics_summary": {}},
    }
    
    # Prepare training data
    E_train = data["energy_scaled"][data["indices_train"]]
    E_test = data["energy_scaled"][data["indices_test"]]
    
    # For rotationally equivariant QML model
    data_train_eq = jnp.array(data["positions_centered"][data["indices_train"]])
    data_test_eq = jnp.array(data["positions_centered"][data["indices_test"]])
    # Extract z-component of H atom force (index 0 since centered on Li)
    F_train_eq = data["forces_scaled"][data["indices_train"], 0, 2]
    F_test_eq = data["forces_scaled"][data["indices_test"], 0, 2]
    
    # For non-equivariant, graph permutation, and classical models
    pos_train_neq = jnp.array(data["positions_raw"][data["indices_train"]])
    pos_test_neq = jnp.array(data["positions_raw"][data["indices_test"]])
    F_train_neq = data["forces_scaled"][data["indices_train"], 0, 2]
    F_test_neq = data["forces_scaled"][data["indices_test"], 0, 2]
    
    # Run experiments
    for run in range(n_runs):
        print(f"\n{'='*90}")
        print(f"RUN {run+1}/{n_runs}")
        print(f"{'='*90}")
        
        run_seed = 42 + run * 100
        
        # --- 1. Rotationally Equivariant QML Model ---
        print(f"\n[1/4] Rotationally Equivariant QML")
        eq_model = RotationallyEquivariantQML(n_qubits=6, depth=6, seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        eq_history = train_rotationally_equivariant(
            eq_model, data_train_eq, E_train, F_train_eq,
            data_test_eq, E_test, F_test_eq, n_epochs=n_epochs
        )
        
        print(f"  Evaluating...")
        eq_metrics, eq_predictions = evaluate_model(eq_model, data, "rotationally_equivariant")
        
        print(f"  Energy: MAE={eq_metrics['E_mae_Ha']:.6f} Ha, R²={eq_metrics['E_r2']:.4f}")
        print(f"  Force:  MAE={eq_metrics['F_mae']:.4f} eV/Å, R²={eq_metrics['F_r2']:.4f}")
        
        all_results["rotationally_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": eq_history,
            "metrics": eq_metrics,
            "predictions": eq_predictions,
        })
        
        # --- 2. Non-Equivariant QML Model ---
        print(f"\n[2/4] Non-Equivariant QML")
        neq_model = NonEquivariantQML(num_qubits=4, depth=3, seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        neq_history = train_non_equivariant(
            neq_model, pos_train_neq, E_train, F_train_neq,
            pos_test_neq, E_test, F_test_neq, n_epochs=n_epochs
        )
        
        print(f"  Evaluating...")
        neq_metrics, neq_predictions = evaluate_model(neq_model, data, "non_equivariant")
        
        print(f"  Energy: MAE={neq_metrics['E_mae_Ha']:.6f} Ha, R²={neq_metrics['E_r2']:.4f}")
        print(f"  Force:  MAE={neq_metrics['F_mae']:.4f} eV/Å, R²={neq_metrics['F_r2']:.4f}")
        
        all_results["non_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": neq_history,
            "metrics": neq_metrics,
            "predictions": neq_predictions,
        })
        
        # --- 3. Graph Permutation Equivariant QML Model ---
        print(f"\n[3/4] Graph Permutation Equivariant QML")
        gpe_model = GraphPermutationEquivariantQML(num_qubits=4, depth=4, seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        gpe_history = train_graph_permutation_equivariant(
            gpe_model, pos_train_neq, E_train, F_train_neq,
            pos_test_neq, E_test, F_test_neq, n_epochs=n_epochs
        )
        
        print(f"  Evaluating...")
        gpe_metrics, gpe_predictions = evaluate_model(gpe_model, data, "graph_permutation_equivariant")
        
        print(f"  Energy: MAE={gpe_metrics['E_mae_Ha']:.6f} Ha, R²={gpe_metrics['E_r2']:.4f}")
        print(f"  Force:  MAE={gpe_metrics['F_mae']:.4f} eV/Å, R²={gpe_metrics['F_r2']:.4f}")
        
        all_results["graph_permutation_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": gpe_history,
            "metrics": gpe_metrics,
            "predictions": gpe_predictions,
        })
        
        # --- 4. Classical Rotationally Equivariant NN ---
        print(f"\n[4/4] Classical Rotationally Equivariant NN")
        classical_model = ClassicalRotationallyEquivariantNN(hidden_dims=[128, 128, 64], seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        classical_history = train_classical_equivariant(
            classical_model, pos_train_neq, E_train, F_train_neq,
            pos_test_neq, E_test, F_test_neq, n_epochs=n_epochs
        )
        
        print(f"  Evaluating...")
        classical_metrics, classical_predictions = evaluate_model(classical_model, data, "classical_equivariant")
        
        print(f"  Energy: MAE={classical_metrics['E_mae_Ha']:.6f} Ha, R²={classical_metrics['E_r2']:.4f}")
        print(f"  Force:  MAE={classical_metrics['F_mae']:.4f} eV/Å, R²={classical_metrics['F_r2']:.4f}")
        
        all_results["classical_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": classical_history,
            "metrics": classical_metrics,
            "predictions": classical_predictions,
        })
    
    # Compute summary statistics
    for model_type in ["rotationally_equivariant", "non_equivariant", "graph_permutation_equivariant", "classical_equivariant"]:
        metrics_list = [r["metrics"] for r in all_results[model_type]["runs"]]
        
        summary = {}
        for key in metrics_list[0].keys():
            values = [m[key] for m in metrics_list]
            summary[key] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "min": float(np.min(values)),
                "max": float(np.max(values)),
                "values": values,
            }
        all_results[model_type]["metrics_summary"] = summary
    
    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    
    # Save numpy arrays for easy loading
    np.savez(os.path.join(output_dir, "metrics.npz"),
             # Rotationally Equivariant QML
             rot_eq_E_r2=[r["metrics"]["E_r2"] for r in all_results["rotationally_equivariant"]["runs"]],
             rot_eq_F_r2=[r["metrics"]["F_r2"] for r in all_results["rotationally_equivariant"]["runs"]],
             rot_eq_E_mae=[r["metrics"]["E_mae_Ha"] for r in all_results["rotationally_equivariant"]["runs"]],
             rot_eq_F_mae=[r["metrics"]["F_mae"] for r in all_results["rotationally_equivariant"]["runs"]],
             # Non-Equivariant QML
             neq_E_r2=[r["metrics"]["E_r2"] for r in all_results["non_equivariant"]["runs"]],
             neq_F_r2=[r["metrics"]["F_r2"] for r in all_results["non_equivariant"]["runs"]],
             neq_E_mae=[r["metrics"]["E_mae_Ha"] for r in all_results["non_equivariant"]["runs"]],
             neq_F_mae=[r["metrics"]["F_mae"] for r in all_results["non_equivariant"]["runs"]],
             # Graph Permutation Equivariant QML
             gpe_E_r2=[r["metrics"]["E_r2"] for r in all_results["graph_permutation_equivariant"]["runs"]],
             gpe_F_r2=[r["metrics"]["F_r2"] for r in all_results["graph_permutation_equivariant"]["runs"]],
             gpe_E_mae=[r["metrics"]["E_mae_Ha"] for r in all_results["graph_permutation_equivariant"]["runs"]],
             gpe_F_mae=[r["metrics"]["F_mae"] for r in all_results["graph_permutation_equivariant"]["runs"]],
             # Classical Equivariant NN
             classical_E_r2=[r["metrics"]["E_r2"] for r in all_results["classical_equivariant"]["runs"]],
             classical_F_r2=[r["metrics"]["F_r2"] for r in all_results["classical_equivariant"]["runs"]],
             classical_E_mae=[r["metrics"]["E_mae_Ha"] for r in all_results["classical_equivariant"]["runs"]],
             classical_F_mae=[r["metrics"]["F_mae"] for r in all_results["classical_equivariant"]["runs"]])
    
    # Create plots
    print("\nGenerating comparison plots...")
    create_comparison_plots(all_results, output_dir)
    
    # Print summary
    print("\n" + "="*100)
    print("SUMMARY")
    print("="*100)
    print(f"\n{'Metric':<15} {'Rot. Equiv. QML':<20} {'Non-Equiv. QML':<20} {'Graph Perm. QML':<20} {'Classical Equiv.':<20}")
    print("-"*95)
    
    for metric in ["E_r2", "E_mae_Ha", "F_r2", "F_mae"]:
        rot_mean = all_results["rotationally_equivariant"]["metrics_summary"][metric]["mean"]
        rot_std = all_results["rotationally_equivariant"]["metrics_summary"][metric]["std"]
        neq_mean = all_results["non_equivariant"]["metrics_summary"][metric]["mean"]
        neq_std = all_results["non_equivariant"]["metrics_summary"][metric]["std"]
        gpe_mean = all_results["graph_permutation_equivariant"]["metrics_summary"][metric]["mean"]
        gpe_std = all_results["graph_permutation_equivariant"]["metrics_summary"][metric]["std"]
        cls_mean = all_results["classical_equivariant"]["metrics_summary"][metric]["mean"]
        cls_std = all_results["classical_equivariant"]["metrics_summary"][metric]["std"]
        
        print(f"{metric:<15} {rot_mean:.4f}±{rot_std:.4f}       {neq_mean:.4f}±{neq_std:.4f}       "
              f"{gpe_mean:.4f}±{gpe_std:.4f}       {cls_mean:.4f}±{cls_std:.4f}")
    
    print("="*100)
    
    return all_results


# =============================================================================
# MAIN
# =============================================================================

def main(n_runs=2, n_epochs=50, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """
    Main function - can be called directly from Jupyter or command line.
    
    Args:
        n_runs: Number of runs for each model
        n_epochs: Training epochs per run
        output_dir: Directory to save results
        data_dir: Directory containing LiH data (.npy files)
    
    Returns:
        results dictionary
    """
    return run_comparison(
        n_runs=n_runs,
        n_epochs=n_epochs,
        output_dir=output_dir,
        data_dir=data_dir
    )


if __name__ == "__main__":
    import sys
    
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter notebook. Call main() directly with parameters:")
        print("  results = main(n_runs=2, n_epochs=50, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')")
    else:
        parser = argparse.ArgumentParser(description="Compare Four Methods on LiH")
        parser.add_argument("--n_runs", type=int, default=2, help="Number of runs")
        parser.add_argument("--n_epochs", type=int, default=50, help="Training epochs per run")
        parser.add_argument("--output_dir", type=str, default="lih_comparison_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_LiH", help="Data directory")
        
        args = parser.parse_args()
        
        results = main(
            n_runs=args.n_runs,
            n_epochs=args.n_epochs,
            output_dir=args.output_dir,
            data_dir=args.data_dir
        )

Running in Jupyter notebook. Call main() directly with parameters:
  results = main(n_runs=2, n_epochs=50, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')


In [4]:
results = main(n_runs=1, n_epochs=200, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')

LiH Energy/Force Prediction: Four Methods Comparison
Methods:
  1. Rotationally Equivariant QML (SO(3) symmetry)
  2. Non-Equivariant QML (baseline)
  3. Graph Permutation Equivariant QML (permutation symmetry)
  4. Classical Rotationally Equivariant NN (E(3) invariant MLP)

Loading data...
  Loaded 2400 samples
  Train: 1920, Test: 480

RUN 1/1

[1/4] Rotationally Equivariant QML
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.027774 Ha, R²=0.9963
  Force:  MAE=0.8454 eV/Å, R²=0.9937

[2/4] Non-Equivariant QML
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.026232 Ha, R²=0.9979
  Force:  MAE=2.5130 eV/Å, R²=0.9584

[3/4] Graph Permutation Equivariant QML
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.033164 Ha, R²=0.9964
  Force:  MAE=2.5839 eV/Å, R²=0.9548

[4/4] Classical Rotationally Equivariant NN
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.023289 Ha, R²=0.9978
  Force:  MAE=0.9685 eV/Å, R²=0.9930

Results saved to: lih_resu