In [2]:
import os
os.environ["NUMBA_NUM_THREADS"] = "20"

import netket as nk
import netket.experimental as nkx
import netket.nn as nknn

from math import pi

from netket.experimental.operator.fermion import destroy as c
from netket.experimental.operator.fermion import create as cdag
from netket.experimental.operator.fermion import number as nc

from vmc_torch.fermion_utils import generate_random_fpeps
import quimb.tensor as qtn
import symmray as sr
import pickle

# Lattice: open-boundary Lx x Ly square grid
Lx = 4
Ly = 4
spinless = False
graph = nk.graph.Grid([Lx, Ly], pbc=False)
N = graph.n_nodes

# Spin-1/2 fermions with N_f = number of sites, i.e. half filling
N_f = Lx * Ly
hi = nkx.hilbert.SpinOrbitalFermions(N, s=1/2, n_fermions=N_f)

# Hubbard model parameters
t = 1.0
U = 8.0
mu = 0.0

# NetKet Hubbard Hamiltonian:
#   H = -t * sum_<ij>,s (c^dag_is c_js + h.c.) + U * sum_i n_i,up n_i,dn
# NOTE: no chemical-potential term is added here, so this H matches a model with
# mu != 0 only if such a term is added explicitly.
H = 0.0
for (i, j) in graph.edges():
    for spin in (1, -1):
        H -= t * (cdag(hi, i, spin) * c(hi, j, spin) + cdag(hi, j, spin) * c(hi, i, spin))
for i in graph.nodes():
    H += U * nc(hi, i, +1) * nc(hi, i, -1)


# Simple update (SU) imaginary-time evolution in quimb, starting from a random
# symmetric fPEPS. Reuses Lx, Ly, N_f, spinless and the Hubbard parameters
# t, U, mu defined above, so both Hamiltonians share one source of truth.
D = 3
seed = 2
symmetry = 'U1'
# generate_random_fpeps returns a tuple; element [0] is used as the PEPS
peps = generate_random_fpeps(Lx, Ly, D=D, seed=seed, symmetry=symmetry, Nf=N_f, spinless=spinless)[0]
edges = qtn.edges_2d_square(Lx, Ly, cyclic=False)
site_info = sr.parse_edges_to_site_info(
    edges,
    D,
    phys_dim=4,
    site_ind_id="k{},{}",
    site_tag_id="I{},{}",
)

# One two-site Hubbard term per nearest-neighbour bond; the coordination of each
# site is passed along to the local-array builder.
terms = {
    (sitea, siteb): sr.fermi_hubbard_local_array(
        t=t, U=U, mu=mu,
        symmetry=symmetry,
        coordinations=(
            site_info[sitea]['coordination'],
            site_info[siteb]['coordination'],
        ),
    ).fuse((0, 1), (2, 3))
    for (sitea, siteb) in peps.gen_bond_coos()
}
ham = qtn.LocalHam2D(Lx, Ly, terms)

su = qtn.SimpleUpdateGen(
    peps,
    ham,
    D=D,
    compute_energy_per_site=True,
    compute_energy_opts={"max_distance": 1},
    gate_opts={'cutoff': 1e-12},
)

# (steps, tau) schedule with decreasing time step. Append (100, 0.01) and
# (100, 0.003) for a more converged state.
# The cluster energies reported during SU may not be accurate yet.
SU_SCHEDULE = [(50, 0.3), (50, 0.1), (50, 0.03)]
for n_steps, tau in SU_SCHEDULE:
    su.evolve(n_steps, tau=tau)

peps = su.get_state()
peps.equalize_norms_(value=1)

# Pack the optimized state into raw arrays + skeleton (e.g. for saving with pickle)
params, skeleton = qtn.pack(peps)

n=50, tau=0.3, max|dS|=0.13, energy~-0.2358: 100%|##########| 50/50 [00:08<00:00,  6.15it/s]
n=100, tau=0.1, max|dS|=0.077, energy~-0.299077: 100%|##########| 50/50 [00:06<00:00,  7.30it/s]
n=150, tau=0.03, max|dS|=0.019, energy~-0.302929: 100%|##########| 50/50 [00:06<00:00,  7.18it/s]


In [34]:
# Split the PEPS into boundary sites (first/last row or column) and bulk sites.
# Each coordinate is compared against its own extent, so non-square lattices work.
params, skeleton = qtn.pack(peps)
boundary_sites = [
    site for site in peps.gen_site_coos()
    if site[0] in (0, peps.Lx - 1) or site[1] in (0, peps.Ly - 1)
]
bulk_site_tags = [
    peps.site_tag_id.format(*site)
    for site in peps.gen_site_coos()
    if site not in boundary_sites
]
boundary_tags = [peps.site_tag_id.format(*site) for site in boundary_sites]
# partition(..., which='any') separates tensors carrying none of the boundary
# tags (bulk) from those carrying at least one (boundary)
bulk_tensors, boundary_tensors = peps.partition(boundary_tags, which='any')
boundary_tids = [next(iter(peps._get_tids_from_tags([tag]))) for tag in boundary_tags]

In [None]:
from vmc_torch.experiment.tn_model import *
class fTN_backflow_attn_Model_boundary(wavefunctionModel):
    """fPEPS wavefunction with an attention-NN backflow correction on its boundary tensors.

    Every block of every tensor of ``ftn`` is a trainable parameter. In
    ``amplitude``, a self-attention network maps each input configuration to a
    correction vector. That vector is scaled by ``nn_eta`` and added to the
    flattened parameters of the boundary tensors, i.e. the sites in the first or
    last row or column. Bulk tensors are used unchanged. The amplitude is then
    contracted from the corrected network.

    Args:
        ftn: fermionic PEPS (quimb/symmray). Its block arrays are wrapped in
            ``nn.Parameter`` directly, so they presumably must already be torch
            tensors (TODO confirm).
        max_bond: bond dimension for approximate boundary contraction of the
            amplitude. ``None`` or ``<= 0`` means exact contraction.
        embedding_dim, attention_heads, nn_hidden_dim: forwarded to
            ``SelfAttn_FFNN_block``.
        nn_eta: scale factor applied to the NN correction.
        dtype: stored as ``self.param_dtype``; the TN arrays are not cast to it here.
    """

    def __init__(self, ftn, max_bond=None, embedding_dim=32, attention_heads=4, nn_hidden_dim=128, nn_eta=1e-3, dtype=torch.float32):
        super().__init__()
        self.param_dtype = dtype

        # Raw block arrays {tid: {sector: array}} plus a skeleton used to rebuild
        # the network from (possibly modified) arrays in `amplitude`.
        params, self.skeleton = qtn.pack(ftn)

        # Register every block as a parameter. ModuleDict/ParameterDict keys must
        # be strings, so tids and sectors are stringified here. Sectors are turned
        # back into objects with ast.literal_eval in `_tn_params_for`.
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Split the tensors into boundary and bulk sets. The lattice extent is read
        # from `ftn` itself rather than notebook globals, and each coordinate is
        # checked against its own extent (Lx for the first, Ly for the second).
        boundary_sites = self._boundary_site_coos(ftn)
        boundary_tags = [ftn.site_tag_id.format(*site) for site in boundary_sites]
        bulk_site_tags = [
            ftn.site_tag_id.format(*site)
            for site in ftn.gen_site_coos()
            if site not in boundary_sites
        ]
        self.bulk_tid_list = [self._tid_of_tag(ftn, tag) for tag in bulk_site_tags]
        self.boundary_tid_list = [self._tid_of_tag(ftn, tag) for tag in boundary_tags]
        boundary_tn_params = {tid: params[tid] for tid in self.boundary_tid_list}
        boundary_tn_params_vec = flatten_proj_params(boundary_tn_params)

        # The NN reads a whole configuration (one entry per site) and outputs one
        # correction entry per flattened boundary parameter.
        input_dim = ftn.Lx * ftn.Ly
        phys_dim = ftn.phys_dim()
        self.nn = SelfAttn_FFNN_block(
            n_site=input_dim,
            num_classes=phys_dim,
            embedding_dim=embedding_dim,
            attention_heads=attention_heads,
            nn_hidden_dim=nn_hidden_dim,
            output_dim=boundary_tn_params_vec.numel(),
        )

        self.symmetry = ftn.arrays[0].symmetry

        # Shapes of all registered parameters (TN blocks and NN weights).
        self.param_shapes = [param.shape for param in self.parameters()]

        # NOTE: records max_bond as passed in, i.e. before the normalization below.
        self.model_structure = {
            'fPEPS_backflow_attn': {'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry, 'nn_hidden_dim': nn_hidden_dim, 'nn_eta': nn_eta, 'max_bond': max_bond},
        }
        # A non-positive max_bond is treated as "no truncation".
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond
        self.nn_eta = nn_eta

    @staticmethod
    def _boundary_site_coos(ftn):
        """Coordinates of the sites in the first/last row or column of ``ftn``, in ``gen_site_coos`` order."""
        return [
            site for site in ftn.gen_site_coos()
            if site[0] in (0, ftn.Lx - 1) or site[1] in (0, ftn.Ly - 1)
        ]

    @staticmethod
    def _tid_of_tag(ftn, tag):
        """Tensor id of a tensor in ``ftn`` carrying ``tag`` (site tags are presumably unique)."""
        return next(iter(ftn._get_tids_from_tags([tag])))

    def _tn_params_for(self, tids):
        """Rebuild ``{tid: {sector: parameter}}`` for ``tids`` from the registered parameters."""
        return {
            tid: {
                ast.literal_eval(sector): data
                for sector, data in self.torch_tn_params[str(tid)].items()
            }
            for tid in tids
        }

    def amplitude(self, x):
        """Amplitudes of the backflow-corrected fPEPS for a batch of configurations.

        Args:
            x: batch of configurations, iterated sample by sample. Presumably shape
                (batch_size, Lx * Ly) (TODO confirm). Non-tensor samples are
                converted with ``torch.tensor``.

        Returns:
            torch.Tensor: the per-sample amplitudes stacked along a new first axis.
        """
        batch_amps = []
        for x_i in x:
            if not isinstance(x_i, torch.Tensor):
                x_i = torch.tensor(x_i)

            # Parameter dicts are rebuilt per sample, so arrays handed to one
            # sample's network are never shared with the next sample's.
            bulk_tn_params = self._tn_params_for(self.bulk_tid_list)
            boundary_tn_params = self._tn_params_for(self.boundary_tid_list)
            boundary_tn_params_vec = flatten_proj_params(boundary_tn_params)

            # Configuration-dependent backflow correction of the boundary tensors.
            nn_correction = self.nn(x_i)
            new_boundary_tn_params_vec = boundary_tn_params_vec + self.nn_eta * nn_correction
            new_boundary_tn_params = reconstruct_proj_params(new_boundary_tn_params_vec, boundary_tn_params)

            # Rebuild the TN from the bulk arrays and the corrected boundary arrays,
            # then project it onto configuration x_i.
            tn_nn_params = {**bulk_tn_params, **new_boundary_tn_params}
            psi = qtn.unpack(tn_nn_params, self.skeleton)
            amp = psi.get_amp(x_i, conj=True)

            if self.max_bond is not None:
                # Approximate contraction: contract the y < Ly//2 half from ymin and
                # the remaining half from ymax, truncating bonds to max_bond.
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly // 2 - 1])
                amp = amp.contract_boundary_from_ymax(max_bond=self.max_bond, cutoff=0.0, yrange=[psi.Ly // 2, psi.Ly - 1])

            amp_val = amp.contract()
            # A zero result is replaced with a fresh torch zero. Presumably the
            # contraction can return a non-tensor 0 that torch.stack would reject;
            # confirm.
            if amp_val == 0.0:
                amp_val = torch.tensor(0.0)
            batch_amps.append(amp_val)

        return torch.stack(batch_amps)