In [1]:
import os, sys, platform, socket
print("Host:", socket.gethostname())
print("Python:", sys.executable)
print("CUDA_VISIBLE_DEVICES:", os.getenv("CUDA_VISIBLE_DEVICES"))

Host: submit63.mit.edu
Python: /work/submit/tzaklama/ct_NNVMC/venvs/ct_nnvmc_venv/bin/python
CUDA_VISIBLE_DEVICES: 0


In [2]:
import os, socket, sys
print("Host:", socket.gethostname())      # MUST be c1234 (NOT submit66)
print("Python:", sys.executable)          # /work/.../ct_nnvmc_venv/bin/python
!ls /dev/nvidia*                           # should list /dev/nvidia0, etc.
!nvidia-smi -L

Host: submit63.mit.edu
Python: /work/submit/tzaklama/ct_NNVMC/venvs/ct_nnvmc_venv/bin/python
/dev/nvidia0  /dev/nvidia2  /dev/nvidiactl   /dev/nvidia-uvm-tools
/dev/nvidia1  /dev/nvidia3  /dev/nvidia-uvm

/dev/nvidia-caps:
nvidia-cap1  nvidia-cap2
GPU 0: NVIDIA GeForce GTX 1080 Ti (UUID: GPU-8cc96b65-f552-c8c5-ad52-03e6616b2931)


In [3]:
# or JAX
import jax; jax.devices()

[CudaDevice(id=0)]

In [4]:
# non_intTest.py
"""
Tests model for non-interacting 1d chain of fermions.

1d non-int chain is calculated analytically for given configuration and then used as target for model.
"""
import jax, jax.numpy as jnp

In [5]:
## Calculate exact coefficients for non-interacting 1d chain of fermions
import numpy as np
import math

def slater_coeffs_1d_sitebasis(L, N, bc=None):
    """Return (coeffs_complex, basis_masks) for the noninteracting ground state."""
    if bc is None:
        bc = 'pbc' if (N % 2 == 1) else 'apbc'
    phi = 0.0 if bc == 'pbc' else np.pi

    # choose occupied momenta
    if N % 2 == 0:
        ns = np.arange(-N//2, N//2)
    else:
        ns = np.arange(-(N//2), N//2 + 1)
    ks = (2*np.pi*ns + phi)/L

    # all N-particle bitmasks (ascending order of positions)
    from itertools import combinations
    masks, positions = [], []
    for occ in combinations(range(L), N):
        masks.append(sum(1<<i for i in occ))
        positions.append(np.array(occ, dtype=int))
    positions = np.stack(positions)                          # (H, N)

    # build orbital matrix Phi[a,x] = e^{ik_a x}/sqrt(L)
    Phi = np.exp(1j*np.outer(ks, np.arange(L))) / np.sqrt(L) # (N, L)

    # for each configuration, slice columns at positions and take det
    coeffs = np.empty(len(masks), dtype=np.complex128)
    for idx, pos in enumerate(positions):
        coeffs[idx] = np.linalg.det(Phi[:, pos]) / np.sqrt(math.factorial(N))
    return coeffs, np.array(masks, dtype=np.uint64)


In [6]:
## Calculate exact ground-state energy for non-interacting 1D chain of fermions for reference
%cd /work/submit/tzaklama/ct_NNVMC
import config as C
def tb_E0_1d(L, N, t, bc=None):
    """
    Exact ground-state energy for non-interacting 1D chain.
    bc: 'pbc' or 'apbc' (auto-choose if None: pbc for odd N, apbc for even N)
    """
    if bc is None:
        bc = 'pbc' if (N % 2 == 1) else 'apbc'
    phi = 0.0 if bc == 'pbc' else np.pi
    # choose the N integers centered around 0
    if N % 2 == 0:
        ns = np.arange(-N//2, N//2)                # even N
    else:
        ns = np.arange(-(N//2), N//2 + 1)          # odd N
    ks = (2*np.pi*ns + phi)/L
    eps = -2*t*np.cos(ks)
    return np.sum(np.sort(eps))  # already lowest; sort is harmless


E_exact = tb_E0_1d(L=C.N_SITES, N=C.N_PART, t=2*C.T_HOP, bc='pbc')
print("Exact TB E0 =", E_exact)

/work/submit/tzaklama/ct_NNVMC
Exact TB E0 = -4.82842712474619


In [7]:
"""
Latest version of the code to train a Lattice-TransFormer on synthetic data for non-interacting fermions in 1D.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C
T_LIST = [1.2, 1.0, 0.4, 0.6, 0.8]          # example
V_LIST = [0.0, 0.0, 0.0, 0.0, 0.0]          # non-interacting, so V=0
N_LIST = [2, 3, 5, 4, 4]          # particle numbers (can all differ)

assert len(T_LIST)==len(V_LIST)==len(N_LIST)
G = len(T_LIST)              # number of Hamiltonians

# ---------------------------------------------------------------------------
# 2)  Build concatenated training set
# ---------------------------------------------------------------------------
OCC_ALL, LAM_ALL, TARGET_ALL, GID_ALL = [], [], [], []
for gid,(t,v,npart) in enumerate(zip(T_LIST, V_LIST, N_LIST)):
    # basis for this particle number --------------------
    if C.LATTICE == "1d":
        from phys_system.lattice1D import enumerate_fock, mask_to_array
        basis = enumerate_fock(C.N_SITES, npart)
        occ   = jnp.array([mask_to_array(m, C.N_SITES) for m in basis],
                          dtype=jnp.int32)
    else:
        import phys_system.honeycomb as hc
        basis = hc.enumerate_fock(C.N_SITES, npart)
        occ   = jnp.array([hc.mask_to_array(m, C.N_SITES) for m in basis],
                          dtype=jnp.int32)

    # λ-vector extended to include N --------------------
    lam_vec = jnp.array([t, v, npart], dtype=jnp.float32)
    lam     = jnp.tile(lam_vec, (len(basis),1))

    
    # synthetic target coefficients replaced by exact non-interacting ones
    #key  = jax.random.PRNGKey(C.SEED + gid)
    #targ = jax.random.normal(key, (len(basis), 2))*0.1
    coeffs, masks = slater_coeffs_1d_sitebasis(L=C.N_SITES, N=npart, bc='pbc')
    # TARGET aligned to OCC (Re, Im)
    targ = jnp.stack([jnp.array(coeffs.real), jnp.array(coeffs.imag)], axis=1)

    gid_vec = jnp.full((len(basis),), gid, dtype=jnp.int32)

    OCC_ALL.append(occ)
    LAM_ALL.append(lam)
    TARGET_ALL.append(targ)
    GID_ALL.append(gid_vec)

# concatenate everything --------------------------------
OCC     = jnp.concatenate(OCC_ALL,    axis=0)
LAM     = jnp.concatenate(LAM_ALL,    axis=0)   # shape (B,3)
TARGET  = jnp.concatenate(TARGET_ALL, axis=0)
GIDS    = jnp.concatenate(GID_ALL,    axis=0)   # shape (B,)

print("Total training states:", OCC.shape[0], "| Hamiltonians:", G)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES, depth=8, d_model=256)

# ---------------------------------------------------------------------------
# 3)  Loss & training step
# ---------------------------------------------------------------------------
from Loss.loss import overlap_loss_multi, amp_phase_loss, loss_normDiff   
LOSS_TYPE = "overlap_multi"   # choose in config or set here

def create_state(rng):
    params = model.init(rng, OCC, LAM, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, LAM, train=False)
    if LOSS_TYPE == "overlap_multi":
        return overlap_loss_multi(preds, TARGET, GIDS, num_groups=G)
    elif LOSS_TYPE == "amp_phase":
        return amp_phase_loss(preds, TARGET, neighbours)
    elif LOSS_TYPE == "original":
        return loss_normDiff(preds, TARGET)
    else:
        raise ValueError("Unknown LOSS_TYPE")

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(3*C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, LAM, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{0.0}):")
for i in range(min(C.N_PRINT, len(basis))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


Total training states: 280 | Hamiltonians: 5
JAX backend: GPU | lattice: 1d | N=8, Np=4
epoch    0  loss = 5.4999e-01
epoch  300  loss = 1.5936e-03
epoch  600  loss = 8.5592e-06
epoch  900  loss = 3.4672e-01
epoch 1200  loss = 3.2688e-04
epoch 1500  loss = 8.9407e-06
epoch 1800  loss = 7.8678e-07
epoch 2100  loss = 2.7895e-06
epoch 2400  loss = 3.5763e-07
epoch 2700  loss = 2.1458e-07
epoch 3000  loss = 4.0531e-07
epoch 3300  loss = 5.0068e-07

First 10 coefficients for (t,V)=(0.5,0.0):
|11000000⟩ → -3.91898 +2.09372 i
|10100000⟩ → -3.92693 +6.84304 i
|10010000⟩ → -2.45174 +8.18225 i
|10001000⟩ → +1.58746 +11.37895 i
|10000100⟩ → +7.04532 +6.03460 i
|10000010⟩ → +5.42719 +5.38753 i
|10000001⟩ → +4.42431 +3.73778 i
|01100000⟩ → -0.74017 +5.49745 i
|01010000⟩ → -0.79228 +7.10872 i
|01001000⟩ → +4.48184 +5.99664 i


In [8]:
# Build the Hamiltonian for the non-interacting 1D chain of fermions
Array = jax.Array
def H_builder(C) -> Array:
    import itertools 
    from scipy.sparse import lil_matrix, csr_matrix 
    Ns              = C.N_SITES          # number of sites of the 1D chain
    Nparticelle     = C.N_PART    # number of particles

    # basis vectors  
    klattice = 2*np.pi/Ns * np.arange(Ns)                                               # k lattice points  
    linking_table = {jr: [np.mod(jr-1, Ns), np.mod(jr+1, Ns)] for jr in range(Ns)}      # 1D linking table 
    # generate Hilbert space 
    def n2occ(n,Nsize):
        binr=np.binary_repr(n,width=Nsize) 
        return binr
    def generate_combinations(N,Np):
        """ Generate all combinations of Np particles from N available positions, returning them as binary numbers. """
        combinations = itertools.combinations(range(N), Np)
        results = []
        for comb in combinations:
            binary_num = 0  # Start with an empty binary number (all zeros)
            for i in comb:
                binary_num |= (1 << i)  # Set the bit at position i
            results.append(binary_num)  # Append the binary number to the results
        return results
    ##### Nparticelle number of particles specified previously >> N=2*Ns 
    combinations = generate_combinations(N=Ns,Np=Nparticelle)
    Nstates = len(combinations)
    states = np.arange(Nstates)
    state2ind = dict(zip(combinations,states))
    print("Number of states =",Nstates)  
    ###### hopping 
    def hop(string,j1,j2): # hopping 
        Ncount = sum( int(string[j]) for j in range(j2) ) + sum( int(string[j]) for j in range(j1) )  
        if j2 >= j1: 
            hop_sign = Ncount
        else: 
            hop_sign = 1 + Ncount 
        # distruggo in j2 
        tmp = string 
        tmp = "o"+tmp+"o" 
        tmp = tmp[:1+j2]+"0"+tmp[2+j2:] 
        tmp_p = tmp[1:-1] 
        # creo in j1 
        tmp_p = "o"+tmp_p+"o" 
        tmp_p = tmp_p[:1+j1]+"1"+tmp_p[2+j1:]  
        out = tmp_p[1:-1]
        return out, (-1)**hop_sign
    # Build the Hamiltonian
    Htunnel = lil_matrix((Nstates, Nstates), dtype=np.float64)   # LIL format for easy assignment
    #Hint    = lil_matrix((Nstates, Nstates), dtype=np.float64)   # LIL format for easy assignment 
    #Hloc    = lil_matrix((Nstates, Nstates), dtype=np.float64)   # LIL format for easy assignment
    for key, index in state2ind.items():
        string = n2occ(n=key,Nsize=Ns)  
        occ_vals = np.array([int(bit) for bit in string])  
        posAH = np.where(occ_vals == 1)[0]
        #posAH_odd = posAH[posAH % 2 == 1]
        #Hloc[index,index] = len(posAH_odd)  # number of odd particles in the state 
        for xx in posAH: 
            jxx = linking_table[xx]
            #Pxx = [ occ_vals[j] for j in jxx]
            # print( xx , f'linking table, {jxx}', f'Pxx {Pxx}' , string , Pxx.count(1) )
            #Hint[index,index] += Pxx.count(1)/2 
            for yy in jxx:  
                if occ_vals[yy] == 0: 
                    string_new, segno = hop(string=string,j1=yy,j2=xx)
                    # print( xx , yy , f'new string {string_new}, sign {segno}' )
                    occ_new = np.array([int(bit) for bit in string_new])  
                    key_new = int(string_new,2) 
                    index_new = state2ind[key_new] 
                    Htunnel[index_new,index] += -segno
    # Convert LIL matrix to dense numpy array, then to jax array
    H_dense = np.array(Htunnel.todense(), dtype=np.float32)
    H_jax = jnp.array(H_dense)
    return H_jax

In [9]:
# sampler_jax.py
"""
Metropolis–Hastings sampler for spinless fermion Fock states at fixed N.
Targets p(σ) ∝ |ψ(σ; λ)|^2 using your trained JAX model.

API:
  occ_batch = sample_occ_batch(
      model_apply, params, lam_vec, L, N,
      num_samples=4096, burn_in=1024, thin=4, n_chains=16, rng_seed=0)

where
  • model_apply(params, occ, lam, train=False) -> (B, 2) [Re, Im]
  • lam_vec is a 1D array, e.g. [t, V, N]  or  [t, V]
  • returns occ_batch as int32 array of shape (num_samples, L)
"""

from typing import Callable, Tuple
import numpy as np
import jax
import jax.numpy as jnp


# ------------------------- utilities ----------------------------------------
def _random_occ(L: int, N: int, rng: np.random.Generator) -> np.ndarray:
    """Random bitstring with exactly N ones among L sites."""
    occ = np.zeros(L, dtype=np.int8)
    occ[rng.choice(L, size=N, replace=False)] = 1
    return occ

def _propose_exchange(occ: np.ndarray, rng: np.random.Generator) -> Tuple[np.ndarray, int, int]:
    """
    Propose a number-conserving move: pick one occupied site i and one empty site j and swap.
    Returns (new_occ, i, j).  If all 0s or 1s (should not happen), returns copy.
    """
    ones  = np.flatnonzero(occ == 1)
    zeros = np.flatnonzero(occ == 0)
    if len(ones) == 0 or len(zeros) == 0:
        return occ.copy(), -1, -1
    i = int(rng.choice(ones))
    j = int(rng.choice(zeros))
    new_occ = occ.copy()
    new_occ[i] = 0
    new_occ[j] = 1
    return new_occ, i, j

@jax.jit
def _coeffs_to_probs(coeff_ri: jnp.ndarray) -> jnp.ndarray:
    """(B,2)->(B,) probabilities |ψ|^2 with small epsilon for stability."""
    psi = coeff_ri[:, 0] + 1j * coeff_ri[:, 1]
    return (jnp.abs(psi) ** 2 + 1e-30).real

def _eval_probs(model_apply, params, occ_batch: np.ndarray, lam_vec: jnp.ndarray) -> np.ndarray:
    """Evaluate |ψ|^2 for a batch of integer 0/1 arrays (np) at fixed λ."""
    occ_j = jnp.asarray(occ_batch, dtype=jnp.int32)
    lam_b = jnp.tile(lam_vec[None, :], (occ_j.shape[0], 1))
    coeff = model_apply(params, occ_j, lam_b, train=False)        # (B,2)
    probs = _coeffs_to_probs(coeff)                               # (B,)
    # IMPORTANT: return a WRITEABLE NumPy array
    return np.array(probs, dtype=np.float64, copy=True)


# ------------------------- main sampler -------------------------------------
def sample_occ_batch(model_apply: Callable,
                     params,
                     lam_vec: jnp.ndarray,
                     L: int,
                     N: int,
                     num_samples: int = 4096,
                     burn_in: int = 1024,
                     thin: int = 4,
                     n_chains: int = 16,
                     rng_seed: int = 0) -> jnp.ndarray:
    """
    Return an array of shape (num_samples, L) with entries in {0,1},
    distributed approximately as |ψ(σ;λ)|^2, using n_chains independent MH chains.
    """
    assert 0 <= N <= L
    rng = np.random.default_rng(rng_seed)

    # allocate per-chain state
    chains = np.stack([_random_occ(L, N, rng) for _ in range(n_chains)], axis=0)  # (C, L)
    probs  = _eval_probs(model_apply, params, chains, lam_vec)                    # (C,)
    if not probs.flags.writeable:
        probs = probs.copy()

    # how many samples per chain (ceil)
    per_chain = (num_samples + n_chains - 1) // n_chains
    samples = []

    total_steps = burn_in + thin * per_chain
    for step in range(total_steps):
        # Propose for all chains in parallel (Python level)
        props = []
        for c in range(n_chains):
            occ_new, _, _ = _propose_exchange(chains[c], rng)
            props.append(occ_new)
        props = np.stack(props, axis=0)                                # (C, L)

        probs_new = _eval_probs(model_apply, params, props, lam_vec)   # (C,)
        # MH accept
        accept_ratio = probs_new / (probs + 1e-300)
        accept = rng.random(n_chains) < np.minimum(1.0, accept_ratio)

        # update chains & probs
        chains[accept] = props[accept]
        probs[accept]  = probs_new[accept]

        # record (after burn-in) with thinning
        if step >= burn_in and ((step - burn_in) % thin == 0):
            samples.append(chains.copy())

    # collect & trim
    if len(samples) == 0:
        # edge case: too small total_steps
        return jnp.asarray(chains[:num_samples], dtype=jnp.int32)

    samples = np.concatenate(samples, axis=0)           # (~per_chain, C, L)
    samples = samples.reshape(-1, L)                    # (C*~, L)
    samples = samples[:num_samples]                     # exact M
    return jnp.asarray(samples, dtype=jnp.int32)

In [10]:
## Given ground state wavefunction, what is ground state energy?
import observables.energy as E
import sampler.sampler_jax as sampler
# (Assumes you have (i) a sampler that draws σ ~ |ψ|^2 and (ii) H for the basis.)

model_apply = lambda params, occ, lam, train=False: \
    model.apply(params, occ, lam, train=train)

L, N, t = C.N_SITES, C.N_PART, C.T_HOP

# 1) Basis + index map (must match H ordering!) (For ED and beyond)
basis_masks = enumerate_fock(L, N)
occ_basis   = jnp.array([mask_to_array(m, L) for m in basis_masks], dtype=jnp.int32)
state_index = E.build_state_index([int(m) for m in basis_masks])

# 2) Hamiltonian for appropriate basis
H = H_builder(C) # can change this but currently C is set to npart = 4 and L = 8

Number of states = 70


In [11]:
# 3) Draw M samples σ ~ |ψ|^2  (use your sampler; pseudo-code shown)
M = 1024
lam_vec   = jnp.array([2*t, 0.0, float(N)], dtype=jnp.float32) # non-interacting, so V=0 # C.T_HOP is 0.5
lam_batch = jnp.tile(lam_vec, (M, 1))

# 1) Draw configurations σ ~ |ψ|^2
occ_batch = sample_occ_batch(model_apply, state.params, lam_vec,
                             L=L, N=N,
                             num_samples=M, burn_in=256, thin=4,
                             n_chains=32, rng_seed=0)

def model_fn(occ, lam):
    coeff = model_apply(state.params, occ, lam, train=False)  # (B,2)
    return (coeff[:, 0] + 1j * coeff[:, 1]).astype(jnp.complex64)

# 4) Monte Carlo energy (module function averages local energies)
E_mc = E.expectation_local_energy(model=model_fn,
                      occ_batch=occ_batch,
                      params_batch=lam_batch,
                      H=H,
                      occ_basis=occ_basis,
                      state_index=state_index)
print("E0(model, MC) =", float(E_mc)) 

E0(model, MC) = -4.266851902008057


In [12]:
# Larger M for better accuracy, takes far longer and is marginally better
M = 4096
lam_vec   = jnp.array([t, 0.0, float(N)], dtype=jnp.float32) # non-interacting, so V=0
lam_batch = jnp.tile(lam_vec, (M, 1))

# 1) Draw configurations σ ~ |ψ|^2
occ_batch = sample_occ_batch(model_apply, state.params, lam_vec,
                             L=L, N=N,
                             num_samples=M, burn_in=1024, thin=4,
                             n_chains=32, rng_seed=0)

def model_fn(occ, lam):
    coeff = model_apply(state.params, occ, lam, train=False)  # (B,2)
    return (coeff[:, 0] + 1j * coeff[:, 1]).astype(jnp.complex64)

# 4) Monte Carlo energy (module function averages local energies)
E_mc = E.expectation_local_energy(model=model_fn,
                      occ_batch=occ_batch,
                      params_batch=lam_batch,
                      H=H,
                      occ_basis=occ_basis,
                      state_index=state_index)
print("E0(model, MC) =", float(E_mc))

E0(model, MC) = -4.828277587890625


In [9]:
# use the SAME H you pass to expectation_local_energy (straight exact diagonalization)
evals = jnp.linalg.eigvalsh(H)      # H is (Hdim,Hdim) in N-sector
E_exact_from_H = float(evals[0])
print("E0(H) =", E_exact_from_H)

E0(H) = -4.828428745269775


In [8]:
# Rayleigh quotient method
# psi_full on the full basis used to build H
psi   = model_fn(occ, lam)
E_ray = float((jnp.vdot(psi, H @ psi) / (jnp.vdot(psi, psi) + 1e-12)).real) ## Hilbert space needs to be same for model and H
# Throws error because model is npart = 3 and L = 8 while H is npart = 4 and L = 8
print("E0(Rayleigh) =", E_ray)


E0(Rayleigh) = -3.919802665710449


In [46]:
E_mc = float(E.expectation_local_energy(
              model=model_fn, occ_batch=occ_batch,
              params_batch=lam_batch, H=H,
              occ_basis=occ_basis, state_index=state_index).real)
print("E0(model, MC) =", E_mc)

E0(model, MC) = -3.914226531982422


In [None]:
# local_energy_1d.py  (drop-in helper)

"""

def local_energy_1d_tb(model_fn, occ_batch: jnp.ndarray, lam_batch: jnp.ndarray,
                       t: float, bc: str = "pbc") -> jnp.ndarray:
    M, L = occ_batch.shape
    N = jnp.sum(occ_batch[0])  # fixed-N sampler
    # base amplitudes
    psi = model_fn(occ_batch, lam_batch)  # (M,)
    eps = 1e-30

    # helper to flip i->j
    def move(occ_np, i, j):
        occ2 = occ_np.copy()
        occ2[i] = 0; occ2[j] = 1
        return occ2

    # JW sign for wrap hops:
    jw = 1.0 if (int(N) % 2 == 1) else -1.0
    bc_phase = 1.0 if bc == "pbc" else -1.0  # add a twist if you want APBC
    wrap_sign = jw * bc_phase

    E_loc = np.zeros(M, dtype=np.complex64)

    # loop in Python (clear & reliable; vectorise later if needed)
    for m in range(M):
        occ = np.asarray(occ_batch[m])
        lam = np.asarray(lam_batch[m])
        denom = complex(psi[m])
        if abs(denom) < eps:
            continue
        acc = 0.0 + 0.0j
        for i in range(L):
            if occ[i] == 0: 
                continue
            # hop right
            j = (i + 1) % L
            if occ[j] == 0:
                occ2 = move(occ, i, j)
                # wrap sign if i->j crosses boundary
                sgn = wrap_sign if (i == L-1 and j == 0) else 1.0
                num = complex(model_fn(jnp.asarray(occ2[None,:], dtype=jnp.int32),
                                       jnp.asarray(lam[None,:], dtype=jnp.float32))[0])
                acc += -t * sgn * (num / denom)
            # hop left
            j = (i - 1) % L
            if occ[j] == 0:
                occ2 = move(occ, i, j)
                sgn = wrap_sign if (i == 0 and j == L-1) else 1.0
                num = complex(model_fn(jnp.asarray(occ2[None,:], dtype=jnp.int32),
                                       jnp.asarray(lam[None,:], dtype=jnp.float32))[0])
                acc += -t * sgn * (num / denom)
        E_loc[m] = acc
    return jnp.asarray(E_loc)

E_loc = local_energy_1d_tb(model_fn, occ_batch, lam_batch, C.T_HOP, bc="pbc")
E_mc  = float(jnp.mean(E_loc).real)
print("E0(model, local energy) =", E_mc) """

E0(model, local energy) = -1.9571131467819214


In [None]:
## Old version sketch of how to get local energy 
""" ## Given ground state wavefunction, what is ground state energy?
import observables.energy as E
import sampler.mcmc as sampler
# (Assumes you have (i) a sampler that draws σ ~ |ψ|^2 and (ii) H for the basis.)

model_apply = lambda params, occ, lam, train=False: \
    model.apply(params, occ, lam, train=train)

# 1) Basis + index map (must match H ordering!)
basis_masks = enumerate_fock(C.N_SITES, C.N_PART)
occ_basis   = jnp.array([mask_to_array(m, C.N_SITES) for m in basis_masks], dtype=jnp.int32)
state_index = E.build_state_index([int(m) for m in basis_masks])

# 2) Your Hamiltonian for that basis
# H = your_ed_builder(...)

# 3) Draw M samples σ ~ |ψ|^2  (use your sampler; pseudo-code shown)
M = 4096
lam_vec   = jnp.array([C.T_HOP, 0.0, C.N_PART], dtype=jnp.float32) # non-interacting, so V=0
lam_batch = jnp.tile(lam_vec, (M, 1))
# occ_batch = sampler.sample(model, state.params, lam_vec, M)  # shape (M, N_sites)
occ_batch = sampler.sample_chain(init_state: Array,
                 log_prob_fn: model_apply,
                 key: PRNGKey,
                 n_samples: 4096,
                 burn_in= 100,
                 thin=1)

# 4) Monte Carlo energy (module function averages local energies)
E_mc = E.expectation_local_energy(model=model_apply,
                      occ_batch=occ_batch,
                      params_batch=lam_batch,
                      H=H,
                      occ_basis=occ_basis,
                      state_index=state_index)
print("E0(model, MC) =", float(E_mc)) """

Number of states = 56
E0(H) = -4.828426837921143
