# In [None]:
# Copyright 2025 The LEVER Authors - All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Löwdin effective Hamiltonian convergence analysis.

Compares five energy estimators with expanding model space S:
  - sCI: Variational H_SS diagonalization
  - sCI+PT2: Epstein-Nesbet second-order correction
  - H_eff: Löwdin effective Hamiltonian (Schur complement approximation)
  - H̃: Full model space (S⊕C) Hamiltonian with diagonal C-block
  - E_var(ψ̃): Exact expectation on model eigenvector

Algorithm: H_eff = H_SS + H_SC·D⁻¹·H_CS where D = diag(E_ref - H_CC)

File: examples/study_effective_ham.py
Author: Zheng (Alex) Che, email: wsmxcz@gmail.com
Date: January 2025
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

import lever

# ============================================================================
# Configuration
# ============================================================================

# Label shown in the printed banners and in the plot title.
SYSTEM_NAME = "N₂/STO-3G"
# Integral file for the system; main() refuses to run if it does not exist.
FCIDUMP_PATH = Path("../benchmark/FCIDUMP/N2_sto3g.FCIDUMP")
# Orbital count and per-spin electron counts handed to the lever calls.
N_ORBITALS = 10
N_ALPHA = 7
N_BETA = 7

# Model space sizes |S| visited by the convergence study, in this order.
K_SEQUENCE = [
    10,
    50,
    100,
    150,
    200,
    400,
    800,
]
# Unified numerical threshold for PT2/Löwdin/matrix elements
EPSILON = 1e-12

# ============================================================================
# Data Structures
# ============================================================================


class EigenPair(NamedTuple):
    """Ground state eigenvalue and eigenvector.

    Being a ``NamedTuple``, an instance also unpacks positionally as
    ``(E, ψ)``.
    """

    # Lowest eigenvalue; lowest_eigenpair stores np.inf here for an empty matrix.
    E: float
    # Eigenvector belonging to E (an empty array for an empty matrix).
    ψ: np.ndarray


@dataclass
class EnergyPoint:
    """Energy estimates at fixed |S|.

    One record per model-space size; run_convergence_study fills every
    energy field with the nuclear repulsion energy already added.
    """

    # Number of determinants in the model space S.
    k: int
    # Size of the connected space C, as reported by the space object.
    size_C: int
    # Stored non-zeros of the full model Hamiltonian H̃.
    nnz_H_tilde: int
    # Stored non-zeros of the effective Hamiltonian operator.
    nnz_H_eff: int
    # Lowest eigenvalue of H_SS (variational sCI energy).
    E_sCI: float
    # sCI energy plus the Epstein-Nesbet second-order correction.
    E_PT2: float
    # Lowest eigenvalue of the Löwdin effective Hamiltonian.
    E_Lowdin: float
    # Lowest eigenvalue of H̃.
    E_H_tilde: float
    # Expectation value of the S∪C Hamiltonian on the H̃ eigenvector.
    E_exact: float


@dataclass
class ConvergenceData:
    """Ordered collection of the energy points gathered during a study."""

    points: list[EnergyPoint] = field(default_factory=list)

    def append(self, point: EnergyPoint) -> None:
        """Add ``point`` to the end of the series."""
        self.points.append(point)

    def to_arrays(self) -> dict[str, np.ndarray]:
        """Return the series as one array per plotted quantity.

        The keys are ``k`` followed by the five energy field names; each
        array has one entry per stored point, in insertion order.
        """
        columns = ("k", "E_sCI", "E_PT2", "E_Lowdin", "E_H_tilde", "E_exact")
        return {
            name: np.array([getattr(pt, name) for pt in self.points])
            for name in columns
        }


# ============================================================================
# Core Utilities
# ============================================================================


def to_csr(op: lever.HamOp) -> sp.csr_matrix:
    """Assemble a SciPy CSR matrix from an operator's triplet data.

    Reads ``op.vals``, ``op.rows``, ``op.cols`` and ``op.shape``; entries
    that share a (row, col) position are added together.
    """
    coords = (op.rows, op.cols)
    coo = sp.coo_matrix((op.vals, coords), shape=op.shape)
    coo.sum_duplicates()  # merge repeated coordinates before converting
    return coo.tocsr()


def lowest_eigenpair(H: sp.csr_matrix) -> EigenPair:
    """Return the lowest eigenvalue of ``H`` and its eigenvector.

    A matrix with zero rows yields ``(inf, empty array)``. Matrices with
    fewer than 10 rows are densified and solved with ``np.linalg.eigh``;
    anything larger goes through ``scipy.sparse.linalg.eigsh`` asking for
    the single smallest-algebraic eigenvalue.
    """
    dim = H.shape[0]
    if dim == 0:
        return EigenPair(np.inf, np.array([]))

    if dim < 10:  # Dense solver threshold
        evals, evecs = np.linalg.eigh(H.toarray())
    else:
        evals, evecs = spla.eigsh(H, k=1, which="SA", tol=1e-15, maxiter=3000)
    return EigenPair(evals[0], evecs[:, 0])


def compute_PT2(
    E_var: float, ψ_S: np.ndarray, ham_sc: lever.HamOp, H_diag_C: np.ndarray
) -> float:
    """Epstein-Nesbet PT2: E₂ = Σ |⟨k|H|ψ_S⟩|²/(E_var - H_kk).

    Args:
        E_var: Energy entering the denominators.
        ψ_S: Vector on the S space that the S-C coupling block acts on.
        ham_sc: S-C coupling block; rows index S and columns index C
            (the layout build_H_tilde also relies on).
        H_diag_C: Diagonal Hamiltonian elements H_kk of the C space.

    Returns:
        The second-order correction, or 0.0 when the coupling block has no
        stored entries. Terms with ``|E_var - H_kk| <= EPSILON`` are left
        out of the sum.
    """
    if ham_sc.nnz == 0:
        return 0.0

    # H_CS = H_SC†, formed the same way build_H_tilde forms its lower-left
    # block, instead of hand-building a transposed operator.
    H_CS = to_csr(ham_sc).T.conj()

    coupling = H_CS @ ψ_S
    gap = E_var - H_diag_C
    keep = np.abs(gap) > EPSILON  # drop (near-)singular denominators

    # Summing an empty selection already gives 0.0, so no special case is
    # needed when every term is masked out.
    return float(np.sum(np.abs(coupling[keep]) ** 2 / gap[keep]))


def build_H_tilde(
    ham_ss: lever.HamOp, ham_sc: lever.HamOp, H_diag_C: np.ndarray
) -> sp.csr_matrix:
    """Full model Hamiltonian: H̃ = [[H_SS, H_SC], [H_SC†, diag(H_CC)]].

    Args:
        ham_ss: S-S block.
        ham_sc: S-C coupling block (rows index S, columns index C).
        H_diag_C: Diagonal of the C-C block; off-diagonal C-C elements are
            not represented in H̃.

    Returns:
        The assembled matrix in CSR format, ordered S first, then C. When
        the coupling block has zero columns the result is just H_SS.
    """
    H_SS = to_csr(ham_ss)
    if ham_sc.shape[1] == 0:
        return H_SS

    # Convert the coupling block once and reuse it for both off-diagonal
    # blocks rather than rebuilding the same matrix twice.
    H_SC = to_csr(ham_sc)
    return sp.bmat(
        [[H_SS, H_SC], [H_SC.T.conj(), sp.diags(H_diag_C)]],
        format="csr",
    )


def compute_exact_expectation(
    ψ_tilde: np.ndarray, space: lever.SpaceRep, int_ctx: lever.IntCtx
) -> float:
    """Rayleigh quotient of ``ψ_tilde`` on the combined T = S∪C space.

    The determinants of S and C are stacked (S first), the Hamiltonian over
    that stacked list is requested from lever, and the real part of
    ⟨ψ̃|H_TT|ψ̃⟩ / ⟨ψ̃|ψ̃⟩ is returned. ``ψ_tilde`` must therefore be ordered
    S first, then C, which is the block order build_H_tilde uses.
    """
    dets_T = np.vstack([space.s_dets, space.c_dets])
    # Only the first of the two returned values is needed here.
    ham_tt, _ = lever.engine.hamiltonian.get_ham_ss(
        dets_T, int_ctx=int_ctx, n_orbitals=N_ORBITALS
    )
    H_TT = to_csr(ham_tt)

    numerator = np.vdot(ψ_tilde, H_TT @ ψ_tilde)
    norm_sq = np.vdot(ψ_tilde, ψ_tilde)
    return float(np.real(numerator / norm_sq))


# ============================================================================
# Convergence Study
# ============================================================================


def run_convergence_study(
    int_ctx: lever.IntCtx, fci_dets: np.ndarray, E_fci: float, ψ_fci: np.ndarray
) -> ConvergenceData:
    """Sweep the model-space size and record every energy estimator.

    For each size k in K_SEQUENCE that does not exceed the number of FCI
    determinants, S is the k determinants with the largest |FCI coefficient|.
    The five estimators are evaluated on that S, stored as an EnergyPoint,
    and one table row of errors against ``E_fci`` (in mHa) is printed.

    Args:
        int_ctx: Integral context; also supplies the nuclear repulsion energy.
        fci_dets: FCI determinant list, indexed in the same order as ψ_fci.
        E_fci: Reference total energy used for the printed errors.
        ψ_fci: FCI ground-state vector, used only to rank determinants.

    Returns:
        The collected points, in the order the sizes were visited.
    """
    rule = "═" * 70
    print(f"\n{rule}")
    print(f"Convergence Study: {SYSTEM_NAME}")
    print(f"{rule}\n")

    # Determinant indices sorted by decreasing |FCI coefficient|.
    by_weight = np.argsort(-np.abs(ψ_fci))
    E_nuc = int_ctx.get_e_nuc()
    n_fci = len(fci_dets)

    print(f"{'|S|':>5} {'|C|':>6} {'nnz(H̃)':>9} {'nnz(Hₑ)':>9} "
          f"{'Δ_sCI':>10} {'Δ_PT2':>10} {'Δ_Löw':>10} {'Δ_H̃':>10}")
    print("─" * 70)

    results = ConvergenceData()

    for k in K_SEQUENCE:
        if k > n_fci:
            # Requested model space is larger than the FCI space; skip it.
            continue

        ham_ss, ham_sc, space = lever.get_ham_proxy(
            fci_dets[by_weight[:k]], int_ctx=int_ctx, n_orbitals=N_ORBITALS, mode="none"
        )
        diag_C = space.H_diag_C
        n_C = space.size_C

        # (1) sCI variational energy
        var_pair = lowest_eigenpair(to_csr(ham_ss))
        E_var = var_pair.E
        E_sCI = E_var + E_nuc

        # (2) sCI+PT2 correction
        E_PT2 = E_var + compute_PT2(E_var, var_pair.ψ, ham_sc, diag_C) + E_nuc

        # (3) Löwdin effective Hamiltonian, referenced to the sCI eigenvalue
        # NOTE(review): upper_only=True is passed and the result is
        # diagonalized as-is; if that flag means only the upper triangle is
        # stored, the solvers would see a non-symmetric matrix — confirm
        # against the get_ham_eff contract.
        ham_eff = lever.engine.hamiltonian.get_ham_eff(
            ham_ss=ham_ss,
            ham_sc=ham_sc,
            h_cc_diag=diag_C,
            e_ref=E_var,
            reg_type="sigma",
            epsilon=EPSILON,
            upper_only=True,
        )
        E_Lowdin = lowest_eigenpair(to_csr(ham_eff)).E + E_nuc

        # (4) Full model Hamiltonian
        H_tilde = build_H_tilde(ham_ss, ham_sc, diag_C)
        tilde_pair = lowest_eigenpair(H_tilde)
        E_tilde = tilde_pair.E + E_nuc

        # (5) Exact expectation on the H̃ eigenvector (stored, not printed)
        E_exact = compute_exact_expectation(tilde_pair.ψ, space, int_ctx) + E_nuc

        results.append(
            EnergyPoint(
                k=k,
                size_C=n_C,
                nnz_H_tilde=H_tilde.nnz,
                nnz_H_eff=ham_eff.nnz,
                E_sCI=E_sCI,
                E_PT2=E_PT2,
                E_Lowdin=E_Lowdin,
                E_H_tilde=E_tilde,
                E_exact=E_exact,
            )
        )

        # Errors against the FCI reference, converted from Ha to mHa
        d_sci, d_pt2, d_low, d_til = (
            (E - E_fci) * 1000 for E in (E_sCI, E_PT2, E_Lowdin, E_tilde)
        )
        print(f"{k:5d} {n_C:6d} {H_tilde.nnz:9d} {ham_eff.nnz:9d} "
              f"{d_sci:10.4f} {d_pt2:10.4f} {d_low:10.4f} {d_til:10.4f}")

    print(f"\n{'─'*70}\n")
    return results


# ============================================================================
# Visualization
# ============================================================================


def plot_convergence(data: ConvergenceData, E_fci: float) -> None:
    """Draw the two-panel convergence figure and show it.

    The upper panel plots the total energy of all five estimators against
    |S|, with the FCI energy and a chemical-accuracy band. The lower panel
    plots |E - E_FCI| on a log axis for the four estimators other than
    E_var(ψ̃). The function updates the global matplotlib rcParams and ends
    with ``plt.show()``; nothing is returned.
    """
    plt.rcParams.update({
        "font.family": "serif",
        "font.size": 10,
        "lines.linewidth": 1.8,
        "lines.markersize": 6,
        "grid.alpha": 0.25,
    })

    series = data.to_arrays()
    sizes = series["k"]
    chem_acc = 1.6e-3  # Chemical accuracy: 1 kcal/mol ≈ 1.6 mHa

    # Keyword arguments common to every estimator curve in both panels.
    curve_style = {"linestyle": "-", "markeredgecolor": "white", "markeredgewidth": 1.0}

    fig, (ax_energy, ax_error) = plt.subplots(
        2, 1, figsize=(8, 6), sharex=True, height_ratios=[3, 2], gridspec_kw={"hspace": 0.08}
    )

    # Panel 1: Absolute energies
    ax_energy.axhspan(E_fci - chem_acc, E_fci + chem_acc, alpha=0.12, color="green", zorder=0,
                      label=f"Chem. Acc. (±{chem_acc*1000:.1f} mHa)")
    ax_energy.axhline(E_fci, color="black", linestyle="--", linewidth=1.2, alpha=0.7,
                      label=f"FCI: {E_fci:.8f} Ha", zorder=1)

    energy_curves = [
        ("sCI", "E_sCI", "d", "#1f77b4"),
        ("sCI+PT2", "E_PT2", "o", "#ff7f0e"),
        ("Löwdin $H_\\mathrm{eff}$", "E_Lowdin", "v", "#9467bd"),
        ("$\\tilde{H}$ (Full Model)", "E_H_tilde", "s", "#d62728"),
        ("$E_\\mathrm{var}(\\tilde{\\psi})$", "E_exact", "^", "#2ca02c"),
    ]
    for label, key, marker, color in energy_curves:
        ax_energy.plot(sizes, series[key], marker=marker, color=color,
                       markerfacecolor=color, label=label, zorder=2, **curve_style)

    ax_energy.set_ylabel("Total Energy (Ha)", fontweight="semibold")
    ax_energy.set_title(f"Effective Hamiltonian Convergence: {SYSTEM_NAME}", fontweight="bold")
    ax_energy.legend(loc="upper right", framealpha=0.95, edgecolor="gray")
    ax_energy.grid(True)
    ax_energy.tick_params(axis="x", labelbottom=False)
    ax_energy.ticklabel_format(style="plain", axis="y", useOffset=False)

    # Panel 2: Log-scale errors
    error_curves = [
        ("sCI", "E_sCI", "d", "#1f77b4"),
        ("sCI+PT2", "E_PT2", "o", "#ff7f0e"),
        ("Löwdin", "E_Lowdin", "v", "#9467bd"),
        ("$\\tilde{H}$", "E_H_tilde", "s", "#d62728"),
    ]
    for label, key, marker, color in error_curves:
        abs_err = np.abs(series[key] - E_fci)
        ax_error.semilogy(sizes, abs_err, marker=marker, color=color,
                          markerfacecolor=color, label=label, **curve_style)

    ax_error.axhline(chem_acc, color="green", linestyle="--", alpha=0.5, linewidth=1.5,
                     label=f"Chem. Acc. ({chem_acc*1000:.1f} mHa)")

    ax_error.set_xlabel("Model Space Size $|S|$", fontweight="semibold")
    ax_error.set_ylabel("$|E - E_\\mathrm{FCI}|$ (Ha)", fontweight="semibold")
    ax_error.legend(loc="upper right", framealpha=0.95, edgecolor="gray")
    ax_error.grid(True, which="both")
    # One tick per sampled model-space size, labelled with its value.
    ax_error.set_xticks(sizes)
    ax_error.set_xticklabels([str(size) for size in sizes])

    plt.show()


# ============================================================================
# Main
# ============================================================================


def main() -> None:
    """Run the whole study: load integrals, solve FCI, sweep |S|, plot.

    Raises:
        FileNotFoundError: If the configured FCIDUMP file does not exist.
    """
    banner = "═" * 70
    print(f"\n{banner}")
    print(f"{'LEVER Effective Hamiltonian Study':^70}")
    print(f"{banner}\n")

    if not FCIDUMP_PATH.exists():
        raise FileNotFoundError(f"FCIDUMP not found: {FCIDUMP_PATH}")

    # Initialize system
    int_ctx = lever.IntCtx(str(FCIDUMP_PATH), N_ORBITALS)
    fci_dets = lever.gen_fci_dets(N_ORBITALS, N_ALPHA, N_BETA)
    n_fci = len(fci_dets)

    print(f"System: {SYSTEM_NAME} | Orbitals: {N_ORBITALS} | "
          f"Electrons: {N_ALPHA}α+{N_BETA}β | FCI Space: {n_fci:,}")
    print(f"Nuclear Energy: {int_ctx.get_e_nuc():.8f} Ha\n")

    # Compute FCI reference: lowest eigenpair of the Hamiltonian over the
    # full determinant list, shifted by the nuclear repulsion energy.
    print("Computing FCI Reference...")
    ham_fci, _ = lever.engine.hamiltonian.get_ham_ss(
        fci_dets, int_ctx=int_ctx, n_orbitals=N_ORBITALS
    )
    fci_pair = lowest_eigenpair(to_csr(ham_fci))
    E_fci = fci_pair.E + int_ctx.get_e_nuc()
    print(f"E_FCI = {E_fci:.12f} Ha\n")

    # Run convergence study and plot
    results = run_convergence_study(int_ctx, fci_dets, E_fci, fci_pair.ψ)
    plot_convergence(results, E_fci)

    print(f"{banner}\n")


if __name__ == "__main__":
    main()