In [1]:
# wavefunction.py
"""
Trains the Lattice-PsiFormer for a *single* (t,V) pair on synthetic targets,
then prints the first 10 complex coefficients ψ(n; t,V).

Swap OCC/TARGET with your exact-diagonalisation data for real physics runs.
"""
import jax, jax.numpy as jnp

In [5]:
# Example output, takes about 5 min to run
from train import (                     
    OCC,                    # (H, N_sites)  integer array of Fock states
    create_state,           # helper that initialises model & optimiser
    train_step,             # one SGD step (jit-compiled)
    model                   # the LatticePsiFormer instance
)

print("Detected JAX backend →", jax.default_backend().upper(),
      "with", len(jax.devices()), "device(s)")

# -------- training set: choose a (t,V) and build batch ----------------------
t, V = 0.5, 4.0
TV_BATCH = jnp.tile(jnp.array([t, V], dtype=jnp.float32), (OCC.shape[0], 1))

# synthetic target coefficients (Re, Im)  –– replace with ED data
key    = jax.random.PRNGKey(123)
TARGET = jax.random.normal(key, (OCC.shape[0], 2)) * 0.1

# -------- quick training ----------------------------------------------------
state = create_state(jax.random.PRNGKey(0))

for epoch in range(1200):                
    state, loss = train_step(state, OCC, TV_BATCH, TARGET)
    if epoch % 300 == 0:
        print(f"epoch {epoch:4d} | MSE = {loss:.3e}")

# -------- wave-function output ---------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)  # (H, 2)
print("\nFirst 10 coefficients for (t,V)=(", t, ",", V, "):")
for i in range(10):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩  →  {re:+.5f}  {im:+.5f} i")


Detected JAX backend → CPU with 1 device(s)
epoch    0 | MSE = 9.984e-01
epoch  300 | MSE = 7.962e-01
epoch  600 | MSE = 9.956e-01
epoch  900 | MSE = 8.566e-01

First 10 coefficients for (t,V)=( 0.5 , 4.0 ):
|1111100000⟩  →  +0.78708  +7.96918 i
|1111010000⟩  →  +1.64936  +4.33142 i
|1111001000⟩  →  -2.59221  -1.57305 i
|1111000100⟩  →  -11.15819  +3.05209 i
|1111000010⟩  →  -0.65483  -4.02020 i
|1111000001⟩  →  +3.58885  +4.49224 i
|1110110000⟩  →  -2.93802  -1.41375 i
|1110101000⟩  →  -3.81637  -0.89691 i
|1110100100⟩  →  -7.02246  +0.06303 i
|1110100010⟩  →  +3.34094  +2.98012 i


In [8]:
# originally got better results with old loss function => dig deeper into code and see why
from train import (                     
    OCC,                    # (H, N_sites)  integer array of Fock states
    LOSS_TYPE,              # 'overlap' or 'amp_phase' or 'original'
    create_state,           # helper that initialises model & optimiser
    train_step,             # one SGD step (jit-compiled)
    model                   # the LatticePsiFormer instance
)
LOSS_TYPE = 'original'
print(f"Loss type:{LOSS_TYPE}\n")

print("Detected JAX backend →", jax.default_backend().upper(),
      "with", len(jax.devices()), "device(s)")

# -------- training set: choose a (t,V) and build batch ----------------------
t, V = 0.5, 4.0
TV_BATCH = jnp.tile(jnp.array([t, V], dtype=jnp.float32), (OCC.shape[0], 1))

# synthetic target coefficients (Re, Im)  –– replace with ED data
key    = jax.random.PRNGKey(123)
TARGET = jax.random.normal(key, (OCC.shape[0], 2)) * 0.1

# -------- quick training ----------------------------------------------------
state = create_state(jax.random.PRNGKey(0))

for epoch in range(1200):                
    state, loss = train_step(state, OCC, TV_BATCH, TARGET)
    if epoch % 300 == 0:
        print(f"epoch {epoch:4d} | MSE = {loss:.3e}")

# -------- wave-function output ---------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)  # (H, 2)
print("\nFirst 10 coefficients for (t,V)=(", t, ",", V, "):")
for i in range(10):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩  →  {re:+.5f}  {im:+.5f} i")

Loss type:original

Detected JAX backend → CPU with 1 device(s)
epoch    0 | MSE = 9.984e-01
epoch  300 | MSE = 7.962e-01
epoch  600 | MSE = 9.956e-01
epoch  900 | MSE = 8.566e-01

First 10 coefficients for (t,V)=( 0.5 , 4.0 ):
|1111100000⟩  →  +0.78708  +7.96918 i
|1111010000⟩  →  +1.64936  +4.33142 i
|1111001000⟩  →  -2.59221  -1.57305 i
|1111000100⟩  →  -11.15819  +3.05209 i
|1111000010⟩  →  -0.65483  -4.02020 i
|1111000001⟩  →  +3.58885  +4.49224 i
|1110110000⟩  →  -2.93802  -1.41375 i
|1110101000⟩  →  -3.81637  -0.89691 i
|1110100100⟩  →  -7.02246  +0.06303 i
|1110100010⟩  →  +3.34094  +2.98012 i


In [None]:
"""
Demo script that:
  • reads hyper-parameters from  config.py
  • enumerates the Fock basis (1-D or honeycomb)
  • builds the λ-conditioned PsiFormer
  • trains with the requested loss for EPOCHS iterations
  • prints the first N_PRINT wave-function amplitudes
Replace the synthetic TARGET with ED coefficients for real use.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C
#### Loss was original here

# ---------- lattice & basis -------------------------------------------------
if C.LATTICE == "1d":
    from phys_system.lattice1D import enumerate_fock, mask_to_array
    BASIS   = enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)
else:                                    # honeycomb demo (Lx×Ly torus)
    from phys_system import honeycomb
    Lx = Ly = int(C.N_SITES // 2)**0.5   # crude: make a square cell
    BASIS   = honeycomb.enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([honeycomb.mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES)

# ---------- loss ------------------------------------------------------------
from Loss.loss import overlap_loss, amp_phase_loss, loss_normDiff
LOSS_FN = overlap_loss if C.LOSS_TYPE == "overlap" else amp_phase_loss if C.LOSS_TYPE == "amp_phase" else loss_normDiff

# build neighbour table if amp/phase loss is selected
if C.LOSS_TYPE == "amp_phase":
    # simple K=1 table: for every state flip the first bit
    neighbours = jnp.array([[ (i+1) % len(BASIS) ] for i in range(len(BASIS))],
                           dtype=jnp.int32)
else:
    neighbours = None

# ---------- optimiser & helpers --------------------------------------------
tv_pair  = jnp.array([C.T_HOP, C.V_INT], dtype=jnp.float32)
TV_BATCH = jnp.tile(tv_pair, (len(BASIS),1))

key   = jax.random.PRNGKey(C.SEED)
# synthetic ED target  –– swap with real data
TARGET = jax.random.normal(key, (len(BASIS), 2))*0.1

def create_state(rng):
    params = model.init(rng, OCC, TV_BATCH, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, TV_BATCH, train=False)
    return LOSS_FN(preds, TARGET) if neighbours is None \
           else LOSS_FN(preds, TARGET, neighbours)

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{C.V_INT}):")
for i in range(min(C.N_PRINT, len(BASIS))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


JAX backend: CPU | lattice: 1d | N=8, Np=4
epoch    0  loss = 3.5439e-01
epoch  300  loss = 8.3593e-03
epoch  600  loss = 7.5911e-03
epoch  900  loss = 6.8420e-03

First 10 coefficients for (t,V)=(0.5,4.0):
|11110000⟩ → +0.17926 +0.11561 i
|11101000⟩ → -0.01923 +0.01427 i
|11100100⟩ → +0.01521 -0.01066 i
|11100010⟩ → +0.00347 +0.06559 i
|11100001⟩ → +0.05009 +0.01844 i
|11011000⟩ → +0.15975 +0.00959 i
|11010100⟩ → +0.04839 -0.00713 i
|11010010⟩ → +0.09587 +0.11234 i
|11010001⟩ → +0.12981 +0.03281 i
|11001100⟩ → -0.03394 -0.05011 i


In [2]:
"""
Demo script that:
  • reads hyper-parameters from  config.py
  • enumerates the Fock basis (1-D or honeycomb)
  • builds the λ-conditioned PsiFormer
  • trains with the requested loss for EPOCHS iterations
  • prints the first N_PRINT wave-function amplitudes
Replace the synthetic TARGET with ED coefficients for real use.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C

print(f"LOSS_TYPE: {C.LOSS_TYPE}, N_SITES: {C.N_SITES}, N_PART: {C.N_PART}\n")

# ---------- lattice & basis -------------------------------------------------
if C.LATTICE == "1d":
    from phys_system.lattice1D import enumerate_fock, mask_to_array
    BASIS   = enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)
else:                                    # honeycomb demo (Lx×Ly torus)
    from phys_system import honeycomb
    Lx = Ly = int(C.N_SITES // 2)**0.5   # crude: make a square cell
    BASIS   = honeycomb.enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([honeycomb.mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES)

# ---------- loss ------------------------------------------------------------
from Loss.loss import overlap_loss, amp_phase_loss, loss_normDiff
LOSS_FN = overlap_loss if C.LOSS_TYPE == "overlap" else amp_phase_loss if C.LOSS_TYPE == "amp_phase" else loss_normDiff

# build neighbour table if amp/phase loss is selected
if C.LOSS_TYPE == "amp_phase":
    # simple K=1 table: for every state flip the first bit
    neighbours = jnp.array([[ (i+1) % len(BASIS) ] for i in range(len(BASIS))],
                           dtype=jnp.int32)
else:
    neighbours = None

# ---------- optimiser & helpers --------------------------------------------
tv_pair  = jnp.array([C.T_HOP, C.V_INT], dtype=jnp.float32)
TV_BATCH = jnp.tile(tv_pair, (len(BASIS),1))

key   = jax.random.PRNGKey(C.SEED)
# synthetic ED target  –– swap with real data
TARGET = jax.random.normal(key, (len(BASIS), 2))*0.1

def create_state(rng):
    params = model.init(rng, OCC, TV_BATCH, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, TV_BATCH, train=False)
    return LOSS_FN(preds, TARGET) if neighbours is None \
           else LOSS_FN(preds, TARGET, neighbours)

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{C.V_INT}):")
for i in range(min(C.N_PRINT, len(BASIS))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


LOSS_TYPE: overlap, N_SITES: 8, N_PART: 4

JAX backend: CPU | lattice: 1d | N=8, Np=4
epoch    0  loss = 9.9563e-01
epoch  300  loss = -4.7684e-07
epoch  600  loss = 0.0000e+00
epoch  900  loss = 0.0000e+00

First 10 coefficients for (t,V)=(0.5,4.0):
|11110000⟩ → -10.57874 -9.42932 i
|11101000⟩ → +2.40518 +0.03086 i
|11100100⟩ → -0.06480 +5.39429 i
|11100010⟩ → +2.22153 -3.11550 i
|11100001⟩ → -2.70599 +5.71653 i
|11011000⟩ → -9.96054 +12.50556 i
|11010100⟩ → -2.07514 -0.52794 i
|11010010⟩ → -8.24413 -6.97237 i
|11010001⟩ → -5.76745 -2.34796 i
|11001100⟩ → +1.61014 +10.33784 i


In [2]:
"""
Demo script that:
  • reads hyper-parameters from  config.py
  • enumerates the Fock basis (1-D or honeycomb)
  • builds the λ-conditioned PsiFormer
  • trains with the requested loss for EPOCHS iterations
  • prints the first N_PRINT wave-function amplitudes
Replace the synthetic TARGET with ED coefficients for real use.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C

print(f"LOSS_TYPE: {C.LOSS_TYPE}, N_SITES: {C.N_SITES}, N_PART: {C.N_PART}\n")

# ---------- lattice & basis -------------------------------------------------
if C.LATTICE == "1d":
    from phys_system.lattice1D import enumerate_fock, mask_to_array
    BASIS   = enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)
else:                                    # honeycomb demo (Lx×Ly torus)
    from phys_system import honeycomb
    Lx = Ly = int(C.N_SITES // 2)**0.5   # crude: make a square cell
    BASIS   = honeycomb.enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([honeycomb.mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES)

# ---------- loss ------------------------------------------------------------
from Loss.loss import overlap_loss, amp_phase_loss, loss_normDiff
LOSS_FN = overlap_loss if C.LOSS_TYPE == "overlap" else amp_phase_loss if C.LOSS_TYPE == "amp_phase" else loss_normDiff

# build neighbour table if amp/phase loss is selected
if C.LOSS_TYPE == "amp_phase":
    # simple K=1 table: for every state flip the first bit
    neighbours = jnp.array([[ (i+1) % len(BASIS) ] for i in range(len(BASIS))],
                           dtype=jnp.int32)
else:
    neighbours = None

# ---------- optimiser & helpers --------------------------------------------
tv_pair  = jnp.array([C.T_HOP, C.V_INT], dtype=jnp.float32)
TV_BATCH = jnp.tile(tv_pair, (len(BASIS),1))

key   = jax.random.PRNGKey(C.SEED)
# synthetic ED target  –– swap with real data
TARGET = jax.random.normal(key, (len(BASIS), 2))*0.1

def create_state(rng):
    params = model.init(rng, OCC, TV_BATCH, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, TV_BATCH, train=False)
    return LOSS_FN(preds, TARGET) if neighbours is None \
           else LOSS_FN(preds, TARGET, neighbours)

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{C.V_INT}):")
for i in range(min(C.N_PRINT, len(BASIS))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


LOSS_TYPE: amp_phase, N_SITES: 8, N_PART: 4

JAX backend: CPU | lattice: 1d | N=8, Np=4
epoch    0  loss = 6.5938e+00
epoch  300  loss = 6.6211e+00
epoch  600  loss = 2.5931e+00
epoch  900  loss = 6.9722e+00

First 10 coefficients for (t,V)=(0.5,4.0):
|11110000⟩ → +0.25724 -0.19093 i
|11101000⟩ → +0.20006 -0.34435 i
|11100100⟩ → +0.21316 -0.32568 i
|11100010⟩ → +0.19771 -0.31076 i
|11100001⟩ → +0.22936 -0.25183 i
|11011000⟩ → +0.21408 +0.02825 i
|11010100⟩ → +0.25752 -0.03252 i
|11010010⟩ → +0.25345 -0.04475 i
|11010001⟩ → +0.26106 -0.01124 i
|11001100⟩ → +0.15463 -0.08958 i


In [4]:
"""
Demo script that:
  • reads hyper-parameters from  config.py
  • enumerates the Fock basis (1-D or honeycomb)
  • builds the λ-conditioned PsiFormer
  • trains with the requested loss for EPOCHS iterations
  • prints the first N_PRINT wave-function amplitudes
Replace the synthetic TARGET with ED coefficients for real use.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C
T_LIST = [0.5, 1.0]          # example
V_LIST = [4.0, 2.5]
N_LIST = [4  , 3  ]          # particle numbers (can all differ)

assert len(T_LIST)==len(V_LIST)==len(N_LIST)
G = len(T_LIST)              # number of Hamiltonians

# ---------------------------------------------------------------------------
# 2)  Build concatenated training set
# ---------------------------------------------------------------------------
OCC_ALL, LAM_ALL, TARGET_ALL, GID_ALL = [], [], [], []
for gid,(t,v,npart) in enumerate(zip(T_LIST, V_LIST, N_LIST)):
    # basis for this particle number --------------------
    if C.LATTICE == "1d":
        from phys_system.lattice1D import enumerate_fock, mask_to_array
        basis = enumerate_fock(C.N_SITES, npart)
        occ   = jnp.array([mask_to_array(m, C.N_SITES) for m in basis],
                          dtype=jnp.int32)
    else:
        import phys_system.honeycomb as hc
        basis = hc.enumerate_fock(C.N_SITES, npart)
        occ   = jnp.array([hc.mask_to_array(m, C.N_SITES) for m in basis],
                          dtype=jnp.int32)

    # λ-vector extended to include N --------------------
    lam_vec = jnp.array([t, v, npart], dtype=jnp.float32)
    lam     = jnp.tile(lam_vec, (len(basis),1))

    # synthetic target coefficients --------------------- =====> REPLACE WITH ED DATA
    ### read from ED/output/ED/output/gsWaveFn_Ns12_Np2_Vnn8.0.csv for example ==> See ED/ed_driver.py to see how it data can be generated
    key  = jax.random.PRNGKey(C.SEED + gid)
    targ = jax.random.normal(key, (len(basis), 2))*0.1

    gid_vec = jnp.full((len(basis),), gid, dtype=jnp.int32)

    OCC_ALL.append(occ)
    LAM_ALL.append(lam)
    TARGET_ALL.append(targ)
    GID_ALL.append(gid_vec)

# concatenate everything --------------------------------
OCC     = jnp.concatenate(OCC_ALL,    axis=0)
LAM     = jnp.concatenate(LAM_ALL,    axis=0)   # shape (B,3)
TARGET  = jnp.concatenate(TARGET_ALL, axis=0)
GIDS    = jnp.concatenate(GID_ALL,    axis=0)   # shape (B,)

print("Total training states:", OCC.shape[0], "| Hamiltonians:", G)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES, depth=8, d_model=256)

# ---------------------------------------------------------------------------
# 3)  Loss & training step
# ---------------------------------------------------------------------------
from Loss.loss import overlap_loss_multi, amp_phase_loss, loss_normDiff   
LOSS_TYPE = "overlap_multi"   # choose in config or set here

def create_state(rng):
    params = model.init(rng, OCC, LAM, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, LAM, train=False)
    if LOSS_TYPE == "overlap_multi":
        return overlap_loss_multi(preds, TARGET, GIDS, num_groups=G)
    elif LOSS_TYPE == "amp_phase":
        return amp_phase_loss(preds, TARGET, neighbours)
    elif LOSS_TYPE == "original":
        return loss_normDiff(preds, TARGET)
    else:
        raise ValueError("Unknown LOSS_TYPE")

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, LAM, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{C.V_INT}):")
for i in range(min(C.N_PRINT, len(basis))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


Total training states: 126 | Hamiltonians: 2
JAX backend: CPU | lattice: 1d | N=8, Np=4
epoch    0  loss = 9.9778e-01
epoch  300  loss = 1.9491e-05
epoch  600  loss = 8.5818e-04
epoch  900  loss = 2.3842e-07

First 10 coefficients for (t,V)=(0.5,4.0):
|11110000⟩ → -10.66331 +9.49846 i
|11101000⟩ → +0.30749 -2.40529 i
|11100100⟩ → +5.39244 +0.68848 i
|11100010⟩ → -2.85898 -2.57981 i
|11100001⟩ → +5.41218 +3.37788 i
|11011000⟩ → +11.36997 +11.41555 i
|11010100⟩ → -0.76415 +2.01575 i
|11010010⟩ → -7.93361 +7.45169 i
|11010001⟩ → -3.01623 +5.50620 i
|11001100⟩ → +10.53860 -0.41466 i
