In [None]:
from functools import partial
import itertools
import numpy as np
import scipy
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from netket.hilbert import SpinOrbitalFermions
from netket.operator import FermionOperator2nd
from netket.experimental.operator import ParticleNumberConservingFermioperator2nd
from netket.operator.fermion import (
    destroy as c,
    create as cdag
)

In [None]:
from src.qm_utils.fermion.fermion_utils import (
    bitset_to_array, 
    bitset_to_mode_indices,
    array_to_bitset
)
from src.qm_utils.fermion.fermionic_fock import DiscreteFermionicFockSpace
from src.acband import acband_form_factors, interaction_matrix, K_func1, interaction_hamiltonian_terms
from brillouin_zones import construct_brillouin_zones
from src.netket_compat import (
    get_sector_constraints,
    csr_from_nk_fermion_op
)

from src.qm_utils.lattice.lattice import Lattice2D

In [None]:
# Size parameters shared by the cells below.
N_s = 36  # number of single-particle modes (mode labels are range(N_s); also the key into `bz`)
N_f = 26  # larger of the two particle numbers kept in the Fock space
N_f_minus_2 = N_f - 2  # smaller particle number: two fewer particles than N_f

In [None]:
# Fock space built with mode labels 0..N_s-1 and the two particle numbers
# N_f - 2 and N_f.
full_hilb = DiscreteFermionicFockSpace(
    mode_labels=list(range(N_s)),
    particle_numbers=[N_f_minus_2, N_f],
)

In [None]:
sqrt3 = 3.0 ** 0.5

# Length units: the lattice constant a_M is fixed to 1 and lB is derived from it.
# Alternative convention kept for reference (fix lB, derive a_M):
# lB = 1.0
# a_M = (((4 * np.pi) / sqrt3) ** 0.5) * lB
a_M = 1
lB = ((sqrt3 / (4 * np.pi)) ** 0.5) * a_M

# Previously tried value, kept for reference:
# fourier_resolution = 128
fourier_resolution = 256
G_radius = 64
V1 = 1.0
# Earlier prefactor, kept for reference:
# v1 = 3 * V1 * (a_M ** 4) / (4 * np.pi)
# TODO(review): the denominator was changed from 4*pi to 8*pi and the original
# author marked the change as uncertain -- confirm the correct normalization.
v1 = 3 * V1 * (a_M ** 4) / (8 * np.pi) # ????? 4 pi -> 8 pi

# Lattice and Brillouin zones
# Cartesian unit vectors.
e1 = np.array([1, 0])
e2 = np.array([0, 1])
# Primitive vectors: both have length a_M and their dot product is a_M**2 / 2,
# i.e. they are 60 degrees apart.
a1 = a_M * e2
a2 = a_M * ((-sqrt3 / 2) * e1 + (1 / 2) * e2)
# The two vectors are stacked as the rows of a (2, 2) array.
lattice = Lattice2D(np.stack([a1, a2]))
recip_lattice = lattice.reciprocal()

# NOTE(review): `bz` is indexed by N_s below, so it is presumably a mapping from
# the number of sites to a Brillouin-zone object -- confirm against
# construct_brillouin_zones.
bz = construct_brillouin_zones(lattice)

# Brillouin-zone object for N_s sites; its `zero()` and `sum_table` are used later.
bz_N_s = bz[N_s]

In [None]:
# JAX-array copies of the Brillouin-zone data used by sum_indices / labeling_fn.
# zero_idx: value returned by bz_N_s.zero(); used as the initial scan carry and
# as the padding value for unused particle slots.
zero_idx = jnp.array(bz_N_s.zero())
# sum_table_jax: lookup table indexed as sum_table_jax[i, j].
# NOTE(review): presumably entry [i, j] is the index of the sum of points i and
# j in the Brillouin zone -- confirm against the `sum_table` definition.
sum_table_jax = jnp.array(bz_N_s.sum_table)

def sum_indices(indices):
    """Fold a sequence of indices through ``sum_table_jax``.

    Starting from ``zero_idx``, each element ``x`` of ``indices`` replaces the
    running value ``r`` with ``sum_table_jax[r, x]``; the final running value
    is returned.

    Args:
        indices: array whose leading axis is iterated over by ``jax.lax.scan``.

    Returns:
        The running index after all elements have been folded in
        (``zero_idx`` itself when ``indices`` has length zero).
    """
    def _fold_step(running_idx, next_idx):
        # Only the carry matters; there is no per-step output.
        return sum_table_jax[running_idx, next_idx], None

    total_idx, _no_outputs = jax.lax.scan(_fold_step, zero_idx, indices)
    return total_idx

def labeling_fn(state_int):
    """Compute the integer sector label of an occupation bitset.

    The label is ``block * N_s + folded``, where ``block`` is
    ``(popcount(state_int) - N_f_minus_2) // 2`` (0 for ``N_f - 2`` particles,
    1 for ``N_f`` particles) and ``folded`` is the result of ``sum_indices``
    applied to the occupied-mode indices of the state.

    Args:
        state_int: integer bitset of occupied modes.

    Returns:
        The integer label described above.
    """
    # Slots beyond the actual particle count are filled with zero_idx.
    # NOTE(review): presumably zero_idx is neutral under the sum table, so the
    # padding does not change the folded result -- confirm.
    occupied_modes = bitset_to_mode_indices(
        state_int,
        n_modes=N_s,
        max_n_particles=N_f,
        fill_value=zero_idx,
    )
    particle_block = (jax.lax.population_count(state_int) - N_f_minus_2) // 2
    return particle_block * N_s + sum_indices(occupied_modes)

# (index, particle_number) pairs: all N_s indices for N_f - 2 first, then all
# N_s indices for N_f, so the pair at list position block * N_s + idx lines up
# with the integer computed by labeling_fn.
labels = [
    (idx, particle_number)
    for particle_number in (N_f_minus_2, N_f)
    for idx in range(N_s)
]

In [None]:
# Decompose the full Fock space into sectors using labeling_fn and the list of
# (index, particle_number) labels.
sectors = full_hilb.decompose_sectors(labeling_fn, labels)

In [None]:
# Number of sectors produced by the decomposition (displayed as a plain integer;
# a trailing comma here would turn the cell output into a one-element tuple).
len(sectors)