In [2]:
from functools import partial
import itertools
import numpy as np
import scipy
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from netket.hilbert import SpinOrbitalFermions
from netket.operator import FermionOperator2nd
from netket.experimental.operator import ParticleNumberConservingFermioperator2nd
from netket.operator.fermion import (
    destroy as c,
    create as cdag
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from src.qm_utils.fermion.fermion_utils import (
    bitset_to_array, 
    bitset_to_mode_indices,
    array_to_bitset
)
from src.qm_utils.fermion.fermionic_fock import DiscreteFermionicFockSpace
from src.acband import acband_form_factors, interaction_matrix, K_func1, interaction_hamiltonian_terms
from brillouin_zones import construct_brillouin_zones
from src.netket_compat import (
    get_sector_constraints,
    csr_from_nk_fermion_op
)

from src.qm_utils.lattice.lattice import Lattice2D

In [4]:
N_s = 27
N_f = 20
N_f_minus_2 = N_f - 2

In [5]:
fock = DiscreteFermionicFockSpace(mode_labels=list(range(N_s)), particle_numbers=[N_f_minus_2, N_f])

In [6]:
sqrt3 = 3.0 ** 0.5

# lB = 1.0
# a_M = (((4 * np.pi) / sqrt3) ** 0.5) * lB
a_M = 1
lB = ((sqrt3 / (4 * np.pi)) ** 0.5) * a_M
     
# fourier_resolution = 128
fourier_resolution = 256
G_radius = 64
V1 = 1.0
# v1 = 3 * V1 * (a_M ** 4) / (4 * np.pi)
v1 = 3 * V1 * (a_M ** 4) / (8 * np.pi) # ????? 4 pi -> 8 pi


# Lattice and Brillouin zones
e1 = np.array([1, 0])
e2 = np.array([0, 1])
a1 = a_M * e2
a2 = a_M * ((-sqrt3 / 2) * e1 + (1 / 2) * e2)
lattice = Lattice2D(np.stack([a1, a2]))
recip_lattice = lattice.reciprocal()

bz = construct_brillouin_zones(lattice)

bz_N_s = bz[N_s]

In [7]:
zero_idx = jnp.array(bz_N_s.zero())
sum_table_jax = jnp.array(bz_N_s.sum_table)

def sum_indices(indices):
    def scan_fn(carry, x):
        new_carry = sum_table_jax[carry, x]
        return new_carry, None
    init_carry = zero_idx
    final_carry, _ = jax.lax.scan(scan_fn, init_carry, indices)
    return final_carry

def labeling_fn(state_int):
    n_particles = jax.lax.population_count(state_int)
    mode_indices = bitset_to_mode_indices(state_int, n_modes=N_s, max_n_particles=N_f, fill_value=zero_idx)
    summed_index = sum_indices(mode_indices)
    return ((n_particles - N_f_minus_2) // 2) * N_s + summed_index

labels = [(idx, N_f_minus_2) for idx in range(N_s)] + [(idx, N_f) for idx in range(N_s)]

In [8]:
sectors = fock.decompose_sectors(labeling_fn, labels)

In [9]:
sectors

{(0, 18): Sector(label=(0, 18), dim=173583, full_hilb_n_modes=27),
 (1, 18): Sector(label=(1, 18), dim=173583, full_hilb_n_modes=27),
 (2, 18): Sector(label=(2, 18), dim=173583, full_hilb_n_modes=27),
 (3, 18): Sector(label=(3, 18), dim=173583, full_hilb_n_modes=27),
 (4, 18): Sector(label=(4, 18), dim=173583, full_hilb_n_modes=27),
 (5, 18): Sector(label=(5, 18), dim=173583, full_hilb_n_modes=27),
 (6, 18): Sector(label=(6, 18), dim=173583, full_hilb_n_modes=27),
 (7, 18): Sector(label=(7, 18), dim=173583, full_hilb_n_modes=27),
 (8, 18): Sector(label=(8, 18), dim=173583, full_hilb_n_modes=27),
 (9, 18): Sector(label=(9, 18), dim=173613, full_hilb_n_modes=27),
 (10, 18): Sector(label=(10, 18), dim=173583, full_hilb_n_modes=27),
 (11, 18): Sector(label=(11, 18), dim=173583, full_hilb_n_modes=27),
 (12, 18): Sector(label=(12, 18), dim=173583, full_hilb_n_modes=27),
 (13, 18): Sector(label=(13, 18), dim=173583, full_hilb_n_modes=27),
 (14, 18): Sector(label=(14, 18), dim=173583, full_hil

In [10]:
constraints = get_sector_constraints(bz_N_s, N_f)
full_hilb = SpinOrbitalFermions(n_orbitals=N_s)
hilbs = [
    SpinOrbitalFermions(
        n_orbitals=N_s, n_fermions=N_f, constraint=constraint
    ) for constraint in constraints
]
b1, b2 = lattice.reciprocal_lattice_vectors
b3 = -(b1 + b2)

In [11]:
b1, b2 = lattice.reciprocal_lattice_vectors
b3 = -(b1 + b2)

K_func_args = (0.8, b1, b2, b3)
K_func = partial(K_func1, args=K_func_args)

G_coords, ac_ff = acband_form_factors(
    bz_N_s,
    lB,
    K_func,
    fourier_resolution,
    G_radius=G_radius,
    pbar=True
)

G_vecs = recip_lattice.get_points(G_coords)
start_idx = 1
end_idx = 2 * G_radius
G_vecs_slice = G_vecs[start_idx:end_idx, start_idx:end_idx]

def V(q):
    return -v1 * np.linalg.norm(q, axis=-1) ** 2

int_mat = interaction_matrix(
    bz_N_s,
    G_coords,
    ac_ff,
    V
)

Computing AC band form factors: 100%|██████████| 16641/16641 [10:45<00:00, 25.80it/s]


In [12]:
hamiltonians = []
for sector_index, sector in enumerate(hilbs):
    terms, weights = interaction_hamiltonian_terms(bz_N_s, int_mat)
    H = FermionOperator2nd(
        sector,
        terms,
        weights
    )
    H = ParticleNumberConservingFermioperator2nd.from_fermionoperator2nd(H)
    hamiltonians.append(H)

In [13]:
k = 0
neg_k = bz_N_s.neg(k)
pair_ann_op = c(full_hilb, k) @ c(full_hilb, neg_k)

In [14]:
domain = sectors[(0, 20)]
codomain = sectors[(0, 18)]
op = pair_ann_op

In [16]:
sparse_rep = csr_from_nk_fermion_op(
    op, domain, codomain
)

Consider using `netket.experimental.operator.ParticleNumberAndSpinConservingFermioperator2nd` to reduce the number of connected elements and
considerably reduce the computational cost.
You can convert this operator by calling `netket.experimental.operator.ParticleNumberAndSpinConservingFermioperator2nd.from_fermionoperator2nd`.

  super()._setup(self)


In [17]:
x = 0b1101111111111111111101
onehot = np.zeros((domain.dim, ), dtype=np.float64)
onehot[domain.label_to_index[x]] = 1.0
y, = np.nonzero(sparse_rep @ onehot)
y = codomain.basis_labels[y.item()]

In [124]:
x = 0b1101111111111111111101
array_to_bitset(bitset_to_array(x, N_s)) == x

Array(True, dtype=bool)

In [18]:
neg_k

np.int64(21)

In [19]:
print(f"{x:027b}")
print(f"{y:027b}")

000001101111111111111111101
000000101111111111111111100


In [125]:
import time
start = time.time()
sparse_rep_ref = H.to_sparse()
elapsed_ref = time.time() - start
print(f"Reference conversion time: {elapsed_ref:.6f} seconds")

start = time.time()
sector = sectors[(0, 20)]
sparse_rep_new = csr_from_nk_fermion_op(
    H, sector, sector
)
elapsed_new = time.time() - start
print(f"New conversion time: {elapsed_new:.6f} seconds")

Reference conversion time: 7.277420 seconds
New conversion time: 5.140373 seconds


In [None]:
# x = jax.vmap(bitset_to_array, in_axes=(0, None))(
#     sector.basis_labels, sector.full_hilb.n_modes
# )
# sections = np.empty(sector.dim, dtype=np.int32)
# x_prime, mels = H.get_conn_flattened(x, sections)
# bitsets = jax.vmap(array_to_bitset)(x_prime)
# numbers = np.searchsorted(sector.basis_labels, bitsets)
# sections1 = np.empty(sections.size + 1, dtype=np.int32)
# sections1[1:] = sections
# sections1[0] = 0
# sparse_rep_new = scipy.sparse.csr_matrix(
#     (mels, numbers, sections1), 
#     shape=(sector.dim, sector.dim)
# )

x = hilbs[0].all_states()

sections = np.empty(sector.dim, dtype=np.int32)
x_prime, mels = H.get_conn_flattened(x, sections)

numbers = hilbs[0].states_to_numbers(x_prime)
sections1 = np.empty(sections.size + 1, dtype=np.int32)
sections1[1:] = sections
sections1[0] = 0
sparse_rep_new = scipy.sparse.csr_matrix(
    (mels, numbers, sections1), 
    shape=(sector.dim, sector.dim)
)

In [127]:
x = hilbs[0].all_states()
x_int = np.array(jax.vmap(array_to_bitset)(x))

In [113]:
perm = np.searchsorted(sector.basis_labels, x_int)

In [114]:
x_int - sector.basis_labels[perm]

array([0, 0, 0, ..., 0, 0, 0], shape=(32890,))

In [117]:
for a, b in zip(x_int, sector.basis_labels[perm]):
    if a != b:
        print("Mismatch found")

In [134]:
# basis_labels = hilbs[0].basis_labels
x = hilbs[0].all_states()
x_int = jax.vmap(array_to_bitset)(x)
perm = np.searchsorted(sector.basis_labels, x_int) # nk basis가 어디에 있는지

diff = 0
for i, j in itertools.product(range(sparse_rep_ref.shape[0]), repeat=2):
    v_ref = sparse_rep_ref[i, j]
    v_new = sparse_rep_new[perm[i], perm[j]]

    if abs(v_ref - v_new) > 1e-12:
        print(f"Mismatch at ({i}, {j}): ref={v_ref}, new={v_new}")
        print(f"{x_int[i]:027b} <- {x_int[j]:027b}, \n{sector.basis_labels[perm][i]:027b} <- {sector.basis_labels[perm][j]:027b}")
        break

diff = sparse_rep_ref - sparse_rep_new[perm, :][:, perm]
norm_ref = scipy.sparse.linalg.norm(sparse_rep_ref.copy())
norm_new = scipy.sparse.linalg.norm(sparse_rep_new.copy())
norm_diff = scipy.sparse.linalg.norm(diff)
print(f"Reference norm: {norm_ref:.6e}")
print(f"New norm: {norm_new:.6e}")
print(f"Difference norm: {norm_diff:.6e}")

Mismatch at (0, 0): ref=(18.539384689335325+3.611846954112911e-16j), new=(23.78528704237422+3.530191728397453e-16j)
111111110111111111111000000 <- 111111110111111111111000000, 
111111110111111111111000000 <- 111111110111111111111000000
Reference norm: 4.011622e+03
New norm: 4.011639e+03
Difference norm: 4.762947e+02


In [49]:
evals_ref = scipy.sparse.linalg.eigsh(sparse_rep_ref, k=20, return_eigenvectors=False)
evals_new = scipy.sparse.linalg.eigsh(sparse_rep_new, k=20, return_eigenvectors=False)

In [51]:
np.abs(evals_ref - evals_new) / np.abs(evals_ref)

array([4.39505415e-05, 5.01050643e-05, 6.28900279e-04, 9.39796448e-05,
       3.88610080e-04, 6.73790601e-05, 6.70041406e-05, 8.57738493e-05,
       8.96269856e-05, 1.02185427e-04, 2.91673675e-04, 2.21377138e-04,
       1.82965065e-04, 1.16043396e-06, 1.39973378e-04, 1.15329595e-04,
       2.32181220e-04, 5.28619117e-06, 2.07134369e-04, 3.82767771e-06])