In [None]:
# Copyright 2025 The LEVER Authors - All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
LEVER optimization example with EFFECTIVE mode.

File: examples/run_evolution.py
Author: Zheng (Alex) Che, email: wsmxcz@gmail.com
Date: November, 2025
"""

import lever
from lever import config, models, evolution, driver, analysis

import jax
import jax.numpy as jnp

# JAX configuration.
# Platform selection must be applied before anything initializes a JAX
# backend. Calling `jax.devices()` initializes the backends, so a later
# `jax_platforms` update would not control which backend is used (see the
# JAX configuration docs). Apply the settings first, then report the devices
# that are actually in use.
jax.config.update("jax_platforms", "cuda")
jax.config.update("jax_log_compiles", False)
print("JAX detected devices:", jax.devices())


def main(fcidump_path="../benchmark/FCIDUMP/N2_sto3g_2.50.FCIDUMP"):
    """Run the N2/STO-3G LEVER optimization example in EFFECTIVE compute mode.

    Builds the system, optimization, evaluation and screening configs. Then
    creates a complex-valued Backflow model and an amplitude-scored top-k
    evolution strategy, and runs the LEVER driver. Finally, it prints an
    analysis summary and plots the convergence history.

    Args:
        fcidump_path: Path to the FCIDUMP integral file. The default is a
            relative path, resolved against the current working directory
            (presumably the script is meant to be launched from the
            ``examples/`` directory).
    """
    # System configuration: 10 orbitals with 7 alpha and 7 beta electrons.
    sys_cfg = config.SystemConfig(
        fcidump_path=fcidump_path,
        n_orbitals=10,
        n_alpha=7,
        n_beta=7,
    )

    # Optimization configuration
    opt_cfg = config.OptimizationConfig(
        seed=42,
        learning_rate=5e-4,
        s_space_size=200,
        steps_per_cycle=200,
        num_cycles=10,
        report_interval=50,
    )

    # Evaluation configuration for EFFECTIVE mode: all optional energy
    # evaluations are switched off.
    eval_cfg = config.EvaluationConfig(
        var_energy_mode=config.EvalMode.NEVER,  # T-space not used in EFFECTIVE
        # S-space CI energy is disabled. NOTE(review): an earlier comment here
        # read "S-space CI at end"; change this mode if a final S-space CI
        # energy is actually wanted.
        s_ci_energy_mode=config.EvalMode.NEVER,
        t_ci_energy_mode=config.EvalMode.NEVER,  # T-space not applicable
    )

    # Screening configuration
    screen_cfg = config.ScreeningConfig(
        mode=config.ScreenMode.DYNAMIC,
        eps1=1e-6,
    )

    # LEVER configuration with EFFECTIVE mode
    lever_config = config.LeverConfig(
        system=sys_cfg,
        optimization=opt_cfg,
        evaluation=eval_cfg,
        screening=screen_cfg,
        compute_mode=config.ComputeMode.EFFECTIVE,  # Schur downfolding
    )

    print(f"Compute mode: {lever_config.compute_mode.value}")
    print(f"Epsilon: {lever_config.epsilon}")
    print(f"Normalize wavefunction: {lever_config.normalize_wf}")

    # Model initialization: the active-space sizes and seed come from the
    # configs above, so the model and the system definition stay in sync.
    model = models.Backflow(
        n_orbitals=sys_cfg.n_orbitals,
        n_alpha=sys_cfg.n_alpha,
        n_beta=sys_cfg.n_beta,
        seed=opt_cfg.seed,
        n_dets=1,
        generalized=True,
        restricted=False,
        hidden_dims=(256,),
        param_dtype=jnp.complex64,
    )

    # Evolution strategy: score candidates by amplitude and keep the top k,
    # where k equals the configured S-space size.
    evo_strategy = evolution.BasicStrategy(
        scorer=evolution.scores.AmplitudeScorer(),
        selector=evolution.selectors.TopKSelector(k=opt_cfg.s_space_size),
    )

    # Run LEVER driver
    lever_driver = driver.Driver(lever_config, model, evo_strategy)
    results = lever_driver.run()

    # Analysis and visualization
    analysis_suite = analysis.AnalysisSuite(results, lever_driver.int_ctx)
    analysis_suite.print_summary()
    analysis_suite.plot_convergence(system_name="N2_STO-3G_EFFECTIVE")


if __name__ == "__main__":
    # Run the example only when this file is executed directly, not when it
    # is imported as a module.
    main()
