# In [None]:
# Copyright 2025 The LEVER Authors - All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Iterative S-space evolution for LEVER wavefunction optimization.

Demonstrates the modular evolution framework, with configurable strategies for
space evolution, Hamiltonian screening, and energy evaluation.

File: examples/run_evolution.py
Author: Zheng (Alex) Che, email: wsmxcz@gmail.com
Date: January, 2025
"""

from __future__ import annotations

import time
from enum import Enum
from pathlib import Path
from typing import Any

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import scipy.sparse as sp
from pyscf import lib

import lever
from lever import engine, evolution

# JAX configuration. Platform selection must happen before anything that
# initializes the backends (such as jax.devices()); once the backends exist,
# changing "jax_platforms" no longer has any effect.
jax.config.update("jax_platforms", "cuda")
jax.config.update("jax_log_compiles", False)
print("JAX detected devices:", jax.devices())


# ========================================================================
# Configuration
# ========================================================================

class EvalMode(Enum):
    """When an optional energy evaluation is performed during evolution.

    Members:
        NEVER: the evaluation is skipped entirely.
        FINAL: the evaluation runs only at the last cycle.
        EVERY: the evaluation runs at every cycle.
    """

    NEVER = "never"
    FINAL = "final"
    EVERY = "every"


class ScreenMode(Enum):
    """Heat-bath screening strategy used when selecting the C-space.

    Members:
        OFF: no screening; the full C-space is kept.
        STATIC: heat-bath screening with a fixed threshold.
        DYNAMIC: heat-bath screening weighted by the current amplitudes.
    """

    OFF = "off"
    STATIC = "static"
    DYNAMIC = "dynamic"


CONFIG = dict(
    # --- System: FCIDUMP integrals and orbital / electron counts ---------
    fcidump_path="../benchmark/FCIDUMP/H2O_631g.FCIDUMP",
    n_orbitals=13,
    n_alpha=5,
    n_beta=5,
    # --- Optimization ----------------------------------------------------
    seed=42,
    learning_rate=5e-4,
    num_cycles=10,
    s_space_size=400,        # k for a top-K S-space selector
    steps_per_cycle=500,
    report_interval=50,      # print progress every N optimizer steps
    # --- When each optional energy evaluation runs ------------------------
    var_energy_mode=EvalMode.FINAL,
    T_CI_energy_mode=EvalMode.NEVER,
    S_CI_energy_mode=EvalMode.FINAL,
    # --- Heat-bath screening of the C-space -------------------------------
    hb_screen=ScreenMode.DYNAMIC,
    eps1=1e-6,
)

# Engine settings shared by the per-step optimizer and the S-space evolution
# evaluator: both energy and gradient use the PROXY modes.
ENGINE_CONFIG = engine.EngineConfig(
    energy_mode=engine.EnergyMode.PROXY,
    grad_mode=engine.GradMode.PROXY,
    compute_dtype=jnp.float64,
)


# ========================================================================
# Determinant & Reference Construction
# ========================================================================

def get_hf_determinant(n_orb: int, n_a: int, n_b: int) -> np.ndarray:
    """Construct the Hartree-Fock reference determinant as spin bit strings.

    Bits ``0 .. n_a-1`` of the alpha mask and ``0 .. n_b-1`` of the beta mask
    are set (assumes bit ``i`` encodes orbital ``i``).

    Args:
        n_orb: Number of spatial orbitals. It bounds the occupations and must
            fit into a 64-bit mask.
        n_a: Number of alpha electrons.
        n_b: Number of beta electrons.

    Returns:
        ``uint64`` array of shape ``(1, 2)`` holding ``[alpha_mask, beta_mask]``.

    Raises:
        ValueError: If ``n_orb`` is outside ``[0, 64]`` or an electron count is
            outside ``[0, n_orb]``.
    """
    if not 0 <= n_orb <= 64:
        raise ValueError(f"n_orb must be in [0, 64] for uint64 masks, got {n_orb}")
    for label, n_occ in (("n_a", n_a), ("n_b", n_b)):
        if not 0 <= n_occ <= n_orb:
            raise ValueError(f"{label} must be in [0, n_orb={n_orb}], got {n_occ}")
    return np.array([[(1 << n_a) - 1, (1 << n_b) - 1]], dtype=np.uint64)

def count_parameters(params: Any) -> int:
    """Return the total number of entries across all leaves of a JAX pytree.

    Args:
        params: A pytree (nested dicts / lists / tuples) whose leaves expose
            ``.size``, e.g. arrays.

    Returns:
        The sum of ``leaf.size`` over every leaf (0 for an empty tree).
    """
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return total


# ========================================================================
# Energy Evaluation
# ========================================================================

def compute_fci_energy(int_ctx: lever.IntCtx, n_orb: int, n_a: int, n_b: int) -> float:
    """
    Exact ground state via full CI Hamiltonian diagonalization.

    Builds the unscreened Hamiltonian over every determinant returned by
    ``lever.core.gen_fci_dets`` and returns its lowest eigenvalue.

    Args:
        int_ctx: Integral context for the system.
        n_orb: Number of spatial orbitals.
        n_a: Number of alpha electrons.
        n_b: Number of beta electrons.

    Returns:
        Electronic energy only (nuclear repulsion excluded).
    """
    fci_dets = lever.core.gen_fci_dets(n_orb, n_a, n_b)
    ham_fci, _, _ = engine.hamiltonian.get_ham_proxy(
        S_dets=fci_dets, int_ctx=int_ctx, n_orbitals=n_orb, use_heatbath=False
    )
    # Reuse the shared sparse eigensolver; e_nuc=0.0 keeps the result electronic-only.
    return diagonalize_hamiltonian(ham_fci, 0.0)


def diagonalize_hamiltonian(ham: Any, e_nuc: float, dense_max_dim: int = 200) -> float:
    """
    Sparse Hamiltonian diagonalization for ground state.

    ``ham`` must expose COO triplets ``rows``, ``cols``, ``vals`` and ``shape``.
    Duplicate entries are summed. The matrix is used as the full operator, as
    in the matrix-vector product, so both triangles must be stored.
    NOTE(review): confirm this holds for every Hamiltonian builder used here.

    Small matrices (dimension <= ``dense_max_dim``) are diagonalized exactly
    with ``numpy.linalg.eigvalsh``. Larger ones use PySCF's Davidson solver
    (``lib.eigh``), seeded on the basis state with the lowest diagonal element.

    Args:
        ham: Sparse Hamiltonian in COO form.
        e_nuc: Nuclear repulsion energy added to the result.
        dense_max_dim: Largest dimension handled by dense diagonalization.

    Returns:
        Total energy (electronic + nuclear repulsion). An empty Hamiltonian
        yields ``e_nuc``.
    """
    n_dim = ham.shape[0]
    if n_dim == 0:
        return 0.0 + e_nuc

    H_csr = sp.coo_matrix(
        (ham.vals, (ham.rows, ham.cols)), shape=ham.shape
    ).tocsr()
    H_csr.sum_duplicates()

    if n_dim <= dense_max_dim:
        # Exact and cheap at this size. It also sidesteps Davidson asking for
        # more roots than the matrix has.
        return float(np.linalg.eigvalsh(H_csr.toarray())[0]) + e_nuc

    diag = H_csr.diagonal()
    x0 = np.zeros(n_dim)
    # Start from the lowest-diagonal basis state (the usual Davidson guess)
    # rather than whatever happens to be stored first.
    x0[int(np.argmin(diag.real))] = 1.0

    nroots = min(5, n_dim)
    e, _ = lib.eigh(lambda x: H_csr @ x, x0, diag,
                    nroots=nroots, max_cycle=200, tol=1e-8)
    # PySCF's Davidson returns a scalar when nroots == 1.
    return float(np.atleast_1d(e)[0]) + e_nuc


def compute_variational_energy(
    variables: Any, logpsi_fn: callable, ham: Any,
    dets: np.ndarray, n_orb: int, e_nuc: float
) -> float:
    """
    Variational energy: ⟨ψ|H|ψ⟩ / ⟨ψ|ψ⟩.

    Evaluates the Rayleigh quotient of the network wavefunction restricted to
    ``dets``, using ``ham`` (COO ``rows`` / ``cols`` / ``vals``) over the same
    determinant ordering.

    Args:
        variables: Model parameters passed to ``logpsi_fn``.
        logpsi_fn: ``logpsi_fn(variables, vecs) -> log ψ`` per determinant.
        ham: Sparse Hamiltonian over ``dets``.
        dets: Determinant bit masks.
        n_orb: Number of spatial orbitals (used to decode the masks).
        e_nuc: Nuclear repulsion energy.

    Returns:
        Total variational energy (electronic + ``e_nuc``) as a Python float.
    """
    t_vecs = engine.utils.masks_to_vecs(jnp.asarray(dets), n_orb)
    log_psi = logpsi_fn(variables, t_vecs)
    # The Rayleigh quotient is invariant under ψ -> c·ψ. Shift log|ψ| by its
    # maximum so exp() cannot overflow (or underflow everything to zero).
    log_psi = log_psi - jnp.max(jnp.real(log_psi))
    psi = np.array(jnp.exp(log_psi))

    h_psi = engine.kernels.coo_matvec(ham.rows, ham.cols, ham.vals, psi, len(dets))
    # np.vdot conjugates its first argument, giving ⟨ψ|Hψ⟩ and ⟨ψ|ψ⟩.
    e_elec = np.vdot(psi, h_psi).real / np.vdot(psi, psi).real
    return float(e_elec + e_nuc)


# ========================================================================
# Wavefunction Analysis
# ========================================================================

def compute_s_space_amplitudes(
    variables: Any, logpsi_fn: callable, s_dets: np.ndarray, n_orb: int
) -> np.ndarray:
    """Compute L2-normalized amplitude magnitudes |ψ(D)| over the S-space.

    Args:
        variables: Model parameters passed to ``logpsi_fn``.
        logpsi_fn: ``logpsi_fn(variables, vecs) -> log ψ`` per determinant.
        s_dets: S-space determinant bit masks.
        n_orb: Number of spatial orbitals (used to decode the masks).

    Returns:
        float64 array of non-negative magnitudes with unit L2 norm. Empty
        input gives an empty array. If every amplitude is zero (or
        non-finite), the raw magnitudes are returned unnormalized.
    """
    s_vecs = engine.utils.masks_to_vecs(jnp.asarray(s_dets), n_orb)
    log_psi_s = logpsi_fn(variables, s_vecs)
    # |exp(z)| = exp(Re z): work with log-magnitudes in float64 and shift by
    # the maximum so the largest magnitude is exactly 1 (no overflow/underflow).
    log_amp = np.real(np.asarray(log_psi_s)).astype(np.float64)
    if log_amp.size == 0:
        return log_amp
    shift = log_amp.max()
    if not np.isfinite(shift):
        # All amplitudes are zero (max = -inf) or the values are non-finite.
        # There is nothing meaningful to rescale.
        return np.exp(log_amp)
    psi_s = np.exp(log_amp - shift)
    # After the shift the max entry is 1, so the norm is >= 1.
    return psi_s / np.linalg.norm(psi_s)


def should_evaluate(mode: EvalMode, current_cycle: int, total_cycles: int) -> bool:
    """Check if evaluation should occur at current cycle."""
    match mode:
        case EvalMode.NEVER:
            return False
        case EvalMode.FINAL:
            return current_cycle == total_cycles - 1
        case EvalMode.EVERY:
            return True


# ========================================================================
# Hamiltonian Construction
# ========================================================================

def build_hamiltonian(
    s_dets: np.ndarray,
    int_ctx: lever.IntCtx,
    n_orb: int,
    screen_mode: ScreenMode,
    eps1: float,
    variables: Any | None = None,
    logpsi_fn: callable | None = None,
) -> tuple[Any, Any, Any]:
    """
    Build Hamiltonian with configurable screening.
  
    Returns:
        (H_SS, H_SC, space_rep): Hamiltonian blocks and space representation
    """
    match screen_mode:
        case ScreenMode.OFF:
            return engine.hamiltonian.get_ham_proxy(
                S_dets=s_dets, int_ctx=int_ctx, n_orbitals=n_orb, use_heatbath=False
            )
        case ScreenMode.STATIC:
            return engine.hamiltonian.get_ham_proxy(
                S_dets=s_dets, int_ctx=int_ctx, n_orbitals=n_orb, 
                use_heatbath=True, eps1=eps1
            )
        case ScreenMode.DYNAMIC:
            if variables is None or logpsi_fn is None:
                raise ValueError("DYNAMIC screening requires variables and logpsi_fn")
            psi_s = compute_s_space_amplitudes(variables, logpsi_fn, s_dets, n_orb)
            return engine.hamiltonian.get_ham_proxy(
                S_dets=s_dets, int_ctx=int_ctx, n_orbitals=n_orb,
                psi_S=psi_s, use_heatbath=True, eps1=eps1
            )


# ========================================================================
# Optimization
# ========================================================================

def _create_jitted_step_fn(
    logpsi_fn: callable, ham_ss: Any, ham_sc: Any, space_rep: Any,
    n_orb: int, e_nuc: float, optimizer: optax.GradientTransformation,
) -> callable:
    """Return a ``jax.jit``-compiled single optimization step.

    The Hamiltonian blocks, space representation and optimizer are captured
    by closure, so the compiled step only takes the evolving state.

    Returns:
        ``step(variables, opt_state) -> (new_variables, new_opt_state, energy)``,
        where ``energy`` is the engine's electronic energy (at the incoming
        ``variables``) plus ``e_nuc``.

    NOTE(review): closed-over arrays are baked into the compiled program as
    constants; for very large H blocks this can inflate compile time/memory.
    """

    @jax.jit
    def _step(variables: Any, opt_state: Any) -> tuple[Any, Any, float]:
        # The evaluator is rebuilt inside the trace around the current parameters.
        result = engine.compute_energy_and_gradient(
            engine.Evaluator(
                params=variables,
                logpsi_fn=logpsi_fn,
                ham_ss=ham_ss,
                ham_sc=ham_sc,
                space=space_rep,
                n_orbitals=n_orb,
                config=ENGINE_CONFIG,
            )
        )
        updates, next_opt_state = optimizer.update(result.gradient, opt_state, variables)
        next_variables = optax.apply_updates(variables, updates)
        return next_variables, next_opt_state, result.energy_elec + e_nuc

    return _step


def run_optimization_cycle(
    variables: Any, logpsi_fn: callable, ham_ss: Any, ham_sc: Any,
    space_rep: Any, n_orb: int, e_nuc: float, num_steps: int,
    lr: float, cycle_num: int, report_freq: int,
) -> tuple[Any, list[float]]:
    """Gradient descent within a fixed S-C space partition.

    A fresh ``optax.adamw`` optimizer (and thus fresh optimizer state) is
    created on every call, i.e. once per evolution cycle.

    Args:
        variables: Initial model parameters.
        logpsi_fn: Log-amplitude function of the model.
        ham_ss, ham_sc, space_rep: Hamiltonian blocks and space for this cycle.
        n_orb: Number of spatial orbitals.
        e_nuc: Nuclear repulsion energy added to reported energies.
        num_steps: Number of optimizer steps.
        lr: Learning rate.
        cycle_num: Cycle label used in progress output.
        report_freq: Print progress every ``report_freq`` steps. Values <= 0
            disable progress output.

    Returns:
        ``(variables, energy_history)``: the updated parameters and, per step,
        the total energy evaluated at the parameters before that step's update.
    """
    optimizer = optax.adamw(lr)
    opt_state = optimizer.init(variables)
    jitted_step = _create_jitted_step_fn(
        logpsi_fn, ham_ss, ham_sc, space_rep, n_orb, e_nuc, optimizer
    )

    energy_history: list[float] = []
    for step in range(num_steps):
        variables, opt_state, total_energy = jitted_step(variables, opt_state)
        # float() pulls the value to the host; the history needs it there anyway.
        energy = float(total_energy)
        energy_history.append(energy)

        if report_freq > 0 and (step + 1) % report_freq == 0:
            print(f"  Cycle {cycle_num} | Step {step+1:4d}/{num_steps} | "
                  f"E = {energy:.8f} Ha")

    return variables, energy_history


# ========================================================================
# Visualization
# ========================================================================

def create_convergence_plot(
    energy_hist: list[float], cycle_bounds: list[int], exact_energy: float,
    var_energies: list[float] | None, s_ci_energies: list[float] | None,
    system_name: str,
) -> None:
    """Dual-panel plot: energy trajectory and logarithmic error.

    The top panel shows the optimization energy at the end of each cycle
    against the reference, with a chemical-accuracy band. It also shows the
    optional variational and S-space CI energies. The bottom panel shows the
    absolute error of the cycle-end optimization energy on a log scale.

    Args:
        energy_hist: Per-step total energies concatenated over all cycles.
        cycle_bounds: Cumulative step counts starting with 0. Cycle ``k``
            ends at index ``cycle_bounds[k + 1] - 1`` of ``energy_hist``.
        exact_energy: Reference (FCI) total energy in Ha.
        var_energies: Optional variational energies, one per evaluated cycle.
        s_ci_energies: Optional S-space CI energies, one per evaluated cycle.
        system_name: Label used in the title.

    Optional series are recorded for the final cycles only: all of them for
    every-cycle evaluation, or just the last one for final-only evaluation.
    They are therefore right-aligned to the last cycles, not plotted from cycle 1.
    """
    chem_acc = 1.6e-3  # Chemical accuracy threshold (Ha)

    # Extract cycle-end energies
    energy = np.asarray(energy_hist, dtype=float)
    cycle_end_steps = [b - 1 for b in cycle_bounds[1:]]
    cycle_end_energies = [energy[idx] for idx in cycle_end_steps]
    cycle_end_errors = [abs(e - exact_energy) for e in cycle_end_energies]
    n_cycles = len(cycle_end_steps)
    cycle_indices = np.arange(1, n_cycles + 1)

    def _last_cycles(values: list[float]) -> np.ndarray:
        # 1-based cycle numbers for a series covering the final len(values) cycles.
        return np.arange(n_cycles - len(values) + 1, n_cycles + 1)

    style = {
        'font.size': 11, 'axes.labelsize': 12, 'axes.titlesize': 14,
        'lines.linewidth': 2.0, 'grid.alpha': 0.3,
    }
    # Scoped styling: unlike plt.rcParams.update it does not leak into later plots.
    with plt.rc_context(style):
        _, (ax1, ax2) = plt.subplots(
            2, 1, figsize=(8, 6), height_ratios=[2, 1], gridspec_kw={'hspace': 0.15}
        )

        # Energy trajectory panel
        ax1.axhspan(exact_energy - chem_acc, exact_energy + chem_acc,
                    alpha=0.15, color='green', label='Chemical Accuracy (±1.6 mHa)')
        ax1.axhline(exact_energy, color='black', linestyle='--', linewidth=1.5,
                    label=f'FCI: {exact_energy:.6f} Ha')
        ax1.plot(cycle_indices, cycle_end_energies, 'o-', color='steelblue',
                 markersize=8, markeredgecolor='white', markeredgewidth=1.5,
                 label='LEVER (opt)')

        if var_energies:
            ax1.plot(_last_cycles(var_energies), var_energies, 's-', color='orange',
                     markersize=6, markeredgecolor='white', markeredgewidth=1.0,
                     label='LEVER (var)')

        if s_ci_energies:
            ax1.plot(_last_cycles(s_ci_energies), s_ci_energies, 'D--', color='purple',
                     markersize=6, markeredgecolor='white', markeredgewidth=1.0,
                     label='S-space CI')

        ax1.set_ylim(exact_energy - 2e-3, exact_energy + 10e-3)
        ax1.set_ylabel('Total Energy (Ha)')
        ax1.set_xlabel('Evolution Cycle')
        ax1.set_title(f'LEVER Evolution: {system_name}')
        ax1.legend(loc='upper right')
        ax1.grid(True)
        ax1.set_xticks(cycle_indices)

        # Error convergence panel
        ax2.semilogy(cycle_indices, cycle_end_errors, 's-', color='crimson',
                     markersize=8, markeredgecolor='white', markeredgewidth=1.5,
                     label='Absolute Error')
        ax2.axhline(chem_acc, color='green', linestyle='--', alpha=0.6)
        ax2.set_xlabel('Evolution Cycle')
        ax2.set_ylabel(r'$|E_{\mathrm{LEVER}} - E_{\mathrm{FCI}}|$ (Ha)')
        ax2.legend(loc='upper right')
        ax2.grid(True)
        ax2.set_xticks(cycle_indices)

        plt.show()


# ========================================================================
# Main Workflow
# ========================================================================

# Cached FCI total energies (Ha) keyed by (FCIDUMP stem, n_orbitals, n_alpha,
# n_beta), so the demo system does not need a full-CI run for its reference.
_KNOWN_FCI_ENERGIES: dict[tuple[str, int, int, int], float] = {
    ("H2O_631g", 13, 5, 5): -76.12087434594525,
}


def _reference_fci_energy(
    cfg: dict[str, Any], system_name: str, int_ctx: lever.IntCtx, e_nuc: float
) -> float:
    """Return the FCI total energy (Ha) used as the reference for this run.

    Priority: an explicit ``cfg['e_fci_ref']``, then the cached table, and
    finally an on-the-fly full-CI calculation (potentially expensive).
    """
    override = cfg.get('e_fci_ref')
    if override is not None:
        return float(override)
    key = (system_name, cfg['n_orbitals'], cfg['n_alpha'], cfg['n_beta'])
    if key in _KNOWN_FCI_ENERGIES:
        return _KNOWN_FCI_ENERGIES[key]
    print(f"No cached FCI reference for {key}; running FCI (may be slow)...")
    e_fci_elec = compute_fci_energy(
        int_ctx, cfg['n_orbitals'], cfg['n_alpha'], cfg['n_beta']
    )
    return e_fci_elec + e_nuc


def _print_gap(label: str, energy: float, e_ref: float) -> None:
    """Print ``energy`` and its deviation from ``e_ref`` in mHa."""
    print(f"  {label}: {energy:.8f} Ha (gap: {(energy - e_ref) * 1e3:+.4f} mHa)")


def main() -> None:
    """Execute LEVER evolution with the strategy configured in ``CONFIG``.

    Each cycle does four things:

    1. Builds the S/C Hamiltonian blocks for the current S-space.
    2. Optimizes the model within that fixed partition.
    3. Runs the optional S-CI, variational and T-CI evaluations.
    4. Evolves the S-space for the next cycle.

    The results are then compared with an FCI reference and plotted.

    Raises:
        ValueError: If ``num_cycles`` or ``steps_per_cycle`` is < 1. The final
            summary needs at least one optimization step.
    """
    cfg = CONFIG
    n_cycles = cfg['num_cycles']
    n_orb = cfg['n_orbitals']
    if n_cycles < 1 or cfg['steps_per_cycle'] < 1:
        raise ValueError(
            "num_cycles and steps_per_cycle must both be >= 1, got "
            f"{n_cycles} and {cfg['steps_per_cycle']}"
        )

    print(f"\n{'═' * 70}")
    print(f"LEVER {lever.__version__} - S-Space Evolution".center(70))
    print(f"{'═' * 70}\n")

    # System initialization
    int_ctx = lever.IntCtx(cfg['fcidump_path'], n_orb)
    int_ctx.hb_prepare()
    e_nuc = int_ctx.get_e_nuc()
    system_name = Path(cfg['fcidump_path']).stem

    print(f"System: {system_name}")
    print(f"  Orbitals: {n_orb} | α: {cfg['n_alpha']} | β: {cfg['n_beta']}")
    print(f"  E_nuc: {e_nuc:.8f} Ha\n")

    # Model initialization
    model = lever.models.Backflow(
        n_orbitals=n_orb, n_alpha=cfg['n_alpha'], n_beta=cfg['n_beta'],
        seed=cfg['seed'], n_dets=1, generalized=True, restricted=False,
        hidden_dims=(256,), param_dtype=jnp.complex64
    )
    variables = model.variables
    n_params = count_parameters(variables)

    # Evolution strategy: amplitude scoring with threshold selection. The
    # threshold can be overridden via CONFIG['selection_threshold'].
    # Top-K alternative: evolution.selectors.TopKSelector(k=cfg['s_space_size']).
    evolution_strategy = evolution.BasicStrategy(
        scorer=evolution.scores.AmplitudeScorer(),
        selector=evolution.selectors.ThresholdSelector(
            threshold=cfg.get('selection_threshold', 1E-4)
        ),
    )

    print(f"Model: {model.__class__.__name__} (Holomorphic: {model.is_holo})")
    print(f"  Parameters: {n_params:,} ({n_params / 1e6:.2f}M)" if n_params >= 1e6
          else f"  Parameters: {n_params:,}")
    print(f"Evolution Strategy: {evolution_strategy.__class__.__name__}")
    print(f"  Cycles: {n_cycles} | S-size: {cfg['s_space_size']} | "
          f"Steps/cycle: {cfg['steps_per_cycle']}")
    print(f"  Learning rate: {cfg['learning_rate']:.4f}")
    print(f"  Screening: {cfg['hb_screen'].value} (eps1={cfg['eps1']:.1e})")
    print(f"  Evaluation: Var={cfg['var_energy_mode'].value} | "
          f"T-CI={cfg['T_CI_energy_mode'].value} | S-CI={cfg['S_CI_energy_mode'].value}\n")

    # Evolution loop
    s_dets = get_hf_determinant(n_orb, cfg['n_alpha'], cfg['n_beta'])
    full_history: list[float] = []
    cycle_boundaries = [0]
    var_energy_history: list[float] = []
    t_ci_energy_history: list[float] = []
    s_ci_energy_history: list[float] = []

    t_start = time.perf_counter()
    print(f"{'─' * 70}\nEvolution Progress\n{'─' * 70}")

    for cycle in range(n_cycles):
        # Build Hamiltonian for current S-space
        ham_ss, ham_sc, space_rep = build_hamiltonian(
            s_dets=s_dets, int_ctx=int_ctx, n_orb=n_orb,
            screen_mode=cfg['hb_screen'], eps1=cfg['eps1'],
            variables=variables, logpsi_fn=model.log_psi,
        )

        print(f"\nCycle {cycle + 1}/{n_cycles} | "
              f"S: {space_rep.size_S} | C: {space_rep.size_C} | "
              f"H_SS: {len(ham_ss.vals):,} | H_SC: {len(ham_sc.vals):,}")

        # Optimize within current space
        variables, cycle_history = run_optimization_cycle(
            variables, model.log_psi, ham_ss, ham_sc, space_rep,
            n_orb, e_nuc, cfg['steps_per_cycle'],
            cfg['learning_rate'], cycle + 1, cfg['report_interval'],
        )

        full_history.extend(cycle_history)
        cycle_boundaries.append(len(full_history))
        print(f"  Optimization energy: {cycle_history[-1]:.8f} Ha")

        # Post-optimization evaluations
        if should_evaluate(cfg['S_CI_energy_mode'], cycle, n_cycles):
            e_s_ci = diagonalize_hamiltonian(ham_ss, e_nuc)
            s_ci_energy_history.append(e_s_ci)
            print(f"  S-space CI: {e_s_ci:.8f} Ha")

        do_var = should_evaluate(cfg['var_energy_mode'], cycle, n_cycles)
        do_t_ci = should_evaluate(cfg['T_CI_energy_mode'], cycle, n_cycles)
        if do_var or do_t_ci:
            # Unscreened Hamiltonian over the full T = S ∪ C determinant set.
            t_dets = np.concatenate([space_rep.s_dets, space_rep.c_dets])
            ham_tt, _, _ = engine.hamiltonian.get_ham_proxy(
                S_dets=t_dets, int_ctx=int_ctx, n_orbitals=n_orb,
                use_heatbath=False
            )

            if do_var:
                e_var = compute_variational_energy(
                    variables, model.log_psi, ham_tt, t_dets, n_orb, e_nuc
                )
                var_energy_history.append(e_var)
                print(f"  Variational energy (T-space): {e_var:.8f} Ha")

            if do_t_ci:
                e_t_ci = diagonalize_hamiltonian(ham_tt, e_nuc)
                t_ci_energy_history.append(e_t_ci)
                print(f"  T-space CI: {e_t_ci:.8f} Ha | Size: {len(t_dets)}")

        # Evolve S-space for next cycle
        if cycle < n_cycles - 1:
            print("  Evolving S-space...")
            evaluator = engine.Evaluator(
                params=variables, logpsi_fn=model.log_psi, ham_ss=ham_ss,
                ham_sc=ham_sc, space=space_rep, n_orbitals=n_orb,
                config=ENGINE_CONFIG,
            )
            s_dets = evolution_strategy.evolve(evaluator)
            print(f"  New S-space size: {len(s_dets)}")

    elapsed = time.perf_counter() - t_start

    # Final analysis
    print(f"\n{'─' * 70}\nFinal Analysis\n{'─' * 70}\n")
    e_fci = _reference_fci_energy(cfg, system_name, int_ctx, e_nuc)

    print(f"Reference Energy:\n  FCI: {e_fci:.8f} Ha\n")
    print("LEVER Results:")
    _print_gap("Optimization final", full_history[-1], e_fci)
    if var_energy_history:
        _print_gap("Variational final (T-space)", var_energy_history[-1], e_fci)
    if t_ci_energy_history:
        _print_gap("T-space CI final", t_ci_energy_history[-1], e_fci)
    if s_ci_energy_history:
        _print_gap("S-space CI final", s_ci_energy_history[-1], e_fci)

    print(f"  Evolution time: {elapsed:.2f} s\n")

    # Visualization
    create_convergence_plot(
        full_history, cycle_boundaries, e_fci,
        var_energy_history if var_energy_history else None,
        s_ci_energy_history if s_ci_energy_history else None,
        system_name
    )


if __name__ == "__main__":
    # Entry point: run the full evolution workflow when executed directly.
    main()