In [1]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
from mpi4py import MPI
import pickle


# torch
from torch.nn.parameter import Parameter
import torch
import torch.nn as nn
torch.autograd.set_detect_anomaly(False)

# quimb
import quimb as qu
import quimb.tensor as qtn
import autoray as ar
from autoray import do

# from vmc_torch.experiment.tn_model import *
from vmc_torch.sampler import MetropolisExchangeSamplerSpinful
from vmc_torch.variational_state import Variational_State
from vmc_torch.optimizer import SGD, SignedSGD, SignedRandomSGD, SR, TrivialPreconditioner, Adam, SGD_momentum, DecayScheduler
from vmc_torch.VMC import VMC
# from vmc_torch.hamiltonian import spinful_Fermi_Hubbard_square_lattice
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
from vmc_torch.torch_utils import SVD,QR
from vmc_torch.fermion_utils import generate_random_fpeps
from vmc_torch.utils import closest_divisible

# Register safe SVD and QR functions to torch
ar.register_function('torch','linalg.svd',SVD.apply)
ar.register_function('torch','linalg.qr',QR.apply)

from vmc_torch.global_var import DEBUG


COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

# Hamiltonian parameters
Lx = int(4)
Ly = int(4)
symmetry = 'Z2'
t = 1.0
U = 8.0
N_f = int(Lx*Ly-2)
# N_f=12
n_fermions_per_spin = (N_f//2, N_f//2)
H = spinful_Fermi_Hubbard_square_lattice_torch(Lx, Ly, t, U, N_f, pbc=False, n_fermions_per_spin=n_fermions_per_spin)
graph = H.graph
# TN parameters
D = 4
chi = 16
dtype=torch.float64

# Load PEPS
try:
    skeleton = pickle.load(open(f"../../data/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/peps_skeleton.pkl", "rb"))
    peps_params = pickle.load(open(f"../../data/{Lx}x{Ly}/t={t}_U={U}/N={N_f}/{symmetry}/D={D}/peps_su_params.pkl", "rb"))
    peps = qtn.unpack(peps_params, skeleton)
except:
    peps = generate_random_fpeps(Lx, Ly, D=D, seed=2, symmetry=symmetry, Nf=N_f, spinless=False)[0]
peps_np = peps.copy()
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype))

# VMC sample size
N_samples = 2
N_samples = closest_divisible(N_samples, SIZE)

In [2]:
from vmc_torch.experiment.tn_model import wavefunctionModel
import ast
import cotengra as ctg
class fTNModel_reuse(wavefunctionModel):
    def __init__(self, ftn, max_bond=None, dtype=torch.float32, functional=False, debug=False):
        super().__init__()
        self.param_dtype = dtype
        self.functional = functional
        self.debug = debug
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)
        self.skeleton.exponent = 0

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            f'fPEPS (chi={max_bond})':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
        }
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond
        self.tree = None
        self.Lx = ftn.Lx
        self.Ly = ftn.Ly
        self._env_x_cache = None
        self._env_y_cache = None
        self.config_ref = None
        self.amp_ref = None
    
    def from_1d_to_2d(self, config, ordering='snake'):
        if ordering == 'snake':
            config_2d = config.reshape((self.Lx, self.Ly))
            return config_2d
        else:
            raise NotImplementedError(f'Ordering {ordering} is not implemented.')
        
    def from_1dsite_to_2dsite(self, site, ordering='snake'):
        """
            Convert a 1d site index to a 2d site index.
            site: 1d site index
        """
        if ordering == 'snake':
            return (site // self.Ly, site % self.Ly)
        else:
            raise ValueError(f"Unsupported ordering: {ordering}")
    
    def transform_quimb_env_x_key_to_config_key(self, env_x, config):
        """
            Return a dictionary with the keys of of the config rows
        """
        config_2d = self.from_1d_to_2d(config)
        env_x_row_config = {}
        for key in env_x.keys():
            if key[0] == 'xmax': # from bottom to top
                row_n = key[1]
                if row_n != self.Lx-1:
                    rows_config = tuple(torch.cat(tuple(config_2d[row_n+1:].to(torch.int))).tolist())
                    env_x_row_config[('xmax', rows_config)] = env_x[key]
            elif key[0] == 'xmin': # from top to bottom
                row_n = key[1]
                if row_n != 0:
                    rows_config = tuple(torch.cat(tuple(config_2d[:row_n].to(torch.int))).tolist())
                    env_x_row_config[('xmin', rows_config)] = env_x[key]
        return env_x_row_config
    
    def transform_quimb_env_y_key_to_config_key(self, env_y, config):
        """
            Return a dictionary with the keys of of the config rows
        """
        config_2d = self.from_1d_to_2d(config)
        env_y_row_config = {}
        for key in env_y.keys():
            if key[0] == 'ymax':
                col_n = key[1]
                if col_n != self.Ly-1:
                    cols_config = tuple(torch.cat(tuple(config_2d[:, col_n+1:].to(torch.int))).tolist())
                    env_y_row_config[('ymax', cols_config)] = env_y[key]
            elif key[0] == 'ymin':
                col_n = key[1]
                if col_n != 0:
                    cols_config = tuple(torch.cat(tuple(config_2d[:, :col_n].to(torch.int))).tolist())
                    env_y_row_config[('ymin', cols_config)] = env_y[key]
        return env_y_row_config

    def cache_env_x(self, amp, config):
        """
            Cache the environment x for the given configuration
        """
        env_x = amp.compute_x_environments(max_bond=self.max_bond, cutoff=0.0)
        env_x_cache = self.transform_quimb_env_x_key_to_config_key(env_x, config)
        self._env_x_cache = env_x_cache
        self.config_ref = config
        self.amp_ref = amp
    
    def cache_env_y(self, amp, config):
        env_y = amp.compute_y_environments(max_bond=self.max_bond, cutoff=0.0)
        env_y_cache = self.transform_quimb_env_y_key_to_config_key(env_y, config)
        self._env_y_cache = env_y_cache
        self.config_ref = config
        self.amp_ref = amp
    
    def cache_env(self, amp, config):
        """
            Cache the environment x and y for the given configuration
        """
        self.cache_env_x(amp, config)
        self.cache_env_y(amp, config)
        
    @property
    def env_x_cache(self):
        """
            Return the cached environment x
        """
        if hasattr(self, '_env_x_cache'):
            return self._env_x_cache
        else:
            return None
        
    @property
    def env_y_cache(self):
        """
            Return the cached environment y
        """
        if hasattr(self, '_env_y_cache'):
            return self._env_y_cache
        else:
            return None
    
    def clear_env_x_cache(self):
        """
            Clear the cached environment x
        """
        self._env_x_cache = None

    def clear_env_y_cache(self):
        """
            Clear the cached environment y
        """
        self._env_y_cache = None
    
    def clear_wavefunction_env_cache(self):
        self.clear_env_x_cache()
        self.clear_env_y_cache()
        self.config_ref = None
        self.amp_ref = None
    
    def detect_changed_sites(self, config_ref, new_config):
        """
            Detect the sites that have changed in the new configuration,
            written in 1d coordinate format.
        """
        changed_sites = set()
        unchanged_sites = set()
        for i in range(self.Lx * self.Ly):
            if config_ref[i] != new_config[i]:
                changed_sites.add(i)
            else:
                unchanged_sites.add(i)
        changed_sites = sorted(changed_sites)
        unchanged_sites = sorted(unchanged_sites)
        if len(changed_sites) == 0:
            return [], []
        return changed_sites, unchanged_sites

    def from_1d_sites_to_tids(self, sites):
        """
            Convert a list of 1d site indices to a list of tensor ids.
        """
        tids_list = list(self.skeleton.tensor_map.keys())
        return [tids_list[site] for site in sites]
    
    def detect_changed_rows(self, config_ref, new_config):
        """
            Detect the rows that have changed in the new configuration
        """
        config_ref_2d = self.from_1d_to_2d(config_ref)
        new_config_2d = self.from_1d_to_2d(new_config)
        changed_rows = []
        for i in range(self.Lx):
            if not torch.equal(config_ref_2d[i], new_config_2d[i]):
                changed_rows.append(i)
        if len(changed_rows) == 0:
            return [], [], []
        unchanged_rows_above = list(range(changed_rows[0]))
        unchanged_rows_below = list(range(changed_rows[-1]+1, self.Lx))
        return changed_rows, unchanged_rows_above, unchanged_rows_below
    
    def detect_changed_cols(self, config_ref, new_config):
        """
            Detect the columns that have changed in the new configuration
        """
        config_ref_2d = self.from_1d_to_2d(config_ref)
        new_config_2d = self.from_1d_to_2d(new_config)
        changed_cols = []
        for i in range(self.Ly):
            if not torch.equal(config_ref_2d[:, i], new_config_2d[:, i]):
                changed_cols.append(i)
        if len(changed_cols) == 0:
            return [], [], []
        unchanged_cols_left = list(range(changed_cols[0]))
        unchanged_cols_right = list(range(changed_cols[-1]+1, self.Ly))
        return changed_cols, unchanged_cols_left, unchanged_cols_right
    
    def update_env_x_cache(self, config):
        """
            Update the cached environment x for the given configuration
        """
        if self.env_x_cache is not None:
            self.clear_env_x_cache()
        amp_tn = self.get_amp_tn(config)
        self.cache_env_x(amp_tn, config)
        self.config_ref = config
        self.amp_ref = amp_tn
    
    def update_env_x_cache_to_row(self, config, row_id, from_which='xmin'):
        amp_tn = self.get_amp_tn(config)
        new_env_x = amp_tn.compute_environments(max_bond=self.max_bond, cutoff=0.0, xrange=(0, row_id+1) if from_which=='xmin' else (row_id-1, self.Lx-1), from_which=from_which)
        new_env_x_cache = self.transform_quimb_env_x_key_to_config_key(new_env_x, config)
        # add the new env_x to the cache
        if self.env_x_cache is None:
            self._env_x_cache = new_env_x_cache
        else:
            self._env_x_cache.update(new_env_x_cache)
        self.config_ref = config
        self.amp_ref = amp_tn
    
    def update_env_y_cache(self, config):
        """
            Update the cached environment y for the given configuration
        """
        if self.env_y_cache is not None:
            self.clear_env_y_cache()
        amp_tn = self.get_amp_tn(config)
        self.cache_env_y(amp_tn, config)
        self.config_ref = config
        self.amp_ref = amp_tn
    
    def update_env_y_cache_to_col(self, config, col_id, from_which='ymin'):
        amp_tn = self.get_amp_tn(config)
        new_env_y = amp_tn.compute_environments(max_bond=self.max_bond, cutoff=0.0, yrange=(0, col_id+1) if from_which=='ymin' else (col_id-1, self.Ly-1), from_which=from_which)
        new_env_y_cache = self.transform_quimb_env_y_key_to_config_key(new_env_y, config)
        # add the new env_y to the cache
        if self.env_y_cache is None:
            self._env_y_cache = new_env_y_cache
        else:
            self._env_y_cache.update(new_env_y_cache)
        self.config_ref = config
        self.amp_ref = amp_tn
    
    def psi(self):
        """
            Return the wavefunction (fPEPS)
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        return psi

    def get_local_amp_tensors(self, sites:list, config:torch.Tensor):
        """
            Get the local tensors for the given tensor ids and configuration.
            tids: a list of tensor ids. list of int.
            config: the input configuration.
        """
        # first pick out the tensor parameters and form the local tn parameters vector
        local_ts_params = {}
        tids = self.from_1d_sites_to_tids(sites)
        for tid in tids:
            local_ts_params[tid] = {
                ast.literal_eval(sector): data
                for sector, data in self.torch_tn_params[str(tid)].items()
            }
        
        # Get sites corresponding to the tids
        sites_2d = [self.from_1dsite_to_2dsite(site) for site in sites]

        # Select the corresponding tensor skeleton
        local_ts_skeleton = self.skeleton.select([self.skeleton.site_tag_id.format(*site) for site in sites_2d], which='any')

        # Reconstruct the TN with the new parameters
        local_ftn = qtn.unpack(local_ts_params, local_ts_skeleton)

        # Fix the physical indices
        return local_ftn.fix_phys_inds(sites_2d, config[sites])
    
    def get_amp_tn(self, config, reconstruct=False):

        if self.amp_ref is None or reconstruct:
            psi = self.psi()
            # Check config type
            if not type(config) == torch.Tensor:
                config = torch.tensor(config, dtype=torch.int if self.functional else self.param_dtype)
            else:
                if config.dtype != self.param_dtype:
                    config = config.to(torch.int if self.functional else self.param_dtype)
            # Get the amplitude
            amp_tn = psi.get_amp(config, conj=True, functional=self.functional)
            return amp_tn
        
        else:
            # detect the sites that have changed
            changed_sites, unchanged_sites = self.detect_changed_sites(self.config_ref, config)

            if len(changed_sites) == 0:
                return self.amp_ref
            else:
                # substitute the changed sites tensors
                local_amp_tn = self.get_local_amp_tensors(changed_sites, config)
                unchanged_sites_2d = [self.from_1dsite_to_2dsite(site) for site in unchanged_sites]
                unchanged_sites_tags = [self.skeleton.site_tag_id.format(*site) for site in unchanged_sites_2d]
                unchanged_amp_tn = self.amp_ref.select(unchanged_sites_tags, which='any')
                # merge the local_amp_tn and unchanged_amp_tn
                amp_tn = local_amp_tn | unchanged_amp_tn
                return amp_tn
    
    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Check x_i type
            if not type(x_i) == torch.Tensor:
                x_i = torch.tensor(x_i, dtype=torch.int if self.functional else self.param_dtype)
            else:
                if x_i.dtype != self.param_dtype:
                    x_i = x_i.to(torch.int if self.functional else self.param_dtype)
            # Get the amplitude
            # amp = psi.get_amp(x_i, conj=True, functional=self.functional)
            amp = self.get_amp_tn(x_i)

            if self.max_bond is None:
                amp = amp
                if self.tree is None:
                    opt = ctg.HyperOptimizer(progbar=True, max_repeats=10, parallel=True)
                    self.tree = amp.contraction_tree(optimize=opt)
                amp_val = amp.contract(optimize=self.tree)

            else:
                if self.cache_env_mode:
                    self.cache_env_x(amp, x_i)
                    # self.cache_env_y(amp, x_i)
                    self.config_ref = x_i
                    config_2d = self.from_1d_to_2d(x_i)
                    key_bot = ('xmax', tuple(torch.cat(tuple(config_2d[self.Lx//2:].to(torch.int))).tolist()))
                    key_top = ('xmin', tuple(torch.cat(tuple(config_2d[:self.Lx//2].to(torch.int))).tolist()))
                    amp_bot = self.env_x_cache[key_bot]
                    amp_top = self.env_x_cache[key_top]
                    amp_val = (amp_bot|amp_top).contract()
                    

                else:
                    if self.env_x_cache is None and self.env_y_cache is None:
                        # check whether we can reuse the cached environment
                        amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly//2-1])
                        amp = amp.contract_boundary_from_ymax(max_bond=self.max_bond, cutoff=0.0, yrange=[psi.Ly//2, psi.Ly-1])
                        amp_val = amp.contract()
                    else:
                        config_2d = self.from_1d_to_2d(x_i)
                        # detect the rows that have changed
                        changed_rows, unchanged_rows_above, unchanged_rows_below = self.detect_changed_rows(self.config_ref, x_i)
                        # detect the columns that have changed
                        changed_cols, unchanged_cols_left, unchanged_cols_right = self.detect_changed_cols(self.config_ref, x_i)
                        if len(changed_rows) == 0:
                            key_bot = ('xmax', tuple(torch.cat(tuple(config_2d[self.Lx//2:].to(torch.int))).tolist()))
                            key_top = ('xmin', tuple(torch.cat(tuple(config_2d[:self.Lx//2].to(torch.int))).tolist()))
                            amp_bot = self.env_x_cache[key_bot]
                            amp_top = self.env_x_cache[key_top]
                            amp_val = (amp_bot|amp_top).contract()
                        else:
                            if len(changed_rows) <= len(changed_cols):
                                # for bottom envs, until the last row in the changed rows, we can reuse the env
                                # for top envs, until the first row in the changed rows, we can reuse the env
                                amp_changed_rows = qtn.TensorNetwork([amp.select(amp.x_tag_id.format(row_n)) for row_n in changed_rows])
                                amp_unchanged_bottom_env = qtn.TensorNetwork()
                                amp_unchanged_top_env = qtn.TensorNetwork()
                                if len(unchanged_rows_below) != 0:
                                    amp_unchanged_bottom_env = self.env_x_cache[('xmax', tuple(torch.cat(tuple(config_2d[unchanged_rows_below].to(torch.int))).tolist()))]
                                if len(unchanged_rows_above) != 0:
                                    amp_unchanged_top_env = self.env_x_cache[('xmin', tuple(torch.cat(tuple(config_2d[unchanged_rows_above].to(torch.int))).tolist()))]
                                amp_val = (amp_unchanged_bottom_env|amp_unchanged_top_env|amp_changed_rows).contract()
                                # print(f'changed rows: {changed_rows}', self.from_1d_to_2d(x_i), self.from_1d_to_2d(self.config_ref))
                            else:
                                # for left envs, until the first column in the changed columns, we can reuse the env
                                # for right envs, until the last column in the changed columns, we can reuse the env
                                amp_changed_cols = qtn.TensorNetwork([amp.select(amp.y_tag_id.format(col_n)) for col_n in changed_cols])
                                amp_unchanged_left_env = qtn.TensorNetwork()
                                amp_unchanged_right_env = qtn.TensorNetwork()
                                if len(unchanged_cols_left) != 0:
                                    amp_unchanged_left_env = self.env_y_cache[('ymin', tuple(torch.cat(tuple(config_2d[:, unchanged_cols_left].to(torch.int))).tolist()))]
                                if len(unchanged_cols_right) != 0:
                                    amp_unchanged_right_env = self.env_y_cache[('ymax', tuple(torch.cat(tuple(config_2d[:, unchanged_cols_right].to(torch.int))).tolist()))]
                                amp_val = (amp_unchanged_left_env|amp_unchanged_right_env|amp_changed_cols).contract()
                                
            if amp_val==0.0:
                amp_val = torch.tensor(0.0)
            
            if self.debug:
                print(f'Reused Amp val: {amp_val}, Exact Amp val: {self.get_amp_tn(x_i).contract()}')
            
            batch_amps.append(amp_val)

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)

In [3]:
random_x = torch.tensor(H.hilbert.random_state())
random_x1 = random_x.clone()
print(random_x)
random_x1[1] = 2
random_x1[5] = 1
model = fTNModel_reuse(peps, max_bond=chi, dtype=dtype)

print((peps.product_bra_state(random_x).conj()|peps).contract(), (peps|peps.product_bra_state(random_x, reverse=1)).contract())

permute_list = []
for tensor in peps.product_bra_state(random_x).conj().tensors:
    # print(tensor.data.charge, tensor.data.oddpos)
    if len(tensor.data.oddpos)>0:
        # print(tensor.data.oddpos[0].label, tensor.data.oddpos[0].dual)
        permute_list.append(tensor.data.oddpos[0].label)

# sort the list and note down the parity of the permutation
print(len(permute_list))
# sum from len(permute_list) to 1
N = 0
for i in range(len(permute_list)):
    N += i
print('phase correction from reversing the ordering:', (-1)**(N%2))

tensor([3, 0, 3, 0, 1, 1, 2, 0, 0, 3, 0, 1, 0, 2, 3, 2])
tensor(-45918.1884, dtype=torch.float64) tensor(45918.1884, dtype=torch.float64)
6
phase correction from reversing the ordering: -1


In [4]:
changed_sites = model.detect_changed_sites(random_x, random_x1)
changed_sites

([1], [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

In [5]:
model = fTNModel_reuse(peps, max_bond=chi, dtype=dtype, functional=False)
model.cache_env_mode = True
model(random_x)
model.cache_env_mode = False

In [6]:
import time
t0 = time.time()
amp0 = model.get_amp_tn(random_x1)
t1 = time.time()
print("Time taken:", t1 - t0)
amp0.contract()

Time taken: 0.0023186206817626953


0.0

In [7]:
t0 = time.time()
amp1 = model.get_amp_tn(random_x1, reconstruct=True)
t1 = time.time()
print("Time taken:", t1 - t0)
amp1.contract()

Time taken: 0.011317729949951172


0.0

# mode = 'dm':
ImportError: autoray couldn't find function 'argsort' for backend 'symmray'.

##### TO-DO: need to implement eigh_truncate for fermionic tensors

# mode = 'fit'

In [9]:
amp0.contract_boundary_from_xmax(xrange=(0, model.Lx-1), max_bond=16, cutoff=0.0, mode='fit').contract()

TypeError: Expected FermionicArray, got <class 'numpy.ndarray'>.