In [9]:
from functools import partial
import itertools
import numpy as np
import scipy
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from netket.hilbert import SpinOrbitalFermions
from netket.operator import FermionOperator2nd
from netket.experimental.operator import ParticleNumberConservingFermioperator2nd
from netket.operator.fermion import (
    destroy as c,
    create as cdag
)

In [10]:
from src.qm_utils.fermion.fermion_utils import (
    bitset_to_array, 
    bitset_to_mode_indices,
    array_to_bitset
)
from src.qm_utils.fermion.fermionic_fock import DiscreteFermionicFockSpace
from src.acband import acband_form_factors, interaction_matrix, K_func1, interaction_hamiltonian_terms
from brillouin_zones import construct_brillouin_zones
from src.netket_compat import (
    csr_from_nk_fermion_op
)

from src.qm_utils.lattice.lattice import Lattice2D

In [11]:
# System size: N_s fermionic modes (indexed by the points of the N_s-site
# Brillouin zone bz[N_s]) and N_f fermions.
# A larger configuration used previously: N_s = 36, N_f = 26.
N_s, N_f = 27, 20
# Particle number after removing one pair.
N_f_minus_2 = N_f - 2

In [12]:
# Fock space on the N_s modes, with the particle numbers N_f - 2 and N_f that
# the sector decomposition below works with.
# NOTE(review): a later cell rebinds `full_hilb` to a NetKet
# SpinOrbitalFermions; re-running decompose_sectors after that would break.
full_hilb = DiscreteFermionicFockSpace(
    mode_labels=list(range(N_s)),
    particle_numbers=[N_f_minus_2, N_f],
)

In [13]:
# Geometry and interaction parameters.
sqrt3 = 3.0 ** 0.5

# Lengths are measured in units of the lattice constant a_M. (An alternative
# convention, used previously, fixed lB = 1.0 and derived
# a_M = sqrt(4*pi/sqrt3) * lB.)
# lB is chosen so that 2*pi*lB**2 equals the unit-cell area |a1 x a2| =
# (sqrt3/2) * a_M**2 of the lattice defined below.
a_M = 1
lB = ((sqrt3 / (4 * np.pi)) ** 0.5) * a_M

# Grid resolution passed to acband_form_factors (128 was used previously) and
# radius of the reciprocal-lattice vector grid; the recorded output shows a
# (2 * G_radius + 1) x (2 * G_radius + 1) = 65 x 65 grid of G coordinates.
fourier_resolution = 256
G_radius = 32
V1 = 1.0
# TODO(review): the denominator was changed from 4*pi to 8*pi without
# explanation (the earlier form was 3 * V1 * a_M**4 / (4 * np.pi)); confirm
# which normalization of the interaction strength is intended.
v1 = 3 * V1 * (a_M ** 4) / (8 * np.pi)

# Lattice and Brillouin zones
# Primitive vectors a1 = a_M * (0, 1) and a2 = a_M * (-sqrt3/2, 1/2): both have
# length a_M with a 60-degree angle between them (a triangular lattice).
e1 = np.array([1, 0])
e2 = np.array([0, 1])
a1 = a_M * e2
a2 = a_M * ((-sqrt3 / 2) * e1 + (1 / 2) * e2)
lattice = Lattice2D(np.stack([a1, a2]))
recip_lattice = lattice.reciprocal()

# Presumably a collection of finite Brillouin zones keyed by number of sites;
# bz_N_s provides the N_s momentum points that index the fermionic modes.
bz = construct_brillouin_zones(lattice)

bz_N_s = bz[N_s]

In [14]:
# JAX copies of Brillouin-zone data consumed by the traced labeling code.
# `zero_idx` (from bz_N_s.zero()) is the starting value of the momentum fold in
# sum_indices and the padding value in labeling_fn; `sum_table_jax[a, b]` is
# presumably the index of the momentum k_a + k_b — TODO confirm.
zero_idx = jnp.array(bz_N_s.zero())
sum_table_jax = jnp.array(bz_N_s.sum_table)

def sum_indices(indices, table=None, identity=None):
    """Fold a sequence of indices into one index using a 2-D lookup table.

    Computes ``table[...table[table[identity, i0], i1]..., i_{n-1}]`` with
    ``jax.lax.scan``, so it can be traced / jitted / vmapped. With the default
    tables this yields the index of the total momentum of the given modes on
    ``bz_N_s`` (assuming ``sum_table`` is the BZ addition table — TODO confirm).

    Args:
        indices: 1-D integer array of indices to combine.
        table: 2-D integer lookup table; ``table[a, b]`` is the index of
            ``a + b``. Defaults to the module-level ``sum_table_jax``, looked
            up at call time.
        identity: starting value of the fold. Defaults to the module-level
            ``zero_idx``, looked up at call time.

    Returns:
        Scalar integer array with the combined index (``identity`` itself when
        ``indices`` is empty).
    """
    # Resolve defaults at call time (not definition time) so the function keeps
    # following the module-level tables if they are redefined.
    if table is None:
        table = sum_table_jax
    if identity is None:
        identity = zero_idx

    def step(carry, idx):
        return table[carry, idx], None

    total, _ = jax.lax.scan(step, identity, indices)
    return total

def labeling_fn(state_int):
    """Map an occupation bitset to the integer index of its sector label.

    The returned integer indexes the module-level ``labels`` list: states with
    ``N_f_minus_2`` particles map to ``0 .. N_s-1`` and states with ``N_f``
    particles to ``N_s .. 2*N_s-1``, each offset by the summed momentum index
    of the occupied modes.
    """
    # Occupied mode indices, padded up to N_f entries with zero_idx; the padding
    # leaves the sum unchanged provided zero_idx is the identity of sum_table
    # (TODO confirm).
    occupied = bitset_to_mode_indices(
        state_int, n_modes=N_s, max_n_particles=N_f, fill_value=zero_idx
    )
    momentum_label = sum_indices(occupied)
    # 0 for N_f - 2 particles, 1 for N_f particles.
    number_block = (jax.lax.population_count(state_int) - N_f_minus_2) // 2
    return number_block * N_s + momentum_label

# Sector labels (momentum index, particle number), ordered to match the integer
# returned by labeling_fn: every N_f - 2 sector first, then every N_f sector.
labels = [(idx, n) for n in (N_f_minus_2, N_f) for idx in range(N_s)]

In [15]:
# Split the Fock space into (momentum, particle-number) sectors; labeling_fn
# gives each basis state's position in `labels`.
sectors = full_hilb.decompose_sectors(labeling_fn, labels)

In [16]:
# Display every sector together with its dimension.
sectors

{(0, 18): Sector(label=(0, 18), dim=173610, full_hilb_n_modes=27),
 (1, 18): Sector(label=(1, 18), dim=173583, full_hilb_n_modes=27),
 (2, 18): Sector(label=(2, 18), dim=173583, full_hilb_n_modes=27),
 (3, 18): Sector(label=(3, 18), dim=173583, full_hilb_n_modes=27),
 (4, 18): Sector(label=(4, 18), dim=173583, full_hilb_n_modes=27),
 (5, 18): Sector(label=(5, 18), dim=173583, full_hilb_n_modes=27),
 (6, 18): Sector(label=(6, 18), dim=173583, full_hilb_n_modes=27),
 (7, 18): Sector(label=(7, 18), dim=173583, full_hilb_n_modes=27),
 (8, 18): Sector(label=(8, 18), dim=173583, full_hilb_n_modes=27),
 (9, 18): Sector(label=(9, 18), dim=173583, full_hilb_n_modes=27),
 (10, 18): Sector(label=(10, 18), dim=173583, full_hilb_n_modes=27),
 (11, 18): Sector(label=(11, 18), dim=173583, full_hilb_n_modes=27),
 (12, 18): Sector(label=(12, 18), dim=173583, full_hilb_n_modes=27),
 (13, 18): Sector(label=(13, 18), dim=173613, full_hilb_n_modes=27),
 (14, 18): Sector(label=(14, 18), dim=173583, full_hil

In [20]:
# Reciprocal primitive vectors; b3 is chosen so that b1 + b2 + b3 = 0.
b1, b2 = lattice.reciprocal_lattice_vectors
b3 = -(b1 + b2)

# NOTE(review): 0.8 is an unexplained parameter of K_func1 — consider naming it
# and moving it to the configuration cell.
K_func_args = (0.8, b1, b2, b3)
K_func = partial(K_func1, args=K_func_args)

# Form factors on the (2 * G_radius + 1)^2 grid of reciprocal-lattice
# coordinates (the recorded output shows G_coords (65, 65, 2) and
# ac_ff (65, 65, N_s, N_s)).
G_coords, ac_ff = acband_form_factors(
    bz_N_s,
    lB,
    K_func,
    fourier_resolution,
    G_radius=G_radius,
    pbar=True
)

# Cartesian G vectors. The slice keeps grid indices 1 .. 2*G_radius - 1, i.e.
# it drops the outermost ring on each side of the 65 x 65 grid.
# NOTE(review): G_vecs_slice is not used in any visible cell.
G_vecs = recip_lattice.get_points(G_coords)
start_idx = 1
end_idx = 2 * G_radius
G_vecs_slice = G_vecs[start_idx:end_idx, start_idx:end_idx]

def V(q):
    """Potential passed to interaction_matrix, evaluated at momentum(s) ``q``.

    ``q`` holds vector components along its last axis. Returns
    ``-v1 * |q|**2`` with the remaining leading shape.
    """
    q_norm = np.linalg.norm(q, axis=-1)
    return -v1 * q_norm ** 2

# Sanity-check the form-factor array shapes.
print("G_coords shape:", G_coords.shape)
print("ac_ff shape:", ac_ff.shape)

# NOTE(review): in the recorded run both prints succeeded and this call then
# raised "ValueError: operands could not be broadcast together with shapes
# (63,63) (62,63)", so `int_mat` was never created. Later cells that use
# int_mat will fail on a fresh kernel — TODO fix before trusting downstream
# results.
int_mat = interaction_matrix(
    bz_N_s,
    G_coords,
    ac_ff,
    V
)

Computing AC band form factors: 100%|██████████| 4225/4225 [02:06<00:00, 33.45it/s]


G_coords shape: (65, 65, 2)
ac_ff shape: (65, 65, 27, 27)


ValueError: operands could not be broadcast together with shapes (63,63) (62,63) 

In [None]:
# Inspect the N_s-point Brillouin-zone object.
bz_N_s

<src.qm_utils.lattice.brillouin_zone.BrillouinZone2D at 0x7b2a1c2bcf80>

In [None]:
from src.netket_compat import (
    csr_from_nk_fermion_op
)

In [None]:
# NetKet Hilbert spaces on N_s orbitals: one without a particle-number
# constraint, and the fixed-number spaces with N_f and N_f - 2 fermions.
# NOTE(review): this rebinds `full_hilb`, which an earlier cell set to the
# DiscreteFermionicFockSpace used for decompose_sectors; re-running that cell
# after this one would use the wrong object. Consider a distinct name.
full_hilb = SpinOrbitalFermions(n_orbitals=N_s)
N_sector, N_minus_2_sector = (
    SpinOrbitalFermions(n_orbitals=N_s, n_particles=n) for n in (N_f, N_f_minus_2)
)

In [None]:
# Pair-annihilation operator c_k c_{neg_k} for mode k = 0, where neg_k is
# presumably the index of -k in bz_N_s.
# NOTE(review): if neg_k == k (k self-inverse, e.g. the zero momentum), this is
# c_k c_k, which vanishes identically by Pauli exclusion — confirm, or pick a
# different k.
k = 0
neg_k = bz_N_s.neg(k)
pair_ann_op = c(full_hilb, k) @ c(full_hilb, neg_k)

In [None]:
# Matrix of the pair-annihilation operator from the momentum-0, N_f-particle
# sector to the momentum-0, (N_f - 2)-particle sector. Removing modes k and
# neg_k presumably leaves the total momentum label unchanged, hence the same
# momentum label on both sides.
# Keys are built from N_f / N_f_minus_2 rather than hard-coded particle
# numbers, so this cell follows the system-size configuration.
domain = sectors[(0, N_f)]
codomain = sectors[(0, N_f_minus_2)]
pair_ann_op_sparse_rep = csr_from_nk_fermion_op(
    pair_ann_op, domain, codomain
)

In [None]:
# Basis states of `domain` with both mode k and mode neg_k occupied, i.e. the
# states on which c_k c_{neg_k} can act non-trivially. Print them as
# N_s-bit occupation strings (bit i = mode i).
paired_states = [
    state for state in domain.basis_labels
    if (state >> k) & 1 == 1 and (state >> neg_k) & 1 == 1
]
for state in paired_states:
    print(f"{state:0{N_s}b}")

# Later cells use `x` as the probe state. Bind it explicitly to the last state
# printed above. Relying on the loop variable would leave `x` equal to the
# last basis label of `domain`, which need not have modes k and neg_k occupied.
if not paired_states:
    raise ValueError(
        f"no basis state in domain has both modes {k} and {neg_k} occupied"
    )
x = paired_states[-1]

In [None]:
# Apply the sparse pair-annihilation matrix to the one-hot vector of basis
# state `x`, and identify the codomain basis state it maps to.
onehot = np.zeros((domain.dim, ), dtype=np.float64)
onehot[domain.label_to_index[x]] = 1.0
nonzero_rows = np.flatnonzero(pair_ann_op_sparse_rep @ onehot)
# A product of annihilators maps a basis state to at most one basis state.
# Fail with a clear message (instead of an opaque `.item()` error) when the
# image is zero, e.g. because x lacks one of the modes k / neg_k.
if nonzero_rows.size != 1:
    raise ValueError(
        "expected the image of x to be a single basis state, "
        f"found {nonzero_rows.size} nonzero entries"
    )
y = codomain.basis_labels[nonzero_rows.item()]

In [None]:
# Show the occupations before and after pair annihilation, one bit per mode.
# The width follows N_s instead of a hard-coded 36 bits.
print(f"Initial state: {x:0{N_s}b}")
print(f"Final state:   {y:0{N_s}b}")

In [None]:
# Build the interaction Hamiltonian once as (terms, weights). Then wrap it as a
# particle-number-conserving operator on each fixed-N Hilbert space
# (N_f first, then N_f - 2).
# NOTE(review): depends on int_mat, whose construction failed in the recorded
# run.
terms, weights = interaction_hamiltonian_terms(bz_N_s, int_mat)
H_N, H_N_minus_2 = (
    ParticleNumberConservingFermioperator2nd.from_fermionoperator2nd(
        FermionOperator2nd(hilbert, terms, weights)
    )
    for hilbert in (N_sector, N_minus_2_sector)
)

In [None]:
import time
# Time the sparse-matrix construction of H_{N-2} restricted to the momentum-0
# sector with N_f - 2 particles. The key is built from N_f_minus_2 so it matches
# the configured system size (the old hard-coded (0, 24) raises KeyError).
# perf_counter is a monotonic, high-resolution clock meant for elapsed time.
# The sector lookup is done before starting the clock.
sector = sectors[(0, N_f_minus_2)]
start = time.perf_counter()
sparse_rep_new = csr_from_nk_fermion_op(
    H_N_minus_2, sector, sector
)
elapsed_new = time.perf_counter() - start