In [7]:
# wavefunction.py
"""
Trains the Lattice-PsiFormer for a *single* (t,V) pair on synthetic targets,
then prints the first 10 complex coefficients ψ(n; t,V).

Swap OCC/TARGET with your exact-diagonalisation data for real physics runs.
"""
import jax, jax.numpy as jnp

In [5]:
# Example output, takes about 5 min to run
from train import (                     
    OCC,                    # (H, N_sites)  integer array of Fock states
    create_state,           # helper that initialises model & optimiser
    train_step,             # one SGD step (jit-compiled)
    model                   # the LatticePsiFormer instance
)

print("Detected JAX backend →", jax.default_backend().upper(),
      "with", len(jax.devices()), "device(s)")

# -------- training set: choose a (t,V) and build batch ----------------------
t, V = 0.5, 4.0
TV_BATCH = jnp.tile(jnp.array([t, V], dtype=jnp.float32), (OCC.shape[0], 1))

# synthetic target coefficients (Re, Im)  –– replace with ED data
key    = jax.random.PRNGKey(123)
TARGET = jax.random.normal(key, (OCC.shape[0], 2)) * 0.1

# -------- quick training ----------------------------------------------------
state = create_state(jax.random.PRNGKey(0))

for epoch in range(1200):                
    state, loss = train_step(state, OCC, TV_BATCH, TARGET)
    if epoch % 300 == 0:
        print(f"epoch {epoch:4d} | MSE = {loss:.3e}")

# -------- wave-function output ---------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)  # (H, 2)
print("\nFirst 10 coefficients for (t,V)=(", t, ",", V, "):")
for i in range(10):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩  →  {re:+.5f}  {im:+.5f} i")


Detected JAX backend → CPU with 1 device(s)
epoch    0 | MSE = 9.984e-01
epoch  300 | MSE = 7.962e-01
epoch  600 | MSE = 9.956e-01
epoch  900 | MSE = 8.566e-01

First 10 coefficients for (t,V)=( 0.5 , 4.0 ):
|1111100000⟩  →  +0.78708  +7.96918 i
|1111010000⟩  →  +1.64936  +4.33142 i
|1111001000⟩  →  -2.59221  -1.57305 i
|1111000100⟩  →  -11.15819  +3.05209 i
|1111000010⟩  →  -0.65483  -4.02020 i
|1111000001⟩  →  +3.58885  +4.49224 i
|1110110000⟩  →  -2.93802  -1.41375 i
|1110101000⟩  →  -3.81637  -0.89691 i
|1110100100⟩  →  -7.02246  +0.06303 i
|1110100010⟩  →  +3.34094  +2.98012 i


In [8]:
# originally got better results with old loss function => dig deeper into code and see why
from train import (                     
    OCC,                    # (H, N_sites)  integer array of Fock states
    LOSS_TYPE,              # 'overlap' or 'amp_phase' or 'original'
    create_state,           # helper that initialises model & optimiser
    train_step,             # one SGD step (jit-compiled)
    model                   # the LatticePsiFormer instance
)
LOSS_TYPE = 'original'
print(f"Loss type:{LOSS_TYPE}\n")

print("Detected JAX backend →", jax.default_backend().upper(),
      "with", len(jax.devices()), "device(s)")

# -------- training set: choose a (t,V) and build batch ----------------------
t, V = 0.5, 4.0
TV_BATCH = jnp.tile(jnp.array([t, V], dtype=jnp.float32), (OCC.shape[0], 1))

# synthetic target coefficients (Re, Im)  –– replace with ED data
key    = jax.random.PRNGKey(123)
TARGET = jax.random.normal(key, (OCC.shape[0], 2)) * 0.1

# -------- quick training ----------------------------------------------------
state = create_state(jax.random.PRNGKey(0))

for epoch in range(1200):                
    state, loss = train_step(state, OCC, TV_BATCH, TARGET)
    if epoch % 300 == 0:
        print(f"epoch {epoch:4d} | MSE = {loss:.3e}")

# -------- wave-function output ---------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)  # (H, 2)
print("\nFirst 10 coefficients for (t,V)=(", t, ",", V, "):")
for i in range(10):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩  →  {re:+.5f}  {im:+.5f} i")

Loss type:original

Detected JAX backend → CPU with 1 device(s)
epoch    0 | MSE = 9.984e-01
epoch  300 | MSE = 7.962e-01
epoch  600 | MSE = 9.956e-01
epoch  900 | MSE = 8.566e-01

First 10 coefficients for (t,V)=( 0.5 , 4.0 ):
|1111100000⟩  →  +0.78708  +7.96918 i
|1111010000⟩  →  +1.64936  +4.33142 i
|1111001000⟩  →  -2.59221  -1.57305 i
|1111000100⟩  →  -11.15819  +3.05209 i
|1111000010⟩  →  -0.65483  -4.02020 i
|1111000001⟩  →  +3.58885  +4.49224 i
|1110110000⟩  →  -2.93802  -1.41375 i
|1110101000⟩  →  -3.81637  -0.89691 i
|1110100100⟩  →  -7.02246  +0.06303 i
|1110100010⟩  →  +3.34094  +2.98012 i


In [None]:
"""
Demo script that:
  • reads hyper-parameters from  config.py
  • enumerates the Fock basis (1-D or honeycomb)
  • builds the λ-conditioned PsiFormer
  • trains with the requested loss for EPOCHS iterations
  • prints the first N_PRINT wave-function amplitudes
Replace the synthetic TARGET with ED coefficients for real use.
"""
import jax, jax.numpy as jnp, optax
from flax.training import train_state
import config as C

# ---------- lattice & basis -------------------------------------------------
if C.LATTICE == "1d":
    from phys_system.lattice1D import enumerate_fock, mask_to_array
    BASIS   = enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)
else:                                    # honeycomb demo (Lx×Ly torus)
    from phys_system import honeycomb
    Lx = Ly = int(C.N_SITES // 2)**0.5   # crude: make a square cell
    BASIS   = honeycomb.enumerate_fock(C.N_SITES, C.N_PART)
    OCC     = jnp.array([honeycomb.mask_to_array(m, C.N_SITES) for m in BASIS],
                        dtype=jnp.int32)

# ---------- model -----------------------------------------------------------
from networks.model import LatticeTransFormer
model = LatticeTransFormer(n_sites=C.N_SITES)

# ---------- loss ------------------------------------------------------------
from Loss.loss import overlap_loss, amp_phase_loss, loss_normDiff
LOSS_FN = overlap_loss if C.LOSS_TYPE == "overlap" else amp_phase_loss if C.LOSS_TYPE == "amp_phase" else loss_normDiff

# build neighbour table if amp/phase loss is selected
if C.LOSS_TYPE == "amp_phase":
    # simple K=1 table: for every state flip the first bit
    neighbours = jnp.array([[ (i+1) % len(BASIS) ] for i in range(len(BASIS))],
                           dtype=jnp.int32)
else:
    neighbours = None

# ---------- optimiser & helpers --------------------------------------------
tv_pair  = jnp.array([C.T_HOP, C.V_INT], dtype=jnp.float32)
TV_BATCH = jnp.tile(tv_pair, (len(BASIS),1))

key   = jax.random.PRNGKey(C.SEED)
# synthetic ED target  –– swap with real data
TARGET = jax.random.normal(key, (len(BASIS), 2))*0.1

def create_state(rng):
    params = model.init(rng, OCC, TV_BATCH, train=False)
    tx     = optax.adam(1e-3)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    preds = model.apply(params, OCC, TV_BATCH, train=False)
    return LOSS_FN(preds, TARGET) if neighbours is None \
           else LOSS_FN(preds, TARGET, neighbours)

@jax.jit
def train_step(state):
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ---------- run -------------------------------------------------------------
print("JAX backend:", jax.default_backend().upper(),
      "| lattice:", C.LATTICE,
      f"| N={C.N_SITES}, Np={C.N_PART}")

state = create_state(jax.random.PRNGKey(42))
for epoch in range(C.EPOCHS):
    state, loss = train_step(state)
    if epoch % C.PRINT_EVERY == 0:
        print(f"epoch {epoch:4d}  loss = {float(loss):.4e}")

# ---------- output ----------------------------------------------------------
coeffs = model.apply(state.params, OCC, TV_BATCH, train=False)
print(f"\nFirst {C.N_PRINT} coefficients for (t,V)=({C.T_HOP},{C.V_INT}):")
for i in range(min(C.N_PRINT, len(BASIS))):
    n_str = ''.join(str(int(b)) for b in OCC[i])
    re, im = map(float, coeffs[i])
    print(f"|{n_str}⟩ → {re:+.5f} {im:+.5f} i")


ImportError: cannot import name 'loss_normDiff' from 'Loss.loss' (/Users/zaklama/VS_Code/Physics_Classes/Python/ct_NNVMC/Loss/loss.py)