In [None]:
import itertools
import time
from functools import partial

import numpy as np
import scipy
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from netket.hilbert import SpinOrbitalFermions
from netket.operator import FermionOperator2nd
from netket.experimental.operator import ParticleNumberConservingFermioperator2nd
from netket.operator.fermion import (
    destroy as c,
    create as cdag
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from src.qm_utils.fermion.fermion_utils import (
    bitset_to_array, 
    bitset_to_mode_indices,
    array_to_bitset
)
from src.qm_utils.fermion.fermionic_fock import DiscreteFermionicFockSpace
from src.acband import acband_form_factors, interaction_matrix, K_func1, interaction_hamiltonian_terms
from brillouin_zones import construct_brillouin_zones
from src.netket_compat import (
    get_sector_constraints,
    csr_from_nk_fermion_op
)

from src.qm_utils.lattice.lattice import Lattice2D

ModuleNotFoundError: No module named 'src'

In [None]:
# System size: N_s single-particle modes (one per point of the bz[N_s] momentum grid)
# and the two particle numbers studied: N_f and N_f - 2 (after removing a pair).
N_s, N_f = 27, 20
N_f_minus_2 = N_f - 2

In [None]:
# Discrete Fock space over all N_s modes, presumably restricted to the two particle
# numbers of interest (N_f - 2 and N_f) so that pair operators can connect them.
fock = DiscreteFermionicFockSpace(
    mode_labels=[*range(N_s)],
    particle_numbers=[N_f_minus_2, N_f],
)

In [None]:
# Physical constants and lattice geometry.
sqrt3 = 3.0 ** 0.5

# Length unit: the lattice constant a_M. lB is chosen so that 2*pi*lB^2 equals the
# unit-cell area |a1 x a2| = (sqrt3/2) * a_M^2 of the lattice defined below.
# (Alternative convention, previously used: lB = 1.0 and
#  a_M = (((4 * np.pi) / sqrt3) ** 0.5) * lB.)
a_M = 1
lB = ((sqrt3 / (4 * np.pi)) ** 0.5) * a_M

# Resolution of the form-factor computation (previously 128) and, presumably, the
# cutoff radius of the reciprocal-lattice G grid passed to acband_form_factors.
fourier_resolution = 256
G_radius = 64
V1 = 1.0
# NOTE(review): the denominator of the prefactor was changed from 4*pi to 8*pi
# (previous: v1 = 3 * V1 * (a_M ** 4) / (4 * np.pi)) without a recorded reason;
# confirm the intended normalization of the interaction strength.
v1 = 3 * V1 * (a_M ** 4) / (8 * np.pi)


# Lattice and Brillouin zones.
# Triangular lattice: |a1| = |a2| = a_M and a1 . a2 = a_M^2 / 2 (60 degrees).
e1 = np.array([1, 0])
e2 = np.array([0, 1])
a1 = a_M * e2
a2 = a_M * ((-sqrt3 / 2) * e1 + (1 / 2) * e2)
lattice = Lattice2D(np.stack([a1, a2]))
recip_lattice = lattice.reciprocal()

# Brillouin-zone discretisations of the lattice; bz[N_s] is presumably the N_s-point
# momentum grid whose points are the N_s fermion modes.
bz = construct_brillouin_zones(lattice)

bz_N_s = bz[N_s]

In [None]:
# JAX copies of the Brillouin-zone data used inside the traceable labeling function.
zero_idx = jnp.array(bz_N_s.zero())          # index of the zero momentum
sum_table_jax = jnp.array(bz_N_s.sum_table)  # presumably [p, q] -> index of p + q


def sum_indices(indices):
    """Fold a sequence of momentum indices into the index of their total momentum.

    Starts from ``zero_idx`` and looks up each partial sum in ``sum_table_jax`` with a
    ``jax.lax.scan``, so the loop is expressed with JAX-traceable primitives.
    """
    total, _ = jax.lax.scan(
        lambda acc, idx: (sum_table_jax[acc, idx], None),
        zero_idx,
        indices,
    )
    return total


def labeling_fn(state_int):
    """Map a bitset-encoded Fock state to an integer sector label.

    The label is ``block * N_s + k_total``: ``block`` is 0 for N_f - 2 particles and
    1 for N_f particles, and ``k_total`` is the index of the summed momentum of the
    occupied modes. It is an index into ``labels`` below.
    """
    # Occupied mode indices, presumably padded to N_f entries with zero_idx, which
    # leaves the momentum sum unchanged.
    occupied = bitset_to_mode_indices(
        state_int, n_modes=N_s, max_n_particles=N_f, fill_value=zero_idx
    )
    k_total = sum_indices(occupied)
    block = (jax.lax.population_count(state_int) - N_f_minus_2) // 2
    return block * N_s + k_total


# labels[i] is the (momentum index, particle number) key for integer label i.
labels = [(k_idx, n) for n in (N_f_minus_2, N_f) for k_idx in range(N_s)]

In [None]:
# Partition the Fock space with labeling_fn. Each integer label i presumably maps to
# labels[i], i.e. a (total-momentum index, particle number) pair used as the dict key.
sectors = fock.decompose_sectors(labeling_fn, labels)

In [None]:
# Display the sector dictionary: one Sector per (momentum index, particle number) label.
sectors

{(0, 18): Sector(label=(0, 18), dim=173583, full_hilb_n_modes=27),
 (1, 18): Sector(label=(1, 18), dim=173583, full_hilb_n_modes=27),
 (2, 18): Sector(label=(2, 18), dim=173583, full_hilb_n_modes=27),
 (3, 18): Sector(label=(3, 18), dim=173583, full_hilb_n_modes=27),
 (4, 18): Sector(label=(4, 18), dim=173583, full_hilb_n_modes=27),
 (5, 18): Sector(label=(5, 18), dim=173583, full_hilb_n_modes=27),
 (6, 18): Sector(label=(6, 18), dim=173583, full_hilb_n_modes=27),
 (7, 18): Sector(label=(7, 18), dim=173583, full_hilb_n_modes=27),
 (8, 18): Sector(label=(8, 18), dim=173583, full_hilb_n_modes=27),
 (9, 18): Sector(label=(9, 18), dim=173613, full_hilb_n_modes=27),
 (10, 18): Sector(label=(10, 18), dim=173583, full_hilb_n_modes=27),
 (11, 18): Sector(label=(11, 18), dim=173583, full_hilb_n_modes=27),
 (12, 18): Sector(label=(12, 18), dim=173583, full_hilb_n_modes=27),
 (13, 18): Sector(label=(13, 18), dim=173583, full_hilb_n_modes=27),
 (14, 18): Sector(label=(14, 18), dim=173583, full_hil

In [None]:
# NetKet Hilbert spaces for the N_f-particle sectors.
# NOTE(review): presumably constraints[i] selects total-momentum index i, matching
# sectors[(i, N_f)]; the per-sector comparison later in the notebook checks this.
constraints = get_sector_constraints(bz_N_s, N_f)

# Fock space over all N_s orbitals with no particle-number constraint (used below to
# build single-mode creation/annihilation operators).
full_hilb = SpinOrbitalFermions(n_orbitals=N_s)

hilbs = [
    SpinOrbitalFermions(n_orbitals=N_s, n_fermions=N_f, constraint=constraint)
    for constraint in constraints
]
# b1, b2, b3 are computed in the next cell, where they are first used.

In [None]:
# Reciprocal basis; b3 = -(b1 + b2) completes the symmetric triple (b1 + b2 + b3 = 0).
b1, b2 = lattice.reciprocal_lattice_vectors
b3 = -(b1 + b2)


def V(q):
    """Interaction in momentum space, V(q) = -v1 * |q|^2, with |q| taken over the last axis."""
    q_abs = np.linalg.norm(q, axis=-1)
    return -v1 * q_abs ** 2


# K_func1 with its extra parameters bound.
# NOTE(review): the meaning of the 0.8 parameter is not documented here.
K_func_args = (0.8, b1, b2, b3)
K_func = partial(K_func1, args=K_func_args)

# Form factors on the reciprocal-lattice grid (the expensive step: ~11 min in the logged run).
G_coords, ac_ff = acband_form_factors(
    bz_N_s, lB, K_func, fourier_resolution, G_radius=G_radius, pbar=True
)

# Cartesian G vectors. Assuming a (2*G_radius + 1)^2 grid (16641 = 129^2 in the progress
# output), the slice drops the outermost index on each side of both axes.
G_vecs = recip_lattice.get_points(G_coords)
start_idx, end_idx = 1, 2 * G_radius
G_vecs_slice = G_vecs[start_idx:end_idx, start_idx:end_idx]

int_mat = interaction_matrix(bz_N_s, G_coords, ac_ff, V)

Computing AC band form factors: 100%|██████████| 16641/16641 [10:56<00:00, 25.33it/s]


In [21]:
# One interacting Hamiltonian per momentum sector in `hilbs`.
# The term list depends only on the BZ grid and the interaction matrix, not on the
# sector, so it is built once instead of once per sector.
terms, weights = interaction_hamiltonian_terms(bz_N_s, int_mat)

hamiltonians = []
for hilb in hilbs:
    H = FermionOperator2nd(hilb, terms, weights)
    # Particle-number-conserving specialisation of the operator (NetKet's suggested
    # cheaper representation for number-conserving Hamiltonians).
    H = ParticleNumberConservingFermioperator2nd.from_fermionoperator2nd(H)
    hamiltonians.append(H)

In [22]:
# Pair annihilation operator c_k c_{-k} on the unconstrained Fock space.
k = 0
neg_k = bz_N_s.neg(k)  # mode index of -k
# For a self-conjugate k (k == -k), c_k c_k vanishes by Pauli exclusion, which would
# make every matrix element computed below zero.
assert k != neg_k, f"k={k} is its own inverse; c_k c_(-k) is identically zero"
pair_ann_op = c(full_hilb, k) @ c(full_hilb, neg_k)

In [23]:
# Pair annihilation removes momenta k and -k (net zero) and lowers the particle number
# by 2. Assuming the first label entry is the total-momentum index (as produced by
# labeling_fn), it maps sector (0, N_f) into sector (0, N_f - 2).
domain = sectors[(0, N_f)]
codomain = sectors[(0, N_f_minus_2)]
op = pair_ann_op

In [24]:
# Sparse matrix of `op` from `domain` to `codomain`, indexed in the sectors' own bases.
sparse_rep = csr_from_nk_fermion_op(op, domain, codomain)

Consider using `netket.experimental.operator.ParticleNumberAndSpinConservingFermioperator2nd` to reduce the number of connected elements and
considerably reduce the computational cost.
You can convert this operator by calling `netket.experimental.operator.ParticleNumberAndSpinConservingFermioperator2nd.from_fermionoperator2nd`.

  super()._setup(self)


In [25]:
# Probe one basis state: bit i of x is the occupation of mode i (20 bits set = N_f).
x = 0b1101111111111111111101
onehot = np.zeros((domain.dim,), dtype=np.float64)
onehot[domain.label_to_index[x]] = 1.0
(rows,) = np.nonzero(sparse_rep @ onehot)
# A product of annihilators maps a basis state to at most one basis state, and to
# nothing unless both k and -k are occupied; fail with a clear message otherwise.
assert rows.size == 1, f"expected exactly one connected state, got {rows.size}"
y = codomain.basis_labels[rows.item()]

In [26]:
x = 0b1101111111111111111101
# Sanity check: bitset -> occupation array over N_s modes -> bitset reproduces x.
occupations = bitset_to_array(x, N_s)
array_to_bitset(occupations) == x

Array(True, dtype=bool)

In [27]:
# Mode index of -k for the k chosen above.
neg_k

np.int64(21)

In [28]:
# Occupation bit strings of the probe state x and its image y, zero-padded to one
# digit per mode (mode 0 is the rightmost bit).
print(f"{x:0{N_s}b}")
print(f"{y:0{N_s}b}")

000001101111111111111111101
000000101111111111111111100


In [29]:
# Benchmark on sector 0: NetKet's `to_sparse` (reference) vs. `csr_from_nk_fermion_op`.
# time.perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals (time.time can jump with system clock adjustments). The first call of
# either path may include one-off compilation cost.
H = hamiltonians[0]
start = time.perf_counter()
sparse_rep_ref = H.to_sparse()
elapsed_ref = time.perf_counter() - start
print(f"Reference conversion time: {elapsed_ref:.6f} seconds")

# Look up the corresponding custom sector outside the timed region.
sector = sectors[(0, N_f)]
start = time.perf_counter()
sparse_rep_new = csr_from_nk_fermion_op(H, sector, sector)
elapsed_new = time.perf_counter() - start
print(f"New conversion time: {elapsed_new:.6f} seconds")

Reference conversion time: 7.835921 seconds
New conversion time: 5.834336 seconds


In [30]:
# Compare the two representations. NetKet orders states as H.hilbert.all_states(); the
# custom sector orders them by sector.basis_labels. perm[i] is the position of NetKet
# state i in the sector basis.
x = H.hilbert.all_states()
x_int = np.asarray(jax.vmap(array_to_bitset)(x))
basis = np.asarray(sector.basis_labels)
# searchsorted is only meaningful on a sorted array and returns an insertion point
# (not an error) for values that are absent, so verify both conditions explicitly.
assert np.all(basis[1:] > basis[:-1]), "sector.basis_labels must be strictly increasing"
perm = np.searchsorted(basis, x_int)
assert np.all(perm < basis.size) and np.array_equal(basis[perm], x_int), (
    "some NetKet basis states are not in the sector basis"
)

diff = sparse_rep_ref - sparse_rep_new[perm, :][:, perm]
norm_ref = scipy.sparse.linalg.norm(sparse_rep_ref.copy())
norm_new = scipy.sparse.linalg.norm(sparse_rep_new.copy())
norm_diff = scipy.sparse.linalg.norm(diff)
print(f"Reference norm: {norm_ref:.6e}")
print(f"New norm: {norm_new:.6e}")
print(f"Difference norm: {norm_diff:.6e}")

Reference norm: 4.011639e+03
New norm: 4.011639e+03
Difference norm: 0.000000e+00


In [32]:
def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)`` measured with perf_counter."""
    t0 = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - t0


def _nk_to_sector_perm(hilbert, sector):
    """Return perm with perm[i] = position of ``hilbert.all_states()[i]`` in ``sector.basis_labels``.

    Raises AssertionError if the sector basis is not strictly increasing (required by
    searchsorted) or if some NetKet state is absent from the sector basis.
    """
    basis = np.asarray(sector.basis_labels)
    states_int = np.asarray(jax.vmap(array_to_bitset)(hilbert.all_states()))
    assert np.all(basis[1:] > basis[:-1]), "sector.basis_labels must be strictly increasing"
    perm = np.searchsorted(basis, states_int)
    assert np.all(perm < basis.size) and np.array_equal(basis[perm], states_int), (
        "some NetKet basis states are not in the sector basis"
    )
    return perm


# For every momentum sector: time both sparse conversions, then check they agree up to
# the basis permutation between NetKet's ordering and the sector's ordering.
for i in range(N_s):
    print("Checking Sector ", i)
    H = hamiltonians[i]
    sparse_rep_ref, elapsed_ref = _timed(H.to_sparse)
    print(f"Reference conversion time: {elapsed_ref:.6f} seconds")

    sector = sectors[(i, N_f)]
    sparse_rep_new, elapsed_new = _timed(csr_from_nk_fermion_op, H, sector, sector)
    print(f"New conversion time: {elapsed_new:.6f} seconds")

    perm = _nk_to_sector_perm(H.hilbert, sector)
    diff = sparse_rep_ref - sparse_rep_new[perm, :][:, perm]
    norm_ref = scipy.sparse.linalg.norm(sparse_rep_ref.copy())
    norm_new = scipy.sparse.linalg.norm(sparse_rep_new.copy())
    norm_diff = scipy.sparse.linalg.norm(diff)
    print(f"Reference norm: {norm_ref:.6e}")
    print(f"New norm: {norm_new:.6e}")
    print(f"Difference norm: {norm_diff:.6e}")
    assert norm_diff < 1e-10
    

Checking Sector  0
Reference conversion time: 6.837692 seconds
New conversion time: 5.891257 seconds
Reference norm: 4.011639e+03
New norm: 4.011639e+03
Difference norm: 0.000000e+00
Checking Sector  1
Reference conversion time: 15.197236 seconds
New conversion time: 4.890256 seconds
Reference norm: 4.011656e+03
New norm: 4.011656e+03
Difference norm: 0.000000e+00
Checking Sector  2
Reference conversion time: 13.401549 seconds
New conversion time: 5.939713 seconds
Reference norm: 4.011639e+03
New norm: 4.011639e+03
Difference norm: 0.000000e+00
Checking Sector  3
Reference conversion time: 12.094738 seconds
New conversion time: 5.379302 seconds
Reference norm: 4.011656e+03
New norm: 4.011656e+03
Difference norm: 0.000000e+00
Checking Sector  4
Reference conversion time: 12.809017 seconds
New conversion time: 5.806751 seconds
Reference norm: 4.011647e+03
New norm: 4.011647e+03
Difference norm: 0.000000e+00
Checking Sector  5
Reference conversion time: 12.661896 seconds
New conversion ti