In [11]:
from mpi4py import MPI
import ast
# torch
import torch
import torch.nn as nn

import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# quimb
import quimb.tensor as qtn
import symmray as sr
import autoray as ar
from autoray import do

from vmc_torch.fermion_utils import *

COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

In [89]:
Lx = 6
Ly = 6
D = 4
symmetry = 'Z2'
N_f = int(4*4/2)-2
# Create a random PEPS
peps, parity_config = generate_random_fpeps(Lx, Ly, D, seed=2, symmetry=symmetry, Nf=N_f)

# Create a random configuration
random_conf = np.zeros(Lx*Ly)
random_conf[:N_f] = 1
np.random.seed(1)
np.random.shuffle(random_conf)

t = 1.0
V = 4.0
mu = 0.0
edges = qtn.edges_2d_square(Lx, Ly)
site_info = sr.utils.parse_edges_to_site_info(
    edges,
    D,
    phys_dim=2,
    site_ind_id="k{},{}",
    site_tag_id="I{},{}",
)
terms = {
    (sitea, siteb): sr.fermi_hubbard_spinless_local_array(
        t=t, V=V, mu=mu,
        symmetry=symmetry,
        coordinations=(
            site_info[sitea]['coordination'],
            site_info[siteb]['coordination'],
        ),
    ).fuse((0, 1), (2, 3))
    for (sitea, siteb) in peps.gen_bond_coos()
}
ham = qtn.LocalHam2D(Lx, Ly, terms)
su = qtn.SimpleUpdateGen(peps, ham, compute_energy_per_site=True,D=D, compute_energy_opts={"max_distance":1}, gate_opts={'cutoff':1e-10})
su.evolve(50, 0.3)
# su.evolve(50, 0.1)
peps = su.state

n=50, tau=0.3000, energy~-0.404314: 100%|##########| 50/50 [00:09<00:00,  5.19it/s]


In [96]:
class fTNModel(torch.nn.Module):

    def __init__(self, ftn, max_bond=None):
        super().__init__()
        
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (exact contraction)':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
        }
        self.max_bond = max_bond

    def product_bra_state(self, config, peps, symmetry='Z2'):
        """Spinless fermion product bra state."""
        product_tn = qtn.TensorNetwork()
        backend = peps.tensors[0].data.backend
        iterable_oddpos = iter(range(2*peps.nsites+1))
        for n, site in zip(config, peps.sites):
            p_ind = peps.site_ind_id.format(*site)
            p_tag = peps.site_tag_id.format(*site)
            tid = peps.sites.index(site)
            nsites = peps.nsites
            # use autoray to ensure the correct backend is used
            with ar.backend_like(backend):
                if symmetry == 'Z2':
                    data = [sr.Z2FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2', charge=0, oddpos=2*tid+1), # It doesn't matter if oddpos is None for even parity tensor.
                            sr.Z2FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2',charge=1, oddpos=2*tid+1)
                        ]
                elif symmetry == 'U1':
                    data = [sr.U1FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=0, oddpos=2*tid+1),
                            sr.U1FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=1, oddpos=2*tid+1)
                        ]
            tsr_data = data[int(n)] # BUG: does not fit in jax compilation, a concrete value is needed for traced arrays
            tsr = qtn.Tensor(data=tsr_data, inds=(p_ind,),tags=(p_tag, 'bra'))
            product_tn |= tsr
        return product_tn

    def get_amp(self, peps, config, inplace=False, symmetry='Z2', conj=True):
        """Get the amplitude of a configuration in a PEPS."""
        if not inplace:
            peps = peps.copy()
        if conj:
            amp = peps|self.product_bra_state(config, peps, symmetry).conj()
        else:
            amp = peps|self.product_bra_state(config, peps, symmetry)
        for site in peps.sites:
            site_tag = peps.site_tag_id.format(*site)
            amp.contract_(tags=site_tag)

        amp.view_as_(
            qtn.PEPS,
            site_ind_id="k{},{}",
            site_tag_id="I{},{}",
            x_tag_id="X{}",
            y_tag_id="Y{}",
            Lx=peps.Lx,
            Ly=peps.Ly,
        )
        return amp
        
    
    def from_params_to_vec(self):
        return torch.cat([param.data.flatten() for param in self.parameters()])
    
    @property
    def num_params(self):
        return len(self.from_params_to_vec())
    
    def params_grad_to_vec(self):
        param_grad_vec = torch.cat([param.grad.flatten() if param.grad is not None else torch.zeros_like(param).flatten() for param in self.parameters()])
        return param_grad_vec

    def clear_grad(self):
        for param in self.parameters():
            param.grad = None
    
    def from_vec_to_params(self, vec, quimb_format=False):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {}
        idx = 0
        for tid, blk_array in self.torch_params.items():
            params[tid] = {}
            for sector, data in blk_array.items():
                shape = data.shape
                size = data.numel()
                if quimb_format:
                    params[tid][ast.literal_eval(sector)] = vec[idx:idx+size].view(shape)
                else:
                    params[tid][sector] = vec[idx:idx+size].view(shape)
                idx += size
        return params
    
    def load_params(self, new_params):
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    
    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            amp = self.get_amp(psi, x_i, symmetry=self.symmetry, conj=True)
            if self.max_bond is None:
                batch_amps.append(amp.contract())
            else:
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly-2])
                batch_amps.append(amp.contract())

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
    def forward(self, x):
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)

class fTN_NNiso_Model(torch.nn.Module):
    
    def __init__(self, ftn, max_bond, nn_hidden_dim=64, nn_eta=1e-3):
        super().__init__()
        self.max_bond = max_bond
        self.nn_eta = nn_eta
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })
        
        self.parity_config = [array.parity for array in ftn.arrays]
        self.N_fermion = sum(self.parity_config)
        dummy_config = torch.zeros(ftn.nsites)
        dummy_config[:self.N_fermion] = 1
        dummy_amp = self.get_amp(ftn, dummy_config, inplace=False)
        dummy_amp_w_proj = insert_proj_peps(dummy_amp, max_bond=max_bond, yrange=[0, ftn.Ly-2])
        dummy_amp_tn, dummy_proj_tn = dummy_amp_w_proj.partition(tags='proj')
        dummy_proj_params, dummy_proj_skeleton = qtn.pack(dummy_proj_tn)
        dummy_proj_params_vec = flatten_proj_params(dummy_proj_params)
        self.proj_params_vec_len = len(dummy_proj_params_vec)

        # Define an MLP layer (or any other neural network layers)
        self.mlp = nn.Sequential(
            nn.Linear(ftn.nsites, nn_hidden_dim),
            nn.ReLU(),
            nn.Linear(nn_hidden_dim, self.proj_params_vec_len)
        )

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry
        assert self.symmetry == 'Z2', "Only Z2 symmetry fPEPS is supported for NN insertion now."
        if self.symmetry == 'Z2':
            assert self.N_fermion %2 == sum(self.parity_config) % 2, "The number of fermions must match the parity of the Z2-TNS."

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (proj inserted)':{'D': ftn.max_bond(), 'chi': self.max_bond, 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry, 'proj_yrange': [0, ftn.Ly-2]},
            '2LayerMLP':{'hidden_dim': nn_hidden_dim, 'nn_eta': nn_eta, 'activation': 'ReLU'}
        }

    def product_bra_state(self, config, peps, symmetry='Z2'):
        """Spinless fermion product bra state."""
        product_tn = qtn.TensorNetwork()
        backend = peps.tensors[0].data.backend
        iterable_oddpos = iter(range(2*peps.nsites+1))
        for n, site in zip(config, peps.sites):
            p_ind = peps.site_ind_id.format(*site)
            p_tag = peps.site_tag_id.format(*site)
            tid = peps.sites.index(site)
            nsites = peps.nsites
            # use autoray to ensure the correct backend is used
            with ar.backend_like(backend):
                if symmetry == 'Z2':
                    data = [sr.Z2FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2', charge=0, oddpos=2*tid+1), # It doesn't matter if oddpos is None for even parity tensor.
                            sr.Z2FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2',charge=1, oddpos=2*tid+1)
                        ]
                elif symmetry == 'U1':
                    data = [sr.U1FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=0, oddpos=2*tid+1),
                            sr.U1FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=1, oddpos=2*tid+1)
                        ]
            tsr_data = data[int(n)] # BUG: does not fit in jax compilation, a concrete value is needed for traced arrays
            tsr = qtn.Tensor(data=tsr_data, inds=(p_ind,),tags=(p_tag, 'bra'))
            product_tn |= tsr
        return product_tn

    def get_amp(self, peps, config, inplace=False, symmetry='Z2', conj=True):
        """Get the amplitude of a configuration in a PEPS."""
        if not inplace:
            peps = peps.copy()
        if conj:
            amp = peps|self.product_bra_state(config, peps, symmetry).conj()
        else:
            amp = peps|self.product_bra_state(config, peps, symmetry)
        for site in peps.sites:
            site_tag = peps.site_tag_id.format(*site)
            amp.contract_(tags=site_tag)

        amp.view_as_(
            qtn.PEPS,
            site_ind_id="k{},{}",
            site_tag_id="I{},{}",
            x_tag_id="X{}",
            y_tag_id="Y{}",
            Lx=peps.Lx,
            Ly=peps.Ly,
        )
        return amp
        
    
    def from_params_to_vec(self):
        return torch.cat([param.data.flatten() for param in self.parameters()])
    
    @property
    def num_params(self):
        return len(self.from_params_to_vec())
    
    @property
    def num_tn_params(self):
        num=0
        for tid, blk_array in self.torch_tn_params.items():
            for sector, data in blk_array.items():
                num += data.numel()
        return num
    
    def params_grad_to_vec(self):
        param_grad_vec = torch.cat([param.grad.flatten() if param.grad is not None else torch.zeros_like(param).flatten() for param in self.parameters()])
        return param_grad_vec

    def clear_grad(self):
        for param in self.parameters():
            param.grad = None
    
    def load_params(self, new_params):
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            amp = self.get_amp(psi, x_i, symmetry=self.symmetry, conj=True)

            # Insert projectors
            amp_w_proj = insert_proj_peps(amp, max_bond=self.max_bond, yrange=[0, psi.Ly-2])
            amp_tn, proj_tn = amp_w_proj.partition(tags='proj')
            proj_params, proj_skeleton = qtn.pack(proj_tn)
            proj_params_vec = flatten_proj_params(proj_params)

            # Check x_i type
            if not type(x_i) == torch.Tensor:
                x_i = torch.tensor(x_i, dtype=torch.float32)
            # Add NN output
            proj_params_vec += self.nn_eta*self.mlp(x_i)
            # Reconstruct the proj parameters
            new_proj_params = reconstruct_proj_params(proj_params_vec, proj_params)
            # Load the new parameters
            new_proj_tn = qtn.unpack(new_proj_params, proj_skeleton)
            new_amp_w_proj = amp_tn | new_proj_tn

            # contract column by column
            
            # batch_amps.append(torch.tensor(new_amp_w_proj.contract(), dtype=torch.float32, requires_grad=True))
            batch_amps.append(new_amp_w_proj.contract())

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
    def forward(self, x):
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)


class fTN_NN_Model(torch.nn.Module):
    
    def __init__(self, ftn, max_bond, nn_hidden_dim=64, nn_eta=1e-3):
        super().__init__()
        self.max_bond = max_bond
        self.nn_eta = nn_eta
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })
        
        self.parity_config = [array.parity for array in ftn.arrays]
        self.N_fermion = sum(self.parity_config)
        dummy_config = torch.zeros(ftn.nsites)
        dummy_config[:self.N_fermion] = 1
        dummy_amp = ftn.get_amp(dummy_config, inplace=False, conj=True)
        dummy_amp_2row = dummy_amp.contract_boundary_from_ymin(max_bond=max_bond, cutoff=0.0, yrange=[0, ftn.Ly//2-1])
        dummy_amp_2row = dummy_amp_2row.contract_boundary_from_ymax(max_bond=max_bond, cutoff=0.0, yrange=[ftn.Ly//2, ftn.Ly-1])
        dummy_2row_params, dummy_2row_skeleton = qtn.pack(dummy_amp_2row)
        dummy_2row_params_vec = flatten_proj_params(dummy_2row_params)
        self.tworow_params_vec_len = len(dummy_2row_params_vec)

        # Define an MLP layer (or any other neural network layers)
        self.mlp = nn.Sequential(
            nn.Linear(ftn.nsites, nn_hidden_dim),
            nn.ReLU(),
            nn.Linear(nn_hidden_dim, self.tworow_params_vec_len)
        )

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry
        assert self.symmetry == 'Z2', "Only Z2 symmetry fPEPS is supported for NN insertion now."
        if self.symmetry == 'Z2':
            assert self.N_fermion %2 == sum(self.parity_config) % 2, "The number of fermions must match the parity of the Z2-TNS."

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (proj inserted)':{'D': ftn.max_bond(), 'chi': self.max_bond, 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
            '2LayerMLP':{'hidden_dim': nn_hidden_dim, 'nn_eta': nn_eta, 'activation': 'ReLU'}
        }
        
    
    def from_params_to_vec(self):
        return torch.cat([param.data.flatten() for param in self.parameters()])
    
    @property
    def num_params(self):
        return len(self.from_params_to_vec())
    
    @property
    def num_tn_params(self):
        num=0
        for tid, blk_array in self.torch_tn_params.items():
            for sector, data in blk_array.items():
                num += data.numel()
        return num
    
    def params_grad_to_vec(self):
        param_grad_vec = torch.cat([param.grad.flatten() if param.grad is not None else torch.zeros_like(param).flatten() for param in self.parameters()])
        return param_grad_vec

    def clear_grad(self):
        for param in self.parameters():
            param.grad = None
    
    def load_params(self, new_params):
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            # Check x_i type
            if not type(x_i) == torch.Tensor:
                x_i = torch.tensor(x_i, dtype=torch.float32)
            
            amp = psi.get_amp(x_i, conj=True)

            # Contract to 2 rows
            amp_2row = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly//2-1])
            amp_2row = amp_2row.contract_boundary_from_ymax(max_bond=self.max_bond, cutoff=0.0, yrange=[psi.Ly//2, psi.Ly-1])
            amp_2row_params, amp_2row_skeleton = qtn.pack(amp_2row)
            amp_2row_params_vec = flatten_proj_params(amp_2row_params)
            # Add NN output
            amp_2row_params_vec = amp_2row_params_vec + self.nn_eta*self.mlp(x_i)
            # Reconstruct the proj parameters
            new_2row_params = reconstruct_proj_params(amp_2row_params_vec, amp_2row_params)
            # Load the new parameters
            new_amp_2row = qtn.unpack(new_2row_params, amp_2row_skeleton)

            batch_amps.append(new_amp_2row.contract())

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
    def forward(self, x):
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)

In [97]:
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float32, requires_grad=True))
# Get the amplitude of the configuration
import pyinstrument
with pyinstrument.profile():
    amp = peps.get_amp(random_conf, conj=True)
    amp.contract()
    # amp.contract_boundary_from_ymin(max_bond=8, yrange=(0, amp.Ly-2), cutoff=0.0).contract()

  peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float32, requires_grad=True))

pyinstrument ........................................
.
.  Block at /tmp/ipykernel_27785/3024431008.py:4
.
.  0.073 <module>  ../../../../../tmp/ipykernel_27785/3024431008.py:4
.  ├─ 0.043 PEPS.contract  quimb/tensor/tensor_core.py:8396
.  │     [8 frames hidden]  functools, quimb, cotengra, autoray, ...
.  │        0.042 wrapper  functools.py:883
.  │        └─ 0.042 tensordot_fermionic  symmray/fermionic_core.py:711
.  │           ├─ 0.029 tensordot_abelian  symmray/abelian_core.py:1996
.  │           │  └─ 0.029 _tensordot_via_fused  symmray/abelian_core.py:1941
.  │           │     ├─ 0.016 Z2FermionicArray.fuse  symmray/abelian_core.py:1418
.  │           │     │  ├─ 0.008 _VariableFunctionsClass.reshape  <built-in>
.  │           │     │  ├─ 0.004 <dictcomp>  symmray/abelian_core.py:1524
.  │           │     │  │  ├─ 0.003 _recurse_concat  symmray/abelian_core.py:1484
.  │           │     

In [98]:
ftn_model = fTNModel(peps, max_bond=8)
with pyinstrument.profile():
    ftn_model.amplitude([random_conf])


pyinstrument ........................................
.
.  Block at /tmp/ipykernel_27785/829635654.py:2
.
.  0.137 <module>  ../../../../../tmp/ipykernel_27785/829635654.py:2
.  └─ 0.137 fTNModel.amplitude  ../../../../../tmp/ipykernel_27785/3087146501.py:119
.     ├─ 0.099 PEPS.contract_boundary_from_ymin  quimb/tensor/tensor_2d.py:2003
.     │     [44 frames hidden]  quimb, functools, autoray, cotengra
.     │        0.010 Composed.__call__  autoray/autoray.py:921
.     │        └─ 0.010 qr_stabilized  symmray/linalg.py:118
.     │           └─ 0.010 wrapper  functools.py:883
.     │              └─ 0.010 qr_fermionic  symmray/linalg.py:107
.     │                 └─ 0.010 qr  symmray/linalg.py:46
.     │                    └─ 0.008 _qr  symmray/linalg.py:26
.     │                       └─ 0.008 qr_stabilized  quimb/tensor/decomp.py:669
.     │                             [4 frames hidden]  autoray, <built-in>, quimb
.     │        0.005 do  autoray/autoray.py:30
.     │        └─ 

In [100]:
ftn_nniso_model = fTN_NNiso_Model(peps, max_bond=8, nn_hidden_dim=16, nn_eta=1e-3)
with pyinstrument.profile():
    ftn_nniso_model.amplitude([random_conf])


pyinstrument ........................................
.
.  Block at /tmp/ipykernel_27785/1814615920.py:2
.
.  0.413 <module>  ../../../../../tmp/ipykernel_27785/1814615920.py:2
.  └─ 0.412 fTN_NNiso_Model.amplitude  ../../../../../tmp/ipykernel_27785/3087146501.py:281
.     ├─ 0.331 insert_proj_peps  vmc_torch/fermion_utils.py:270
.     │  ├─ 0.322 insert_compressor  vmc_torch/fermion_utils.py:204
.     │  │  ├─ 0.196 TensorNetwork.contract  quimb/tensor/tensor_core.py:8396
.     │  │  │     [6 frames hidden]  functools, quimb, cotengra, autoray
.     │  │  │        0.172 wrapper  functools.py:883
.     │  │  │        └─ 0.170 tensordot_fermionic  symmray/fermionic_core.py:711
.     │  │  │           ├─ 0.128 tensordot_abelian  symmray/abelian_core.py:1996
.     │  │  │           │  └─ 0.126 _tensordot_via_fused  symmray/abelian_core.py:1941
.     │  │  │           │     ├─ 0.058 Z2FermionicArray.fuse  symmray/abelian_core.py:1418
.     │  │  │           │     │  ├─ 0.020 <dictcomp>  

In [101]:
ftn_nn_model = fTN_NN_Model(peps, max_bond=8, nn_hidden_dim=16, nn_eta=1e-3)
with pyinstrument.profile():
    ftn_nn_model.amplitude([random_conf])


pyinstrument ........................................
.
.  Block at /tmp/ipykernel_27785/1036372546.py:2
.
.  0.140 <module>  ../../../../../tmp/ipykernel_27785/1036372546.py:2
.  └─ 0.139 fTN_NN_Model.amplitude  ../../../../../tmp/ipykernel_27785/3087146501.py:413
.     ├─ 0.050 PEPS.contract_boundary_from_ymin  quimb/tensor/tensor_2d.py:2003
.     │     [44 frames hidden]  quimb, functools, autoray, cotengra
.     │        0.004 Composed.__call__  autoray/autoray.py:921
.     │        └─ 0.004 qr_stabilized  symmray/linalg.py:118
.     │           └─ 0.004 wrapper  functools.py:883
.     │              └─ 0.004 qr_fermionic  symmray/linalg.py:107
.     │                 └─ 0.004 qr  symmray/linalg.py:46
.     │                    └─ 0.004 _qr  symmray/linalg.py:26
.     │                       └─ 0.004 qr_stabilized  quimb/tensor/decomp.py:669
.     │                             [2 frames hidden]  autoray, <built-in>
.     │        0.004 do  autoray/autoray.py:30
.     │        └─ 0

In [111]:
# check number of parameters in each model
print('Number of parameters in fTN model:', ftn_model.num_params)
print('Number of TN parameters in fTN_NNiso model: {}, NN parameters: {}'.format(ftn_nniso_model.num_tn_params, ftn_nniso_model.num_params-ftn_nniso_model.num_tn_params))
print('Number of TN parameters in fTN_NN model: {}, NN parameters: {}'.format(ftn_nn_model.num_tn_params, ftn_nn_model.num_params-ftn_nn_model.num_tn_params))

Number of parameters in fTN model: 5184
Number of TN parameters in fTN_NNiso model: 5184, NN parameters: 54992
Number of TN parameters in fTN_NN model: 5184, NN parameters: 14192
