In [1]:
import functools

import tqdm
import equinox as eqx
import jax

# Turn on JAX's 64-bit mode right after importing jax and before jax.numpy is
# imported below, so it is in effect before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
import optax
from jaxtyping import Array, PyTree, Float, Complex, PRNGKeyArray

import netket as nk
from netket.operator import FermionOperator2ndJax
from netket.operator.fermion import destroy as c
from netket.operator.fermion import create as cdag
from netket.operator.fermion import number as nc

import matplotlib.pyplot as plt

from models.slaternet import SlaterNet
from models.psi_solid import PsiSolid

from systems.continuous import moire
from utils.sampler import metropolis_hastings


# Root PRNG key for the notebook, created from a fixed seed (42).
key = jax.random.key(42)

In [2]:
@functools.partial(jax.jit, static_argnums=(1,))
def nkstate_to_indexseq(
    state, n_particles: int
) -> tuple[jnp.ndarray, ...]:
    """Return the indices of the non-zero entries of ``state``.

    Args:
        state: Array-like occupation vector; its non-zero entries are reported.
        n_particles: Number of indices to return. It is a static
            (compile-time) argument because ``jnp.nonzero`` needs a fixed
            output size under ``jax.jit``; each distinct value triggers a
            separate compilation.

    Returns:
        The tuple produced by ``jnp.nonzero``: one index array of length
        ``n_particles`` per dimension of ``state`` (a 1-tuple for a 1-D
        state). Callers that want the bare index array must take element
        ``[0]`` of the result.

    Note:
        If ``state`` has fewer than ``n_particles`` non-zero entries, the
        result is padded with ``jnp.nonzero``'s default fill value (0); if it
        has more, only the first ``n_particles`` indices are kept.
    """
    return jnp.nonzero(state, size=n_particles)

In [3]:
# One-dimensional lattice: nk.graph.Chain is built from the number of sites L.
L = 12  # Number of sites along the chain
graph = nk.graph.Chain(L)
N = graph.n_nodes  # Number of lattice sites as reported by the graph

In [4]:
# Hilbert space of spin-1/2 fermions living on the N lattice sites.
N_f = 4  # Particle number handed to n_fermions below
# NOTE(review): n_fermions presumably fixes the total count over both spin
# species — confirm against the NetKet version in use.
hi = nk.hilbert.SpinOrbitalFermions(N, s=1/2, n_fermions=N_f)

In [5]:
# Lazily convert each basis state of `hi` into the indices of its non-zero
# entries; nothing is computed until the iterator is advanced.
states = map(lambda occupation: nkstate_to_indexseq(occupation, N_f), hi.states())
# Show the first converted state as a quick sanity check.
next(states)

(Array([20, 21, 22, 23], dtype=int64),)

In [None]:
# Hubbard-type Hamiltonian: hopping along the edges of `graph` plus an on-site
# density-density term between the two spin labels (+1 and -1).
t = 1.0  # hopping amplitude
U = 4.0  # on-site coupling strength

H = FermionOperator2ndJax(hi)

# Hopping: for each edge (i, j) and each spin label s, add create_i * destroy_j
# together with the reversed term create_j * destroy_i, scaled by -t.
for i, j in graph.edges():
    for s in (-1, 1):
        hop_ij = cdag(hi, i, s) * c(hi, j, s)
        hop_ji = cdag(hi, j, s) * c(hi, i, s)
        H += -t * (hop_ij + hop_ji)

# On-site term: product of the two spin-resolved number operators on each site.
# NOTE(review): with U > 0 the coefficient -U makes this term attractive, while
# the conventional repulsive Hubbard model uses +U — confirm the sign is intended.
for i in graph.nodes():
    n_plus = nc(hi, i, 1)
    n_minus = nc(hi, i, -1)
    H += -U * n_plus * n_minus

