In [1]:
"""
Comparison of Graph Equivariant vs Non-Equivariant QML for LiH Energy/Force Prediction

This script runs both methods on LiH molecular data and compares their performance.
All results are saved to an output directory for later analysis.

Methods compared:
1. Graph Equivariant QML - Uses SO(3) equivariant encoding with Heisenberg observable
2. Non-Equivariant QML - Simple QNN with basic rotations

Usage:
    python run_comparison.py --n_runs 3 --n_epochs 100 --output_dir results
"""

import pennylane as qml
import numpy as np
import json
import os
from datetime import datetime
import argparse

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp

from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# GRAPH EQUIVARIANT QML MODEL
# =============================================================================

class EquivariantQML:
    """
    Graph Equivariant Quantum Machine Learning model for molecular properties.
    Uses SO(3) equivariant encoding with Heisenberg Hamiltonian observable.

    Each relative position vector ``v`` is loaded onto one qubit by the gate
    ``exp(-i * alpha * (v_x X + v_y Y + v_z Z) / 2)``. Qubit ``i`` encodes row
    ``i % n_vectors`` of the per-sample input, so inputs holding several
    vectors are spread across the register. Trainable two-qubit gates are
    ``exp(-i * w * (XX + YY + ZZ))``. The output is the expectation of
    ``X0 X1 + Y0 Y1 + Z0 Z1``.

    Args:
        num_qubits: Number of wires of the ``default.qubit`` device.
        depth: Number of (trainable sweeps + re-encoding) layers after the
            initial encoding.
        blocks: Number of trainable-gate sweeps inside each layer.
        seed: Seed used to initialise ``weights`` and ``alphas``.
    """
    
    def __init__(self, num_qubits=3, depth=6, blocks=2, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.blocks = blocks
        self.seed = seed
        
        # Pauli matrices
        self.X = np.array([[0, 1], [1, 0]])
        self.Y = np.array([[0, -1.0j], [1.0j, 0]])
        self.Z = np.array([[1, 0], [0, -1]])
        
        # Stacked single-qubit generators (X, Y, Z) and two-qubit exchange
        # terms (XX, YY, ZZ) used to build the gate unitaries.
        self.sigmas = jnp.array(np.array([self.X, self.Y, self.Z]))
        self.sigmas_sigmas = jnp.array(np.array([
            np.kron(self.X, self.X),
            np.kron(self.Y, self.Y),
            np.kron(self.Z, self.Z)
        ]))
        
        # Create device and circuit
        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        
        # Initialize parameters
        self._init_params()
    
    def _create_circuit(self):
        """Build ``self.circuit`` (one sample) and ``self.vec_circuit`` (batched over samples)."""
        num_qubits = self.num_qubits
        depth = self.depth
        blocks = self.blocks
        sigmas = self.sigmas
        # The trainable gates' generator XX + YY + ZZ does not depend on any
        # input, so sum it once instead of on every gate application.
        exchange_hamiltonian = jnp.einsum("ijk->jk", self.sigmas_sigmas)
        
        # Heisenberg observable on wires 0 and 1
        Heisenberg = [
            qml.PauliX(0) @ qml.PauliX(1),
            qml.PauliY(0) @ qml.PauliY(1),
            qml.PauliZ(0) @ qml.PauliZ(1),
        ]
        self.Observable = qml.Hamiltonian(np.ones((3)), Heisenberg)
        
        def singlet(wires):
            # H and Z on the first wire, X on the second, then CNOT:
            # prepares (|01> - |10>) / sqrt(2).
            qml.Hadamard(wires=wires[0])
            qml.PauliZ(wires=wires[0])
            qml.PauliX(wires=wires[1])
            qml.CNOT(wires=wires)
        
        def equivariant_encoding(alpha, vector, wires):
            # exp(-i * alpha * (vector . sigma) / 2) on a single wire.
            hamiltonian = jnp.einsum("i,ijk", vector, sigmas)
            U = jax.scipy.linalg.expm(-1.0j * alpha * hamiltonian / 2)
            qml.QubitUnitary(U, wires=wires, id="E")
        
        def trainable_layer(weight, wires):
            # exp(-i * weight * (XX + YY + ZZ)) on a pair of wires.
            U = jax.scipy.linalg.expm(-1.0j * weight * exchange_hamiltonian)
            qml.QubitUnitary(U, wires=wires, id="U")
        
        @qml.qnode(self.dev, interface="jax")
        def circuit(data, params):
            weights = params["params"]["weights"]
            alphas = params["params"]["alphas"]
            
            # Cast once. Row k of `vectors` is the k-th relative position vector
            # of this sample, and qubit i encodes row i % n_vectors.
            vectors = jnp.asarray(data, dtype=complex)
            n_vectors = vectors.shape[0]
            
            # Initial entangled state
            if num_qubits >= 2:
                singlet(wires=[0, 1])
            if num_qubits >= 3:
                qml.CNOT(wires=[1, 2])
            
            # Initial encoding (alpha column 0)
            for i in range(num_qubits):
                equivariant_encoding(alphas[i, 0], vectors[i % n_vectors, ...], wires=[i])
            
            # Each layer first runs trainable sweeps on even pairs and then on
            # odd pairs; the odd pass wraps to wire 0 when num_qubits is even.
            # It then re-encodes with alpha column d + 1. weights[:, 0, :] is
            # never read.
            for d in range(depth):
                qml.Barrier()
                for b in range(blocks):
                    for i in range(0, num_qubits - 1, 2):
                        trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
                    for i in range(1, num_qubits, 2):
                        trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
                
                for i in range(num_qubits):
                    equivariant_encoding(alphas[i, d + 1], vectors[i % n_vectors, ...], wires=[i])
            
            return qml.expval(self.Observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)
    
    def _init_params(self):
        """Initialize trainable parameters.

        ``weights`` has shape ``(num_qubits, depth + 1, blocks)`` and is drawn
        uniformly from ``+-sqrt(1 / (num_qubits * depth))``. ``alphas`` has
        shape ``(num_qubits, depth + 1)`` and is drawn uniformly from
        ``[0.3, 0.8)``. NumPy's global RNG is reseeded with ``seed`` and
        ``seed + 1``.
        """
        np.random.seed(self.seed)
        limit = np.sqrt(1.0 / (self.num_qubits * self.depth))
        weights = np.random.uniform(-limit, limit, (self.num_qubits, self.depth + 1, self.blocks))
        
        np.random.seed(self.seed + 1)
        alphas = np.random.uniform(0.3, 0.8, (self.num_qubits, self.depth + 1))
        
        self.params = {
            "params": {
                "weights": jnp.array(weights),
                "alphas": jnp.array(alphas),
                # Not read by the circuit; as None it adds no pytree leaves.
                "epsilon": None
            }
        }
    
    def get_params(self):
        """Return the current parameter pytree."""
        return self.params
    
    def set_params(self, params):
        """Replace the parameter pytree (e.g. after training)."""
        self.params = params


# =============================================================================
# NON-EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """
    Simple non-equivariant QNN for molecular properties.

    The gates are basic RY/RZ rotations and CNOT ladders with no built-in
    symmetry preservation. The only geometric input, however, is the distance
    ``|r_1 - r_0|`` between atoms 0 and 1. The predicted energy is therefore
    unchanged by rigid rotations or translations of those two atoms.

    Args:
        num_qubits: Number of wires; the read-out sums ``Z`` over all of them.
        depth: Number of (distance-scaled RY, CNOT ladder, RZ/RY) layers.
        seed: Seed used to draw the initial weights.
    """
    
    def __init__(self, num_qubits=4, depth=3, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        self._init_params()
    
    def _create_circuit(self):
        """Build ``self.circuit`` (one sample) and ``self.vec_circuit`` (batched over samples)."""
        num_qubits = self.num_qubits
        depth = self.depth
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            weights = params["weights"]
            
            # Single feature: bond length between atoms 0 and 1
            dist = jnp.linalg.norm(positions[1] - positions[0])
            
            # Fixed, input-independent starting rotations
            for i in range(num_qubits):
                qml.RY(0.5, wires=i)
            
            for layer in range(depth):
                # The distance scales a trainable RY angle in every layer.
                for i in range(num_qubits):
                    qml.RY(weights[layer, i, 0] * dist, wires=i)
                
                for i in range(num_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                
                for i in range(num_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)
            
            # Sum of Z over every wire. The sum is built from num_qubits so that
            # widths other than four neither reference missing wires nor drop
            # extra ones.
            observable = qml.PauliZ(0)
            for i in range(1, num_qubits):
                observable = observable + qml.PauliZ(i)
            return qml.expval(observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)
    
    def _init_params(self):
        """Draw ``weights`` of shape ``(depth, num_qubits, 3)`` from N(0, 0.1^2).

        NumPy's global RNG is reseeded with ``seed``.
        """
        np.random.seed(self.seed)
        weights = np.random.normal(0, 0.1, (self.depth, self.num_qubits, 3))
        self.params = {"weights": jnp.array(weights)}
    
    def get_params(self):
        """Return the current parameter pytree."""
        return self.params
    
    def set_params(self, params):
        """Replace the parameter pytree (e.g. after training)."""
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_equivariant(model, data_train, E_train, F_train, data_test, E_test, F_test,
                      n_epochs=200, lr=0.01, lambda_E=1.5, lambda_F=2.0):
    """Train the equivariant model jointly on energies and forces with Adam.

    Forces are predicted as the negative gradient of the predicted energy with
    respect to the model input. The loss is
    ``lambda_E * MSE(E) + lambda_F * MSE(F)`` over the full training batch.

    Args:
        model: Object exposing ``circuit(x, params)``, ``vec_circuit(X, params)``,
            a ``params`` pytree and ``set_params(params)``.
        data_train, data_test: Batched model inputs (first axis = sample).
        E_train, E_test: Target energies, one per sample.
        F_train, F_test: Target forces; compared against ``-dE/dx``, so they
            are expected to match the input shape.
        n_epochs: Number of full-batch Adam steps.
        lr: Adam learning rate.
        lambda_E, lambda_F: Weights of the energy and force MSE terms.

    Returns:
        dict of lists ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``, recorded every ``max(1, n_epochs // 20)`` epochs.
        The trained parameters are written back via ``model.set_params``.
    """
    
    def energy_single(coords, params):
        return model.circuit(coords, params)
    
    def force_single(coords, params):
        # F = -dE/dx for one sample.
        return -jax.grad(energy_single, argnums=0)(coords, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def mse_loss(predictions, targets):
        return jnp.mean((predictions - targets) ** 2)
    
    @jax.jit
    def cost(params, data, E_target, F_target):
        E_loss = mse_loss(model.vec_circuit(data, params), E_target)
        F_loss = mse_loss(vec_force(data, params), F_target)
        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)
    
    # Build the compiled loss-and-gradient function once, instead of
    # re-wrapping `cost` with value_and_grad on every epoch.
    loss_and_grad = jax.jit(jax.value_and_grad(cost, argnums=0, has_aux=True))
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    log_every = max(1, n_epochs // 20)
    
    for epoch in range(n_epochs):
        (loss, _), grads = loss_and_grad(get_params(opt_state), data_train, E_train, F_train)
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % log_every == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(model.vec_circuit(data_test, test_params))
            F_pred_test = np.array(vec_force(data_test, test_params))
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            # Loss of the parameters *before* this epoch's update.
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_non_equivariant(model, pos_train, E_train, F_train, pos_test, E_test, F_test,
                          n_epochs=200, lr=0.01, lambda_E=2.0, lambda_F=1.0,
                          max_grad_norm=10.0):
    """Train the non-equivariant model on energies and one force component with Adam.

    The supervised force is the z-component of ``-dE/dpositions`` on atom 1
    (``[:, 1, 2]``). The loss is ``lambda_E * MSE(E) + lambda_F * MSE(F_z)``.
    Before each step, non-finite gradient entries are zeroed and the global
    L2 norm of the gradient is clipped to ``max_grad_norm``.

    Args:
        model: Object exposing ``circuit(pos, params)``, ``vec_circuit(P, params)``,
            a ``params`` pytree and ``set_params(params)``.
        pos_train, pos_test: Batched positions indexed ``[sample, atom, xyz]``.
        E_train, E_test: Target energies, one per sample.
        F_train, F_test: Target z-forces on atom 1, one per sample.
        n_epochs: Number of full-batch Adam steps.
        lr: Adam learning rate.
        lambda_E, lambda_F: Weights of the energy and force MSE terms.
        max_grad_norm: Global gradient-norm clipping threshold.

    Returns:
        dict of lists ``epoch``, ``train_loss``, ``test_E_loss`` and
        ``test_F_loss``, recorded every ``max(1, n_epochs // 20)`` epochs.
        The trained parameters are written back via ``model.set_params``.
    """
    
    def energy_single(coords, params):
        return model.circuit(coords, params)
    
    def force_single(coords, params):
        # F = -dE/dx for one sample.
        return -jax.grad(energy_single, argnums=0)(coords, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred_z = vec_force(positions, params)[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)
        
        # This keeps the *reported* loss finite. It does not keep NaNs out of
        # the gradient, because the masked branch still back-propagates NaN.
        # Gradients are sanitised separately in the training loop.
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)
    
    # Compiled once and reused for every epoch.
    loss_and_grad = jax.jit(jax.value_and_grad(combined_loss, has_aux=True))
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    log_every = max(1, n_epochs // 20)
    
    for epoch in range(n_epochs):
        (loss, _), grads = loss_and_grad(get_params(opt_state), pos_train, E_train, F_train)
        
        # Zero non-finite entries so one bad step cannot turn the Adam state
        # (and hence every parameter) into NaN. Then rescale the gradient
        # whenever its global L2 norm exceeds max_grad_norm.
        grads = jax.tree.map(lambda g: jnp.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0), grads)
        grads = optimizers.clip_grads(grads, max_grad_norm)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % log_every == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(model.vec_circuit(pos_test, test_params))
            F_pred_test = np.array(vec_force(pos_test, test_params))[:, 1, 2]
            
            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)
            
            history["epoch"].append(epoch + 1)
            # Loss of the parameters *before* this epoch's update.
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))
    
    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# DATA LOADING
# =============================================================================

def load_lih_data(data_dir="eqnn_force_field_data_LiH"):
    """Read the LiH dataset arrays from ``data_dir``.

    The directory must contain ``Energy.npy``, ``Forces.npy`` and
    ``Positions.npy``; they are loaded in that order.

    Returns:
        Tuple ``(energy, forces, positions)`` exactly as stored on disk.

    Raises:
        FileNotFoundError: from ``np.load`` when a file is missing.
    """
    file_names = ("Energy.npy", "Forces.npy", "Positions.npy")
    loaded = [np.load(os.path.join(data_dir, name)) for name in file_names]
    return tuple(loaded)


def prepare_data(energy, forces, positions, test_split=0.2, seed=42):
    """Scale targets, build relative coordinates and draw a train/test split.

    Args:
        energy: Energies, shape ``(n,)`` or ``(n, 1)``.
        forces: Per-atom forces indexed ``[sample, atom, xyz]``.
        positions: Per-atom positions indexed ``[sample, atom, xyz]``.
        test_split: Fraction of samples held out for testing.
        seed: Seed for the train/test split.

    Returns:
        dict with keys:
            ``energy_scaler`` / ``force_scaler``: fitted ``MinMaxScaler((-1, 1))``.
            ``energy_scaled``: flattened scaled energies, shape ``(n,)``.
            ``forces_scaled``: zeros shaped like ``forces_H``, except
                ``[:, 0, 2]`` (z-force on atom 1), which holds the scaled value.
            ``positions_centered``: atoms 1.. relative to atom 0,
                shape ``(n, n_atoms - 1, 3)``.
            ``positions_raw``: the input ``positions``.
            ``forces_H``: raw forces on atoms 1.. (``forces[:, 1:, :]``).
            ``indices_train`` / ``indices_test``: disjoint sample indices.

    Note:
        Both scalers are fit on every sample, test samples included.
    """
    n_samples = positions.shape[0]
    
    # Scale energy to [-1, 1]
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()
    
    # Positions of every atom after the first, relative to atom 0. All
    # n_atoms - 1 rows are filled, so molecules with more than two atoms are
    # handled too. For LiH this is the single (atom 1 - atom 0) vector.
    positions_centered = (positions[:, 1:, :] - positions[:, :1, :]).astype(float)
    
    # Scale forces (z-component of the force on atom 1)
    forces_H = forces[:, 1:, :]
    force_scaler = MinMaxScaler((-1, 1))
    forces_z_only = forces_H[:, 0, 2].reshape(-1, 1)
    forces_z_scaled = force_scaler.fit_transform(forces_z_only).flatten()
    
    forces_scaled = np.zeros_like(forces_H)
    forces_scaled[:, 0, 2] = forces_z_scaled
    
    # Train/test split. A local RandomState draws exactly the indices that
    # np.random.seed(seed) followed by np.random.choice would, without
    # resetting NumPy's global generator as a side effect.
    rng = np.random.RandomState(seed)
    all_indices = np.arange(n_samples)
    indices_train = rng.choice(all_indices, size=int((1 - test_split) * n_samples), replace=False)
    indices_test = np.setdiff1d(all_indices, indices_train)
    
    data = {
        "energy_scaler": energy_scaler,
        "force_scaler": force_scaler,
        "energy_scaled": energy_scaled,
        "forces_scaled": forces_scaled,
        "positions_centered": positions_centered,
        "positions_raw": positions,
        "forces_H": forces_H,
        "indices_train": indices_train,
        "indices_test": indices_test,
    }
    
    return data


# =============================================================================
# EVALUATION
# =============================================================================

def _regression_metrics(pred, true):
    """Return ``(MAE, RMSE, R^2)`` of ``pred`` against ``true`` (1-D arrays)."""
    residual = pred - true
    mae = np.mean(np.abs(residual))
    rmse = np.sqrt(np.mean(residual ** 2))
    r2 = 1 - np.sum(residual ** 2) / np.sum((true - true.mean()) ** 2)
    return mae, rmse, r2


def evaluate_model(model, data, model_type="equivariant"):
    """Predict on every sample, recalibrate on the train split and score the test split.

    Energies get a quadratic correction and the z-force on atom 1 gets a linear
    correction. Both are fit on training indices only, and predictions are
    then mapped back to original units with the fitted scalers.

    Args:
        model: Trained model exposing ``circuit``, ``vec_circuit`` and ``params``.
        data: Dict produced by ``prepare_data``.
        model_type: ``"equivariant"`` feeds ``positions_centered`` (atom 1 is
            slot 0). Any other value feeds ``positions_raw`` (atom 1 is slot 1).

    Returns:
        ``(metrics, predictions)``. ``metrics`` holds test-set MAE/RMSE/R^2 for
        energy (Ha and eV) and force. ``predictions`` holds the full predicted
        and true arrays plus ``indices_test``.
    """
    energy_scaled = data["energy_scaled"]
    forces_z_scaled = data["forces_scaled"][:, 0, 2]
    energy_scaler = data["energy_scaler"]
    force_scaler = data["force_scaler"]
    indices_train = data["indices_train"]
    indices_test = data["indices_test"]
    forces_H = data["forces_H"]
    
    # The two models take different inputs, and the tracked force component
    # lives at a different atom slot in each.
    if model_type == "equivariant":
        inputs = jnp.array(data["positions_centered"])
        atom_slot = 0
    else:
        inputs = jnp.array(data["positions_raw"])
        atom_slot = 1
    
    def energy_single(coords, params):
        return model.circuit(coords, params)
    
    def force_single(coords, params):
        # F = -dE/dx for one sample.
        return -jax.grad(energy_single, argnums=0)(coords, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    E_pred_scaled = np.array(model.vec_circuit(inputs, model.params))
    F_pred_z_scaled = np.array(vec_force(inputs, model.params))[:, atom_slot, 2]
    
    # Errors the recalibration fits raise on degenerate input: non-convergence
    # (RuntimeError), non-finite or empty data (ValueError), and fewer samples
    # than fit parameters (TypeError from curve_fit). Anything else, such as
    # KeyboardInterrupt or programming errors, propagates.
    fit_errors = (RuntimeError, ValueError, TypeError)
    
    # Post-correction for energy (quadratic, fit on train samples)
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c
    
    try:
        popt_E, _ = curve_fit(corr_E, E_pred_scaled[indices_train], energy_scaled[indices_train])
        E_pred_corrected = corr_E(E_pred_scaled, *popt_E)
    except fit_errors as err:
        print(f"  Energy post-correction failed ({err}); using raw predictions")
        E_pred_corrected = E_pred_scaled
    
    # Post-correction for force (linear, fit on train samples)
    try:
        lr_model = LinearRegression()
        lr_model.fit(F_pred_z_scaled[indices_train].reshape(-1, 1), forces_z_scaled[indices_train])
        F_pred_corrected = lr_model.predict(F_pred_z_scaled.reshape(-1, 1)).flatten()
    except fit_errors as err:
        print(f"  Force post-correction failed ({err}); using raw predictions")
        F_pred_corrected = F_pred_z_scaled
    
    # Inverse transform to original units
    E_pred_original = energy_scaler.inverse_transform(E_pred_corrected.reshape(-1, 1)).flatten()
    F_pred_original = force_scaler.inverse_transform(F_pred_corrected.reshape(-1, 1)).flatten()
    
    E_true_original = energy_scaler.inverse_transform(energy_scaled.reshape(-1, 1)).flatten()
    F_true_original = forces_H[:, 0, 2]
    
    # Metrics on the test split only
    E_mae, E_rmse, E_r2 = _regression_metrics(E_pred_original[indices_test], E_true_original[indices_test])
    F_mae, F_rmse, F_r2 = _regression_metrics(F_pred_original[indices_test], F_true_original[indices_test])
    
    hartree_to_ev = 27.2114
    metrics = {
        "E_mae_Ha": float(E_mae),
        "E_mae_eV": float(E_mae * hartree_to_ev),
        "E_rmse_Ha": float(E_rmse),
        "E_rmse_eV": float(E_rmse * hartree_to_ev),
        "E_r2": float(E_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
        "F_r2": float(F_r2),
    }
    
    predictions = {
        "E_pred": E_pred_original.tolist(),
        "E_true": E_true_original.tolist(),
        "F_pred": F_pred_original.tolist(),
        "F_true": F_true_original.tolist(),
        "indices_test": indices_test.tolist(),
    }
    
    return metrics, predictions


# =============================================================================
# MAIN COMPARISON
# =============================================================================

def _print_metrics(metrics):
    """Print the test-set energy and force metrics of one trained model."""
    print(f"  Energy: MAE={metrics['E_mae_Ha']:.6f} Ha, R²={metrics['E_r2']:.4f}")
    print(f"  Force:  MAE={metrics['F_mae']:.4f} eV/Å, R²={metrics['F_r2']:.4f}")


def _summarize_runs(runs):
    """Aggregate per-run ``metrics`` dicts into mean/std/min/max/values for each metric key."""
    metrics_list = [run["metrics"] for run in runs]
    summary = {}
    for key in metrics_list[0]:
        values = [m[key] for m in metrics_list]
        summary[key] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "values": values,
        }
    return summary


def run_comparison(n_runs=3, n_epochs=100, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """Run comparison between equivariant and non-equivariant models.

    In each run, both models are re-initialised with seed ``42 + 100 * run``
    and trained for ``n_epochs`` on the same train/test split. They are then
    evaluated with ``evaluate_model``. The function writes ``results.json``
    (config, per-run history/metrics/predictions, per-metric summary) and
    ``metrics.npz`` (per-run R^2 and MAE) to ``output_dir``, and prints a
    summary table.

    Args:
        n_runs: Number of independent runs per model (at least 1).
        n_epochs: Training epochs per run.
        output_dir: Directory for the result files (created if missing).
        data_dir: Directory holding the LiH ``.npy`` files.

    Returns:
        The results dict, or ``None`` when the data files are not found.

    Raises:
        ValueError: if ``n_runs`` is less than 1, since the summary needs at
            least one run.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    
    os.makedirs(output_dir, exist_ok=True)
    
    print("="*70)
    print("LiH Energy/Force Prediction: Equivariant vs Non-Equivariant QML")
    print("="*70)
    
    # Load data
    print("\nLoading data...")
    try:
        energy, forces, positions = load_lih_data(data_dir)
        print(f"  Loaded {len(energy)} samples")
    except FileNotFoundError:
        print(f"  ERROR: Data not found in {data_dir}")
        print("  Please ensure the LiH data files are available")
        return None
    
    # Prepare data
    data = prepare_data(energy, forces, positions)
    print(f"  Train: {len(data['indices_train'])}, Test: {len(data['indices_test'])}")
    
    # Results storage
    all_results = {
        "config": {
            "n_runs": n_runs,
            "n_epochs": n_epochs,
            "data_dir": data_dir,
            "timestamp": datetime.now().isoformat(),
        },
        "equivariant": {"runs": [], "metrics_summary": {}},
        "non_equivariant": {"runs": [], "metrics_summary": {}},
    }
    
    # Prepare training data
    E_train = data["energy_scaled"][data["indices_train"]]
    E_test = data["energy_scaled"][data["indices_test"]]
    
    # Equivariant model: relative coordinates and full force vectors
    data_train_eq = jnp.array(data["positions_centered"][data["indices_train"]])
    data_test_eq = jnp.array(data["positions_centered"][data["indices_test"]])
    F_train_eq = data["forces_scaled"][data["indices_train"]]
    F_test_eq = data["forces_scaled"][data["indices_test"]]
    
    # Non-equivariant model: raw positions and the scalar z-force on atom 1
    pos_train_neq = jnp.array(data["positions_raw"][data["indices_train"]])
    pos_test_neq = jnp.array(data["positions_raw"][data["indices_test"]])
    F_train_neq = data["forces_scaled"][data["indices_train"], 0, 2]
    F_test_neq = data["forces_scaled"][data["indices_test"], 0, 2]
    
    # Run experiments
    for run in range(n_runs):
        print(f"\n{'='*70}")
        print(f"RUN {run+1}/{n_runs}")
        print(f"{'='*70}")
        
        run_seed = 42 + run * 100
        
        # --- Equivariant Model ---
        print("\n[Equivariant Model]")
        eq_model = EquivariantQML(num_qubits=3, depth=6, blocks=2, seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        eq_history = train_equivariant(
            eq_model, data_train_eq, E_train, F_train_eq,
            data_test_eq, E_test, F_test_eq, n_epochs=n_epochs
        )
        
        print("  Evaluating...")
        eq_metrics, eq_predictions = evaluate_model(eq_model, data, "equivariant")
        _print_metrics(eq_metrics)
        
        all_results["equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": eq_history,
            "metrics": eq_metrics,
            "predictions": eq_predictions,
        })
        
        # --- Non-Equivariant Model ---
        print("\n[Non-Equivariant Model]")
        neq_model = NonEquivariantQML(num_qubits=4, depth=3, seed=run_seed)
        
        print(f"  Training for {n_epochs} epochs...")
        neq_history = train_non_equivariant(
            neq_model, pos_train_neq, E_train, F_train_neq,
            pos_test_neq, E_test, F_test_neq, n_epochs=n_epochs
        )
        
        print("  Evaluating...")
        neq_metrics, neq_predictions = evaluate_model(neq_model, data, "non_equivariant")
        _print_metrics(neq_metrics)
        
        all_results["non_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": neq_history,
            "metrics": neq_metrics,
            "predictions": neq_predictions,
        })
    
    # Compute summary statistics
    for model_type in ("equivariant", "non_equivariant"):
        all_results[model_type]["metrics_summary"] = _summarize_runs(all_results[model_type]["runs"])
    
    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    
    # Save per-run metric arrays for easy loading, keyed "<eq|neq>_<metric>"
    metric_arrays = {}
    for prefix, model_type in (("eq", "equivariant"), ("neq", "non_equivariant")):
        runs = all_results[model_type]["runs"]
        for short_name, metric in (("E_r2", "E_r2"), ("F_r2", "F_r2"),
                                   ("E_mae", "E_mae_Ha"), ("F_mae", "F_mae")):
            metric_arrays[f"{prefix}_{short_name}"] = [r["metrics"][metric] for r in runs]
    np.savez(os.path.join(output_dir, "metrics.npz"), **metric_arrays)
    
    # Print summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    print(f"\n{'Metric':<20} {'Equivariant':<25} {'Non-Equivariant':<25}")
    print("-"*70)
    
    for metric in ["E_r2", "E_mae_Ha", "F_r2", "F_mae"]:
        eq_mean = all_results["equivariant"]["metrics_summary"][metric]["mean"]
        eq_std = all_results["equivariant"]["metrics_summary"][metric]["std"]
        neq_mean = all_results["non_equivariant"]["metrics_summary"][metric]["mean"]
        neq_std = all_results["non_equivariant"]["metrics_summary"][metric]["std"]
        
        print(f"{metric:<20} {eq_mean:.4f} ± {eq_std:.4f}       {neq_mean:.4f} ± {neq_std:.4f}")
    
    print("="*70)
    
    return all_results


# =============================================================================
# MAIN
# =============================================================================

def main(n_runs=2, n_epochs=50, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """
    Entry point usable from a notebook cell or from the command line.

    Args:
        n_runs: Number of independent runs per model.
        n_epochs: Training epochs per run.
        output_dir: Directory the result files are written to.
        data_dir: Directory holding the LiH ``.npy`` files.

    Returns:
        Whatever ``run_comparison`` returns for these settings (the results
        dict).
    """
    settings = dict(
        n_runs=n_runs,
        n_epochs=n_epochs,
        output_dir=output_dir,
        data_dir=data_dir,
    )
    return run_comparison(**settings)


if __name__ == "__main__":
    import sys
    
    # Inside a Jupyter kernel, sys.argv belongs to the kernel rather than this
    # script, so only parse command-line options when running as a plain script.
    if 'ipykernel' not in sys.modules:
        parser = argparse.ArgumentParser(description="Compare Equivariant vs Non-Equivariant QML on LiH")
        parser.add_argument("--n_runs", type=int, default=2, help="Number of runs")
        parser.add_argument("--n_epochs", type=int, default=50, help="Training epochs per run")
        parser.add_argument("--output_dir", type=str, default="lih_comparison_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_LiH", help="Data directory")
        
        args = parser.parse_args()
        # The parsed namespace holds exactly main()'s four keyword arguments.
        results = main(**vars(args))
    else:
        print("Running in Jupyter notebook. Call main() directly with parameters:")
        print("  results = main(n_runs=2, n_epochs=50, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')")



Running in Jupyter notebook. Call main() directly with parameters:
  results = main(n_runs=2, n_epochs=50, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')


In [2]:
# One run per model for 200 epochs; with n_runs=1 every reported std is 0.
results = main(n_runs=1, n_epochs=200, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')



LiH Energy/Force Prediction: Equivariant vs Non-Equivariant QML

Loading data...
  Loaded 2400 samples
  Train: 1920, Test: 480

RUN 1/1

[Equivariant Model]
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.033382 Ha, R²=0.9966
  Force:  MAE=3.2999 eV/Å, R²=0.9239

[Non-Equivariant Model]
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.026232 Ha, R²=0.9979
  Force:  MAE=2.5130 eV/Å, R²=0.9584

Results saved to: lih_results/results.json

SUMMARY

Metric               Equivariant               Non-Equivariant          
----------------------------------------------------------------------
E_r2                 0.9966 ± 0.0000       0.9979 ± 0.0000
E_mae_Ha             0.0334 ± 0.0000       0.0262 ± 0.0000
F_r2                 0.9239 ± 0.0000       0.9584 ± 0.0000
F_mae                3.2999 ± 0.0000       2.5130 ± 0.0000


In [None]:
"""
Comparison of Graph Equivariant vs Non-Equivariant QML for LiH Energy/Force Prediction

This script runs both methods on LiH molecular data and compares their performance.
All results are saved to an output directory for later analysis.

Methods compared:
1. Rotationally Equivariant QML - Uses SO(3) equivariant encoding with Heisenberg observable
2. Graph Embedding Equivariant QML - Simple QNN with basic rotations

Usage:
    python run_comparison.py --n_runs 3 --n_epochs 100 --output_dir results
"""

import pennylane as qml
import numpy as np
import json
import os
from datetime import datetime
import argparse

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp

from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# ROTATIONALLY EQUIVARIANT QML MODEL
# =============================================================================

class EquivariantQML:
    """
    Rotationally Equivariant Quantum Machine Learning model for molecular properties.
    Uses SO(3) equivariant encoding with Heisenberg Hamiltonian observable.

    Each relative position vector ``v`` is loaded onto one qubit by the gate
    ``exp(-i * alpha * (v_x X + v_y Y + v_z Z) / 2)``. Qubit ``i`` encodes row
    ``i % n_vectors`` of the per-sample input, so inputs holding several
    vectors are spread across the register. Trainable two-qubit gates are
    ``exp(-i * w * (XX + YY + ZZ))``. The output is the expectation of
    ``X0 X1 + Y0 Y1 + Z0 Z1``.

    Args:
        num_qubits: Number of wires of the ``default.qubit`` device.
        depth: Number of (trainable sweeps + re-encoding) layers after the
            initial encoding.
        blocks: Number of trainable-gate sweeps inside each layer.
        seed: Seed used to initialise ``weights`` and ``alphas``.
    """
    
    def __init__(self, num_qubits=3, depth=6, blocks=2, seed=42):
        self.num_qubits = num_qubits
        self.depth = depth
        self.blocks = blocks
        self.seed = seed
        
        # Pauli matrices
        self.X = np.array([[0, 1], [1, 0]])
        self.Y = np.array([[0, -1.0j], [1.0j, 0]])
        self.Z = np.array([[1, 0], [0, -1]])
        
        # Stacked single-qubit generators (X, Y, Z) and two-qubit exchange
        # terms (XX, YY, ZZ) used to build the gate unitaries.
        self.sigmas = jnp.array(np.array([self.X, self.Y, self.Z]))
        self.sigmas_sigmas = jnp.array(np.array([
            np.kron(self.X, self.X),
            np.kron(self.Y, self.Y),
            np.kron(self.Z, self.Z)
        ]))
        
        # Create device and circuit
        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        
        # Initialize parameters
        self._init_params()
    
    def _create_circuit(self):
        """Build ``self.circuit`` (one sample) and ``self.vec_circuit`` (batched over samples)."""
        num_qubits = self.num_qubits
        depth = self.depth
        blocks = self.blocks
        sigmas = self.sigmas
        # The trainable gates' generator XX + YY + ZZ does not depend on any
        # input, so sum it once instead of on every gate application.
        exchange_hamiltonian = jnp.einsum("ijk->jk", self.sigmas_sigmas)
        
        # Heisenberg observable on wires 0 and 1
        Heisenberg = [
            qml.PauliX(0) @ qml.PauliX(1),
            qml.PauliY(0) @ qml.PauliY(1),
            qml.PauliZ(0) @ qml.PauliZ(1),
        ]
        self.Observable = qml.Hamiltonian(np.ones((3)), Heisenberg)
        
        def singlet(wires):
            # H and Z on the first wire, X on the second, then CNOT:
            # prepares (|01> - |10>) / sqrt(2).
            qml.Hadamard(wires=wires[0])
            qml.PauliZ(wires=wires[0])
            qml.PauliX(wires=wires[1])
            qml.CNOT(wires=wires)
        
        def equivariant_encoding(alpha, vector, wires):
            # exp(-i * alpha * (vector . sigma) / 2) on a single wire.
            hamiltonian = jnp.einsum("i,ijk", vector, sigmas)
            U = jax.scipy.linalg.expm(-1.0j * alpha * hamiltonian / 2)
            qml.QubitUnitary(U, wires=wires, id="E")
        
        def trainable_layer(weight, wires):
            # exp(-i * weight * (XX + YY + ZZ)) on a pair of wires.
            U = jax.scipy.linalg.expm(-1.0j * weight * exchange_hamiltonian)
            qml.QubitUnitary(U, wires=wires, id="U")
        
        @qml.qnode(self.dev, interface="jax")
        def circuit(data, params):
            weights = params["params"]["weights"]
            alphas = params["params"]["alphas"]
            
            # Cast once. Row k of `vectors` is the k-th relative position vector
            # of this sample, and qubit i encodes row i % n_vectors.
            vectors = jnp.asarray(data, dtype=complex)
            n_vectors = vectors.shape[0]
            
            # Initial entangled state
            if num_qubits >= 2:
                singlet(wires=[0, 1])
            if num_qubits >= 3:
                qml.CNOT(wires=[1, 2])
            
            # Initial encoding (alpha column 0)
            for i in range(num_qubits):
                equivariant_encoding(alphas[i, 0], vectors[i % n_vectors, ...], wires=[i])
            
            # Each layer first runs trainable sweeps on even pairs and then on
            # odd pairs; the odd pass wraps to wire 0 when num_qubits is even.
            # It then re-encodes with alpha column d + 1. weights[:, 0, :] is
            # never read.
            for d in range(depth):
                qml.Barrier()
                for b in range(blocks):
                    for i in range(0, num_qubits - 1, 2):
                        trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
                    for i in range(1, num_qubits, 2):
                        trainable_layer(weights[i, d + 1, b], wires=[i, (i + 1) % num_qubits])
                
                for i in range(num_qubits):
                    equivariant_encoding(alphas[i, d + 1], vectors[i % n_vectors, ...], wires=[i])
            
            return qml.expval(self.Observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)
    
    def _init_params(self):
        """Initialize trainable parameters.

        ``weights`` has shape ``(num_qubits, depth + 1, blocks)`` and is drawn
        uniformly from ``+-sqrt(1 / (num_qubits * depth))``. ``alphas`` has
        shape ``(num_qubits, depth + 1)`` and is drawn uniformly from
        ``[0.3, 0.8)``. NumPy's global RNG is reseeded with ``seed`` and
        ``seed + 1``.
        """
        np.random.seed(self.seed)
        limit = np.sqrt(1.0 / (self.num_qubits * self.depth))
        weights = np.random.uniform(-limit, limit, (self.num_qubits, self.depth + 1, self.blocks))
        
        np.random.seed(self.seed + 1)
        alphas = np.random.uniform(0.3, 0.8, (self.num_qubits, self.depth + 1))
        
        self.params = {
            "params": {
                "weights": jnp.array(weights),
                "alphas": jnp.array(alphas),
                # Not read by the circuit; as None it adds no pytree leaves.
                "epsilon": None
            }
        }
    
    def get_params(self):
        """Return the current parameter pytree."""
        return self.params
    
    def set_params(self, params):
        """Replace the parameter pytree (e.g. after training)."""
        self.params = params


# =============================================================================
# GRAPH EMBEDDING EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """
    Baseline QNN for molecular properties with no symmetry-specific circuit design.

    Labelled "Graph Embedding Equivariant QML" in the comparison printout.
    The circuit sees the geometry through a single scalar feature: the
    distance between rows 0 and 1 of ``positions``. This feature is
    re-uploaded in every layer as RY angles scaled by trainable weights. The
    remaining gates are a CNOT chain and trainable RZ/RY rotations. The
    distance is unchanged by rigid rotations and translations of those two
    atoms, so the predicted energy inherits that invariance even though the
    circuit itself is generic.
    """

    def __init__(self, num_qubits=4, depth=3, seed=42):
        """
        Args:
            num_qubits: Number of wires. The readout sums PauliZ over all of them.
            depth: Number of data re-uploading layers.
            seed: Seed for the parameter initialisation.
        """
        self.num_qubits = num_qubits
        self.depth = depth
        self.seed = seed

        self.dev = qml.device("default.qubit", wires=num_qubits)
        self._create_circuit()
        self._init_params()

    def _create_circuit(self):
        """Build ``self.circuit`` (one sample) and ``self.vec_circuit`` (batched over axis 0).

        ``circuit(positions, params)`` returns the expectation value of
        ``sum_i Z_i`` over all ``num_qubits`` wires.
        """
        num_qubits = self.num_qubits
        depth = self.depth

        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            weights = params["weights"]  # shape (depth, num_qubits, 3), see _init_params

            # Single rotation-invariant feature: distance between rows 0 and 1.
            dist = jnp.linalg.norm(positions[1] - positions[0])

            # Fixed initial rotation on every wire.
            for i in range(num_qubits):
                qml.RY(0.5, wires=i)

            for layer in range(depth):
                # Data re-uploading: angle = trainable weight * distance.
                for i in range(num_qubits):
                    qml.RY(weights[layer, i, 0] * dist, wires=i)

                # Nearest-neighbour CNOT chain (open boundary).
                for i in range(num_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])

                # Trainable single-qubit rotations.
                for i in range(num_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)

            # Sum PauliZ over every wire so the readout follows num_qubits. A fixed
            # Z0+Z1+Z2+Z3 fails for fewer than 4 wires and ignores any extra ones.
            observable = qml.PauliZ(0)
            for i in range(1, num_qubits):
                observable = observable + qml.PauliZ(i)
            return qml.expval(observable)

        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, (0, None), 0)

    def _init_params(self):
        """Draw initial weights from N(0, 0.1**2) with shape (depth, num_qubits, 3).

        A private ``RandomState`` seeded with ``self.seed`` gives the same
        draws as seeding NumPy's global RNG. It avoids resetting the global
        random state as a side effect.
        """
        rng = np.random.RandomState(self.seed)
        weights = rng.normal(0, 0.1, (self.depth, self.num_qubits, 3))
        self.params = {"weights": jnp.array(weights)}

    def get_params(self):
        """Return the current parameter dict (by reference)."""
        return self.params

    def set_params(self, params):
        """Replace the parameter dict (stored by reference)."""
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_equivariant(model, data_train, E_train, F_train, data_test, E_test, F_test,
                      n_epochs=200, lr=0.01, lambda_E=1.5, lambda_F=2.0):
    """Train the equivariant model on a weighted energy + force MSE loss.

    Uses full-batch Adam (``jax.example_libraries.optimizers.adam``). Forces
    are predicted as the negative gradient of ``model.circuit`` with respect
    to its input coordinates. ``F_train``/``F_test`` must therefore have the
    same per-sample shape as the coordinates in ``data_train``/``data_test``.

    Args:
        model: Object with ``circuit(coords, params)``, a batched
            ``vec_circuit(data, params)``, a ``params`` pytree and
            ``set_params(params)``.
        data_train, data_test: Batched input coordinates (first axis = sample).
        E_train, E_test: Energy targets, one per sample.
        F_train, F_test: Force targets, shaped like the coordinates.
        n_epochs: Number of optimizer steps.
        lr: Adam learning rate.
        lambda_E, lambda_F: Weights of the energy and force MSE terms.

    Returns:
        A history dict with lists ``epoch``, ``train_loss``, ``test_E_loss``
        and ``test_F_loss``, recorded every ``max(1, n_epochs // 20)`` epochs.
        ``train_loss`` is the loss at the parameters before that epoch's
        update. The test losses use the parameters after the update.

    The final parameters are written back with ``model.set_params``.
    """

    def energy_single(coords, params):
        return model.circuit(coords, params)

    def force_single(coords, params):
        # F = -dE/dx: differentiate the circuit output w.r.t. its input coordinates.
        return -jax.grad(energy_single, argnums=0)(coords, params)

    vec_force = jax.vmap(force_single, (0, None), 0)

    def mse_loss(predictions, targets):
        return jnp.mean((predictions - targets) ** 2)

    def cost(params, data, E_target, F_target):
        E_loss = mse_loss(model.vec_circuit(data, params), E_target)
        F_loss = mse_loss(vec_force(data, params), F_target)
        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)

    # Build and compile the loss-and-gradient function once, outside the epoch loop.
    loss_and_grad = jax.jit(jax.value_and_grad(cost, argnums=0, has_aux=True))
    # Compiled predictors for the periodic test-set evaluation.
    predict_energy = jax.jit(model.vec_circuit)
    predict_force = jax.jit(vec_force)

    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)

    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    log_every = max(1, n_epochs // 20)

    for epoch in range(n_epochs):
        (loss, _), grads = loss_and_grad(get_params(opt_state), data_train, E_train, F_train)
        opt_state = opt_update(epoch, grads, opt_state)

        if (epoch + 1) % log_every == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(predict_energy(data_test, test_params))
            F_pred_test = np.array(predict_force(data_test, test_params))

            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)

            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))

    model.set_params(get_params(opt_state))
    return history


def train_non_equivariant(model, pos_train, E_train, F_train, pos_test, E_test, F_test,
                          n_epochs=200, lr=0.01, lambda_E=2.0, lambda_F=1.0,
                          max_grad_norm=10.0):
    """Train the baseline model on a weighted energy + z-force MSE loss.

    Uses full-batch Adam with global-norm gradient clipping. Forces are the
    negative gradient of ``model.circuit`` w.r.t. the input positions. Only
    the z-component of the force on atom row 1 (``[:, 1, 2]``) is
    supervised, so ``F_train``/``F_test`` hold one scalar per sample.

    Args:
        model: Object with ``circuit(positions, params)``, a batched
            ``vec_circuit(positions, params)``, a ``params`` pytree and
            ``set_params(params)``.
        pos_train, pos_test: Batched positions (first axis = sample).
        E_train, E_test: Energy targets, one per sample.
        F_train, F_test: z-force targets for atom row 1, one per sample.
        n_epochs: Number of optimizer steps attempted.
        lr: Adam learning rate.
        lambda_E, lambda_F: Weights of the energy and force MSE terms.
        max_grad_norm: Gradients whose global L2 norm exceeds this are
            rescaled to this norm.

    Returns:
        A history dict with lists ``epoch``, ``train_loss``, ``test_E_loss``
        and ``test_F_loss``, recorded every ``max(1, n_epochs // 20)`` epochs.

    Steps with a non-finite gradient norm (NaN/inf) are skipped entirely.
    This keeps a single bad batch from poisoning the Adam moments and the
    parameters. The final parameters are written back with
    ``model.set_params``.
    """

    def energy_single(coords, params):
        return model.circuit(coords, params)

    def force_single(coords, params):
        # F = -dE/dx: differentiate the circuit output w.r.t. its input positions.
        return -jax.grad(energy_single, argnums=0)(coords, params)

    vec_force = jax.vmap(force_single, (0, None), 0)

    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)

        F_pred_z = vec_force(positions, params)[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)

        # Keeps the *reported* loss finite. It does not make the gradient
        # finite, because the NaN intermediates still propagate through
        # backprop. Non-finite gradients are therefore rejected in the loop.
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)

        total_loss = lambda_E * E_loss + lambda_F * F_loss
        return total_loss, (E_loss, F_loss)

    # Build and compile the loss-and-gradient function once, outside the epoch loop.
    loss_and_grad = jax.jit(jax.value_and_grad(combined_loss, has_aux=True))
    # Compiled predictors for the periodic test-set evaluation.
    predict_energy = jax.jit(model.vec_circuit)
    predict_force = jax.jit(vec_force)

    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)

    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    log_every = max(1, n_epochs // 20)

    for epoch in range(n_epochs):
        (loss, _), grads = loss_and_grad(get_params(opt_state), pos_train, E_train, F_train)

        grad_norm = float(jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads))))
        if np.isfinite(grad_norm):
            if grad_norm > max_grad_norm:
                scale = max_grad_norm / grad_norm
                grads = jax.tree.map(lambda g: g * scale, grads)
            opt_state = opt_update(epoch, grads, opt_state)
        # else: skip this update; NaN > max_grad_norm is False, so without this
        # check NaN gradients would be fed straight into Adam.

        if (epoch + 1) % log_every == 0:
            test_params = get_params(opt_state)
            E_pred_test = np.array(predict_energy(pos_test, test_params))
            F_pred_test = np.array(predict_force(pos_test, test_params))[:, 1, 2]

            E_test_loss = np.mean((E_pred_test - np.array(E_test)) ** 2)
            F_test_loss = np.mean((F_pred_test - np.array(F_test)) ** 2)

            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_test_loss))
            history["test_F_loss"].append(float(F_test_loss))

    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# DATA LOADING
# =============================================================================

def load_lih_data(data_dir="eqnn_force_field_data_LiH"):
    """Read the LiH arrays stored as ``.npy`` files in ``data_dir``.

    Returns:
        Tuple ``(energy, forces, positions)`` loaded from ``Energy.npy``,
        ``Forces.npy`` and ``Positions.npy``, read in that order.

    Raises:
        FileNotFoundError: if any of the three files is missing.
    """
    file_names = ("Energy.npy", "Forces.npy", "Positions.npy")
    return tuple(np.load(os.path.join(data_dir, name)) for name in file_names)


def prepare_data(energy, forces, positions, test_split=0.2, seed=42):
    """Scale targets, build relative coordinates and draw a train/test split.

    Args:
        energy: Energies, shape (n,) or (n, 1).
        forces: Per-atom forces, shape (n, n_atoms, 3).
        positions: Per-atom positions, shape (n, n_atoms, 3).
        test_split: Fraction of samples held out for testing.
        seed: Seed for the train/test split.

    Returns:
        Dict with:
          * ``energy_scaler`` / ``energy_scaled``: MinMax(-1, 1) fit on energy.
          * ``force_scaler``: MinMax(-1, 1) fit on ``forces[:, 1, 2]``.
          * ``forces_H``: ``forces[:, 1:, :]``.
          * ``forces_scaled``: zeros shaped like ``forces_H``, except that
            ``[:, 0, 2]`` holds the scaled z-force of atom row 1.
          * ``positions_centered``: positions of atoms 1..n_atoms-1 relative to
            atom 0, shape (n, n_atoms - 1, 3), float64.
          * ``positions_raw``: the input ``positions``.
          * ``indices_train`` (random order) and ``indices_test`` (sorted).

    The split uses a private ``RandomState(seed)``. It gives the same indices
    as seeding the global NumPy RNG, but leaves the global random state
    untouched.
    """
    n_samples = positions.shape[0]

    # Scale energy to [-1, 1].
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()

    # Coordinates of every other atom relative to atom 0 (for LiH: a single row).
    positions_centered = (positions[:, 1:, :] - positions[:, :1, :]).astype(np.float64)

    # Scale forces: only the z-component of atom row 1 is used as a target.
    forces_H = forces[:, 1:, :]
    force_scaler = MinMaxScaler((-1, 1))
    forces_z_only = forces_H[:, 0, 2].reshape(-1, 1)
    forces_z_scaled = force_scaler.fit_transform(forces_z_only).flatten()

    forces_scaled = np.zeros_like(forces_H)
    forces_scaled[:, 0, 2] = forces_z_scaled

    # Train/test split.
    rng = np.random.RandomState(seed)
    indices_train = rng.choice(np.arange(n_samples), size=int((1 - test_split) * n_samples), replace=False)
    indices_test = np.setdiff1d(np.arange(n_samples), indices_train)

    return {
        "energy_scaler": energy_scaler,
        "force_scaler": force_scaler,
        "energy_scaled": energy_scaled,
        "forces_scaled": forces_scaled,
        "positions_centered": positions_centered,
        "positions_raw": positions,
        "forces_H": forces_H,
        "indices_train": indices_train,
        "indices_test": indices_test,
    }


# =============================================================================
# EVALUATION
# =============================================================================

def _predict_energy_and_force_z(model, positions, atom_index):
    """Run ``model`` on a batch of positions and return (scaled energies, scaled z-forces).

    Forces are the negative gradient of ``model.circuit`` w.r.t. its input
    coordinates. Only the z-component (last axis, index 2) of row
    ``atom_index`` is returned. Both outputs are NumPy arrays.
    """
    def force_single(coords, params):
        return -jax.grad(model.circuit, argnums=0)(coords, params)

    coords = jnp.array(positions)
    E_pred = np.array(model.vec_circuit(coords, model.params))
    F_pred = np.array(jax.vmap(force_single, (0, None), 0)(coords, model.params))
    return E_pred, F_pred[:, atom_index, 2]


def _error_metrics(pred, true):
    """Return (MAE, RMSE, R^2) of ``pred`` against ``true``."""
    err = pred - true
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))
    r2 = 1 - np.sum(err ** 2) / np.sum((true - true.mean()) ** 2)
    return mae, rmse, r2


def evaluate_model(model, data, model_type="equivariant"):
    """Evaluate a trained model on the whole dataset and score the test split.

    Steps:
      1. Predict scaled energies and the scaled z-force. With
         ``model_type == "equivariant"`` the model receives
         ``data["positions_centered"]`` and the force is read from row 0. For
         any other value it receives ``data["positions_raw"]`` and the force is
         read from row 1.
      2. Post-correct using the training indices: fit a quadratic
         ``a*E**2 + b*E + c`` to the energies with ``curve_fit``, and fit a
         ``LinearRegression`` to the forces. If a fit fails (no convergence,
         non-finite inputs, too few points), the uncorrected predictions are
         used. Any other exception propagates.
      3. Undo the MinMax scaling and compute MAE/RMSE/R^2 on ``indices_test``.

    Args:
        model: Trained model exposing ``circuit``, ``vec_circuit`` and ``params``.
        data: Dict as produced by ``prepare_data``.
        model_type: ``"equivariant"``, or anything else for the baseline.

    Returns:
        ``(metrics, predictions)``. In ``metrics``, the ``*_eV`` entries equal
        the ``*_Ha`` entries times 27.2114. ``predictions`` holds
        JSON-serialisable lists covering all samples, plus the test indices.
    """
    energy_scaled = data["energy_scaled"]
    forces_z_scaled = data["forces_scaled"][:, 0, 2]
    energy_scaler = data["energy_scaler"]
    force_scaler = data["force_scaler"]
    indices_train = data["indices_train"]
    indices_test = data["indices_test"]
    forces_H = data["forces_H"]

    if model_type == "equivariant":
        E_pred_scaled, F_pred_z_scaled = _predict_energy_and_force_z(
            model, data["positions_centered"], atom_index=0)
    else:
        E_pred_scaled, F_pred_z_scaled = _predict_energy_and_force_z(
            model, data["positions_raw"], atom_index=1)

    # Post-correction for energy (quadratic calibration fitted on train split).
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c

    try:
        popt_E, _ = curve_fit(corr_E, E_pred_scaled[indices_train], energy_scaled[indices_train])
        E_pred_corrected = corr_E(E_pred_scaled, *popt_E)
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: fit did not converge; ValueError: NaN/inf inputs;
        # TypeError: fewer training points than fit parameters.
        E_pred_corrected = E_pred_scaled

    # Post-correction for force (linear calibration fitted on train split).
    try:
        lr_model = LinearRegression()
        lr_model.fit(F_pred_z_scaled[indices_train].reshape(-1, 1), forces_z_scaled[indices_train])
        F_pred_corrected = lr_model.predict(F_pred_z_scaled.reshape(-1, 1)).flatten()
    except ValueError:
        # Raised for NaN/inf or empty training inputs.
        F_pred_corrected = F_pred_z_scaled

    # Inverse transform back to original units.
    E_pred_original = energy_scaler.inverse_transform(E_pred_corrected.reshape(-1, 1)).flatten()
    F_pred_original = force_scaler.inverse_transform(F_pred_corrected.reshape(-1, 1)).flatten()

    E_true_original = energy_scaler.inverse_transform(energy_scaled.reshape(-1, 1)).flatten()
    F_true_original = forces_H[:, 0, 2]

    # Metrics on the test split only.
    E_mae, E_rmse, E_r2 = _error_metrics(E_pred_original[indices_test], E_true_original[indices_test])
    F_mae, F_rmse, F_r2 = _error_metrics(F_pred_original[indices_test], F_true_original[indices_test])

    metrics = {
        "E_mae_Ha": float(E_mae),
        "E_mae_eV": float(E_mae * 27.2114),
        "E_rmse_Ha": float(E_rmse),
        "E_rmse_eV": float(E_rmse * 27.2114),
        "E_r2": float(E_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
        "F_r2": float(F_r2),
    }

    predictions = {
        "E_pred": E_pred_original.tolist(),
        "E_true": E_true_original.tolist(),
        "F_pred": F_pred_original.tolist(),
        "F_true": F_true_original.tolist(),
        "indices_test": indices_test.tolist(),
    }

    return metrics, predictions


# =============================================================================
# MAIN COMPARISON
# =============================================================================

def _print_fit_metrics(metrics):
    """Print the headline energy and force lines for one evaluated model."""
    print(f"  Energy: MAE={metrics['E_mae_Ha']:.6f} Ha, R²={metrics['E_r2']:.4f}")
    print(f"  Force:  MAE={metrics['F_mae']:.4f} eV/Å, R²={metrics['F_r2']:.4f}")


def _summarize_runs(runs):
    """Aggregate a non-empty list of run records into per-metric statistics.

    Returns ``{metric: {"mean", "std", "min", "max", "values"}}`` computed over
    ``run["metrics"][metric]`` of every run.
    """
    summary = {}
    for key in runs[0]["metrics"].keys():
        values = [run["metrics"][key] for run in runs]
        summary[key] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
            "values": values,
        }
    return summary


def run_comparison(n_runs=3, n_epochs=100, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """Train and evaluate both QML models ``n_runs`` times on the LiH data.

    Run ``r`` initialises both models with seed ``42 + 100 * r``, trains them
    for ``n_epochs`` epochs and evaluates them. Per-run histories, metrics,
    predictions and aggregate statistics go to ``<output_dir>/results.json``.
    The headline per-run metrics also go to ``<output_dir>/metrics.npz``.

    Args:
        n_runs: Number of independent runs per model (must be >= 1).
        n_epochs: Training epochs per run.
        output_dir: Directory for the results (created if missing).
        data_dir: Directory holding the LiH ``.npy`` files.

    Returns:
        The results dictionary, or None when the data files are not found.

    Raises:
        ValueError: if ``n_runs < 1``, because the summary statistics need at
            least one run.
    """
    # Validate before touching the filesystem or training anything.
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    os.makedirs(output_dir, exist_ok=True)

    print("="*70)
    print("LiH Energy/Force Prediction: Rotationally Equivariant vs Graph Embedding Equivariant QML")
    print("="*70)

    # Load data
    print("\nLoading data...")
    try:
        energy, forces, positions = load_lih_data(data_dir)
        print(f"  Loaded {len(energy)} samples")
    except FileNotFoundError:
        print(f"  ERROR: Data not found in {data_dir}")
        print("  Please ensure the LiH data files are available")
        return None

    # Prepare data
    data = prepare_data(energy, forces, positions)
    print(f"  Train: {len(data['indices_train'])}, Test: {len(data['indices_test'])}")

    all_results = {
        "config": {
            "n_runs": n_runs,
            "n_epochs": n_epochs,
            "timestamp": datetime.now().isoformat(),
            # Recorded for provenance of the saved results.
            "output_dir": output_dir,
            "data_dir": data_dir,
        },
        "equivariant": {"runs": [], "metrics_summary": {}},
        "non_equivariant": {"runs": [], "metrics_summary": {}},
    }

    idx_train, idx_test = data["indices_train"], data["indices_test"]
    E_train = data["energy_scaled"][idx_train]
    E_test = data["energy_scaled"][idx_test]

    # Equivariant model: relative coordinates; full force arrays as targets.
    data_train_eq = jnp.array(data["positions_centered"][idx_train])
    data_test_eq = jnp.array(data["positions_centered"][idx_test])
    F_train_eq = data["forces_scaled"][idx_train]
    F_test_eq = data["forces_scaled"][idx_test]

    # Baseline model: raw positions; only the scaled z-force is a target.
    pos_train_neq = jnp.array(data["positions_raw"][idx_train])
    pos_test_neq = jnp.array(data["positions_raw"][idx_test])
    F_train_neq = data["forces_scaled"][idx_train, 0, 2]
    F_test_neq = data["forces_scaled"][idx_test, 0, 2]

    for run in range(n_runs):
        print(f"\n{'='*70}")
        print(f"RUN {run+1}/{n_runs}")
        print(f"{'='*70}")

        run_seed = 42 + run * 100

        # --- Rotationally Equivariant Model ---
        print("\n[Rotationally Equivariant QML]")
        eq_model = EquivariantQML(num_qubits=3, depth=6, blocks=2, seed=run_seed)

        print(f"  Training for {n_epochs} epochs...")
        eq_history = train_equivariant(
            eq_model, data_train_eq, E_train, F_train_eq,
            data_test_eq, E_test, F_test_eq, n_epochs=n_epochs
        )

        print("  Evaluating...")
        eq_metrics, eq_predictions = evaluate_model(eq_model, data, "equivariant")
        _print_fit_metrics(eq_metrics)

        all_results["equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": eq_history,
            "metrics": eq_metrics,
            "predictions": eq_predictions,
        })

        # --- Graph Embedding Equivariant Model ---
        print("\n[Graph Embedding Equivariant QML]")
        neq_model = NonEquivariantQML(num_qubits=4, depth=3, seed=run_seed)

        print(f"  Training for {n_epochs} epochs...")
        neq_history = train_non_equivariant(
            neq_model, pos_train_neq, E_train, F_train_neq,
            pos_test_neq, E_test, F_test_neq, n_epochs=n_epochs
        )

        print("  Evaluating...")
        neq_metrics, neq_predictions = evaluate_model(neq_model, data, "non_equivariant")
        _print_fit_metrics(neq_metrics)

        all_results["non_equivariant"]["runs"].append({
            "run_id": run,
            "seed": run_seed,
            "history": neq_history,
            "metrics": neq_metrics,
            "predictions": neq_predictions,
        })

    # Summary statistics (n_runs >= 1 is guaranteed above).
    for model_type in ("equivariant", "non_equivariant"):
        all_results[model_type]["metrics_summary"] = _summarize_runs(all_results[model_type]["runs"])

    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")

    # Save numpy arrays for easy loading (key order: eq_* then neq_*).
    npz_arrays = {}
    for prefix, model_type in (("eq", "equivariant"), ("neq", "non_equivariant")):
        runs = all_results[model_type]["runs"]
        for short_name, metric_key in (("E_r2", "E_r2"), ("F_r2", "F_r2"),
                                       ("E_mae", "E_mae_Ha"), ("F_mae", "F_mae")):
            npz_arrays[f"{prefix}_{short_name}"] = [r["metrics"][metric_key] for r in runs]
    np.savez(os.path.join(output_dir, "metrics.npz"), **npz_arrays)

    # Print summary
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    print(f"\n{'Metric':<20} {'Rotationally Equivariant QML':<30} {'Graph Embedding Equivariant QML':<30}")
    print("-"*80)

    for metric in ["E_r2", "E_mae_Ha", "F_r2", "F_mae"]:
        eq_mean = all_results["equivariant"]["metrics_summary"][metric]["mean"]
        eq_std = all_results["equivariant"]["metrics_summary"][metric]["std"]
        neq_mean = all_results["non_equivariant"]["metrics_summary"][metric]["mean"]
        neq_std = all_results["non_equivariant"]["metrics_summary"][metric]["std"]

        print(f"{metric:<20} {eq_mean:.4f} ± {eq_std:.4f}            {neq_mean:.4f} ± {neq_std:.4f}")

    print("="*70)

    return all_results


# =============================================================================
# MAIN
# =============================================================================

def main(n_runs=2, n_epochs=50, output_dir="lih_comparison_results", data_dir="eqnn_force_field_data_LiH"):
    """
    Entry point usable from a Jupyter cell or from the command line.

    Thin wrapper: forwards every argument unchanged, by keyword, to
    ``run_comparison``.

    Args:
        n_runs: Number of runs for each model.
        n_epochs: Training epochs per run.
        output_dir: Directory to save results.
        data_dir: Directory containing the LiH ``.npy`` files.

    Returns:
        Whatever ``run_comparison`` returns: the results dictionary, or None
        when the data files could not be found.
    """
    settings = {
        "n_runs": n_runs,
        "n_epochs": n_epochs,
        "output_dir": output_dir,
        "data_dir": data_dir,
    }
    return run_comparison(**settings)


if __name__ == "__main__":
    import sys

    if 'ipykernel' not in sys.modules:
        # Script mode: every CLI flag maps 1:1 onto a keyword argument of main().
        parser = argparse.ArgumentParser(description="Compare Rotationally Equivariant vs Graph Embedding Equivariant QML on LiH")
        parser.add_argument("--n_runs", type=int, default=2, help="Number of runs")
        parser.add_argument("--n_epochs", type=int, default=50, help="Training epochs per run")
        parser.add_argument("--output_dir", type=str, default="lih_comparison_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_LiH", help="Data directory")

        args = parser.parse_args()
        results = main(**vars(args))
    else:
        # Inside a Jupyter kernel: skip argparse and only print how to call main().
        print("Running in Jupyter notebook. Call main() directly with parameters:")
        print("  results = main(n_runs=2, n_epochs=50, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')")

In [3]:
# Single run (seed 42), 200 epochs per model; results.json and metrics.npz are
# written to lih_results/.
# NOTE(review): the captured output below was produced by an earlier version of
# the script. Its banner and model labels ("Equivariant" / "Non-Equivariant")
# differ from the current print statements, so re-run this cell to refresh it.
results = main(n_runs=1, n_epochs=200, output_dir='lih_results', data_dir='eqnn_force_field_data_LiH')

LiH Energy/Force Prediction: Equivariant vs Non-Equivariant QML

Loading data...
  Loaded 2400 samples
  Train: 1920, Test: 480

RUN 1/1

[Equivariant Model]
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.033382 Ha, R²=0.9966
  Force:  MAE=3.2999 eV/Å, R²=0.9239

[Non-Equivariant Model]
  Training for 200 epochs...
  Evaluating...
  Energy: MAE=0.026232 Ha, R²=0.9979
  Force:  MAE=2.5130 eV/Å, R²=0.9584

Results saved to: lih_results/results.json

SUMMARY

Metric               Equivariant               Non-Equivariant          
----------------------------------------------------------------------
E_r2                 0.9966 ± 0.0000       0.9979 ± 0.0000
E_mae_Ha             0.0334 ± 0.0000       0.0262 ± 0.0000
F_r2                 0.9239 ± 0.0000       0.9584 ± 0.0000
F_mae                3.2999 ± 0.0000       2.5130 ± 0.0000
