<a href="https://colab.research.google.com/github/yhkimlab/ESCSchool/blob/main/exercise1/HartreeFock_ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Hartree-Fock calculations for Atoms
Ryong-Gyu Lee, Minsu Jeong, & Kaptan Rajput (KAIST EE, Prof. Yong-Hoon Kim Group)

2025.04.25 Ryong-Gyu Lee - Revised: design, OOP, occupation, spin & XC  
2024.07.03 Minsu Jeong - Revised: interface & modularization  
2024.05.14 Kaptan Rajput - Fortran-to-Python translation of Koonin Project 3

**Reference:**
1. Steven E. Koonin & Dawn C. Meredith, Computational Physics: Fortran Version (Addison-Wesley, 1990)  
   Project III: Atomic structure in the Hartree-Fock approximation

- Main code

In [2]:
# =============================================================================
# Hartree-Fock calculations for Atomic Systems
#
# Version History
# -----------------------------------------------------------------------------
# 2024.05.14   Written by Dr. Kaptan Rajput
# 2024.07.03   Revised by Minsu Jeong         - Interface & modularization
# 2025.04.25   Revised by Ryong-Gyu Lee       - Occupation, Spin & XC
#
# Reference
# -----------------------------------------------------------------------------
# [1] COMPUTATIONAL PHYSICS (Fortran Version)
#     by Steven E. Koonin & Dawn C. Meredith
# =============================================================================


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


# =============================================================================
# Periodic Table Data and UI Widgets for Element Selection
# =============================================================================
# Ground-state data for H (Z=1) through Rn (Z=86).  All four columns are
# parallel lists indexed by Z-1:
#   'Atomic number'      : Z, stored as strings ('1' ... '86')
#   'Element symbol'     : chemical symbol
#   'Period'             : periodic-table row
#   'Orbital occupation' : 15-tuple of electron counts, in the orbital order of
#                          `orbital_symbol` below. Known exceptions to the
#                          Aufbau order are included (e.g. Cr 4s1 3d5,
#                          Cu 4s1 3d10, Pd 5s0 4d10, Gd 5d1 4f7, Pt 6s1 5d9).
data = {
    'Atomic number': ['1', '2',
                      '3', '4', '5', '6', '7', '8', '9', '10',
                      '11', '12', '13', '14', '15', '16', '17', '18',
                      '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36',
                      '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54',
                      '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86'],
    'Element symbol': ['H', 'He',
                       'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
                       'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
                       'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
                       'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
                       'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn'],
    'Period': [1, 1,
               2, 2, 2, 2, 2, 2, 2, 2,
               3, 3, 3, 3, 3, 3, 3, 3,
               4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
               5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
               6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
    # One row per element; each row sums to Z (neutral atom).
    'Orbital occupation': [(2,0,0,0,0,0,0,0,0,0,0,0,0,0,0), (2,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
                           (2,1,0,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,0,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,1,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,2,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,3,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,4,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,5,0,0,0,0,0,0,0,0,0,0,0,0), (2,2,6,0,0,0,0,0,0,0,0,0,0,0,0),
                           (2,2,6,1,0,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,0,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,1,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,2,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,3,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,4,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,5,0,0,0,0,0,0,0,0,0,0), (2,2,6,2,6,0,0,0,0,0,0,0,0,0,0),
                           (2,2,6,2,6,1,0,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,0,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,1,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,2,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,3,0,0,0,0,0,0,0,0), (2,2,6,2,6,1,5,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,5,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,6,0,0,0,0,0,0,0,0),
                           (2,2,6,2,6,2,7,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,8,0,0,0,0,0,0,0,0), (2,2,6,2,6,1,10,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,0,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,1,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,2,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,3,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,4,0,0,0,0,0,0,0),
                           (2,2,6,2,6,2,10,5,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,6,0,0,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,0,0,0,0,0,0), (2,2,6,2,6,2,10,6,2,0,0,0,0,0,0), (2,2,6,2,6,2,10,6,2,1,0,0,0,0,0), (2,2,6,2,6,2,10,6,2,2,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,4,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,5,0,0,0,0,0),
                           (2,2,6,2,6,2,10,6,2,5,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,7,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,8,0,0,0,0,0), (2,2,6,2,6,2,10,6,0,10,0,0,0,0,0), (2,2,6,2,6,2,10,6,1,10,0,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,0,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,1,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,2,0,0,0,0),
                           (2,2,6,2,6,2,10,6,2,10,3,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,4,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,5,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,6,0,0,0,0), (2,2,6,2,6,2,10,6,2,10,6,1,0,0,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,0,0), (2,2,6,2,6,2,10,6,2,10,6,2,1,0,0), (2,2,6,2,6,2,10,6,2,10,6,2,1,1,0),
                           (2,2,6,2,6,2,10,6,2,10,6,2,0,3,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,4,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,5,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,6,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,7,0), (2,2,6,2,6,2,10,6,2,10,6,2,1,7,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,9,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,10,0),
                           (2,2,6,2,6,2,10,6,2,10,6,2,0,11,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,12,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,13,0), (2,2,6,2,6,2,10,6,2,10,6,2,0,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,1,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,2,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,3,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,4,14,0),
                           (2,2,6,2,6,2,10,6,2,10,6,2,5,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,6,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,7,14,0), (2,2,6,2,6,2,10,6,2,10,6,1,9,14,0), (2,2,6,2,6,2,10,6,2,10,6,1,10,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,0), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,1), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,2),
                           (2,2,6,2,6,2,10,6,2,10,6,2,10,14,3), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,4), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,5), (2,2,6,2,6,2,10,6,2,10,6,2,10,14,6)],
}
# Orbital labels in filling (Madelung) order; this order defines the
# positions of the 15-tuples above and of the calculator's state arrays.
orbital_symbol = ['1s', '2s', '2p', '3s', '3p', '4s', '3d', '4p', '5s' , '4d', '5p', '6s', '5d', '4f', '6p']
# Maximum electrons per orbital, 2*(2l+1); the sum is 86 (a full Rn configuration).
orbital_capacity = (2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 14, 6)

# Table with one row per element (index 0 = H).
periodic_table = pd.DataFrame(data)

# =============================================================================
# Helper functions
# =============================================================================
def setup_occupation(nelec):
    """
    Fill the 15 orbitals in plain Aufbau (Madelung) order with ``nelec`` electrons.

    Orbital order and capacities are 1s, 2s, 2p, 3s, 3p, 4s, 3d, 4p, 5s, 4d,
    5p, 6s, 5d, 4f, 6p with (2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 14, 6)
    electrons each. No exceptions to the Aufbau rule (e.g. Cr, Cu) are applied.

    Parameters
    ----------
    nelec : int
        Number of electrons to place, 0 <= nelec <= 86.

    Returns
    -------
    tuple of int
        Occupation of each of the 15 orbitals.

    Raises
    ------
    ValueError
        If ``nelec`` is negative or larger than the total capacity (86).
        Previously the excess electrons were silently dropped.
    """
    capacities = (2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 14, 6)
    max_electrons = sum(capacities)
    if nelec < 0 or nelec > max_electrons:
        raise ValueError(
            f"nelec must be between 0 and {max_electrons}, got {nelec}"
        )

    occupation = []
    remaining = nelec
    for cap in capacities:
        take = min(cap, remaining)  # 0 once every electron has been placed
        occupation.append(take)
        remaining -= take
    return tuple(occupation)


class HartreeFockCalculator:
    """
    A class that reorganizes the Fortran-style Hartree-Fock calculation code into a Python OOP structure.
    Each subroutine is implemented as a method, and instance variables (self.…) are used instead of global variables.
    """
    def __init__(self, spin_polarized=False):
        # Set spin polarization: NSPIN=2 if spin polarized, otherwise 1
        self.spin_polarized = spin_polarized
        self.NSPIN = 2 if spin_polarized else 1

        # Maximum array sizes and number of states
        self.MAXSTP = 20000   # Maximum size of data arrays (radial points)
        self.MAXSTT = 15       # Maximum number of states (15 orbitals + total)
        if self.NSPIN == 1:
            self.NSTATE = self.MAXSTT
        else:
            self.NSTATE = int(sum(orbital_capacity)/2)

        # Main varibales
        self.E = np.zeros((self.NSTATE + 1, 9, self.NSPIN))       # Energies and related components (rows: states, columns: energy components)
        self.FOCK = np.zeros((self.MAXSTP + 1, self.NSTATE, self.NSPIN))
        self.PSTOR = np.zeros((self.MAXSTP + 1, self.NSTATE, self.NSPIN))
        self.NOCC = np.zeros((self.NSTATE + 1, self.NSPIN))  # Separate occupation per spin
        self.RHO = np.zeros((self.MAXSTP + 1, self.NSPIN))
        self.RHOTOT = np.zeros(self.MAXSTP + 1)
        self.EX = np.zeros(self.MAXSTP + 1)
        self.EC = np.zeros(self.MAXSTP + 1)
        self.VX = np.zeros((self.MAXSTP + 1, self.NSPIN))
        self.VC = np.zeros((self.MAXSTP + 1, self.NSPIN))

        # Internal variables
        self.BANDS = np.zeros((self.NSTATE + 1, self.NSPIN))
        self.PSTOR_OLD = np.zeros((self.MAXSTP + 1, self.NSTATE, self.NSPIN))
        self.NORB = np.zeros((self.MAXSTT + 1, self.NSPIN))
        self.PHI = np.zeros(self.MAXSTP + 1)
        self.PSIIN = np.zeros(self.MAXSTP + 1)
        self.PSIOUT = np.zeros(self.MAXSTP + 1)
        self.ANGMOM = np.zeros(self.NSTATE)       # Angular momentum for each istate

        self.ID = []

        if self.NSPIN == 2:
            for lab, cap in zip(orbital_symbol, orbital_capacity):
                nbasis = int(cap/2)
                for ibasis in range(nbasis):
                    self.ID.append(f"{lab}")
        else:
            self.ID = orbital_symbol.copy()


        # finally add the total-energy entry
        self.ID.append("Total")


        # Factorial array (for the 3-j coefficient)
        self.FACTRL = np.zeros(21)
        self.FACTRL[0] = 1
        for i in range(1, 21):
            self.FACTRL[i] = i * self.FACTRL[i - 1]

        # Constants (electronic constants)
        self.HBARM = 7.6359    # hbar^2/m for the electron (eV-Å²)
        self.CHARGE = 14.409   # Square of the electron charge (eV-Å)
        self.ABOHR = self.HBARM / self.CHARGE  # Bohr radius (Å)
        self.HARTREE2EV = 27.211386245988

        # Energy array column indices (as in the original Fortran code)
        self.IKEN = 0     # Kinetic energy
        self.ICEN = 1     # Centrifugal energy
        self.IVEN = 2     # Electron-nuclear interaction energy
        self.IVEE = 3     # Electron-electron interaction energy
        self.IVXC = 4     # Exchange-correlation energy
        self.IKTOT = 5    # Total kinetic energy
        self.IVTOT = 6    # Total potential energy
        self.ITOT = 7     # Total energy (single-particle)
        self.ITOT_LAST = 8  # Previous step's total energy

        # Initial angular momentum values for 15 orbitals
        init_angmom = [0, 0, 1, 0, 1, 0, 2, 1, 0, 2, 1, 0, 2, 3, 1]
        if self.NSPIN == 2:
            idx = 0
            for ang, cap in zip(init_angmom, orbital_capacity):
                nbasis = int(cap/2)
                self.ANGMOM[idx:idx+nbasis] = ang
                idx += nbasis
        else:
            n = min(len(init_angmom), len(self.ANGMOM))
            self.ANGMOM[:n] = init_angmom[:n]


        # Parameters to be set later (via set_params)
        self.Z = None        # Atomic number (nuclear charge)
        self.NR = None       # Number of radial grid points (set later)
        self.ZCHARG = None
        self.NE = None       # Total number of electrons
        self.DR = None       # Radial step size (Å)
        self.RMAX = None     # Outer radius (Å)
        self.MIX = None      # Mixing factor (user input)
        self.MIX_SCHEME = 'rho'
        self.Total_Energy_Diff = None

    def set_params(self, z, bands, nocc, h, x, c, xalpha = 1, calpha = 1):
        """
        (Equivalent to the PARAM subroutine)
        Set the atomic number and the electron occupations for each orbital.

        Parameters
        ----------
        z : int
            Atomic number (nuclear charge)
        nocc : tuple or list
            Electron occupations for 15 orbitals (order: 1s, 2s, 2p, 3s, 3p, ..., 6p)
        x_tpye : str
            Exchange functiona; type
        c_tpye : str
            Correlation functional type
        """
        self.Z = z
        self.hartree_check = h
        self.exchange_type = x
        self.correlation_type = c
        self.xalpha = xalpha
        self.calpha = calpha

        # Construct NOCC array
        if self.NSPIN == 1:
            for istate in range(15):
                self.NOCC[istate,0] = nocc[istate]
        else:
            ibasis = 0
            for istate in range(15):
                etot = nocc[istate]
                nbasis = int(orbital_capacity[istate]/2)
                if etot >= nbasis:
                    self.NOCC[ibasis:ibasis+nbasis,0] = 1
                    self.NOCC[ibasis:ibasis+nbasis,1] = (etot-nbasis)/nbasis
                else:
                    self.NOCC[ibasis:ibasis+nbasis,0] = etot/nbasis
                ibasis += nbasis


        # Construct BANDS array
        if self.NSPIN == 1:
            for istate in range(15):
                self.BANDS[istate,0] = bands[istate]
        else:
            ibasis = 0
            for istate in range(15):
                etot = bands[istate]
                nbasis = int(orbital_capacity[istate]/2)
                if etot >= nbasis:
                    self.BANDS[ibasis:ibasis+nbasis,0] = 1
                    self.BANDS[ibasis:ibasis+nbasis,1] = (etot-nbasis)/nbasis
                else:
                    self.BANDS[ibasis:ibasis+nbasis,0] = etot/nbasis
                ibasis += nbasis



        for ispin in range(self.NSPIN):
            self.NOCC[self.NSTATE,ispin] = np.sum(self.NOCC[:self.NSTATE,ispin])

        # Radial lattice parameters
        self.DR = 0.001    # Radial step size in Angstroms (example value)
        self.RMAX = 5.0   # Outer radius in Angstroms (example value)
        self.NR = int(self.RMAX / self.DR)
        if self.NR >= self.MAXSTP:
            raise ValueError(f"The radial grid has {self.NR} points, which exceeds the maximum allowed {self.MAXSTP}.")

        self.ZCHARG = self.Z * self.CHARGE
        self.NE = np.sum(self.NOCC[self.NSTATE,:])

    def hydrgn(self, zstar):
        """
        (Equivalent to the HYDRGN subroutine)
        Create hydrogen-like radial wavefunctions (PSTOR) for the effective nuclear charge zstar.
        """
        if self.NSPIN == 2:
            for ispin in range(self.NSPIN):
                ibasis = 0
                for istate in range(15):
                    nbasis = int(orbital_capacity[istate]/2)
                    for iorb in range(nbasis):
                        for ir in range(1, self.NR + 1):
                            rstar = ir * self.DR * zstar / self.ABOHR  # Scaled radius
                            erstar = np.exp(-rstar / 2)
                            erstar1 = np.exp(-rstar / 3)
                            erstar2 = np.exp(-rstar / 4)
                            erstar4 = np.exp(-rstar / 4)
                            erstar5 = np.exp(-rstar / 5)
                            erstar6 = np.exp(-rstar / 6)

                            if self.BANDS[ibasis+iorb,ispin] != 0:
                                if istate == 0: #1s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = rstar * (erstar ** 2)
                                if istate == 1: #2s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (2 - rstar) * rstar * erstar
                                if istate == 2: #2p
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar ** 2) * erstar
                                if istate == 3: #3s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (1 - (2/3) * rstar + (2/27) * rstar ** 2) * rstar * erstar1
                                if istate == 4: #3p
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar - rstar ** 2 / 6) * rstar * erstar1
                                if istate == 5: #4s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (1 - (3/4) * rstar  + (1/8) * rstar ** 2 - (1/192) * rstar ** 3) * rstar * erstar2
                                if istate == 6: #3d
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar ** 3) * erstar1
                                if istate == 7: #4p
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar - rstar ** 2 / 4 + rstar ** 3 / 80) * rstar * erstar2
                                if istate == 8: #5s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (1 - (4/5) * rstar + (4/25) * rstar ** 2 -  (20/1875) * rstar ** 3  + (10/46875) * rstar ** 4 ) * rstar * erstar5
                                if istate == 9: #4d
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar ** 3) * (1 - rstar / 12) * erstar4
                                if istate == 10: #5p
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar - rstar ** 2 / 10 + rstar ** 3 / 300) * rstar * erstar5
                                if istate == 11: #6s
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (1 - rstar / 6 + rstar ** 2 / 72 - rstar ** 3 / 1296 + rstar ** 4 / 31104) * rstar * erstar6
                                if istate == 12: #5d
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar ** 2) * (1 - rstar / 10 + rstar ** 2 / 300) * erstar5
                                if istate == 13: #4f
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar ** 4) * erstar4
                                if istate == 14: #6p
                                    self.PSTOR[ir, ibasis+iorb, ispin] = (rstar - rstar ** 2 / 12 + rstar ** 3 / 432) * rstar * erstar6
                    ibasis += nbasis

                # Normalize the wavefunctions for each orbital
                for istate in range(self.NSTATE):
                    if self.BANDS[istate, ispin] != 0:
                        norm = np.sum(self.PSTOR[:self.NR + 1, istate, ispin] ** 2)
                        norm = 1 / np.sqrt(norm * self.DR)
                        self.PSTOR[:self.NR + 1, istate, ispin] *= norm
        else:
            for istate in range(15):
                for ir in range(1, self.NR + 1):
                    rstar = ir * self.DR * zstar / self.ABOHR  # Scaled radius
                    erstar = np.exp(-rstar / 2)
                    erstar1 = np.exp(-rstar / 3)
                    erstar2 = np.exp(-rstar / 4)
                    erstar4 = np.exp(-rstar / 4)
                    erstar5 = np.exp(-rstar / 5)
                    erstar6 = np.exp(-rstar / 6)
                    if self.BANDS[istate,0] != 0:
                        if istate == 0:
                            self.PSTOR[ir, istate, 0] = rstar * (erstar ** 2)
                        if istate == 1:
                            self.PSTOR[ir, istate, 0] = (2 - rstar) * rstar * erstar
                        if istate == 2:
                            self.PSTOR[ir, istate, 0] = (rstar ** 2) * erstar
                        if istate == 3: # 3s
                            self.PSTOR[ir, istate, 0] = (1 - (2/3) * rstar + (2/27) * rstar ** 2) * rstar * erstar1
                        if istate == 4:
                            self.PSTOR[ir, istate, 0] = (rstar - rstar ** 2 / 6) * rstar * erstar1
                        if istate == 5: # 4s
                            self.PSTOR[ir, istate, 0] = (1 - (3/4) * rstar  + (1/8) * rstar ** 2 - (1/192) * rstar ** 3) * rstar * erstar2
                        if istate == 6:
                            self.PSTOR[ir, istate, 0] = (rstar ** 3) * erstar1
                        if istate == 7: # 4p
                            self.PSTOR[ir, istate, 0] = (rstar - rstar ** 2 / 4 + rstar ** 3 / 80) * rstar * erstar2
                        if istate == 8: # 5s
                            self.PSTOR[ir, istate, 0] = (1 - (4/5) * rstar + (4/25) * rstar ** 2 -  (20/1875) * rstar ** 3  + (10/46875) * rstar ** 4 ) * rstar * erstar5
                        if istate == 9: # 4d
                            self.PSTOR[ir, istate, 0] = (rstar ** 3) * (1 - rstar / 12) * erstar4
                        if istate == 10:
                            self.PSTOR[ir, istate, 0] = (rstar - rstar ** 2 / 10 + rstar ** 3 / 300) * rstar * erstar5
                        if istate == 11:
                            self.PSTOR[ir, istate, 0] = (1 - rstar / 6 + rstar ** 2 / 72 - rstar ** 3 / 1296 + rstar ** 4 / 31104) * rstar * erstar6
                        if istate == 12:
                            self.PSTOR[ir, istate, 0] = (rstar ** 2) * (1 - rstar / 10 + rstar ** 2 / 300) * erstar5
                        if istate == 13:
                            self.PSTOR[ir, istate, 0] = (rstar ** 3) * erstar4
                        if istate == 14:
                            self.PSTOR[ir, istate, 0] = (rstar - rstar ** 2 / 12 + rstar ** 3 / 432) * rstar * erstar6
            # Normalize the wavefunctions for each orbital
            for istate in range(self.NSTATE):
                if self.BANDS[istate, 0] != 0:
                    norm = np.sum(self.PSTOR[:self.NR + 1, istate, 0] ** 2)
                    norm = 1 / np.sqrt(norm * self.DR)
                    self.PSTOR[:self.NR + 1, istate, 0] *= norm

    def energy(self):
        """
        (Equivalent to the ENERGY subroutine)
        Calculate the energy components of the normalized single-particle wavefunctions (PSTOR).

        Before evaluating energies this calls:
        - source(): density update, including mixing;
        - xc(): Fock and/or LDA exchange-correlation terms;
        - poissn(): electron potential PHI, unless ``hartree_check == 'None'``.

        For every state with BANDS != 0, E[istate, :, ispin] receives:
            IKEN   kinetic            (hbar^2/2m) * sum (P_i - P_{i-1})^2 / DR
            ICEN   centrifugal        (hbar^2/2m) * l(l+1) * int P^2 / r^2 dr
            IVEN   electron-nucleus   -Z e^2 * int P^2 / r dr
            IVEE   electron-electron  int PHI P^2 dr
            IVXC   exchange-corr.     int FOCK P dr + int (VX + VC) P^2 dr
            IKTOT / IVTOT / ITOT      kinetic, potential and single-particle totals
        Row NSTATE holds the occupation-weighted totals for each spin.
        Total_Energy_Diff is set to (previous total - new total) summed over
        spins, and ITOT_LAST is updated to the new total.
        """
        self.source()   # Update the electron density
        self.xc()
        if self.hartree_check != 'None':
            self.poissn()   # Compute the electron potential based on density

        for ispin in range(self.NSPIN):
            # Reset the totals row for this spin (columns IKEN..ITOT)
            self.E[self.NSTATE, self.IKEN:self.ITOT+1, ispin] = 0
            for istate in range(self.NSTATE):
                if self.BANDS[istate, ispin] != 0:
                    self.E[istate, self.IKEN:self.ITOT+1, ispin] = 0
                    ll1 = self.ANGMOM[istate] * (self.ANGMOM[istate] + 1)
                    pm = 0   # P at the previous grid point; P(r=0) = 0
                    for ir in range(1, self.NR + 1):
                        r = ir * self.DR
                        pz = self.PSTOR[ir, istate, ispin]
                        pz2 = pz ** 2
                        self.E[istate, self.IKEN, ispin] += (pz - pm) ** 2
                        self.E[istate, self.ICEN, ispin] += pz2 * ll1 / (r ** 2)
                        self.E[istate, self.IVEN, ispin] += -pz2 / r
                        self.E[istate, self.IVEE, ispin] += self.PHI[ir] * pz2
                        # Fock term acts on P (not P^2); local VX/VC act on P^2
                        self.E[istate, self.IVXC, ispin] += self.FOCK[ir, istate, ispin] * pz
                        self.E[istate, self.IVXC, ispin] += self.VX[ir, ispin] * pz2
                        self.E[istate, self.IVXC, ispin] += self.VC[ir, ispin] * pz2
                        pm = pz

                    # Turn the raw grid sums into energies (step DR, eV units)
                    self.E[istate, self.IKEN, ispin] *= self.HBARM / (2 * self.DR)
                    self.E[istate, self.ICEN, ispin] *= (self.DR * self.HBARM / 2)
                    self.E[istate, self.IVEN, ispin] *= self.ZCHARG * self.DR
                    self.E[istate, self.IVEE, ispin] *= self.DR
                    self.E[istate, self.IVXC, ispin] *= self.DR

                    self.E[istate, self.IKTOT, ispin] = self.E[istate, self.IKEN, ispin] + self.E[istate, self.ICEN, ispin]
                    self.E[istate, self.IVTOT, ispin] = (self.E[istate, self.IVEN, ispin] +
                                                self.E[istate, self.IVEE, ispin] +
                                                self.E[istate, self.IVXC, ispin])
                    self.E[istate, self.ITOT, ispin] = self.E[istate, self.IKTOT, ispin] + self.E[istate, self.IVTOT, ispin]

                    # Occupation-weighted totals of IKEN, ICEN, IVEN, IVEE
                    # (IVXC is accumulated separately below)
                    for ie in range(self.IKEN, self.IVXC):
                        self.E[self.NSTATE, ie, ispin] += self.E[istate, ie, ispin] * self.NOCC[istate, ispin]

                    # Exchange energy from the Fock term: half of NOCC * <P|FOCK>
                    # (the DR factor is applied after the spin loop body below)
                    for ir in range(1, self.NR + 1):
                        pz = self.PSTOR[ir, istate, ispin]
                        self.E[self.NSTATE, self.IVXC, ispin] += self.NOCC[istate, ispin] * self.FOCK[ir, istate, ispin] * pz / 2

            # LDA contributions: EX weighted by this spin's radial density; EC added as is.
            # NOTE(review): EC is a single array but is added once per spin channel,
            # so for NSPIN == 2 it enters twice unless lda_c() accounts for that — confirm.
            for ir in range(1, self.NR + 1):
                self.E[self.NSTATE, self.IVXC, ispin] += self.EX[ir] * self.RHO[ir,ispin]
                self.E[self.NSTATE, self.IVXC, ispin] += self.EC[ir]


            self.E[self.NSTATE, self.IVXC, ispin] *= self.DR
            # Each electron pair appears twice in the weighted e-e sum
            self.E[self.NSTATE, self.IVEE, ispin] /= 2
            self.E[self.NSTATE, self.IKTOT, ispin] = self.E[self.NSTATE, self.IKEN, ispin] + self.E[self.NSTATE, self.ICEN, ispin]
            self.E[self.NSTATE, self.IVTOT, ispin] = (self.E[self.NSTATE, self.IVEN, ispin] +
                                            self.E[self.NSTATE, self.IVEE, ispin] +
                                            self.E[self.NSTATE, self.IVXC, ispin])
            self.E[self.NSTATE, self.ITOT, ispin] = self.E[self.NSTATE, self.IKTOT, ispin] + self.E[self.NSTATE, self.IVTOT, ispin]

        # Convergence measure: change of the total energy since the last call
        self.Total_Energy_Diff = np.sum(self.E[self.NSTATE, self.ITOT_LAST, :] - self.E[self.NSTATE, self.ITOT, :])
        self.E[self.NSTATE, self.ITOT_LAST, :] = self.E[self.NSTATE, self.ITOT, :]

    def source(self):
        """
        (Equivalent to the SOURCE subroutine)
        Compute the electron density (RHO) using the squared wavefunctions (PSTOR).
        """
        scheme = self.MIX_SCHEME

        if scheme == 'rho':
            # update electron densities
            for ispin in range(self.NSPIN):
                for ir in range(1, self.NR + 1):
                    self.RHO[ir, ispin] = (1 - self.MIX) * self.RHO[ir, ispin]
                for istate in range(self.NSTATE):
                    if self.NOCC[istate, ispin] != 0:
                        for ir in range(1, self.NR + 1):
                            self.RHO[ir, ispin] += self.MIX * self.NOCC[istate, ispin] * (self.PSTOR[ir, istate, ispin] ** 2)
        elif scheme == 'wave':
            # update electron densities
            for ispin in range(self.NSPIN):
                for istate in range(self.NSTATE):
                    if self.NOCC[istate, ispin] != 0:
                        for ir in range(1, self.NR + 1):
                            self.PSTOR[ir, istate, ispin] = self.MIX * self.PSTOR[ir, istate, ispin] + (1 - self.MIX) * self.PSTOR_OLD[ir, istate, ispin]
                            self.RHO[ir, ispin] = self.NOCC[istate, ispin] * (self.PSTOR[ir, istate, ispin] ** 2)

        self.RHOTOT[:] = 0
        for ispin in range(self.NSPIN):
            for ir in range(1, self.NR +1):
                self.RHOTOT[ir] += self.RHO[ir, ispin]

    def xc(self):
        """
        Compute exchange and correlation contributions.
        """
        # --- exchange ---
        if self.exchange_type == 'HF':
            self.fock()
        elif self.exchange_type == 'LDA':
            self.lda_x(self.xalpha)
        elif self.exchange_type == 'None':
            pass
        else:
            raise NotImplementedError(f"XC: exchange '{self.exchange_type}' not implemented")

        # --- correlation ---
        if self.correlation_type == 'LDA':
            self.lda_c(self.calpha)
        elif self.correlation_type == 'None':
            pass
        else:
            raise NotImplementedError(f"XC: correlation '{self.correlation_type}' not implemented")

    def fock(self):
        """
        (Together with source(), equivalent to the SOURCE subroutine)
        Compute the Fock (exchange) terms FOCK from the radial wavefunctions PSTOR.

        For every treated state i (BANDS != 0), every occupied state j
        (NOCC != 0) of the same spin channel, and every multipole lam from
        |l_i - l_j| to l_i + l_j in steps of 2, this adds

            fac * P_j(r) * [ r**-(lam+1) * int_0^r    r'**lam      P_j P_i dr'
                           + r**lam      * int_r^Rmax r'**-(lam+1) P_j P_i dr' ]

        to FOCK[r, i, spin], where Rmax = NR*DR and
            fac = -CHARGE/2 * NOCC_j * sqr3j(l_i, l_j, lam)   (NSPIN == 1)
            fac = -CHARGE   * NOCC_j * sqr3j(l_i, l_j, lam)   (NSPIN == 2).
        FOCK[1:NR+1] of each treated state is reset to zero first.
        sqr3j is defined elsewhere in the class (presumably the squared 3-j symbol).
        """
        # calculate the Fock terms (exchange)
        if self.NSPIN == 1:
                for istate in range(self.NSTATE):
                    if self.BANDS[istate, 0] != 0:
                        for ir in range(1, self.NR + 1):
                            self.FOCK[ir, istate, 0] = 0
                        for jstate in range(self.NSTATE):
                            if self.NOCC[jstate, 0] != 0:
                                l1 = self.ANGMOM[istate]
                                l2 = self.ANGMOM[jstate]
                                lstart = abs(l1 - l2)
                                lstop = l1 + l2
                                # Multipoles allowed by parity: |l1-l2|, |l1-l2|+2, ..., l1+l2
                                for lam in range(int(lstart), int(lstop + 1), 2):
                                    threej = self.sqr3j(l1, l2, lam)
                                    # Presumably halved because NOCC counts both spins here and
                                    # only same-spin electrons exchange.
                                    fac = (-self.CHARGE / 2) * self.NOCC[jstate, 0] * threej
                                    # Outward pass: r**-(lam+1) * int_0^r r'**lam P_j P_i dr'
                                    sum_out = 0
                                    for ir in range(1, self.NR + 1):
                                        r = ir * self.DR
                                        rlam = r ** lam
                                        term = self.PSTOR[ir, jstate, 0] * self.PSTOR[ir, istate, 0] * rlam / 2
                                        # Half of this point's integrand before use, half after
                                        # (trapezoid-style cumulative integral ending at r)
                                        sum_out += term
                                        df = self.PSTOR[ir, jstate, 0] * fac * sum_out * self.DR / (rlam * r)
                                        self.FOCK[ir, istate, 0] += df
                                        sum_out += term
                                    # Inward pass: r**lam * int_r^Rmax r'**-(lam+1) P_j P_i dr'
                                    sum_in = 0
                                    for ir in range(self.NR, 0, -1):
                                        r = ir * self.DR
                                        rlam1 = r ** (lam + 1)
                                        term = self.PSTOR[ir, jstate, 0] * self.PSTOR[ir, istate, 0] / (rlam1 * 2)
                                        sum_in += term
                                        df = self.PSTOR[ir, jstate, 0] * fac * sum_in * self.DR * rlam1 / r
                                        self.FOCK[ir, istate, 0] += df
                                        sum_in += term
        else:
            # Spin-polarized: exchange only within the same spin channel
            for ispin in range(self.NSPIN):
                for istate in range(self.NSTATE):
                    if self.BANDS[istate, ispin] != 0:
                        for ir in range(1, self.NR + 1):
                            self.FOCK[ir, istate, ispin] = 0
                        for jstate in range(self.NSTATE):
                            if self.NOCC[jstate, ispin] != 0:
                                l1 = self.ANGMOM[istate]
                                l2 = self.ANGMOM[jstate]
                                lstart = abs(l1 - l2)
                                lstop = l1 + l2
                                for lam in range(int(lstart), int(lstop + 1), 2):
                                    threej = self.sqr3j(l1, l2, lam)
                                    # NOCC here is already per spin, so no factor 1/2
                                    fac = (-self.CHARGE) * threej * self.NOCC[jstate,ispin]
                                    # Outward pass (same half-step scheme as above)
                                    sum_out = 0
                                    for ir in range(1, self.NR + 1):
                                        r = ir * self.DR
                                        rlam = r ** lam
                                        term = self.PSTOR[ir, jstate, ispin] * self.PSTOR[ir, istate, ispin] * rlam / 2
                                        sum_out += term
                                        df = self.PSTOR[ir, jstate, ispin] * fac * sum_out * self.DR / (rlam * r)
                                        self.FOCK[ir, istate, ispin] += df
                                        sum_out += term
                                    # Inward pass
                                    sum_in = 0
                                    for ir in range(self.NR, 0, -1):
                                        r = ir * self.DR
                                        rlam1 = r ** (lam + 1)
                                        term = self.PSTOR[ir, jstate, ispin] * self.PSTOR[ir, istate, ispin] / (rlam1 * 2)
                                        sum_in += term
                                        df = self.PSTOR[ir, jstate, ispin] * fac * sum_in * self.DR * rlam1 / r
                                        self.FOCK[ir, istate, ispin] += df
                                        sum_in += term


    def lda_x(self, alpha=1.0):
        """
        Local (Slater-Dirac) exchange on the radial grid, spin-resolved when NSPIN == 2.

        For each grid point ir = 1..NR the radial densities in self.RHO are
        divided by 4*pi*r^2 and multiplied by ABOHR**3 to obtain volume
        densities per Bohr^3.  From the total density n the unpolarized
        exchange energy per electron and exchange potential are

            eps_x = -alpha * (3/4) * (3/pi)^(1/3) * n^(1/3)     (Hartree)
            v_x   = -alpha *         (3/pi)^(1/3) * n^(1/3)     (Hartree)

        Both are converted with HARTREE2EV and stored in self.EX[ir] and
        self.VX[ir, spin].

        NSPIN == 1 : RHO[:, 0] is the total density; only VX[:, 0] is written.
        NSPIN == 2 : RHO[:, 0] and RHO[:, 1] are the up and down densities.
                     The energy interpolates between the zeta = 0 limit and the
                     fully polarized limit (2^(1/3) times larger) with
                     f(zeta) = [(1+zeta)^(4/3) + (1-zeta)^(4/3) - 2] / (2^(4/3) - 2),
                     and each spin potential includes the df/dzeta term.

        Grid points whose total density is below 1e-12 Bohr^-3 get EX = 0 and
        VX[ir, :] = 0.

        Parameters
        ----------
        alpha : float
            Overall scale of the exchange prefactors (1.0 = plain LDA exchange).
        """
        one_third = 1.0 / 3.0
        cbrt_3_over_pi = (3.0 / np.pi) ** one_third
        eps_coef = alpha * 0.75 * cbrt_3_over_pi    # prefactor of eps_x
        pot_coef = alpha * cbrt_3_over_pi           # prefactor of v_x (= 4/3 * eps_coef)
        pol_factor = 2.0 ** one_third               # zeta = 1 limit relative to zeta = 0
        f_norm = 2.0 ** (4.0 / 3.0) - 2.0           # denominator of f(zeta)
        to_bohr3 = self.ABOHR ** 3
        ev = self.HARTREE2EV

        for ir in range(1, self.NR + 1):
            r = ir * self.DR
            shell = 4 * np.pi * r * r

            # Radial line densities -> volume densities in Bohr^-3.
            n_up = self.RHO[ir, 0] / shell * to_bohr3
            n_dn = self.RHO[ir, 1] / shell * to_bohr3 if self.NSPIN == 2 else 0.0
            n_tot = n_up + n_dn

            if n_tot < 1e-12:
                self.EX[ir] = 0.0
                self.VX[ir, :] = 0.0
                continue

            cbrt_n = n_tot ** one_third
            e_unpol = -eps_coef * cbrt_n            # Hartree, per electron
            v_unpol = -pot_coef * cbrt_n            # Hartree
            e0 = e_unpol * ev
            e1 = (pol_factor * e_unpol) * ev
            v0 = v_unpol * ev
            v1 = (pol_factor * v_unpol) * ev

            if self.NSPIN == 1:
                self.EX[ir] = e0
                self.VX[ir, 0] = v0
                continue

            # Spin-polarization interpolation between the two limits.
            zeta = (n_up - n_dn) / n_tot
            f = ((1 + zeta) ** (4 / 3) + (1 - zeta) ** (4 / 3) - 2.0) / f_norm
            dfdz = (4.0 / 3.0) * ((1 + zeta) ** (1 / 3) - (1 - zeta) ** (1 / 3)) / f_norm
            de = e1 - e0
            v_mix = v0 + f * (v1 - v0)

            self.EX[ir] = e0 + f * de
            self.VX[ir, 0] = v_mix + (1 - zeta) * dfdz * de
            self.VX[ir, 1] = v_mix - (1 + zeta) * dfdz * de

    def lda_c(self, alpha=1):
        """
        LDA correlation from the Perdew-Zunger (1981) fit to Ceperley-Alder data, spin-unpolarized.

        The total radial density self.RHOTOT is used, so every spin column of
        self.VC receives the same value.  For each grid point ir = 1..NR:

            n         = RHOTOT[ir] / (4 pi r^2) * ABOHR^3      (per Bohr^3)
            r_s       = alpha * (3 / (4 pi n))^(1/3)           (Bohr)
            VC[ir, :] = v_c(r_s) * HARTREE2EV
            EC[ir]    = n * eps_c(r_s) * 4 pi r^2 * HARTREE2EV / ABOHR

        Points with RHOTOT[ir] / (4 pi r^2) <= 1e-9 get EC = VC = 0.
        NOTE(review): the unit of EC depends on how energy() integrates it; confirm there.

        The PZ81 coefficients below are the standard Rydberg-unit values
        (gamma = -0.2846 Ry = -0.1423 Ha, A = 0.0622 Ry = 0.0311 Ha, ...).
        Results are therefore converted Ry -> Ha before HARTREE2EV is applied.

        Parameters
        ----------
        alpha : float
            Scale factor applied to r_s (1 = standard PZ81).
        """
        ry_to_ha = 0.5

        # Large r_s (r_s >= 1): eps_c = gamma / (1 + beta1 sqrt(r_s) + beta2 r_s)
        c1p0529 = 1.0529          # beta1
        c3334   = 0.3334          # beta2
        c2846   = 0.2846          # -gamma   [Ry]
        con10   = 7.3703 / 6.0    # (7/6) beta1   (potential)
        con11   = 1.3336 / 3.0    # (4/3) beta2   (potential)

        # Small r_s (r_s < 1): eps_c = A ln r_s + B + C r_s ln r_s + D r_s
        c0622   = 0.0622          # A        [Ry]
        c096    = 0.096           # -B       [Ry]
        c004    = 0.004           # C        [Ry]
        c0232   = 0.0232          # -D       [Ry]
        con2    = 0.008 / 3.0     # (2/3) C          [Ry]
        con3    = 0.3502 / 3.0    # A/3 - B          [Ry]
        con4    = 0.0504 / 3.0    # (C - 2D) / 3     [Ry]

        # r_s = pref * n^(-1/3)
        pref = alpha * (3.0 / (4.0 * np.pi))**(1.0/3.0)

        for ir in range(1, self.NR + 1):
            r = ir * self.DR
            shell = 4.0 * np.pi * r ** 2
            rho = self.RHOTOT[ir] / shell          # volume density in grid length units

            if rho <= 1e-9:
                self.EC[ir] = 0.0
                self.VC[ir, :] = 0.0
                continue

            rho = rho * self.ABOHR**3              # e / Bohr^3
            rs = pref * rho**(-1.0/3.0)            # Wigner-Seitz radius in Bohr

            if rs >= 1.0:
                # low-density (large r_s) branch
                sqrs = np.sqrt(rs)
                denom = 1.0 + c1p0529 * sqrs + c3334 * rs
                ec_ry = -c2846 / denom
                vc_ry = ec_ry * (1.0 + con10 * sqrs + con11 * rs) / denom
            else:
                # high-density (small r_s) branch
                rslog = np.log(rs)
                ec_ry = (c0622 + c004 * rs) * rslog - c096 - c0232 * rs
                vc_ry = (c0622 + con2 * rs) * rslog - con3 - con4 * rs

            ec_h = ry_to_ha * ec_ry                # correlation energy per electron (Hartree)
            vc_h = ry_to_ha * vc_ry                # correlation potential (Hartree)

            self.EC[ir] = rho * ec_h * shell * self.HARTREE2EV / self.ABOHR
            self.VC[ir, :] = vc_h * self.HARTREE2EV

    def poissn(self):
        """
        (Equivalent to the POISSN subroutine)
        Solve the radial Poisson equation for the electron potential self.PHI.

        The Numerov recursion is applied to u(r) = r * PHI(r), with source term
        -CHARGE * RHOTOT(r) / r on the grid r = k * DR.
        - u(0) = 0.
        - u(DR) is DR times the r -> 0 limit of the potential,
          CHARGE * sum_k RHOTOT[k] / k.
        - After dividing by r, the slope of u estimated over the last 10 grid
          points is subtracted, which removes the spurious linear
          (homogeneous) component.

        Reads self.NR, self.DR, self.CHARGE and self.RHOTOT; overwrites
        self.PHI[0..NR] in place.
        """
        nr = self.NR
        dr = self.DR
        q = self.CHARGE
        rho = self.RHOTOT
        phi = self.PHI

        # r -> 0 value of the potential (up to the factor CHARGE * DR applied below).
        inv_r_sum = 0
        for k in range(1, nr + 1):
            inv_r_sum += rho[k] / k

        h2_over_12 = dr ** 2 / 12
        phi[0] = 0
        phi[1] = inv_r_sum * q * dr

        # Numerov: u[k+1] = 2 u[k] - u[k-1] + h^2/12 (s[k-1] + 10 s[k] + s[k+1])
        s_prev = 0
        s_curr = -q * rho[1] / dr
        for k in range(1, nr):
            s_next = -q * rho[k + 1] / ((k + 1) * dr)
            phi[k + 1] = 2 * phi[k] - phi[k - 1] + h2_over_12 * (10 * s_curr + s_next + s_prev)
            s_prev, s_curr = s_curr, s_next

        # Remove the linear part and convert u(r) back to PHI(r).
        slope = (phi[nr] - phi[nr - 10]) / (10 * dr)
        for k in range(1, nr + 1):
            phi[k] = phi[k] / (k * dr) - slope


    def sngwfn(self, state, energy, spin):
        """
        (Equivalent to the SNGWFN subroutine)
        Solve the radial equation for one orbital at a fixed trial energy and
        store the normalized result in self.PSTOR[:, state, spin].

        Two homogeneous Numerov solutions are built first:
        - PSIOUT, integrated outward from PSIOUT[0] = 0.
        - PSIIN, integrated inward from PSIIN[NR] = 0.
        Both use the local factor
        energy - PHI - VX[:, spin] - VC[:, spin] + (ZCHARG - ll1 / r) / r,
        where ll1 = ANGMOM*(ANGMOM+1)*HBARM/2.

        * exchange_type == 'HF': the orbital is the Green's-function solution
          with FOCK[:, state, spin] as the inhomogeneous term. It is assembled
          from PSIIN, PSIOUT and their Wronskian evaluated at NR // 2.
        * Any other exchange_type: PSIOUT is rescaled to match PSIIN at the
          first grid point where the local factor changes sign relative to its
          value at r = DR (NR - 3 if no sign change is found). PSIOUT is used
          up to that point and PSIIN beyond it.

        The result is then:
        1. normalized;
        2. Gram-Schmidt orthogonalized against a hard-coded list of lower
           states (index tables below; labels follow the original comments);
        3. normalized again.
        States not listed in those tables are not orthogonalized here.

        Parameters
        ----------
        state : int
            Index of the orbital to be calculated.
        energy : float
            Single-particle energy for the orbital.
        spin : int
            Spin channel index selecting the VX/VC/FOCK/PSTOR columns
            (run_simulation passes 0..NSPIN-1).
        """
        # drhb multiplies the local factor (E - V_eff) to give the Numerov
        # coefficients k2m/k2z/k2p; ll1 is the centrifugal coefficient.
        drhb = self.DR ** 2 / (self.HBARM * 6)
        ll1 = self.ANGMOM[state] * (self.ANGMOM[state] + 1) * self.HBARM / 2

        # Outward homogeneous solution using the Numerov algorithm
        # (k2m, k2z, k2p hold the coefficient at ir-2, ir-1, ir; PSIOUT starts
        # at 0 with a tiny value at the first grid point).
        k2m = 0
        k2z = drhb * (energy - self.PHI[1] - self.VX[1, spin] - self.VC[1, spin] + (self.ZCHARG - ll1 / self.DR) / self.DR)
        self.PSIOUT[0] = 0
        self.PSIOUT[1] = self.DR * 1e-9
        for ir in range(2, self.NR + 1):
            r = self.DR * ir
            k2p = drhb * (energy - self.PHI[ir] - self.VX[ir, spin] - self.VC[ir, spin] + (self.ZCHARG - ll1 / r) / r)
            self.PSIOUT[ir] = (self.PSIOUT[ir - 1] * (2 - 10 * k2z) - self.PSIOUT[ir - 2] * (1 + k2m)) / (1 + k2p)
            k2m = k2z
            k2z = k2p

        # Inward homogeneous solution using the Numerov algorithm
        # (mirror of the outward pass: k2p, k2z, k2m hold the coefficient at
        # ir+2, ir+1, ir; PSIIN vanishes at the outer edge NR).
        k2p = 0
        r = (self.NR - 1) * self.DR
        k2z = drhb * (energy - self.PHI[self.NR - 1] - self.VX[self.NR - 1, spin] - self.VC[self.NR - 1, spin] + (self.ZCHARG - ll1 / r) / r)
        self.PSIIN[self.NR] = 0
        self.PSIIN[self.NR - 1] = self.DR * 1e-9
        for ir in range(self.NR - 2, 0, -1):
            r = self.DR * ir
            k2m = drhb * (energy - self.PHI[ir] - self.VX[ir, spin] - self.VC[ir, spin] + (self.ZCHARG - ll1 / r) / r)
            self.PSIIN[ir] = (self.PSIIN[ir + 1] * (2 - 10 * k2z) - self.PSIIN[ir + 2] * (1 + k2p)) / (1 + k2m)
            k2p = k2z
            k2z = k2m

        nr2 = self.NR // 2

        if (self.exchange_type == 'HF'):
            # Calculate the Wronskian at the mid-point of the grid
            # (central differences for both derivatives).
            wronsk = ((self.PSIIN[nr2 + 1] - self.PSIIN[nr2 - 1]) / (2 * self.DR)) * self.PSIOUT[nr2]
            wronsk -= ((self.PSIOUT[nr2 + 1] - self.PSIOUT[nr2 - 1]) / (2 * self.DR)) * self.PSIIN[nr2]

            # Outward integration using the Green's function approach:
            # PSTOR(r) = PSIIN(r) * integral_0^r of (-PSIOUT*FOCK) dr'.
            # Adding half of each term before and half after the store gives
            # a trapezoid-style running integral.
            summ = 0
            for ir in range(1, self.NR + 1):
                term = -self.PSIOUT[ir] * self.FOCK[ir, state, spin] / 2
                summ += term
                self.PSTOR[ir, state, spin] = self.PSIIN[ir] * summ * self.DR
                summ += term

            # Inward integration using the Green's function approach:
            # add PSIOUT(r) * integral_r^R of (-PSIIN*FOCK) dr', then divide by
            # the Wronskian.
            summ = 0
            for ir in range(self.NR, 0, -1):
                term = -self.PSIIN[ir] * self.FOCK[ir, state, spin] / 2
                summ += term
                self.PSTOR[ir, state, spin] = (self.PSTOR[ir, state, spin] + self.PSIOUT[ir] * summ * self.DR) / wronsk
                summ += term

        else:

            # Find the matching point. kim1 is evaluated once at r = DR and
            # never updated, so the loop stops at the first grid point whose
            # local factor has the opposite sign from the r = DR value, i.e. the
            # first sign change. Default is NR - 3 when none is found.
            kim1 = (energy - self.PHI[1] - self.VX[1, spin] - self.VC[1, spin] + (self.ZCHARG - ll1 / self.DR) / self.DR)
            imax = self.NR - 3
            for ir in range(2, self.NR + 1):
                r = self.DR * ir
                ki = energy - self.PHI[ir] - self.VX[ir, spin] - self.VC[ir, spin] + (self.ZCHARG - ll1 / r) / r
                if (ki * kim1 < 0):
                    imax = ir
                    break

            # Scale PSIOUT to PSIIN at imax; use PSIOUT inside, PSIIN outside.
            # NOTE(review): no guard against PSIOUT[imax] == 0.
            scale = self.PSIIN[imax] / self.PSIOUT[imax]
            for ir in range(1, self.NR+1):
                    if ir <= imax:
                        self.PSTOR[ir,state,spin] = self.PSIOUT[ir] * scale
                    else:
                        self.PSTOR[ir,state,spin] = self.PSIIN[ir]


        # Normalize the solution for this state (sum of squares times DR = 1)
        norm = np.sum(self.PSTOR[:self.NR + 1, state, spin] ** 2)
        norm = 1 / np.sqrt(norm * self.DR)
        self.PSTOR[:self.NR + 1, state, spin] *= norm


        # Orthogonalize specific states (if necessary, e.g., 1S-2S, 1S-3S, etc.)
        # Each step subtracts the projection onto an already-processed lower
        # state. The index -> orbital mapping is hard-coded and differs
        # between the two spin settings.
        if self.NSPIN == 1:
            # One state per orbital: 0=1S, 1=2S, 2=2P, 3=3S, 4=3P, 5=4S.
            if state == 1:  # Orthogonalize 1S and 2S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, 1, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 1, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
            if state == 3:  # Orthogonalize 1S & 3S and 2S & 3S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, 3, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 3, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 1, spin] * self.PSTOR[:self.NR + 1, 3, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 3, spin] -= summ_val * self.PSTOR[:self.NR + 1, 1, spin]
            if state == 4:  # Orthogonalize 2P and 3P
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 2, spin] * self.PSTOR[:self.NR + 1, 4, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 4, spin] -= summ_val * self.PSTOR[:self.NR + 1, 2, spin]
            if state == 5:  # Orthogonalize 1S, 2S, 3S with 4S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, 5, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 5, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 1, spin] * self.PSTOR[:self.NR + 1, 5, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 5, spin] -= summ_val * self.PSTOR[:self.NR + 1, 1, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 3, spin] * self.PSTOR[:self.NR + 1, 5, spin]) * self.DR
                self.PSTOR[:self.NR + 1, 5, spin] -= summ_val * self.PSTOR[:self.NR + 1, 3, spin]
        else:
            # Per-basis-state indexing: 0=1S, 1=2S, 2-4=2P, 5=3S, 6-8=3P, 9=4S
            # (taken from the comments on each branch below).
            if state in [1]:  # Orthogonalize 1S and 2S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
            if state in [5]:  # Orthogonalize 1S & 3S and 2S & 3S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 1, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 1, spin]
            if state in [6,7,8] :  # Orthogonalize 2P and 3P
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 2, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 2, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 3, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 3, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 4, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 4, spin]
            if state in [9]:  # Orthogonalize 1S, 2S, 3S with 4S
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 0, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 0, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 1, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 1, spin]
                summ_val = np.sum(self.PSTOR[:self.NR + 1, 5, spin] * self.PSTOR[:self.NR + 1, state, spin]) * self.DR
                self.PSTOR[:self.NR + 1, state, spin] -= summ_val * self.PSTOR[:self.NR + 1, 5, spin]


        # Normalize the solution for this state (again, after orthogonalization)
        norm = np.sum(self.PSTOR[:self.NR + 1, state, spin] ** 2)
        norm = 1 / np.sqrt(norm * self.DR)
        self.PSTOR[:self.NR + 1, state, spin] *= norm


    def sqr3j(self, l1, l2, lam):
        """
        (Equivalent to the SQR3J subroutine)
        Return the square of the Wigner 3-j symbol (l1 l2 lam; 0 0 0).

        For J = l1 + l2 + lam even and p = J / 2 the closed form is

            (l1 l2 lam; 0 0 0)^2 = (J-2*l1)! (J-2*l2)! (J-2*lam)! / (J+1)!
                                   * [p! / ((p-l1)! (p-l2)! (p-lam)!)]^2

        with factorials read from the table self.FACTRL.  The symbol vanishes
        when J is odd or when (l1, l2, lam) violate the triangle condition
        |l1 - l2| <= lam <= l1 + l2.  In those cases 0.0 is returned
        instead of indexing the factorial table with invalid arguments.

        Parameters
        ----------
        l1, l2, lam : float or int
            Integer-valued angular momenta.

        Returns
        -------
        threej : float
            The squared 3-j coefficient.
        """
        l1, l2, lam = int(l1), int(l2), int(lam)
        total = l1 + l2 + lam

        # Selection rules: odd J or a triangle violation gives a zero symbol.
        if total % 2 or min(l1, l2, lam) < 0 or not abs(l1 - l2) <= lam <= l1 + l2:
            return 0.0

        fact = self.FACTRL
        p = total // 2
        delta = fact[l1 + l2 - lam] * fact[-l1 + l2 + lam]
        delta = delta * fact[l1 - l2 + lam] / fact[total + 1]
        threej = delta * (fact[p] ** 2)
        threej /= (fact[p - l1] ** 2)
        threej /= (fact[p - l2] ** 2)
        threej /= (fact[p - lam] ** 2)
        return threej

    def run_simulation(self, max_iter=300, tol=1e-7):
        """
        (Equivalent to the ARCHON subroutine)
        Run the self-consistent field (SCF) loop and print a report to stdout.

        1. Start from hydrogen-like orbitals (self.hydrgn) with the bare charge
           Z and evaluate energies (self.energy).
        2. Restart hydrgn with the virial-adjusted charge
           Z* = -Z * Vtot / (2 * Ktot).
        3. In each iteration, re-solve every state whose BANDS entry is nonzero
           (self.sngwfn) at its current energy E[state, ITOT, spin]. Positive
           values are replaced by -10 so the orbital stays bound. Then call
           self.energy().
        4. Stop when |self.Total_Energy_Diff| <= tol. If max_iter iterations
           pass without that, a "did not converge" warning is printed and the
           last iterate is reported.

        Finally the per-state energy table and the totals are printed.

        Parameters
        ----------
        max_iter : int
            Maximum number of SCF iterations.
        tol : float
            Convergence criterion on the total-energy change (eV).
        """
        # Initial guess: set MIX factor and compute hydrogen-like orbitals and energy
        self.MIX = self.MIX if self.MIX is not None else 1.0
        zstar = self.Z
        self.hydrgn(zstar)
        self.energy()

        # Rescale the effective nuclear charge with the virial ratio -Vtot / (2 Ktot)
        zstar = -self.Z * (np.sum(self.E[self.NSTATE, self.IVTOT, :]) / (2 * np.sum(self.E[self.NSTATE, self.IKTOT, :])))
        self.hydrgn(zstar)
        self.energy()

        print(f"Nuclear Charge (Z): {self.Z}")
        print(f"Number of Electrons: {self.NE}")
        print("Occupation of the States:")

        if self.NSPIN == 2:
            # States come in groups of orbital_capacity[iorb] / 2 basis states
            # per orbital; the label is taken from the first state of each group.
            istate = 0
            for iorb in range(15):
                nbasis = int(orbital_capacity[iorb]/2)
                for ispin in range(self.NSPIN):
                    for ibasis in range(nbasis):
                        jstate = istate + ibasis
                        if self.NOCC[jstate, ispin] != 0:
                            symbol = 'up' if ispin == 0 else 'dw'
                            # Report this basis state's own occupation.
                            print(f"  {self.ID[istate]} {symbol}: {self.NOCC[jstate, ispin]:3.2f}")
                istate += nbasis
        else:
            # One entry per state, same range as the final table below.
            for istate in range(self.NSTATE):
                for ispin in range(self.NSPIN):
                    if self.NOCC[istate, ispin] != 0:
                        symbol = 'up' if ispin == 0 else 'dw'
                        print(f"  {self.ID[istate]} {symbol}: {self.NOCC[istate, ispin]:2.1f}")

        print(f"Radial Step (Å): {self.DR}")
        print(f"Maximum Radius (Å): {self.DR * self.NR}")
        print(f"Effective Nuclear Charge (Z*): {zstar}")
        print("All energies are in electron volts (eV).\n")
        print("*" * 50)
        print("Iteration: 0")

        for iter_count in range(1, max_iter + 1):
            # Re-solve every active state at its current energy estimate.
            for ispin in range(self.NSPIN):
                for istate in range(self.NSTATE):
                    if self.BANDS[istate, ispin] != 0:
                        eigenvalue = self.E[istate, self.ITOT, ispin]
                        if eigenvalue > 0:
                            eigenvalue = -10  # Ensure the orbital remains bound
                        self.sngwfn(istate, eigenvalue, ispin)

            self.energy()

            total_energy = np.sum(self.E[self.NSTATE, self.ITOT, :])
            print(f"Iteration: {iter_count} \t Total Energy: {total_energy:12.8f} eV \t Ediff: {self.Total_Energy_Diff:12.8f} eV")
            if abs(self.Total_Energy_Diff) <= tol:
                print("*" * 50)
                print(f"Convergence reached at iteration: {iter_count}")
                print(f"Final Converged Total Energy is : {total_energy} eV")
                break
        else:
            # Loop exhausted without meeting tol: report it instead of
            # silently presenting the last iterate as a result.
            print("*" * 50)
            print(f"WARNING: SCF did not converge within {max_iter} iterations "
                  f"(|Ediff| > {tol} eV); results below are from the last iteration.")

        # Summarize final output of energies (row NSTATE holds the totals, summed over spin)
        total_energy = np.sum(self.E[self.NSTATE, self.ITOT, :])
        kinetic_energy = np.sum(self.E[self.NSTATE, self.IKTOT, :])
        ion_energy = np.sum(self.E[self.NSTATE, self.IVEN, :])
        hartree_energy = np.sum(self.E[self.NSTATE, self.IVEE, :])
        exchange_energy = np.sum(self.E[self.NSTATE, self.IVXC, :])
        electrostatic_energy = np.sum(self.E[self.NSTATE, self.IVTOT, :])

        header = f"{'States':<9s}{'Nocc':<6s} {'Ktot':>14s} {'Ven':>14s} {'Vee':>14s} {'Vex':>14s} {'Vtot':>14s} {'Etot':>14s}"
        print(header)

        if self.NSPIN == 2:
            istate = 0
            for iorb in range(15):
                nbasis = int(orbital_capacity[iorb]/2)
                for ispin in range(self.NSPIN):
                    for ibasis in range(nbasis):
                        if self.BANDS[istate+ibasis, ispin] != 0:
                            symbol = 'up' if ispin == 0 else 'dw'
                            line = f"{self.ID[istate]:<3s} {symbol}   {self.NOCC[istate+ibasis, ispin]:<6.2f} {self.E[istate+ibasis, self.IKTOT, ispin]:>14.8f} " \
                                f"{self.E[istate+ibasis, self.IVEN, ispin]:>14.8f} {self.E[istate+ibasis, self.IVEE, ispin]:>14.8f} " \
                                f"{self.E[istate+ibasis, self.IVXC, ispin]:>14.8f} {self.E[istate+ibasis, self.IVTOT, ispin]:>14.8f} " \
                                f"{self.E[istate+ibasis, self.ITOT, ispin]:>14.8f}"
                            print(line)
                istate += nbasis
        else:
            for istate in range(self.NSTATE):
                for ispin in range(self.NSPIN):
                    if self.BANDS[istate, ispin] != 0:
                        symbol = 'up' if ispin == 0 else 'dw'
                        line = f"{self.ID[istate]:<3s} {symbol}   {self.NOCC[istate, ispin]:<6.2f} {self.E[istate, self.IKTOT, ispin]:>14.8f} " \
                            f"{self.E[istate, self.IVEN, ispin]:>14.8f} {self.E[istate, self.IVEE, ispin]:>14.8f} " \
                            f"{self.E[istate, self.IVXC, ispin]:>14.8f} {self.E[istate, self.IVTOT, ispin]:>14.8f} " \
                            f"{self.E[istate, self.ITOT, ispin]:>14.8f}"
                        print(line)
        print("\nTOTALS")
        print(f"NE : {self.NE}")
        print(f"KTOT : {kinetic_energy} eV")
        print(f"VENTOT : {ion_energy} eV")
        print(f"VEETOT : {hartree_energy} eV")
        print(f"VXCTOT : {exchange_energy} eV")
        print(f"VENTOT + VEETOT + VEXTOT : {electrostatic_energy} eV")
        print(f"Total Energy : {total_energy} eV")


- Interface

In [5]:
from io import StringIO
import sys
import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output


'''
Last update:

25.06.26. Converted Streamlit version to IPython interface (Ryong-Gyu Lee)

'''


# --- Helper function for running HF calculation ---
def run_hf(atom_num, num_band, num_elec,
           hartree, exchange, correlation,
           spin, mix, max_iter, tol, rmax, dr):
    """
    Configure and run one HartreeFockCalculator from the UI settings.

    Parameters
    ----------
    atom_num : int
        Atomic number Z.
    num_band, num_elec : number
        Passed to setup_occupation to build the BANDS and NOCC tables.
    hartree, exchange, correlation : str
        Method labels from the dropdowns ('None' disables a term).
        'Xα' is run as 'LDA' exchange with a 2/3 exchange prefactor.
    spin : bool
        Spin-polarized calculation.
    mix, max_iter, tol :
        SCF mixing factor, iteration cap and energy tolerance (eV).
    rmax, dr : float
        Radial box size and grid step (Å); NR = round(rmax / dr).

    Returns
    -------
    (calc, logs)
        The finished calculator and everything run_simulation printed.

    Raises
    ------
    ValueError
        If dr is not positive.
    """
    import contextlib  # local import keeps this notebook cell self-contained

    if dr <= 0:
        raise ValueError(f"Radial step DR must be positive, got {dr}")

    xalpha = 1
    calpha = 1
    calc = HartreeFockCalculator(spin_polarized=spin)
    bands = setup_occupation(num_band)
    occ   = setup_occupation(num_elec)

    if exchange == 'Xα':
        # Xα reuses the LDA exchange code with a scaled prefactor.
        # NOTE(review): lda_x treats alpha = 1 as plain LDA exchange, so if
        # xalpha is forwarded there, 2/3 yields 2/3 of LDA; confirm this is the
        # intended Xα convention.
        exchange = 'LDA'
        xalpha = 2/3

    calc.set_params(atom_num, bands, occ, hartree, exchange, correlation,
                    xalpha, calpha)
    calc.MIX  = mix
    calc.DR   = dr
    calc.RMAX = rmax
    # round() rather than int(): e.g. 4.1 / 0.02 == 204.99999999999997.
    calc.NR   = int(round(rmax / dr))

    # Capture the solver's printed log; redirect_stdout restores sys.stdout
    # even if the simulation raises.
    buf = StringIO()
    with contextlib.redirect_stdout(buf):
        calc.run_simulation(max_iter=max_iter, tol=tol)
    logs = buf.getvalue()
    return calc, logs

# --- Widgets ---
# Period selector; options are the sorted distinct values of periodic_table['Period'].
period_widget   = widgets.Dropdown(
    options=sorted(periodic_table['Period'].unique()),
    description='Period:'
)
# Element selector; its options are filled in below for the initial period.
elements_widget = widgets.Dropdown(description='Element:')

# Initialize element options based on period
# NOTE: element symbols are sorted alphabetically, not by atomic number.
initial_period = period_widget.value
initial_elements = sorted(
    periodic_table[periodic_table['Period'] == initial_period]['Element symbol'].unique()
)

elements_widget.options = initial_elements

# Explicitly select the first (alphabetical) element when the period has any.
if initial_elements:
    elements_widget.value = initial_elements[0]

# Update elements when period changes
def on_period_change(change):
    """Observer for period_widget: repopulate elements_widget with the new period's element symbols."""
    new_period = change['new']
    elems = periodic_table[periodic_table['Period'] == new_period]['Element symbol'].unique()
    elements_widget.options = sorted(elems)
# The observer is registered after the initial options were assigned above.
period_widget.observe(on_period_change, names='value')

# Parameter widgets.  The limits given here for "Num Elec" (atom_num_widget)
# and "Num Band" are placeholders; on_element_change resets them for the
# selected element.
atom_num_widget = widgets.FloatSlider(description='Num Elec:', min=1.0, max=1.0, step=0.05)
num_band_widget = widgets.IntSlider(description='Num Band:', min=1, max=50, step=1)

# Choice of Coulomb, exchange and correlation treatment.
hartree_widget = widgets.Dropdown(options=['Poisson', 'None'], description='Coulomb:')
exchange_widget = widgets.Dropdown(options=['HF', 'Xα', 'LDA', 'None'], description='Exchange:')
correlation_widget = widgets.Dropdown(
    options=['LDA', 'None'],
    value='None',
    description='Correlation:',
)
spin_widget = widgets.Checkbox(description='Spin Polarized')

# SCF controls and radial grid.
mix_widget = widgets.FloatSlider(
    description='SCF Mix:', value=0.5, min=0.0, max=1.0, step=0.01,
)
max_iter_widget = widgets.IntText(description='Max Iter:', value=300)
tol_widget = widgets.FloatText(description='Tol [eV]:', value=1e-7)
rmax_widget = widgets.FloatSlider(
    description='Rmax [Å]:', value=5.0, min=3.0, max=30.0, step=0.1,
)
dr_widget = widgets.FloatText(description='DR [Å]:', value=0.02)

# Update Num Elec and Num Band ranges when element changes
def on_element_change(change):
    """
    Observer for elements_widget: reset the electron- and band-count sliders.

    change['new'] is the selected element symbol. Its atomic number is looked
    up in periodic_table together with the currently selected period.
    """
    period = period_widget.value
    elem = change['new']
    atom_num = int(
        periodic_table.loc[
            (periodic_table['Period'] == period) &
            (periodic_table['Element symbol'] == elem),
            'Atomic number'
        ].iloc[0]
    )
    # Num Elec slider: from 1 to atom_num+2, default atom_num
    atom_num_widget.min = 1.0
    atom_num_widget.max = atom_num + 2.0
    atom_num_widget.value = float(atom_num)
    # Num Band slider: min and value = atom_num.  The slider was created with
    # max=50 and ipywidgets rejects min > max, so raise max first for heavier
    # elements.
    if num_band_widget.max < atom_num:
        num_band_widget.max = atom_num
    num_band_widget.min = atom_num
    num_band_widget.value = atom_num

elements_widget.observe(on_element_change, names='value')

# Trigger initial element-based setup
on_element_change({'new': elements_widget.options[0] if elements_widget.options else None})

run_button = widgets.Button(description='Calculate')
output_log = widgets.Output()     # receives the SCF log
output_plot = widgets.Output()    # receives the figure
plot_choice = widgets.Dropdown(
    options=['Electron Density', 'Wavefunctions', 'Hartree Potential', 'XC Potential'],
    description='Plot:',
)

# Callback for run_button
def on_run_clicked(b):
    """
    Run one calculation with the current widget settings.

    The SCF log is printed into output_log, the calculator is cached on
    run_button.calc for update_plot, and the plot is then refreshed.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused).
    """
    with output_log:
        clear_output()
        is_selected = (
            (periodic_table['Period'] == period_widget.value)
            & (periodic_table['Element symbol'] == elements_widget.value)
        )
        z = int(periodic_table.loc[is_selected, 'Atomic number'].iloc[0])
        calc, logs = run_hf(
            z,
            num_band_widget.value,
            atom_num_widget.value,    # the "Num Elec" slider
            hartree_widget.value,
            exchange_widget.value,
            correlation_widget.value,
            spin_widget.value,
            mix_widget.value,
            max_iter_widget.value,
            tol_widget.value,
            rmax_widget.value,
            dr_widget.value,
        )
        run_button.calc = calc
        print(logs)
    update_plot(None)

# Callback for plot_choice
def update_plot(change):
    """
    Redraw output_plot for the most recent calculation (run_button.calc).

    Returns immediately until a calculation has been run. Every curve is
    drawn on the solver's radial grid r_i = i * DR, i = 0..NR.

    Parameters
    ----------
    change : dict or None
        ipywidgets change notification (unused; None when called directly).
    """
    if not hasattr(run_button, 'calc'):
        return
    calc = run_button.calc
    npts = calc.NR + 1
    # Grid used by the solver (index ir <-> r = ir * DR). A linspace with NR
    # points over [0, NR*DR] has spacing NR*DR/(NR-1) != DR and misplaces curves.
    r = np.arange(npts) * calc.DR
    with output_plot:
        clear_output()
        choice = plot_choice.value
        if choice == 'Electron Density':
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.plot(r, calc.RHOTOT[:npts])
            ax.set(title='Electron density', xlabel='Radius [Å]', ylabel=r'$4\pi r^2 \rho(r)$')
        elif choice == 'Wavefunctions':
            nspin = calc.NSPIN
            fig, axs = plt.subplots(1, nspin, figsize=(5 * nspin, 4))
            if nspin == 1:
                axs = [axs]
            for ispin in range(nspin):
                for i in range(calc.NSTATE):
                    if calc.BANDS[i, ispin] != 0:
                        # Flip the sign so the larger-magnitude lobe is drawn positive.
                        psi = calc.PSTOR[:, i, ispin]
                        sign = 1 if abs(psi.max()) > abs(psi.min()) else -1
                        axs[ispin].plot(r, sign * calc.PSTOR[:npts, i, ispin], label=f'State {calc.ID[i]}')
                # only show legend if there are plotted states
                _, labels = axs[ispin].get_legend_handles_labels()
                if labels:
                    axs[ispin].legend(fontsize=8)
                axs[ispin].set(title='Wavefunctions', xlabel='r [Å]')

        elif choice == 'Hartree Potential':
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.plot(r[1:], -calc.PHI[1:npts])
            ax.set(title='Hartree Potential', xlabel='Radius [Å]', ylabel=r'$\Phi(r)$')
        elif choice == 'XC Potential':
            nspin = calc.NSPIN
            fig, axs = plt.subplots(1, nspin, figsize=(5 * nspin, 4))
            if nspin == 1:
                axs = [axs]
            for ispin in range(nspin):
                axs[ispin].plot(r[1:], calc.VX[1:npts, ispin] + calc.VC[1:npts, ispin])
                axs[ispin].set(title='XC Potential', xlabel='Radius [Å]')
        plt.tight_layout()
        plt.show()

# Wire the callbacks: the button runs a calculation, the plot dropdown redraws.
run_button.on_click(on_run_clicked)
plot_choice.observe(update_plot, names='value')

# Arrange layout (one row per group of controls, then log and plot areas)
ui = widgets.VBox([
    widgets.HBox([period_widget, elements_widget]),
    widgets.HBox([atom_num_widget, num_band_widget]),
    widgets.HBox([hartree_widget]),
    widgets.HBox([exchange_widget, correlation_widget]),
    widgets.HBox([mix_widget, max_iter_widget]),
    widgets.HBox([tol_widget, rmax_widget, dr_widget]),
    widgets.HBox([spin_widget]),
    run_button,
    output_log,
    plot_choice,
    output_plot
])

# Render the title and the assembled interface in the notebook.
display(Markdown('### Hartree–Fock calculations for Atomic Systems'))
display(ui)


### Hartree–Fock calculations for Atomic Systems

VBox(children=(HBox(children=(Dropdown(description='Period:', options=(np.int64(1), np.int64(2), np.int64(3), …