In [1]:
import os
# Cap numba's worker-thread count. Set before the remaining imports because numba
# presumably reads NUMBA_NUM_THREADS when it is first imported (numba frames appear
# under the netket benchmark's get_conn in the profiler output further down).
# NOTE(review): confirm which import first loads numba.
os.environ['NUMBA_NUM_THREADS'] = '20'
import itertools
import quimb.tensor as qtn
import numpy as np

In [None]:
from autoray import do
from vmc_torch.fermion_utils import from_quimb_config_to_netket_config, from_netket_config_to_quimb_config, calc_phase_symmray

def generate_binary_vectors(N, m):
    """Return every length-N 0/1 list that contains exactly m ones.

    Vectors are emitted in the order in which ``itertools.combinations``
    yields the positions of their ones (lexicographic in those positions).

    Args:
        N: Length of each vector.
        m: Number of entries set to 1.

    Returns:
        list[list[int]]: one vector per choice of m positions out of N.
    """
    result = []
    for ones in itertools.combinations(range(N), m):
        occupied = set(ones)
        result.append([1 if k in occupied else 0 for k in range(N)])
    return result

class Hilbert:
    """Abstract interface for a finite many-body Hilbert space.

    Concrete subclasses override the members below; every default here simply
    raises ``NotImplementedError``.
    """

    @property
    def size(self):
        """Number of basis states in the Hilbert space."""
        raise NotImplementedError

    def to_quimb_config(self, config):
        """Convert ``config`` into the configuration format used with quimb."""
        raise NotImplementedError

    def all_states(self):
        """Enumerate every basis state of the Hilbert space."""
        raise NotImplementedError

    def random_state(self, key):
        """Draw one random basis state; ``key`` is the source of randomness (e.g. a seed)."""
        raise NotImplementedError

class SpinfulFermion(Hilbert):
    def __init__(self, n_orbitals, n_fermions=None, n_fermions_per_spin=None):
        """
        Spinful fermionic Hilbert space with a fixed number of fermions per spin species.

        Netket-style configuration layout: [nu_1, ..., nu_N, nd_1, ..., nd_N], i.e.
        the N spin-up occupations followed by the N spin-down occupations.

        Args:
            n_orbitals: Number of spatial orbitals N. Each has a spin-up and a
                spin-down mode, so ``self.n_orbitals`` stores ``2 * n_orbitals``.
            n_fermions: Total number of fermions. May be omitted when
                ``n_fermions_per_spin`` is given; it is then ``n_up + n_down``.
            n_fermions_per_spin: ``(n_up, n_down)``. If omitted, ``n_fermions`` is
                split evenly between the two spin species.

        Raises:
            ValueError: if neither particle-number argument is given, if
                ``n_fermions`` is odd and no split is given (an even split is
                impossible), or if both are given and disagree.
        """
        if n_fermions_per_spin is None:
            if n_fermions is None:
                raise ValueError("Either n_fermions or n_fermions_per_spin must be given.")
            if n_fermions % 2:
                # An n_fermions // 2 split per species would silently lose one fermion.
                raise ValueError(
                    f"n_fermions={n_fermions} is odd; pass n_fermions_per_spin=(n_up, n_down)."
                )
            n_up = n_down = n_fermions // 2
        else:
            n_up, n_down = n_fermions_per_spin[0], n_fermions_per_spin[1]
            if n_fermions is None:
                n_fermions = n_up + n_down
            elif n_fermions != n_up + n_down:
                raise ValueError(
                    f"n_fermions={n_fermions} does not match "
                    f"n_fermions_per_spin={tuple(n_fermions_per_spin)}."
                )

        self.n_orbitals = n_orbitals * 2  # total number of spin-orbitals (up + down)
        self.n_fermions = n_fermions
        self.n_fermions_spin_up = n_up
        self.n_fermions_spin_down = n_down

    @property
    def size(self):
        """Number of basis states in the fixed (n_up, n_down) particle-number sector."""
        import math

        n_half = self.n_orbitals // 2
        return math.comb(n_half, self.n_fermions_spin_up) * math.comb(n_half, self.n_fermions_spin_down)

    def all_states(self):
        """Lazily yield every basis state of the sector.

        Each state is a plain list in the netket layout ``[up..., down...]``
        (length ``self.n_orbitals``). Note that, unlike ``random_state``, these
        states are NOT converted to the quimb format.
        """
        n_half = self.n_orbitals // 2
        spin_up_states = generate_binary_vectors(n_half, self.n_fermions_spin_up)
        spin_down_states = generate_binary_vectors(n_half, self.n_fermions_spin_down)
        return (up + down for up, down in itertools.product(spin_up_states, spin_down_states))

    def random_state(self, key=None):
        """Generate a random state of the sector.

        Args:
            key (int, optional): Random seed for reproducibility.

        Returns:
            The result of ``from_netket_config_to_quimb_config`` applied to a random
            netket-layout binary vector ``[up..., down...]`` of length
            ``self.n_orbitals`` -- i.e. a quimb-format configuration, not the binary
            vector itself. NOTE(review): in this notebook's examples that is one
            entry per spatial orbital.
        """
        rng = np.random.default_rng(key)

        n_up = self.n_fermions_spin_up
        n_down = self.n_fermions_spin_down
        n_half = self.n_orbitals // 2

        # Randomly select positions for spin-up fermions
        up_positions = rng.choice(n_half, size=n_up, replace=False)
        up_state = np.zeros(n_half, dtype=np.int32)
        up_state[up_positions] = 1

        # Randomly select positions for spin-down fermions
        down_positions = rng.choice(n_half, size=n_down, replace=False)
        down_state = np.zeros(n_half, dtype=np.int32)
        down_state[down_positions] = 1

        # Concatenate in netket layout, then convert to the quimb format.
        return from_netket_config_to_quimb_config(np.concatenate([up_state, down_state]))

class Spin(Hilbert):
    def __init__(self, s, N, total_sz=None):
        """Hilbert space obtained as tensor product of N local spin-s states.

        Args:
           s: Spin at each site. Currently only s=1/2 is supported.
           N: Number of sites.
           total_sz: If given, constrains the total S^z of the system. For spin-1/2
                the number of up spins is N/2 + total_sz, which must be an integer in
                [0, N]; so total_sz is an integer for even N and a half-integer for odd N.

        Raises:
            ValueError: if s != 1/2, or if total_sz is incompatible with N.
        """
        self.s = s
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if float(s) != 0.5:
            raise ValueError("Currently only supports s=1/2 for Spin Hilbert space")
        self.N = N
        self.total_sz = total_sz
        # Validate eagerly: int() truncation would otherwise silently pick a wrong sector.
        self._n_up_sector()
        self._size = int(2 * s + 1) ** N  # unconstrained dimension (ignores total_sz)

    def _n_up_sector(self):
        """Number of up spins fixed by ``total_sz``, or None when unconstrained."""
        if self.total_sz is None:
            return None
        n_up = self.N / 2 + self.total_sz  # for s=1/2: n_up - n_down = 2 * S^z
        if not float(n_up).is_integer() or not 0 <= n_up <= self.N:
            raise ValueError(
                f"total_sz={self.total_sz} is incompatible with N={self.N} spin-1/2 sites "
                "(N/2 + total_sz must be an integer in [0, N])"
            )
        return int(n_up)

    @property
    def size(self):
        """Number of basis states, taking the total_sz constraint into account."""
        import math

        n_up = self._n_up_sector()
        if n_up is None:
            return self._size
        return math.comb(self.N, n_up)

    def from_netket_to_quimb_spin_config(self, config):
        """From (-1,1) to (0,1) basis; a 2D input is converted row by row."""
        def func(x):
            return (x + 1) / 2
        if len(config.shape) == 1:
            return func(config)
        else:
            return do('array',([func(c) for c in config]))

    def from_quimb_to_netket_spin_config(self, config):
        """From (0,1) to (-1,1) basis; a 2D input is converted row by row."""
        def func(x):
            return 2*x - 1
        if len(config.shape) == 1:
            return func(config)
        else:
            return do('array',([func(c) for c in config]))

    def all_states(self):
        """
        Generate all states in the Hilbert space as an array of 0/1 spin labels.

        With ``total_sz`` set, only states with exactly N/2 + total_sz ones are
        returned (integer entries, from ``generate_binary_vectors``); otherwise all
        product states are returned (float entries 0.0/1.0 from ``np.linspace``).
        """
        single_site_states = np.linspace(0, 1, int(2 * self.s + 1))
        assert len(single_site_states) == 2, 'Currently only supports s=1/2 for Spin Hilbert space, len(single_site_states) should be 2'

        n_up = self._n_up_sector()
        if n_up is not None:
            return do('array', list(generate_binary_vectors(self.N, n_up)))
        # All combinations of single-site states for N sites.
        return do('array', list(itertools.product(single_site_states, repeat=self.N)))

    def random_state(self, key=None):
        """
        Generate a random state in the Hilbert space.

        Args:
            key (int, optional): Random seed for reproducibility.

        Returns:
            A length-N array of 0/1 spin labels: int32 entries with exactly
            N/2 + total_sz ones when ``total_sz`` is set, otherwise float entries
            sampled independently per site.
        """
        rng = np.random.default_rng(key)
        n_up = self._n_up_sector()

        if n_up is not None:
            # Sample a state in the requested total-S^z sector.
            up_positions = rng.choice(self.N, size=n_up, replace=False)
            random_state = np.zeros(self.N, dtype=np.int32)
            random_state[up_positions] = 1
        else:
            # Sample each site independently from the single-site labels.
            random_state = rng.choice(np.linspace(0, 1, int(2 * self.s + 1)), size=self.N)

        return do('array', random_state)

class Graph:
    """Minimal graph container exposing ``edges`` and ``site_index_map``.

    Subclasses fill in ``_edges`` (e.g. a list of site-index pairs) and
    ``_site_index_map``; the base class leaves both as ``None``.
    """

    def __init__(self):
        self._site_index_map = None
        self._edges = None

    @property
    def site_index_map(self):
        """Callable flattening lattice coordinates to a site index (None in the base class)."""
        return self._site_index_map

    @property
    def edges(self):
        """Edges of the graph (None in the base class)."""
        return self._edges

class SquareLatticeGraph(Graph):
    def __init__(self, Lx, Ly, pbc=False, site_index_map=lambda i, j, Lx, Ly: i * Ly + j):
        """Nearest-neighbour graph of an Lx x Ly square lattice.

        Args:
            Lx, Ly: Lattice dimensions.
            pbc: If True, pass ``cyclic=True`` to quimb so boundary bonds wrap around.
            site_index_map: Callable ``(i, j, Lx, Ly) -> int`` flattening lattice
                coordinates to a site index. The default ``i * Ly + j`` is a
                row-major ("zig-zag") ordering: every row is traversed in the same
                direction before jumping to the start of the next row.
        """
        super().__init__()
        self.Lx = Lx
        self.Ly = Ly
        self.pbc = pbc
        self._site_index_map = site_index_map
        # quimb yields each bond as a pair of (i, j) coordinates; flatten both ends.
        self._edges = [
            (site_index_map(*site_i, Lx, Ly), site_index_map(*site_j, Lx, Ly))
            for site_i, site_j in qtn.edges_2d_square(Lx, Ly, cyclic=pbc)
        ]


class Hamiltonian:
    """Base container bundling Hamiltonian terms, a Hilbert space and a graph.

    Subclasses implement ``get_conn``.
    """

    def __init__(self, H, hi, graph):
        self._H, self._hi, self._graph = H, hi, graph
        self._hilbert = None  # Reserved slot for a customized Hilbert space; unused here.

    @property
    def H(self):
        """Hamiltonian terms exactly as passed to the constructor."""
        return self._H

    @property
    def hi(self):
        """Hilbert space"""
        return self._hi

    @property
    def graph(self):
        """Graph"""
        return self._graph

    def get_conn(self, sigma):
        """Return the configurations connected to ``sigma`` and their matrix elements."""
        raise NotImplementedError


def square_lattice_spinful_Fermi_Hubbard(Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=None):
    """Build the terms of the spinful Fermi-Hubbard model on an Lx x Ly square lattice.

    Args:
        Lx, Ly: Lattice dimensions.
        t: Hopping amplitude; every bond gets -t for each spin species.
        U: On-site interaction strength.
        N_f: Total number of fermions.
        pbc: Periodic boundaries (not implemented; raises).
        n_fermions_per_spin: Optional (n_up, n_down) split of N_f.

    Returns:
        (H, hi, graph): H maps (i, j, spin) -> -t for hopping on bond (i, j) with
        spin in (1, -1), and (i,) -> U for the on-site term at site i.

    Raises:
        NotImplementedError: if pbc is True.
    """
    if pbc:
        raise NotImplementedError("PBC not implemented yet")
    N = Lx * Ly
    # Always forward N_f so hi.n_fermions is set even when a per-spin split is given.
    hi = SpinfulFermion(n_orbitals=N, n_fermions=N_f, n_fermions_per_spin=n_fermions_per_spin)
    graph = SquareLatticeGraph(Lx, Ly, pbc)

    H = {(i, j, spin): -t for i, j in graph.edges for spin in (1, -1)}
    H.update({(i,): U for i in range(N)})
    return H, hi, graph



class spinful_Fermi_Hubbard_square_lattice_torch(Hamiltonian):
    def __init__(self, Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=None):
        """
        Spinful Fermi-Hubbard model on a square lattice, acting on quimb-format configs.

        Args:
            Lx, Ly: Lattice dimensions.
            t: Hopping amplitude.
            U: On-site interaction strength.
            N_f: Total number of fermions; used to restrict the Hilbert space.
            pbc: Periodic boundaries (the term builder raises NotImplementedError).
            n_fermions_per_spin: Optional (n_up, n_down).
        """
        H, hi, graph = square_lattice_spinful_Fermi_Hubbard(
            Lx, Ly, t, U, N_f, pbc=pbc, n_fermions_per_spin=n_fermions_per_spin
        )
        super().__init__(H, hi, graph)

    def get_conn(self, sigma_quimb):
        """
        Return the configurations <eta| connected to |sigma> by H and the elements <eta|H|sigma>.

        Args:
            sigma_quimb: quimb-format configuration, one entry per spatial site. The
                on-site U term is applied where the entry equals 3 (presumably a
                doubly occupied site).

        Returns:
            (etas, coeffs) built with autoray ``do('array', ...)``: row k of ``etas``
            is a quimb-format configuration and ``coeffs[k]`` its matrix element.
            Configurations reached by several terms have their coefficients summed.
        """
        sigma = from_quimb_config_to_netket_config(sigma_quimb)
        # Loop-invariant reference configuration for the fermionic phase.
        sigma_ref = from_netket_config_to_quimb_config(sigma)
        n_half = self.hi.n_orbitals // 2  # offset of the spin-down block in netket layout
        connected = {}

        for key, value in self._H.items():
            if len(key) == 3:
                # Hopping term on bond (i0, j0) for one spin species.
                i0, j0, spin = key
                offset = 0 if spin == 1 else n_half
                i, j = i0 + offset, j0 + offset
                if sigma[i] == sigma[j]:
                    continue  # both occupied or both empty: no hop possible
                eta = sigma.copy()
                eta[i], eta[j] = sigma[j], sigma[i]
                eta_quimb0 = from_netket_config_to_quimb_config(eta)
                phase = calc_phase_symmray(sigma_ref, eta_quimb0)
                eta_key = tuple(eta_quimb0)
                contribution = value * phase
            elif len(key) == 1:
                # On-site U term: diagonal, only for entries equal to 3.
                if sigma_quimb[key[0]] != 3:
                    continue
                eta_key = tuple(sigma_quimb)
                contribution = value
            else:
                continue

            if eta_key in connected:
                connected[eta_key] += contribution
            else:
                connected[eta_key] = contribution

        return do('array', list(connected.keys())), do('array', list(connected.values()))


def square_lattice_spin_Heisenberg(Lx, Ly, J, pbc=False, total_sz=None):
    """Build the nearest-neighbour spin-1/2 Heisenberg model on an Lx x Ly square lattice.

    H = sum_<i,j> J_ij S_i . S_j
      = sum_<i,j> [ 0.5 J_ij (S_i^+ S_j^- + S_i^- S_j^+) + J_ij S_i^z S_j^z ],
    with S = sigma / 2 (spin operators, not Pauli matrices).

    Args:
        Lx, Ly: Lattice dimensions.
        J: Uniform coupling, or a mapping from bond (i, j) to its coupling. A bond
            may be keyed in either orientation, (i, j) or (j, i); bonds missing
            from the mapping get coupling 0.
        pbc: Periodic boundary conditions.
        total_sz: Optional total S^z sector for the Hilbert space.

    Returns:
        (H, hi, graph) where H maps each bond (i, j) of the graph to its coupling.
    """
    from collections.abc import Mapping

    N = Lx * Ly
    hi = Spin(s=1/2, N=N, total_sz=total_sz)  # Spin-1/2 Hilbert space
    graph = SquareLatticeGraph(Lx, Ly, pbc)

    def bond_coupling(i, j):
        """Coupling on bond (i, j); mapping keys may use either orientation."""
        if not isinstance(J, Mapping):
            return J
        if (i, j) in J:
            return J[(i, j)]
        return J.get((j, i), 0)

    H = {(i, j): bond_coupling(i, j) for i, j in graph.edges}
    return H, hi, graph

class spin_Heisenberg_square_lattice_torch(Hamiltonian):
    def __init__(self, Lx, Ly, J, pbc=False, total_sz=None):
        """
        Spin-1/2 Heisenberg model on a square lattice, acting on 0/1 spin configs.

        Args:
            J: Coupling constant (can be a dict for edge-specific couplings)
            total_sz: If given, constrains the total spin of system to a particular value.
        """
        H, hi, graph = square_lattice_spin_Heisenberg(Lx, Ly, J, pbc=pbc, total_sz=total_sz)
        super().__init__(H, hi, graph)

    def get_conn(self, sigma_quimb):
        """
        Return the configurations <eta| connected to |sigma> by H and the elements <eta|H|sigma>.

        Args:
            sigma_quimb: per-site spin labels (0/1 as produced by Spin.random_state);
                must support ``.copy()`` and item assignment, e.g. a NumPy array.
                Any two-valued encoding works, since only equality of labels is used.

        Returns:
            (etas, coeffs) built with autoray ``do('array', ...)``. Each anti-aligned
            bond contributes 0.5*J to the configuration with that bond flipped, and every bond
            contributes +-0.25*J to the diagonal entry (sigma itself).
        """
        sigma = sigma_quimb
        sigma_key = tuple(sigma)  # the diagonal entry's key is the same for every bond
        connected = {}

        for (i, j), J in self._H.items():
            aligned = sigma[i] == sigma[j]
            if not aligned:
                # Off-diagonal: 0.5 J (S_i^+ S_j^- + S_i^- S_j^+) swaps the two spins.
                eta = sigma.copy()
                eta[i], eta[j] = sigma[j], sigma[i]
                eta_key = tuple(eta)
                connected[eta_key] = connected.get(eta_key, 0) + 0.5 * J
            # Diagonal: J S_i^z S_j^z = +J/4 for aligned spins, -J/4 for anti-aligned.
            connected[sigma_key] = connected.get(sigma_key, 0) + 0.25 * J * (1 if aligned else -1)

        return do('array', list(connected.keys())), do('array', list(connected.values()))

In [71]:
from vmc_torch.hamiltonian import spinful_Fermi_Hubbard_square_lattice,spin_Heisenberg_square_lattice, spin_J1J2_square_lattice

In [90]:
import pyinstrument

Lx, Ly = 4, 4
spinhi = Spin(s=1/2, N=Lx * Ly, total_sz=0)
random_config = spinhi.random_state(key=0)  # fixed seed so the comparison is reproducible
# The benchmark uses J=1.0 and the local implementation J=4.0. Their printed matrix
# elements agree, which is consistent with the benchmark using Pauli matrices
# (sigma = 2S) instead of spin operators. NOTE(review): confirm in vmc_torch.hamiltonian.
H_benchmark = spin_Heisenberg_square_lattice(Lx, Ly, J=1.0, pbc=False, total_sz=0)
H = spin_Heisenberg_square_lattice_torch(Lx, Ly, J=4.0, pbc=False, total_sz=0)

with pyinstrument.Profiler() as profiler:
    etas, coeffs = H_benchmark.get_conn(random_config)
profiler.print()
with pyinstrument.Profiler() as profiler:
    etas_, coeffs_ = H.get_conn(random_config)
profiler.print()

eta_coeff_benchmark = {tuple(eta): coeff for eta, coeff in zip(etas, coeffs)}
eta_coeff = {tuple(eta): coeff for eta, coeff in zip(etas_, coeffs_)}
print(random_config)
for eta, coeff_ref in eta_coeff_benchmark.items():
    assert eta in eta_coeff, f"missing connection {eta}"
    assert np.allclose(coeff_ref, eta_coeff[eta])
    print(np.array(eta), coeff_ref, eta_coeff[eta])
# Also reject connections the benchmark does not produce (unless their element is zero).
extra = [eta for eta, c in eta_coeff.items() if eta not in eta_coeff_benchmark and not np.allclose(c, 0)]
assert not extra, f"connections absent from the benchmark: {extra}"



  _     ._   __/__   _ _  _  _ _/_   Recorded: 17:02:49  Samples:  0
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.001     CPU time: 0.001
/   _/                      v5.0.1

Profile at /tmp/ipykernel_7807/4058469748.py:8

No samples were recorded.



  _     ._   __/__   _ _  _  _ _/_   Recorded: 17:02:49  Samples:  0
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.000     CPU time: 0.000
/   _/                      v5.0.1

Profile at /tmp/ipykernel_7807/4058469748.py:11

No samples were recorded.


[0 1 1 1 1 1 0 0 0 0 1 0 1 1 0 0]
[0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0.] -4.0 -4.0
[0. 1. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 0. 0.] 2.0 2.0
[0. 1. 1. 1. 1. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0.] 2.0 2.0
[0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0.] 2.0 2.0
[0. 1. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 1. 1. 0. 0.] 2.0 2.0
[0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 1. 1. 0. 0.] 2.0 2.0
[0. 1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0.] 2.0 2.0
[1. 0. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0

In [4]:
import pyinstrument

Lx, Ly = 4, 4
n_orbitals = Lx * Ly
n_fermions = n_orbitals  # half filling: one fermion per site on average
spinful_hi = SpinfulFermion(n_orbitals=n_orbitals, n_fermions=n_fermions)
# SpinfulFermion.random_state already returns a quimb-format config, so it must not be
# passed through from_netket_config_to_quimb_config a second time.
random_quimb_config = spinful_hi.random_state(key=0)  # fixed seed for reproducibility
# Netket layout ([up..., down...]) kept for side-by-side display with the benchmark.
random_state = from_quimb_config_to_netket_config(random_quimb_config)
H_benchmark = spinful_Fermi_Hubbard_square_lattice(Lx=Lx, Ly=Ly, t=1.0, U=8.0, N_f=n_fermions, pbc=False, n_fermions_per_spin=(n_fermions//2, n_fermions//2))
H = spinful_Fermi_Hubbard_square_lattice_torch(Lx=Lx, Ly=Ly, t=1.0, U=8.0, N_f=n_fermions, pbc=False, n_fermions_per_spin=(n_fermions//2, n_fermions//2))
with pyinstrument.Profiler() as profiler:
    etas, coeffs = H_benchmark.get_conn(random_quimb_config)
profiler.print()
random_quimb_config


  _     ._   __/__   _ _  _  _ _/_   Recorded: 00:06:47  Samples:  1141
 /_//_/// /_\ / //_// / //_'/ //     Duration: 1.715     CPU time: 2.167
/   _/                      v4.7.3

Profile at /tmp/ipykernel_29274/2345923011.py:10

1.714 <module>  ../../../../../tmp/ipykernel_29274/2345923011.py:10
└─ 1.714 spinful_Fermi_Hubbard_square_lattice.get_conn  ../hamiltonian.py:255
   └─ 1.713 FermionOperator2nd.get_conn  netket/operator/_discrete_operator.py:126
         [399 frames hidden]  netket, numba, colorama, llvmlite, pp...




array([2, 3, 0, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1])

In [5]:
# Profile the local implementation on the same configuration. The benchmark's
# `etas, coeffs` were already computed in the previous cell, so they are not recomputed here.
with pyinstrument.Profiler() as profiler:
    etas_, coeffs_ = H.get_conn(random_quimb_config)
profiler.print()


  _     ._   __/__   _ _  _  _ _/_   Recorded: 00:06:50  Samples:  2
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.003     CPU time: 0.003
/   _/                      v4.7.3

Profile at /tmp/ipykernel_29274/1755742582.py:1

0.002 <module>  ../../../../../tmp/ipykernel_29274/1755742582.py:1
└─ 0.002 spinful_Fermi_Hubbard_square_lattice_torch.get_conn  ../../../../../tmp/ipykernel_29274/3808544642.py:202
   └─ 0.002 calc_phase_symmray  ../fermion_utils.py:1039
      └─ 0.002 from_quimb_config_to_netket_config  ../fermion_utils.py:990
         └─ 0.002 func  ../fermion_utils.py:993




In [6]:
eta_coeff_benchmark = {tuple(eta): coeff for eta, coeff in zip(etas, coeffs)}
eta_coeff = {tuple(eta): coeff for eta, coeff in zip(etas_, coeffs_)}
for eta, coeff_ref in eta_coeff_benchmark.items():
    assert eta in eta_coeff, f"missing connection {eta}"
    assert np.allclose(coeff_ref, eta_coeff[eta])
    print(eta, coeff_ref, eta_coeff[eta])
# Also reject connections the benchmark does not produce (unless their element is zero).
extra = [eta for eta, c in eta_coeff.items() if eta not in eta_coeff_benchmark and not np.allclose(c, 0)]
assert not extra, f"connections absent from the benchmark: {extra}"

(2, 3, 0, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) 48.0 48.0
(2, 3, 0, 2, 3, 3, 3, 3, 3, 0, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 3, 3, 2, 3, 0, 0, 0, 0, 1, 0, 1) -1.0 -1.0
(2, 3, 0, 3, 3, 3, 3, 2, 2, 1, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 3, 3, 2, 1, 2, 0, 0, 1, 0, 0, 1) -1.0 -1.0
(2, 3, 0, 3, 3, 3, 3, 2, 1, 0, 0, 0, 3, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 2, 3, 2, 3, 1, 0, 0, 1, 0, 0, 1) -1.0 -1.0
(2, 3, 0, 3, 3, 1, 3, 2, 3, 2, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 1, 0) -1.0 -1.0
(3, 2, 0, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 2, 1, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 1, 2, 3, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) -1.0 -1.0
(3, 3, 0, 3, 2, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 3, 2, 3, 3, 0, 0, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 0, 3, 3, 3, 2, 2, 3, 0, 1, 0, 1, 0, 0, 1) -1.0 -1.0
(2, 3, 0, 3, 3, 3, 1, 2, 3, 0, 2, 0, 1, 0, 0, 1) 1.0 1.0
(2, 3, 1, 2, 3, 3, 3, 2, 3, 0, 0, 0, 1, 0, 0, 1) -1.0 -1.0
(2, 3, 2, 1, 3,

In [15]:
from vmc_torch.fermion_utils import calc_phase_symmray

# Fermionic phase of each benchmark connection relative to the sampled configuration.
for eta in etas:
    print(calc_phase_symmray(random_quimb_config, eta))

1
1
1
-1
-1
1
-1
-1
-1
-1


In [57]:
keys = np.array(['a', 'b', 'c'])
values = np.array([1, 2, 3])

# Pair the i-th key with the i-th value.
dictionary = {k: v for k, v in zip(keys, values)}
print(dictionary)

{'a': 1, 'b': 2, 'c': 3}
