In [6]:
from mpi4py import MPI
import ast
# torch
import torch
import torch.nn as nn

# quimb
import quimb.tensor as qtn
import symmray as sr
import autoray as ar
from autoray import do

from vmc_torch.fermion_utils import *

COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()
RANK = COMM.Get_rank()

class fTNModel(torch.nn.Module):

    def __init__(self, ftn):
        super().__init__()
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter
        self.torch_params = {
            tid: nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        }

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

    def product_bra_state(self, config, peps, symmetry='Z2'):
        """Spinless fermion product bra state."""
        product_tn = qtn.TensorNetwork()
        backend = peps.tensors[0].data.backend
        iterable_oddpos = iter(range(2*peps.nsites+1))
        for n, site in zip(config, peps.sites):
            p_ind = peps.site_ind_id.format(*site)
            p_tag = peps.site_tag_id.format(*site)
            tid = peps.sites.index(site)
            nsites = peps.nsites
            # use autoray to ensure the correct backend is used
            with ar.backend_like(backend):
                if symmetry == 'Z2':
                    data = [sr.Z2FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2', charge=0, oddpos=2*tid+1), # It doesn't matter if oddpos is None for even parity tensor.
                            sr.Z2FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2',charge=1, oddpos=2*tid+1)
                        ]
                elif symmetry == 'U1':
                    data = [sr.U1FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=0, oddpos=2*tid+1),
                            sr.U1FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=1, oddpos=2*tid+1)
                        ]
            tsr_data = data[int(n)] # BUG: does not fit in jax compilation, a concrete value is needed for traced arrays
            tsr = qtn.Tensor(data=tsr_data, inds=(p_ind,),tags=(p_tag, 'bra'))
            product_tn |= tsr
        return product_tn

    def get_amp(self, peps, config, inplace=False, symmetry='Z2', conj=True):
        """Get the amplitude of a configuration in a PEPS."""
        if not inplace:
            peps = peps.copy()
        if conj:
            amp = peps|self.product_bra_state(config, peps, symmetry).conj()
        else:
            amp = peps|self.product_bra_state(config, peps, symmetry)
        for site in peps.sites:
            site_tag = peps.site_tag_id.format(*site)
            amp.contract_(tags=site_tag)

        amp.view_as_(
            qtn.PEPS,
            site_ind_id="k{},{}",
            site_tag_id="I{},{}",
            x_tag_id="X{}",
            y_tag_id="Y{}",
            Lx=peps.Lx,
            Ly=peps.Ly,
        )
        return amp
        
    def parameters(self):
        # Manually yield all parameters from the nested structure
        for tid_dict in self.torch_params.values():
            for param in tid_dict.values():
                yield param
    
    def from_params_to_vec(self):
        return torch.cat([param.data.flatten() for param in self.parameters()])
    
    @property
    def num_params(self):
        return len(self.from_params_to_vec())
    
    def params_grad_to_vec(self):
        param_grad_vec = torch.cat([param.grad.flatten() if param.grad is not None else torch.zeros_like(param).flatten() for param in self.parameters()])
        return param_grad_vec

    def clear_grad(self):
        for param in self.parameters():
            param.grad = None
    
    def from_vec_to_params(self, vec, quimb_format=False):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {}
        idx = 0
        for tid, blk_array in self.torch_params.items():
            params[tid] = {}
            for sector, data in blk_array.items():
                shape = data.shape
                size = data.numel()
                if quimb_format:
                    params[tid][ast.literal_eval(sector)] = vec[idx:idx+size].view(shape)
                else:
                    params[tid][sector] = vec[idx:idx+size].view(shape)
                idx += size
        return params
    
    def load_params(self, new_params):
        if isinstance(new_params, torch.Tensor):
            new_params = self.from_vec_to_params(new_params)
        # Update the parameters manually
        with torch.no_grad():
            for tid, blk_array in new_params.items():
                for sector, data in blk_array.items():
                    self.torch_params[tid][sector].data = data

    
    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            tid: {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
       # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            amp = self.get_amp(psi, x_i, symmetry=self.symmetry, conj=True)
            batch_amps.append(amp.contract())

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
    def forward(self, x):
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)


class fTN_NNiso_Model(torch.nn.Module):
    
    def __init__(self, ftn, max_bond, nn_hidden_dim=64, nn_eta=1e-3):
        super().__init__()
        self.max_bond = max_bond
        self.nn_eta = nn_eta
        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # Flatten the dictionary structure and assign each parameter as a part of a ModuleDict
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })
        
        self.parity_config = [array.parity for array in ftn.arrays]
        self.N_fermion = sum(self.parity_config)
        dummy_config = torch.zeros(ftn.nsites)
        dummy_config[:self.N_fermion] = 1
        dummy_amp = self.get_amp(ftn, dummy_config, inplace=False)
        dummy_amp_w_proj = insert_proj_peps(dummy_amp, max_bond=max_bond, yrange=[0, ftn.Ly-2])
        dummy_amp_tn, dummy_proj_tn = dummy_amp_w_proj.partition(tags='proj')
        dummy_proj_params, dummy_proj_skeleton = qtn.pack(dummy_proj_tn)
        dummy_proj_params_vec = flatten_proj_params(dummy_proj_params)
        self.proj_params_vec_len = len(dummy_proj_params_vec)

        # Define an MLP layer (or any other neural network layers)
        self.mlp = nn.Sequential(
            nn.Linear(ftn.nsites, nn_hidden_dim),
            nn.ReLU(),
            nn.Linear(nn_hidden_dim, self.proj_params_vec_len)
        )

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry
        assert self.symmetry == 'Z2', "Only Z2 symmetry fPEPS is supported for NN insertion now."
        if self.symmetry == 'Z2':
            assert self.N_fermion %2 == sum(self.parity_config) % 2, "The number of fermions must match the parity of the Z2-TNS."

        # Store the shapes of the parameters
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (proj inserted)':{'D': ftn.max_bond(), 'chi': self.max_bond, 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry, 'proj_yrange': [0, ftn.Ly-2]},
            '2LayerMLP':{'hidden_dim': nn_hidden_dim, 'nn_eta': nn_eta, 'activation': 'ReLU'}
        }

    def product_bra_state(self, config, peps, symmetry='Z2'):
        """Spinless fermion product bra state."""
        product_tn = qtn.TensorNetwork()
        backend = peps.tensors[0].data.backend
        iterable_oddpos = iter(range(2*peps.nsites+1))
        for n, site in zip(config, peps.sites):
            p_ind = peps.site_ind_id.format(*site)
            p_tag = peps.site_tag_id.format(*site)
            tid = peps.sites.index(site)
            nsites = peps.nsites
            # use autoray to ensure the correct backend is used
            with ar.backend_like(backend):
                if symmetry == 'Z2':
                    data = [sr.Z2FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2', charge=0, oddpos=2*tid+1), # It doesn't matter if oddpos is None for even parity tensor.
                            sr.Z2FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='Z2',charge=1, oddpos=2*tid+1)
                        ]
                elif symmetry == 'U1':
                    data = [sr.U1FermionicArray.from_blocks(blocks={(0,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=0, oddpos=2*tid+1),
                            sr.U1FermionicArray.from_blocks(blocks={(1,):do('array', [1.0,], like=backend)}, duals=(True,),symmetry='U1', charge=1, oddpos=2*tid+1)
                        ]
            tsr_data = data[int(n)] # BUG: does not fit in jax compilation, a concrete value is needed for traced arrays
            tsr = qtn.Tensor(data=tsr_data, inds=(p_ind,),tags=(p_tag, 'bra'))
            product_tn |= tsr
        return product_tn

    def get_amp(self, peps, config, inplace=False, symmetry='Z2', conj=True):
        """Get the amplitude of a configuration in a PEPS."""
        if not inplace:
            peps = peps.copy()
        if conj:
            amp = peps|self.product_bra_state(config, peps, symmetry).conj()
        else:
            amp = peps|self.product_bra_state(config, peps, symmetry)
        for site in peps.sites:
            site_tag = peps.site_tag_id.format(*site)
            amp.contract_(tags=site_tag)

        amp.view_as_(
            qtn.PEPS,
            site_ind_id="k{},{}",
            site_tag_id="I{},{}",
            x_tag_id="X{}",
            y_tag_id="Y{}",
            Lx=peps.Lx,
            Ly=peps.Ly,
        )
        return amp
        
    
    def from_params_to_vec(self):
        return torch.cat([param.data.flatten() for param in self.parameters()])
    
    @property
    def num_params(self):
        return len(self.from_params_to_vec())
    
    def params_grad_to_vec(self):
        param_grad_vec = torch.cat([param.grad.flatten() if param.grad is not None else torch.zeros_like(param).flatten() for param in self.parameters()])
        return param_grad_vec

    def clear_grad(self):
        for param in self.parameters():
            param.grad = None
    
    def load_params(self, new_params):
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    def amplitude(self, x):
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the new parameters
        psi = qtn.unpack(params, self.skeleton)
        # `x` is expected to be batched as (batch_size, input_dim)
        # Loop through the batch and compute amplitude for each sample
        batch_amps = []
        for x_i in x:
            amp = self.get_amp(psi, x_i, symmetry=self.symmetry, conj=True)

            # Insert projectors
            amp_w_proj = insert_proj_peps(amp, max_bond=self.max_bond, yrange=[0, psi.Ly-2])
            amp_tn, proj_tn = amp_w_proj.partition(tags='proj')
            proj_params, proj_skeleton = qtn.pack(proj_tn)
            proj_params_vec = flatten_proj_params(proj_params)

            # Check x_i type
            if not type(x_i) == torch.Tensor:
                x_i = torch.tensor(x_i, dtype=torch.float32)
            # Add NN output
            proj_params_vec += self.nn_eta*self.mlp(x_i)
            # Reconstruct the proj parameters
            new_proj_params = reconstruct_proj_params(proj_params_vec, proj_params)
            # Load the new parameters
            new_proj_tn = qtn.unpack(new_proj_params, proj_skeleton)
            new_amp_w_proj = amp_tn | new_proj_tn

            # contract column by column
            
            # batch_amps.append(torch.tensor(new_amp_w_proj.contract(), dtype=torch.float32, requires_grad=True))
            batch_amps.append(new_amp_w_proj.contract())

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)
    
    def forward(self, x):
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)


In [19]:
Lx = 6
Ly = 6
D = 8
symmetry = 'Z2'
N_f = int(4*4/2)-2
# Create a random PEPS
peps, parity_config = generate_random_fpeps(Lx, Ly, D, seed=2, symmetry=symmetry, Nf=N_f)

# Create a random configuration
random_conf = np.zeros(Lx*Ly)
random_conf[:N_f] = 1
np.random.seed(1)
np.random.shuffle(random_conf)

t = 1.0
V = 4.0
mu = 0.0
edges = qtn.edges_2d_square(Lx, Ly)
site_info = sr.utils.parse_edges_to_site_info(
    edges,
    D,
    phys_dim=2,
    site_ind_id="k{},{}",
    site_tag_id="I{},{}",
)
terms = {
    (sitea, siteb): sr.fermi_hubbard_spinless_local_array(
        t=t, V=V, mu=mu,
        symmetry=symmetry,
        coordinations=(
            site_info[sitea]['coordination'],
            site_info[siteb]['coordination'],
        ),
    ).fuse((0, 1), (2, 3))
    for (sitea, siteb) in peps.gen_bond_coos()
}
ham = qtn.LocalHam2D(Lx, Ly, terms)
su = qtn.SimpleUpdateGen(peps, ham, compute_energy_per_site=True,D=D, compute_energy_opts={"max_distance":1}, gate_opts={'cutoff':1e-10})
su.evolve(50, 0.3)
# su.evolve(50, 0.1)
peps = su.state

n=50, tau=0.3000, energy~-0.402307: 100%|##########| 50/50 [00:10<00:00,  4.77it/s]


In [28]:
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float32, requires_grad=True))
# Get the amplitude of the configuration
amp = peps.get_amp(random_conf, conj=True)
import pyinstrument
with pyinstrument.profile():
    amp.contract()
    # amp.contract_boundary_from_ymin(max_bond=8, yrange=(0, amp.Ly-2), cutoff=0.0).contract()

  peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float32, requires_grad=True))

pyinstrument ........................................
.
.  Block at /tmp/ipykernel_22117/2079912326.py:5
.
.  0.066 <module>  ../../../../../tmp/ipykernel_22117/2079912326.py:5
.  └─ 0.066 PEPS.contract  quimb/tensor/tensor_core.py:8396
.        [6 frames hidden]  functools, quimb, cotengra
.           0.066 wrapper  functools.py:883
.           └─ 0.065 tensordot_fermionic  symmray/fermionic_core.py:711
.              ├─ 0.052 tensordot_abelian  symmray/abelian_core.py:1996
.              │  └─ 0.052 _tensordot_via_fused  symmray/abelian_core.py:1941
.              │     ├─ 0.031 Z2FermionicArray.fuse  symmray/abelian_core.py:1418
.              │     │  ├─ 0.010 <dictcomp>  symmray/abelian_core.py:1524
.              │     │  │  └─ 0.010 _recurse_concat  symmray/abelian_core.py:1484
.              │     │  │     ├─ 0.005 translated_function  autoray/autoray.py:1310
.              │     │  │   

In [3]:
# Get the amplitude of the configuration
amp = peps.get_amp(random_conf, conj=True)
amp.contract()
yrange = [0, Lx-2]
chi = 8
amp_w_proj = insert_proj_peps(amp, max_bond=chi, yrange=yrange)

# Phase check
print(amp_w_proj.contract(), amp.contract())

amp_tn, proj_tn = amp_w_proj.partition(tags='proj')
proj_params, proj_skeleton = qtn.pack(proj_tn)

# Flatten the proj parameters
proj_params_vec = flatten_proj_params(proj_params)

# Perturbation
perturbation = 1e-4
perturbed_params = proj_params_vec + perturbation*np.random.randn(len(proj_params_vec))

# Reconstruct the proj parameters
perturbed_proj_params = reconstruct_proj_params(perturbed_params, proj_params)

# Load the perturbed parameters
new_proj_tn = qtn.unpack(perturbed_proj_params, proj_skeleton)
new_amp_w_proj = amp_tn | new_proj_tn
new_amp_w_proj.contract()

-6.023293644344654e-08 -5.956890913005554e-08


-5.990340537268858e-08

In [23]:
peps = su.state.copy()
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=torch.float32, requires_grad=True))
np.random.shuffle(random_conf)
test_model = fTN_NNiso_Model(peps, max_bond=chi, nn_hidden_dim=8, nn_eta=0.0)
# # test_model = fTNModel(peps)
torch.autograd.set_detect_anomaly(True)
print(torch.tensor([random_conf], dtype=torch.float32))
loss = test_model.amplitude(torch.tensor([random_conf], dtype=torch.float32))
# amp = peps.get_amp(random_conf, conj=True)
# amp_w_proj = insert_proj_peps(amp, max_bond=chi, yrange=[0, peps.Ly-2])
# loss = amp_w_proj.contract()
loss.backward()

tensor([[0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1.]])
[tensor(-1.9720e-08, grad_fn=<ViewBackward0>)]


Note: If symmetry is U1, the shape of each block is different, and the numbers of parameters in the projectors for different configurations are also different.

Must use NN sturcture that can have dynamic output dimensions.

In [127]:
np.random.shuffle(random_conf)
ampx = peps.get_amp(random_conf, conj=True)
ampx_w_proj = insert_proj_peps(ampx, max_bond=chi, yrange=[0, peps.Ly-2])
ampx_tn, projx_tn = ampx_w_proj.partition(tags='proj')
projx_params, projx_skeleton = qtn.pack(projx_tn)
projx_params_vec = flatten_proj_params(projx_params)

{(0,
  0,
  0): tensor([[[ 0.4578,  0.2153],
          [-0.0884,  0.2482]],
 
         [[-0.1604,  0.4821],
          [-0.1731, -0.6656]]]),
 (1,
  1,
  0): tensor([[[ -0.1208, -11.7845],
          [  0.0700,   7.7760]],
 
         [[  0.0388,   4.7961],
          [ -0.0251,  -3.2224]]]),
 (0,
  1,
  1): tensor([[[  2.3070, -11.2277],
          [ -1.4781,   8.7886]],
 
         [[ -0.5556,  25.5320],
          [  0.3850, -17.3789]]]),
 (1,
  0,
  1): tensor([[[-6.5736e-05, -2.6272e-02],
          [ 1.0638e-01, -3.7179e-01]],
 
         [[ 8.8650e-04,  7.9483e-02],
          [-3.2298e-02,  3.0391e-01]]])}