# In [None]:
# Copyright 2025 The LEVER Authors - All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Löwdin partitioning convergence analysis for selected CI.

Compares energy estimators with expanding model space S:
- sCI: Variational energy from H_SS diagonalization
- sCI+PT2: Epstein-Nesbet second-order correction
- H_eff: Löwdin effective Hamiltonian in S-space
- H̃: Model-space (S⊕C) Hamiltonian with a diagonal C-block (C–C couplings neglected)
- E_var(ψ̃): Exact expectation value ⟨ψ̃|H|ψ̃⟩ of the H̃ ground-state eigenvector

Löwdin formula: H_eff = H_SS + H_SC·D⁻¹·H_CS
where D_jj = E_ref - H_CC,jj with E_ref from H_SS ground state.

File: examples/study_effective_ham.py
Author: Zheng (Alex) Che, wsmxcz@gmail.com
Date: January 2025
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import NamedTuple

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import matplotlib.pyplot as plt

import lever


# --- Configuration ---

# Label used in printed headers and plot titles; it must describe the system
# in FCIDUMP_PATH (C2 in STO-3G: 10 orbitals, 12 electrons = 6α + 6β).
SYSTEM_NAME = "C2_STO3G"
FCIDUMP_PATH = Path("../benchmark/FCIDUMP/C2_sto3g_2.00.FCIDUMP")
N_ORBITALS = 10  # Orbital count passed to IntCtx and determinant generation
N_ALPHA, N_BETA = 6, 6  # Electron counts per spin (12 electrons total)

K_MAX = 800  # Max S-space size
PT2_TOL = 1e-12  # Denominator cutoff for PT2
LOWDIN_TOL = 1e-12  # Denominator cutoff for Löwdin
HAM_THRESH = 1e-12  # Matrix element threshold
DENSE_CUTOFF = 10  # Use dense solver below this size


# --- Data Structures ---

class EigenPair(NamedTuple):
    """Lowest eigenvalue of a matrix bundled with its eigenvector.

    Fields are positional as well as named: ``pair[0]`` is ``pair.E`` and
    ``pair[1]`` is ``pair.ψ``.
    """

    E: float         # eigenvalue
    ψ: np.ndarray    # matching eigenvector (one column of the solver output)


@dataclass
class ConvergenceData:
    """Results of the convergence scan, stored as parallel per-k lists.

    Entry i of every list belongs to the i-th S-space size recorded in ``k``.
    """

    # Space sizes
    k: list[int] = field(default_factory=list)            # |S|
    size_C: list[int] = field(default_factory=list)       # |C|
    # Stored-entry counts of the assembled sparse matrices
    nnz_H_tilde: list[int] = field(default_factory=list)  # Non-zeros in H̃
    nnz_H_eff: list[int] = field(default_factory=list)    # Non-zeros in H_eff
    # Total energies (nuclear repulsion included)
    E_sCI: list[float] = field(default_factory=list)      # variational sCI
    E_PT2: list[float] = field(default_factory=list)      # sCI + EN-PT2
    E_Lowdin: list[float] = field(default_factory=list)   # Löwdin H_eff
    E_H_tilde: list[float] = field(default_factory=list)  # lowest eigenvalue of H̃
    E_exact: list[float] = field(default_factory=list)    # ⟨ψ̃|H|ψ̃⟩


# --- Sparse Linear Algebra ---

def to_csr(op: lever.HamOp) -> sp.csr_matrix:
    """Assemble a HamOp's (vals, rows, cols) triplets into a CSR matrix.

    Entries sharing the same (row, col) coordinate are summed before the
    conversion, so the result has one stored value per coordinate.
    """
    triplets = (op.vals, (op.rows, op.cols))
    coo = sp.coo_matrix(triplets, shape=op.shape)
    coo.sum_duplicates()  # merge repeated coordinates in place
    return coo.tocsr()


def lowest_eigenpair(H: sp.csr_matrix) -> EigenPair:
    """
    Compute the lowest eigenvalue and its eigenvector of a Hermitian matrix.

    Matrices smaller than ``DENSE_CUTOFF`` (or with at most 2 rows) are
    diagonalized densely with ``numpy.linalg.eigh``, which is robust for tiny
    problems. Larger ones use ARPACK (``scipy.sparse.linalg.eigsh``) in
    smallest-algebraic mode.

    Args:
        H: Square sparse matrix, assumed Hermitian.

    Returns:
        EigenPair holding the lowest eigenvalue ``E`` and its eigenvector ``ψ``.
        An empty (0×0) matrix yields ``EigenPair(inf, array([]))``.
    """
    n = H.shape[0]
    if n == 0:
        return EigenPair(np.inf, np.array([]))

    # ARPACK needs k < ncv <= n (k = 1 here), so tiny matrices must not go
    # through eigsh; a dense solve is cheaper and more reliable there anyway.
    if n < DENSE_CUTOFF or n <= 2:
        evals, evecs = np.linalg.eigh(H.toarray())
        return EigenPair(evals[0], evecs[:, 0])

    # The Krylov subspace size may not exceed the matrix dimension, otherwise
    # eigsh raises ValueError ("ncv must be k<ncv<=n") for 10 <= n < 200.
    ncv = min(200, n)
    evals, evecs = spla.eigsh(H, k=1, which='SA', tol=1e-15, ncv=ncv, maxiter=3000)
    return EigenPair(evals[0], evecs[:, 0])


# --- Energy Estimators ---

def compute_PT2(
    E_var: float,
    ψ_S: np.ndarray,
    ham_sc: lever.HamOp,
    H_diag_C: np.ndarray,
) -> float:
    """
    Epstein-Nesbet PT2 correction: E₂ = Σ_k |⟨k|H|ψ_S⟩|² / (E_var - H_kk).

    Args:
        E_var: Variational (electronic) energy of ψ_S in the S-space.
        ψ_S: S-space eigenvector.
        ham_sc: S×C coupling block H_SC.
        H_diag_C: Diagonal elements H_kk of the C-space states.

    Returns:
        The second-order energy correction. C-states whose denominator has
        magnitude <= PT2_TOL are skipped rather than divided by (near) zero.
        Returns 0.0 when there is no S–C coupling.
    """
    if ham_sc.nnz == 0:
        return 0.0

    # H_CS = H_SC^†, formed the same way as in the Löwdin downfolding.
    H_CS = to_csr(ham_sc).conj().T

    numerator = H_CS @ ψ_S  # ⟨k|H|ψ_S⟩ for every C-state k
    denominator = E_var - H_diag_C

    # Exclude (near-)degenerate C-states instead of dividing by ~0.
    mask = np.abs(denominator) > PT2_TOL
    # An all-False mask sums to 0.0, so no special case is needed.
    return float(np.sum(np.abs(numerator[mask]) ** 2 / denominator[mask]))


def build_H_eff_Lowdin(
    ham_ss: lever.HamOp,
    ham_sc: lever.HamOp,
    H_diag_C: np.ndarray,
    E_ref: float,
) -> sp.csr_matrix:
    """
    Downfold the C-space onto S: H_eff = H_SS + H_SC·D⁻¹·H_CS.

    D is diagonal with D_jj = E_ref - H_CC,jj, where E_ref is the H_SS
    ground-state energy supplied by the caller. Entries of D⁻¹ whose
    denominator magnitude is <= LOWDIN_TOL are set to zero, which removes
    those C-states from the correction. Without any S–C coupling the plain
    H_SS block is returned.
    """
    h_ss = to_csr(ham_ss)
    if ham_sc.nnz == 0:
        return h_ss

    gaps = E_ref - H_diag_C
    keep = np.abs(gaps) > LOWDIN_TOL
    inv_gaps = np.zeros_like(gaps)
    # Only invert the safe gaps; masked-out slots keep their zero.
    np.divide(1.0, gaps, out=inv_gaps, where=keep)

    h_sc = to_csr(ham_sc)
    # Left-to-right: (H_SC · D⁻¹) · H_SC^†
    correction = h_sc @ sp.diags(inv_gaps) @ h_sc.T.conj()
    return h_ss + correction


def build_H_tilde(
    ham_ss: lever.HamOp,
    ham_sc: lever.HamOp,
    H_diag_C: np.ndarray,
) -> sp.csr_matrix:
    """
    Model Hamiltonian H̃ on T = S⊕C with a diagonal C-block.

        H̃ = [[H_SS,    H_SC          ],
              [H_SC^†,  diag(H_diag_C)]]

    Off-diagonal C–C couplings are neglected. Serves as the validation
    reference for the approximate estimators.

    Returns:
        H̃ in CSR format, ordered S first then C; just H_SS when the C-space
        is empty.
    """
    H_SS = to_csr(ham_ss)
    if ham_sc.shape[1] == 0:
        return H_SS

    H_SC = to_csr(ham_sc)  # converted once, reused for both coupling blocks
    return sp.bmat([
        [H_SS, H_SC],
        [H_SC.T.conj(), sp.diags(H_diag_C)],
    ], format='csr')


def compute_exact_expectation(
    ψ_tilde: np.ndarray,
    space: lever.SpaceRep,
    int_ctx: lever.IntCtx,
) -> float:
    """
    Rayleigh quotient ⟨ψ̃|H|ψ̃⟩ / ⟨ψ̃|ψ̃⟩ with the Hamiltonian rebuilt on T = S∪C.

    The determinants are stacked as [S; C], the same block order used for
    H̃, so ψ̃ (an H̃ eigenvector) lines up row-for-row. Unlike H̃, the matrix
    built here is not restricted to a diagonal C-block.
    NOTE(review): assumes get_ham_proxy keeps the input determinant order
    for its first returned operator — confirm against lever.

    Returns:
        The electronic energy (nuclear repulsion not included).
    """
    t_dets = np.vstack([space.s_dets, space.c_dets])
    h_tt_op, _, _ = lever.get_ham_proxy(
        t_dets, int_ctx=int_ctx, n_orbitals=N_ORBITALS,
        use_heatbath=False, thresh=HAM_THRESH,
    )
    h_psi = to_csr(h_tt_op) @ ψ_tilde
    numerator = np.vdot(ψ_tilde, h_psi)
    norm_sq = np.vdot(ψ_tilde, ψ_tilde)
    return float(np.real(numerator / norm_sq))


# --- Convergence Study ---

def run_convergence_study(
    int_ctx: lever.IntCtx,
    fci_dets: np.ndarray,
    E_fci: float,
    ψ_fci: np.ndarray,
    k_points: list[int] | None = None,
) -> ConvergenceData:
    """
    Analyze energy convergence with an expanding S-space.

    For each size k, S is the set of k determinants with the largest
    |FCI coefficient|. Five estimators are evaluated: sCI, sCI+PT2,
    Löwdin H_eff, H̃, and E_var(ψ̃). Each entry also records |C| and the
    stored-entry counts of H̃ and H_eff. Errors of the first four estimators
    relative to E_fci are printed as a table in mHa; E_var(ψ̃) is recorded
    but not printed.

    Args:
        int_ctx: Integral context for the system.
        fci_dets: All FCI determinants, row-aligned with ψ_fci.
        E_fci: FCI total energy (Ha), used as the error reference.
        ψ_fci: FCI ground-state vector used to rank determinants.
        k_points: S-space sizes to scan, in order. Defaults to
            [1, 10, 50, 100, 150, 200, 400]. Sizes outside
            1..min(len(fci_dets), K_MAX) are skipped.

    Returns:
        ConvergenceData with one entry per scanned k. All energies include
        nuclear repulsion.
    """
    print(f"\n{'─'*80}\nConvergence Study: {SYSTEM_NAME}\n{'─'*80}")

    if k_points is None:
        k_points = [1, 10, 50, 100, 150, 200, 400]

    # Determinant indices sorted by decreasing |FCI coefficient|.
    ranking = np.argsort(-np.abs(ψ_fci))
    E_nuc = int_ctx.get_e_nuc()
    k_limit = min(len(fci_dets), K_MAX)
    # k = 0 would give an empty S-space, so only positive sizes are kept.
    ks = [k for k in k_points if 1 <= k <= k_limit]

    data = ConvergenceData()

    # Table header
    print(f"{'k':>4} {'|C|':>6} {'nnz(H̃)':>9} {'nnz(Hₑ)':>9} "
          f"{'ΔE_sCI':>10} {'ΔE_PT2':>10} {'ΔE_Löw':>10} {'ΔE_H̃':>10}")
    print('─' * 80)

    for k in ks:
        S_dets = fci_dets[ranking[:k]]

        # Build Hamiltonian blocks
        ham_ss, ham_sc, space = lever.get_ham_proxy(
            S_dets, int_ctx=int_ctx, n_orbitals=N_ORBITALS,
            use_heatbath=False, thresh=HAM_THRESH
        )

        # (1) sCI: Variational in S-space
        ep_sCI = lowest_eigenpair(to_csr(ham_ss))
        E_sCI = ep_sCI.E + E_nuc

        # (2) sCI+PT2: Epstein-Nesbet correction around the sCI state
        E_PT2_corr = compute_PT2(ep_sCI.E, ep_sCI.ψ, ham_sc, space.H_diag_C)
        E_PT2 = ep_sCI.E + E_PT2_corr + E_nuc

        # (3) H_eff: Löwdin partitioning with E_ref = sCI electronic energy
        H_eff = build_H_eff_Lowdin(ham_ss, ham_sc, space.H_diag_C, E_ref=ep_sCI.E)
        ep_Lowdin = lowest_eigenpair(H_eff)
        E_Lowdin = ep_Lowdin.E + E_nuc

        # (4) H̃: Full model Hamiltonian
        H_tilde = build_H_tilde(ham_ss, ham_sc, space.H_diag_C)
        ep_tilde = lowest_eigenpair(H_tilde)
        E_tilde = ep_tilde.E + E_nuc

        # (5) Exact expectation value of the H̃ eigenvector
        E_exact = compute_exact_expectation(ep_tilde.ψ, space, int_ctx) + E_nuc

        # Record results
        data.k.append(k)
        data.size_C.append(space.size_C)
        data.nnz_H_tilde.append(H_tilde.nnz)
        data.nnz_H_eff.append(H_eff.nnz)
        data.E_sCI.append(E_sCI)
        data.E_PT2.append(E_PT2)
        data.E_Lowdin.append(E_Lowdin)
        data.E_H_tilde.append(E_tilde)
        data.E_exact.append(E_exact)

        # Print errors (mHa)
        errors_mHa = [(E - E_fci) * 1000 for E in [E_sCI, E_PT2, E_Lowdin, E_tilde]]
        print(f"{k:4d} {space.size_C:6d} {H_tilde.nnz:9d} {H_eff.nnz:9d} "
              f"{errors_mHa[0]:10.3f} {errors_mHa[1]:10.3f} "
              f"{errors_mHa[2]:10.3f} {errors_mHa[3]:10.3f}")

    return data


# --- Visualization ---

def plot_convergence(
    data: ConvergenceData,
    E_fci: float,
    save_path: Path | str | None = None,
) -> None:
    """
    Dual-panel convergence plot.

    Top: total energy of every estimator vs |S|, with the FCI energy line and
    a ±1.6 mHa chemical-accuracy band.
    Bottom: |E - E_FCI| of the same estimators on a log scale.

    The style settings are applied via ``plt.rc_context`` so they do not leak
    into later figures.

    Args:
        data: Results from the convergence study.
        E_fci: FCI total energy (Ha) used as the reference.
        save_path: If given, the figure is also written to this path (format
            inferred from the suffix) before being shown.
    """
    rc_style = {
        'font.size': 11,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'lines.linewidth': 2.0,
        'grid.alpha': 0.3,
    }

    k_vals = np.array(data.k)
    chem_acc = 1.6e-3  # Chemical accuracy: 1.6 mHa

    # (label, energies, marker, color): one table drives both panels so the
    # energy curves and error curves always show the same estimators/styles.
    estimators = [
        ('sCI', data.E_sCI, 'd', 'steelblue'),
        ('sCI+PT2', data.E_PT2, 'o', 'orange'),
        ('Hₑff (Löwdin)', data.E_Lowdin, 'v', 'purple'),
        ('E(H̃)', data.E_H_tilde, 's', 'crimson'),
        ('E_var(ψ̃)', data.E_exact, '^', 'darkgreen'),
    ]
    curve_style = {
        'linestyle': '-',
        'markersize': 7,
        'markeredgecolor': 'white',
        'markeredgewidth': 1.2,
    }

    with plt.rc_context(rc_style):
        # height_ratios via gridspec_kw also works on matplotlib < 3.6.
        fig, (ax1, ax2) = plt.subplots(
            2, 1, figsize=(12, 9), sharex=True,
            gridspec_kw={'height_ratios': [2.5, 1], 'hspace': 0.10},
        )

        # --- Panel 1: Energy trajectories ---
        ax1.axhspan(
            E_fci - chem_acc, E_fci + chem_acc,
            alpha=0.15, color='green', zorder=0,
            label='Chemical Accuracy (±1.6 mHa)'
        )
        ax1.axhline(
            E_fci, color='black', linestyle='--', linewidth=1.5,
            label=f'FCI: {E_fci:.6f} Ha', zorder=1
        )
        for label, E_data, marker, color in estimators:
            ax1.plot(
                k_vals, E_data, marker=marker, color=color,
                label=label, zorder=2, **curve_style
            )

        ax1.set_ylabel('Total Energy (Ha)')
        ax1.set_title(f'Effective Hamiltonian Convergence: {SYSTEM_NAME}')
        ax1.legend(loc='upper right', framealpha=0.95, fontsize=10)
        ax1.grid(True)
        ax1.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
        ax1.ticklabel_format(style='plain', axis='y', useOffset=False)

        # --- Panel 2: Absolute errors ---
        for label, E_data, marker, color in estimators:
            err = np.abs(np.asarray(E_data) - E_fci)
            ax2.semilogy(
                k_vals, err, marker=marker, color=color,
                label=label, **curve_style
            )

        ax2.axhline(chem_acc, color='green', linestyle='--', alpha=0.6, linewidth=1.5)
        ax2.set_xlabel('Model Space Size |S|')
        ax2.set_ylabel(r'$|E - E_\mathrm{FCI}|$ (Ha)')
        ax2.legend(loc='upper right', fontsize=10)
        ax2.grid(True)
        ax2.set_xticks(k_vals)

        if save_path is not None:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()


# --- Main Execution ---

def main() -> None:
    """
    Execute the convergence analysis workflow.

    Loads the integrals, solves the FCI reference, scans S-space sizes, and
    plots the results.

    Raises:
        FileNotFoundError: If FCIDUMP_PATH does not exist.
    """
    print(f"\n{'═'*80}")
    print(f"{'LEVER Effective Hamiltonian Study (Löwdin Partitioning)':^80}")
    print(f"{'═'*80}\n")

    if not FCIDUMP_PATH.exists():
        raise FileNotFoundError(f"FCIDUMP not found: {FCIDUMP_PATH}")

    # Initialize system
    int_ctx = lever.IntCtx(str(FCIDUMP_PATH), N_ORBITALS)
    fci_dets = lever.gen_fci_dets(N_ORBITALS, N_ALPHA, N_BETA)

    print(f"System:     {SYSTEM_NAME}")
    print(f"Orbitals:   {N_ORBITALS}")
    print(f"Electrons:  ({N_ALPHA}α, {N_BETA}β)")
    print(f"FCI Space:  {len(fci_dets):,} determinants\n")

    # Compute FCI reference. Use the same matrix-element threshold as the
    # S-space Hamiltonians so the errors vs. FCI are not skewed by a
    # different cutoff in the reference.
    print(f"{'─'*80}\nSolving FCI Reference\n{'─'*80}")
    ham_fci, _, _ = lever.get_ham_proxy(
        fci_dets, int_ctx=int_ctx, n_orbitals=N_ORBITALS,
        use_heatbath=False, thresh=HAM_THRESH
    )
    ep_fci = lowest_eigenpair(to_csr(ham_fci))
    E_fci = ep_fci.E + int_ctx.get_e_nuc()
    print(f"E_FCI = {E_fci:.12f} Ha\n")

    # Run convergence study
    conv_data = run_convergence_study(int_ctx, fci_dets, E_fci, ep_fci.ψ)

    # Visualize results
    print(f"\n{'─'*80}\nGenerating Convergence Plot\n{'─'*80}")
    plot_convergence(conv_data, E_fci)

    print(f"\n{'═'*80}")
    print(f"{'Analysis Complete':^80}")
    print(f"{'═'*80}\n")


if __name__ == "__main__":
    main()