In [None]:
# examples/run_evolution.py

import lever
from lever import config, models, evolution, driver, analysis

import jax
import jax.numpy as jnp
# Configure JAX *before* anything queries devices or creates arrays.
# `jax.devices()` initialises the backends; once that has happened the
# `jax_platforms` setting is no longer honoured, so the platform is pinned
# first and the device list is printed afterwards (showing what is really used).
jax.config.update("jax_platforms", "cuda")
# The engine configuration requests `jnp.float64` compute. Without x64 mode
# JAX silently truncates 64-bit dtypes to 32-bit, which would defeat that
# "higher precision" request, so enable it here at start-up.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_log_compiles", False)
print("JAX detected devices:", jax.devices())

def main():
    """Run a LEVER evolution for the H2O 6-31G FCIDUMP and report results.

    Builds the LEVER configuration, a single-determinant ``Backflow`` model
    and a ``BasicStrategy`` (``AmplitudeScorer`` + ``TopKSelector`` with
    ``k = s_space_size``), runs the ``Driver``, then prints a summary and
    plots convergence through ``AnalysisSuite``.
    """
    # --- Configuration -------------------------------------------------
    # System: 13 orbitals, 5 alpha + 5 beta electrons.
    # NOTE(review): the FCIDUMP path is relative; presumably it is resolved
    # against the current working directory, so run from `examples/`.
    system_settings = config.SystemConfig(
        fcidump_path="../benchmark/FCIDUMP/H2O_631g.FCIDUMP",
        n_orbitals=13,
        n_alpha=5,
        n_beta=5,
    )
    optim_settings = config.OptimizationConfig(
        s_space_size=400,
        steps_per_cycle=500,
        num_cycles=10,
    )
    eval_settings = config.EvaluationConfig(
        var_energy_mode=config.EvalMode.FINAL,
        s_ci_energy_mode=config.EvalMode.FINAL,
    )
    screen_settings = config.ScreeningConfig(
        mode=config.ScreenMode.DYNAMIC,
        eps1=1e-6,
    )
    # float64 compute for higher precision.
    # NOTE(review): JAX only honours 64-bit dtypes when `jax_enable_x64`
    # is on — confirm it is enabled before this runs.
    engine_settings = lever.engine.EngineConfig(compute_dtype=jnp.float64)

    full_config = config.LeverConfig(
        system=system_settings,
        optimization=optim_settings,
        evaluation=eval_settings,
        screening=screen_settings,
        engine=engine_settings,
    )

    # --- Model and selection strategy ----------------------------------
    # Parameters are complex64 while compute is float64; presumably an
    # intentional mixed-precision choice — confirm.
    wavefunction = models.Backflow(
        n_orbitals=system_settings.n_orbitals,
        n_alpha=system_settings.n_alpha,
        n_beta=system_settings.n_beta,
        seed=optim_settings.seed,
        n_dets=1,
        generalized=True,
        restricted=False,
        hidden_dims=(256,),
        param_dtype=jnp.complex64,
    )

    # The selector's k is tied to the S-space size from the optimization config.
    strategy = evolution.BasicStrategy(
        scorer=evolution.scores.AmplitudeScorer(),
        selector=evolution.selectors.TopKSelector(k=optim_settings.s_space_size),
    )

    # --- Run ------------------------------------------------------------
    runner = driver.Driver(full_config, wavefunction, strategy)
    outcome = runner.run()

    # --- Report -----------------------------------------------------------
    report = analysis.AnalysisSuite(outcome, runner.int_ctx)
    report.print_summary()
    report.plot_convergence()


if __name__ == "__main__":
    main()