In [1]:
import os
# Limit BLAS/OpenMP to one thread per process. These are set before numpy/torch are
# imported below, because those libraries typically read them when they load.
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
# NOTE(review): project-specific flag, presumably read by vmc_torch / its TN backend -- confirm its effect.
os.environ['DENSE_TENSOR'] = '0'
import sys
import warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
from mpi4py import MPI
import numpy as np
import pickle
# Root directory of the pickled PEPS data (absolute, machine-specific path).
pwd = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data'
# torch
import torch
# Keep autograd anomaly detection disabled (it is a slow debugging mode).
torch.autograd.set_detect_anomaly(False)

# quimb
import quimb.tensor as qtn
import autoray as ar

# NOTE: star import -- presumably supplies names used later in the notebook without an
# explicit import (e.g. wavefunctionModel, nn, ctg); confirm against vmc_torch.experiment.tn_model.
from vmc_torch.experiment.tn_model import *
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful, MetropolisMPSSamplerSpinful
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SR,Adam, SGD_momentum, DecayScheduler, TrivialPreconditioner
from vmc_torch.VMC import VMC
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
from vmc_torch.torch_utils import SVD,QR

# Route autoray's torch dispatch of 'linalg.svd' / 'linalg.qr' to the project's
# "safe" SVD and QR autograd Functions, so code that calls them through autoray
# (e.g. quimb contractions on torch arrays) uses these implementations.
ar.register_function('torch','linalg.svd',SVD.apply)
ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG
from vmc_torch.utils import closest_divisible


COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

# Hamiltonian parameters: spinful Fermi-Hubbard model on an open Lx x Ly square lattice
Lx = int(4)
Ly = int(4)
symmetry = 'Z2'
t = 1.0
U = 8.0
N_f = int(Lx*Ly)  # total fermion number: one per site on average
n_fermions_per_spin = (N_f//2, N_f//2)  # equal spin-up / spin-down populations
H = spinful_Fermi_Hubbard_square_lattice_torch(Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=n_fermions_per_spin)
graph = H.graph

# TN parameters
D = 4
chi = -1  # boundary-contraction bond dimension; fTNModel_vec_test treats <= 0 as exact contraction
dtype = torch.float64
# Seed every MPI rank differently so random draws differ across ranks
torch.random.manual_seed(RANK)
np.random.seed(RANK)

# Load the pre-computed PEPS (skeleton + parameter arrays) for these parameters.
# The `with` blocks close the file handles right away instead of leaking them until GC.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
peps_dir = pwd + f"/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}"
with open(peps_dir + "/peps_skeleton.pkl", "rb") as f:
    skeleton = pickle.load(f)
with open(peps_dir + "/peps_su_params.pkl", "rb") as f:
    peps_params = pickle.load(f)
peps = qtn.unpack(peps_params, skeleton)
device = torch.device("cpu")
# Convert every block array (and the global exponent) to torch tensors on `device`
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype, device=device))
peps.exponent = torch.tensor(peps.exponent, dtype=dtype, device=device)

# # randomize the PEPS tensors
# peps.apply_to_arrays(lambda x: torch.randn_like(torch.tensor(x, dtype=dtype), dtype=dtype))

# VMC sample size: rounded by closest_divisible (presumably to a multiple of SIZE), then
# bumped by SIZE when needed so that each rank gets an even number of samples
N_samples = int(20)
N_samples = closest_divisible(N_samples, SIZE)
if (N_samples/SIZE)%2 != 0:
    N_samples += SIZE

# Set up sampler
sampler = MetropolisExchangeSamplerSpinful(H.hilbert, graph, N_samples=N_samples, burn_in_steps=2, reset_chain=False, random_edge=False, equal_partition=True, dtype=dtype, subchain_length=10)
# Alternative: MPS-based sampler
# mps_dir = '/home/sijingdu/TNVMC/VMC_code/vmc_torch/data'+f'/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/tmp'
# sampler = MetropolisMPSSamplerSpinful(H.hilbert, graph, mps_dir=mps_dir, mps_n_sample=1, N_samples=N_samples, burn_in_steps=20, reset_chain=True, equal_partition=True, dtype=dtype)

In [9]:
from vmc_torch.fermion_utils import *
from vmc_torch.experiment.tn_model import *

class fTNModel_vec_test(wavefunctionModel):
    """Experimental fermionic-PEPS wavefunction model aimed at ``torch.func.vmap``.

    The block-sparse PEPS data are registered as ``nn.Parameter``s in nested
    ``nn.ModuleDict``/``nn.ParameterDict`` containers keyed by ``str(tid)`` and
    ``str(sector)``. ``amplitude`` rebuilds the quimb network from them on every call.

    Args:
        ftn: fermionic PEPS exposing ``Lx``, ``Ly``, ``arrays``, ``max_bond()`` and
            ``get_amp(config, conj=..., functional=...)``.
        max_bond: boundary-contraction bond dimension. ``None`` or ``<= 0`` selects
            exact contraction along a cotengra tree computed once in ``__init__``.
        dtype: parameter dtype; also the configuration dtype when ``functional`` is False.
        functional: build amplitudes via ``get_amp(..., functional=True)`` and feed it
            integer (``torch.int``) configurations.
        tree: accepted for interface compatibility; unused (the tree is recomputed here).
        device: device of the parity bookkeeping tensor.
    """

    def __init__(self, ftn, max_bond=None, dtype=torch.float32, functional=False, tree=None, device=None):
        super().__init__()
        self.param_dtype = dtype
        self.functional = functional
        self.device = device
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the {tid: {sector: array}} structure into string-keyed containers
        # (nn.ModuleDict / nn.ParameterDict require string keys); `amplitude` parses them back.
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            f'fPEPS (chi={max_bond})':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
        }
        # Non-positive bond dimensions mean "no truncation" (exact contraction).
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond

        # Find a contraction tree once on a random configuration; `amplitude` reuses it
        # for every configuration. Occupations take the values 0..3, so draw from that
        # full range (randint's upper bound is exclusive).
        opt = ctg.HyperOptimizer(progbar=True, max_repeats=10, parallel=True)
        random_x = torch.randint(0, 4, (ftn.Lx*ftn.Ly,), dtype=torch.int)
        amp = ftn.get_amp(random_x, conj=True, functional=self.functional)
        self.tree = amp.contraction_tree(optimize=opt)

        # Parity of each PEPS tensor, used by the global-phase bookkeeping
        self.parity_config = torch.tensor([array.parity for array in ftn.arrays], dtype=torch.int, device=self.device)

        # BUG: in tree.traverse(), the tids are not automatically sorted, so we need to sort them manually
        self.sorted_tree_traverse_path = {i: tuple(sorted(left_tids))+tuple(sorted(right_tids)) for i, (_, left_tids, right_tids) in enumerate(self.tree.traverse())}

        # compute the permutation dict for future global phase computation
        self.perm_dict = {i: tuple(torch.argsort(torch.tensor(left_right_tids))) for i, left_right_tids in self.sorted_tree_traverse_path.items()}
        self.perm_dict_desc = {i: tuple(torch.argsort(torch.tensor(tuple(sorted(left_tids))[::-1]+tuple(sorted(right_tids))[::-1]), descending=True)) for i, (_, left_tids, right_tids) in enumerate(self.tree.traverse())}

        # NOTE(review): both decompositions pass asc=False even though the first dict is
        # named "_asc" -- confirm this is intended.
        self.adjacent_transposition_dict_asc = {i: decompose_permutation_into_transpositions(perm, asc=False) for i, perm in self.perm_dict.items()}
        self.adjacent_transposition_dict_desc = {i: decompose_permutation_into_transpositions(perm, asc=False) for i, perm in self.perm_dict_desc.items()}

    def compute_global_phase(self, input_config):
        """Get the global phase of contracting an amplitude of the fPEPS given computational graph.

        Args:
            input_config: integer tensor with one occupation value (0..3) per site; it is
                used to index a length-4 on-site parity table.

        Returns:
            The phase factor computed by ``calculate_phase_from_adjacent_trans_dict`` for
            the precomputed contraction tree.
        """
        # Parity of the on-site occupations 0, 1, 2, 3
        on_site_parity_tensor = torch.tensor([0, 1, 1, 0], dtype=torch.int, device=input_config.device)
        input_config_parity = on_site_parity_tensor[input_config]
        # Parity of each site tensor after the configuration has been absorbed into it
        amp_parity_config = (self.parity_config + input_config_parity) % 2

        return calculate_phase_from_adjacent_trans_dict(
            self.tree,
            input_config_parity,
            self.parity_config,
            amp_parity_config,
            self.adjacent_transposition_dict_asc,
            self.adjacent_transposition_dict_desc,
            self.sorted_tree_traverse_path,
        )

    def amplitude(self, x):
        """Return the amplitude of a single configuration ``x``.

        Each call handles one configuration. Batches are meant to be handled by wrapping
        the model in ``torch.func.vmap``.

        Args:
            x: one occupation value per site (tensor or array-like). It is converted to
                ``torch.int`` when ``functional`` is True, otherwise to ``param_dtype``.

        Returns:
            The (scalar) amplitude. With exact contraction it is multiplied by
            ``compute_global_phase``.
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)

        # Configuration dtype expected by the amplitude construction
        target_dtype = torch.int if self.functional else self.param_dtype

        def amplitude_func(psi, x_i):
            # Normalize the input dtype. Compare against the *target* dtype: in functional
            # mode a float input must still become an integer tensor.
            if not isinstance(x_i, torch.Tensor):
                x_i = torch.tensor(x_i, dtype=target_dtype)
            elif x_i.dtype != target_dtype:
                x_i = x_i.to(target_dtype)
            # Get the amplitude
            amp = psi.get_amp(x_i, conj=True, functional=self.functional)
            if self.max_bond is None:
                # Exact contraction along the precomputed tree, followed by the global phase.
                # NOTE(review): the phase is applied regardless of `functional` -- confirm
                # it is not double-counted when the bra carries odd positions.
                amp_val = amp.contract(optimize=self.tree)
                phase = self.compute_global_phase(x_i.int())
                amp_val = phase * amp_val
            else:
                # Approximate boundary contraction from both x-edges towards the middle.
                # NOTE(review): no global phase is applied on this branch.
                amp = amp.contract_boundary_from_xmin(max_bond=self.max_bond, cutoff=0.0, xrange=[0, psi.Lx//2-1])
                amp = amp.contract_boundary_from_xmax(max_bond=self.max_bond, cutoff=0.0, xrange=[psi.Lx//2, psi.Lx-1])
                amp_val = amp.contract()
            return amp_val

        return amplitude_func(psi, x)

    def forward(self, x):
        """Alias for :meth:`amplitude`."""
        return self.amplitude(x)
    
class fPEPS_vec(qtn.PEPS):
    """Fermionic PEPS with helpers to project onto a computational-basis configuration.

    A configuration assigns one integer per site, in ``self.sites`` order. Spinful
    sites take values 0..3; spinless sites take 0..1.

    - ``get_amp`` attaches a product bra state and contracts it into each site.
    - ``get_amp_functional`` does the same with a bra that carries no odd positions
      (apparently intended for ``torch.func`` transforms).
    - ``get_amp_efficient`` slices each physical index directly.
    """

    def __init__(self, arrays, *, shape="urdlp", tags=None, site_ind_id="k{},{}", site_tag_id="I{},{}", x_tag_id="X{}", y_tag_id="Y{}", **tn_opts):
        super().__init__(arrays, shape=shape, tags=tags, site_ind_id=site_ind_id, site_tag_id=site_tag_id, x_tag_id=x_tag_id, y_tag_id=y_tag_id, **tn_opts)
        # Symmetry label shared by all block-sparse site arrays (e.g. 'Z2', 'U1', 'U1U1')
        self.symmetry = self.arrays[0].symmetry
        # A physical dimension of 2 means one fermionic mode per site
        self.spinless = self.phys_dim() == 2

    def _array_kwargs(self, config):
        """Return ``do('array', ...)`` kwargs matching this PEPS's backend, dtype and device.

        Raises:
            TypeError: if ``config`` is neither a ``numpy.ndarray`` nor a ``torch.Tensor``.
        """
        data = self.tensors[0].data
        # e.g. 'torch' + '.' + 'float64' -> torch.float64. Both strings come from the
        # PEPS arrays themselves, not from user input.
        dtype = eval(data.backend + '.' + data.dtype)
        if isinstance(config, numpy.ndarray):
            return {'like': config, 'dtype': dtype}
        if isinstance(config, torch.Tensor):
            device = next(iter(data.blocks.values())).device
            return {'like': config, 'device': device, 'dtype': dtype}
        raise TypeError(
            f"config must be a numpy.ndarray or torch.Tensor, got {type(config).__name__}"
        )

    def _spinful_maps(self, kwargs):
        """Return ``(index_map, array_map)`` for spinful occupations 0..3 under ``self.symmetry``.

        ``index_map[n]`` is the symmetry charge of occupation ``n``, and ``array_map[n]``
        is the vector stored in that charge block.

        Raises:
            NotImplementedError: for symmetries other than 'Z2', 'U1' and 'U1U1'.
        """
        if self.symmetry == 'Z2':
            index_map = {0: 0, 1: 1, 2: 1, 3: 0}
            vectors = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0], 3: [0.0, 1.0]}
        elif self.symmetry == 'U1':
            index_map = {0: 0, 1: 1, 2: 1, 3: 2}
            vectors = {0: [1.0], 1: [1.0, 0.0], 2: [0.0, 1.0], 3: [1.0]}
        elif self.symmetry == 'U1U1':
            index_map = {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
            vectors = {0: [1.0], 1: [1.0], 2: [1.0], 3: [1.0]}
        else:
            raise NotImplementedError(f"Unsupported symmetry for spinful fermions: {self.symmetry}")
        array_map = {n: do('array', vec, **kwargs) for n, vec in vectors.items()}
        return index_map, array_map

    @staticmethod
    def _combine_charges(symmetry, charge, tensor_charge):
        """Add an occupation charge to a tensor's total charge under ``symmetry``."""
        if symmetry == 'U1':
            return charge + tensor_charge
        if symmetry == 'Z2':
            return (charge + tensor_charge) % 2  # Z2 charges live in {0, 1}
        if symmetry == 'U1U1':
            # one U1 charge per component, added independently
            return (charge[0] + tensor_charge[0], charge[1] + tensor_charge[1])
        raise NotImplementedError(f"Unsupported symmetry: {symmetry}")

    @staticmethod
    def _attach_bra(peps, product_state, contract):
        """Join ``peps`` with ``product_state``; if ``contract``, absorb each bra into its site.

        Returns the combined network, viewed as a ``qtn.PEPS`` after contraction.
        """
        amp = peps | product_state  # ---T---<---|n>
        if not contract:
            return amp

        for site in peps.sites:
            amp.contract_(tags=peps.site_tag_id.format(*site))

        amp.view_as_(
            qtn.PEPS,
            site_ind_id="k{},{}",
            site_tag_id="I{},{}",
            x_tag_id="X{}",
            y_tag_id="Y{}",
            Lx=peps.Lx,
            Ly=peps.Ly,
        )
        return amp

    def product_bra_state(self, config, reverse=1):
        """Build the product state for ``config`` as one-index fermionic tensors.

        Args:
            config: one occupation value per site, in ``self.sites`` order
                (numpy array or torch tensor).
            reverse: odd-position labels are multiplied by ``(-1)**reverse``.

        Returns:
            qtn.TensorNetwork: one tensor per site, tagged with its site tag and ``'bra'``.

        Raises:
            TypeError: for unsupported ``config`` types.
            NotImplementedError: for unsupported spinful symmetries.
        """
        kwargs = self._array_kwargs(config)
        if self.spinless:
            index_map = {0: 0, 1: 1}
            array_map = {
                0: do('array', [1.0,], **kwargs),
                1: do('array', [1.0,], **kwargs),
            }
        else:
            index_map, array_map = self._spinful_maps(kwargs)
        sign = (-1)**reverse

        product_tn = qtn.TensorNetwork()
        # enumerate() gives the site position directly (was an O(N) sites.index per site)
        for tid, (n, site) in enumerate(zip(config, self.sites)):
            n_int = int(n)
            p_ind = self.site_ind_id.format(*site)
            p_tag = self.site_tag_id.format(*site)

            n_charge = index_map[n_int]
            n_array = array_map[n_int]

            if self.spinless or n_int == 1:
                oddpos = (3*tid+1)*sign
            elif n_int == 2:
                oddpos = (3*tid+2)*sign
            else:
                # Occupations 0 and 3 are parity-even: no odd position.
                # (alternative kept for reference:
                #  oddpos = ((3*tid+1)*(-1)**reverse, (3*tid+2)*(-1)**reverse))
                oddpos = None

            tsr_data = sr.FermionicArray.from_blocks(
                blocks={(n_charge,): n_array},
                duals=(True,),
                symmetry=self.symmetry,
                charge=n_charge,
                oddpos=oddpos,
            )
            product_tn |= qtn.Tensor(data=tsr_data, inds=(p_ind,), tags=(p_tag, 'bra'))

        return product_tn

    def product_bra_state_functional(self, config, reverse=1):
        """Phase-free product state for a torch ``config`` (spinful, Z2 only).

        Every bra tensor is created with ``charge=0`` and no odd positions, so no
        fermionic reordering sign is attached here. The global phase is presumably
        restored by the caller (e.g. a model's ``compute_global_phase``).
        ``reverse`` is accepted for interface compatibility but unused.

        Raises:
            NotImplementedError: for spinless PEPS or non-Z2 symmetry.
            TypeError: if ``config`` is not a ``torch.Tensor``.
        """
        #XXX remember to comment out the drop_missing_blocks in tensordot_via_fused in line 2304-2305 in symmray abelian_core.py
        if self.spinless:
            raise NotImplementedError("Functional bra state is not implemented for spinless fermions.")
        if self.symmetry != 'Z2':
            raise NotImplementedError(f"Functional bra state is only implemented for Z2 symmetry, got {self.symmetry}.")
        if not isinstance(config, torch.Tensor):
            raise TypeError(f"Functional bra state requires a torch.Tensor config, got {type(config).__name__}")

        kwargs = self._array_kwargs(config)
        # Z2 charge and basis vector for occupations 0..3, indexed by the occupation itself
        charge_tensor = torch.tensor([0, 1, 1, 0], dtype=torch.int, device=kwargs['device'])
        vector_tensor = do('array', [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], **kwargs)

        product_tn = qtn.TensorNetwork()
        for n, site in zip(config, self.sites):
            p_ind = self.site_ind_id.format(*site)
            p_tag = self.site_tag_id.format(*site)
            n_idx = n.int()
            # NOTE: int(...) on a tensor calls .item(), which torch.func.vmap does not
            # support -- presumably what triggers the vmap RuntimeError in this notebook.
            blocks = {(int(charge_tensor[n_idx]),): vector_tensor[n_idx]}
            tsr_data = sr.FermionicArray.from_blocks(
                blocks=blocks,
                duals=(True,),
                symmetry=self.symmetry,
                charge=0,  # deliberately even; see docstring
            )
            product_tn |= qtn.Tensor(data=tsr_data, inds=(p_ind,), tags=(p_tag, 'bra'))

        return product_tn

    # NOTE: don't use @classmethod here, as we need to access the specific instance attributes
    def get_amp(self, config, inplace=False, conj=True, reverse=1, contract=True, functional=False):
        """Get the amplitude of a configuration in a PEPS.

        Args:
            config: one occupation value per site.
            inplace: attach to ``self`` rather than a copy.
            conj: conjugate the product state before attaching it.
            reverse: forwarded to the product-state builder.
            contract: if False, return the uncontracted PEPS + bra network.
            functional: use the phase-free functional bra (``get_amp_functional``).
                All other options are forwarded to it.

        Returns:
            The amplitude network, viewed as a ``qtn.PEPS`` when ``contract`` is True.
        """
        if functional:
            return self.get_amp_functional(config, inplace=inplace, conj=conj, reverse=reverse, contract=contract)

        peps = self if inplace else self.copy()
        product_state = self.product_bra_state(config, reverse=reverse)
        if conj:
            product_state = product_state.conj()
        return self._attach_bra(peps, product_state, contract)

    def get_amp_functional(self, config, inplace=False, conj=True, reverse=1, contract=True):
        """Like ``get_amp`` but with the phase-free ``product_bra_state_functional`` bra."""
        peps = self if inplace else self.copy()
        product_state = self.product_bra_state_functional(config, reverse=reverse)
        if conj:
            product_state = product_state.conj()
        return self._attach_bra(peps, product_state, contract)

    def get_amp_efficient(self, config, inplace=False):
        """Slicing to get the amplitude, faster than contraction with a tensor product state.

        For each site, only the blocks whose physical charge matches the occupation are
        kept. Their physical axis is contracted with the occupation's basis vector, and
        the site's odd positions and total charge are updated accordingly.

        Raises:
            NotImplementedError: for spinless PEPS or unsupported symmetries.
            TypeError: for unsupported ``config`` types.
            Any error raised while rebuilding a site tensor is re-raised after the
            offending site's data is printed.
        """
        peps = self if inplace else self.copy()
        kwargs = self._array_kwargs(config)

        if self.spinless:
            raise NotImplementedError("Efficient amplitude calculation is not implemented for spinless fermions.")
        index_map, array_map = self._spinful_maps(kwargs)

        for site_id, (n, site) in enumerate(zip(config, self.sites)):
            n_int = int(n)
            p_ind = peps.site_ind_id.format(*site)
            fts = peps.tensors[site_id]
            ftsdata = fts.data
            # Explicitly apply all lazy phases that are stored and not yet applied.
            # NOTE(review): with inplace=False, self.copy() may share array objects with
            # `self`, so this in-place sync could also touch self's data -- confirm.
            ftsdata.phase_sync(inplace=True)
            phys_ind_order = fts.inds.index(p_ind)
            charge = index_map[n_int]  # also rejects occupations outside 0..3 (KeyError)
            input_vec = array_map[n_int]
            charge_sec_data_dict = ftsdata.blocks

            new_fts_inds = fts.inds[:phys_ind_order] + fts.inds[phys_ind_order + 1:]
            new_charge_sec_data_dict = {}
            for charge_blk, data in charge_sec_data_dict.items():
                if charge_blk[phys_ind_order] == charge:
                    # contract the block's physical axis with the basis vector
                    new_charge_blk = charge_blk[:phys_ind_order] + charge_blk[phys_ind_order + 1:]
                    new_charge_sec_data_dict[new_charge_blk] = do('tensordot', data, input_vec, axes=([phys_ind_order], [0]))

            new_duals = ftsdata.duals[:phys_ind_order] + ftsdata.duals[phys_ind_order + 1:]

            # Singly occupied sites contribute one (dual) odd position; 0 and 3 contribute none.
            if n_int == 1:
                extra_oddpos = (FermionicOperator(-(3 * site_id + 1), dual=True),)
            elif n_int == 2:
                extra_oddpos = (FermionicOperator(-(3 * site_id + 2), dual=True),)
            else:
                extra_oddpos = ()
            # NOTE(review): the combined odd positions are passed in reversed order -- confirm.
            oddpos = list(ftsdata.oddpos + extra_oddpos)[::-1]

            try:
                new_charge = self._combine_charges(ftsdata.symmetry, charge, ftsdata.charge)
                new_fts_data = sr.FermionicArray.from_blocks(new_charge_sec_data_dict, duals=new_duals, charge=new_charge, oddpos=oddpos, symmetry=ftsdata.symmetry)
            except Exception:
                # Print diagnostics, then propagate. Continuing would silently reuse a stale
                # (or undefined) `new_fts_data` from a previous site.
                print(n, site, phys_ind_order, charge_sec_data_dict, new_charge_sec_data_dict)
                raise

            fts.modify(data=new_fts_data, inds=new_fts_inds, left_inds=None)

        return qtn.PEPS(peps)

# TODO: the only workable solution I can see is to represent the charge tuples as tensors too, so that everything can be done at the tensor level (no Python-int conversions, which break vmap).

In [15]:
from torch.func import vmap


# Ten configurations from H.hilbert.random_state(i) (i presumably acts as a seed),
# stacked into one tensor, presumably of shape (10, n_sites).
X = torch.tensor(
    [H.hilbert.random_state(i) for i in range(10)],
    dtype=dtype,
    device=device,
)
vec_peps = fPEPS_vec(peps, Lx=Lx, Ly=Ly, symmetry=symmetry)
model = fTNModel_vec_test(
    vec_peps,
    max_bond=chi,
    dtype=dtype,
    functional=True,
)
# Batched amplitude evaluation. NOTE: in the recorded run this raised
# "RuntimeError: vmap: ... calling .item() on a Tensor" (see the output below).
# Non-vmapped check for a single sample: vec_peps.get_amp(X[0], functional=True).contract()
vmap(model)(X)

F=4.76 C=5.57 S=10.00 P=11.43: 100%|██████████| 10/10 [00:00<00:00, 622.98it/s]


RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

In [24]:
# Z2 charge and basis vector per occupation value 0..3 (same tables as in
# fPEPS_vec.product_bra_state_functional).
charge_tensor = torch.tensor([0, 1, 1, 0], dtype=torch.int, device=device)
vector_tensor = torch.tensor(
    [
        [1.0, 0.0],
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 1.0],
    ],
    dtype=dtype,
    device=device,
)
# Print the charge looked up for every site of the first sample.
for n in X[0]:
    site_charge = charge_tensor[n.int()]
    print((site_charge,))

(tensor(1, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(0, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)
(tensor(1, dtype=torch.int32),)


In [None]:
# Toy block dictionaries of the form {(sector,): values}, the same layout passed as
# `blocks` to FermionicArray.from_blocks above.
d0, d1 = (
    {(sector,): torch.tensor([value], dtype=dtype, device=device)}
    for sector, value in ((0, 1.0), (1, 2.0))
)

data = {
    (sector,): torch.tensor([value], dtype=dtype, device=device)
    for sector, value in ((0, 1.0), (1, -1.0))
}
