In [241]:
import os
os.environ["NUMBA_NUM_THREADS"] = "20"

import netket as nk
import netket.experimental as nkx
import netket.nn as nknn

from math import pi

from netket.experimental.operator.fermion import destroy as c
from netket.experimental.operator.fermion import create as cdag
from netket.experimental.operator.fermion import number as nc

from vmc_torch.fermion_utils import generate_random_fpeps
import quimb.tensor as qtn
import symmray as sr
import pickle

# ---------------------------------------------------------------------------
# Lattice geometry: Lx x Ly square grid with open boundaries (pbc=False).
# ---------------------------------------------------------------------------
L = 4  # Side of the square
Lx = int(L)
Ly = int(L)
spinless = False  # forwarded to generate_random_fpeps below
graph = nk.graph.Grid([Lx, Ly], pbc=False)
N = graph.n_nodes

# ---------------------------------------------------------------------------
# Fermion filling and Hilbert space: N_f = Lx*Ly fermions in total on
# spin-1/2 orbitals.
# ---------------------------------------------------------------------------
N_f = int(Lx * Ly)
hi = nkx.hilbert.SpinOrbitalFermions(N, s=1/2, n_fermions=N_f)

# ---------------------------------------------------------------------------
# Hubbard parameters, defined once and shared by the NetKet operator and the
# quimb/symmray local terms further down.
# ---------------------------------------------------------------------------
t = 1.0   # prefactor of the hopping terms
U = 8.0   # prefactor of the on-site n_up * n_down term
mu = 0.0  # only passed to sr.fermi_hubbard_local_array; it does not enter H

# ---------------------------------------------------------------------------
# NetKet Hubbard Hamiltonian: hopping on every edge for both spin
# projections, plus the on-site interaction on every node.
# ---------------------------------------------------------------------------
H = 0.0
for (i, j) in graph.edges():
    for spin in (1, -1):
        H -= t * (cdag(hi, i, spin) * c(hi, j, spin) + cdag(hi, j, spin) * c(hi, i, spin))
for i in graph.nodes():
    H += U * nc(hi, i, +1) * nc(hi, i, -1)


# ---------------------------------------------------------------------------
# SU in quimb
# ---------------------------------------------------------------------------
D = 3
seed = 2
symmetry = 'U1'
# The generator returns a sequence; only its first element (the PEPS) is used.
# `seed` is passed through so that changing it above actually takes effect.
peps = generate_random_fpeps(Lx, Ly, D=D, seed=seed, symmetry=symmetry, Nf=N_f, spinless=spinless)[0]
edges = qtn.edges_2d_square(Lx, Ly, cyclic=False)
site_info = sr.utils.parse_edges_to_site_info(
    edges,
    D,
    phys_dim=4,
    site_ind_id="k{},{}",
    site_tag_id="I{},{}",
)

# One two-site term per bond of the PEPS. Each term is built from the
# coordination of its two sites, then index pairs (0, 1) and (2, 3) are fused.
terms = {
    (sitea, siteb): sr.fermi_hubbard_local_array(
        t=t, U=U, mu=mu,
        symmetry=symmetry,
        coordinations=(
            site_info[sitea]['coordination'],
            site_info[siteb]['coordination'],
        ),
    ).fuse((0, 1), (2, 3))
    for (sitea, siteb) in peps.gen_bond_coos()
}
ham = qtn.LocalHam2D(Lx, Ly, terms)

su = qtn.SimpleUpdateGen(
    peps,
    ham,
    compute_energy_per_site=True,
    D=D,
    compute_energy_opts={"max_distance": 1},
    gate_opts={'cutoff': 1e-12},
)

# Evolution schedule as (number of steps, tau), with decreasing tau.
# Cluster energies may not be accurate yet; appending (100, 0.01) and
# (100, 0.003) extends the schedule with finer steps.
su_schedule = [(50, 0.3), (50, 0.1), (50, 0.03)]
for n_steps, tau in su_schedule:
    su.evolve(n_steps, tau=tau)

peps = su.get_state()
peps.equalize_norms_(value=1)

# Split the state into raw arrays (`params`) and a structure-only `skeleton`.
# Nothing is written to disk in this cell.
params, skeleton = qtn.pack(peps)

n=50, tau=0.3000, energy~-0.266968: 100%|##########| 50/50 [00:06<00:00,  7.35it/s]
n=100, tau=0.1000, energy~-0.342996: 100%|##########| 50/50 [00:06<00:00,  7.67it/s]
n=150, tau=0.0300, energy~-0.345458: 100%|##########| 50/50 [00:06<00:00,  7.73it/s]


In [242]:
from mpi4py import MPI
import ast
# torch
import torch
import torch.nn as nn

# quimb
import quimb.tensor as qtn
from quimb.tensor.tensor_2d import Rotator2D, pairwise
import symmray as sr
import autoray as ar
from autoray import do

from vmc_torch.fermion_utils import insert_proj_peps, flatten_proj_params, reconstruct_proj_params, insert_compressor
from vmc_torch.global_var import DEBUG, set_debug

# Near-zero initialisation hook, intended for use with ``Module.apply``.
def init_weights_to_zero(m, std=1e-3):
    """Re-initialise one layer with small random weights and zero bias.

    Despite the name, weights are not set exactly to zero: they are drawn
    from a normal distribution with mean 0 and standard deviation ``std``.
    Only the bias is set to exactly 0.

    Args:
        m: any module. Only ``nn.Linear``, ``nn.Conv2d``, ``nn.Embedding``
            and ``nn.LayerNorm`` instances are modified; every other module
            is left untouched.
        std: standard deviation of the normal distribution for the weights.

    Returns:
        None. The module is modified in place.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm)):
        return

    # A layer may lack the attribute entirely or carry it as None;
    # both cases are skipped.
    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, mean=0.0, std=std)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)

class fTNModel(torch.nn.Module):
    """Torch module whose trainable parameters are the blocks of a tensor network.

    The network ``ftn`` is split with ``qtn.pack`` into raw arrays and a
    skeleton. Every block of every tensor is registered as an
    ``nn.Parameter``; ``amplitude`` rebuilds the network from the current
    parameter values and contracts it for each input configuration.

    Args:
        ftn: tensor network accepted by ``qtn.pack``. It must expose
            ``arrays``, ``max_bond()``, ``Lx`` and ``Ly``, and its block data
            must be accepted by ``nn.Parameter``.
        max_bond: maximum bond dimension handed to
            ``contract_boundary_from_ymin``. ``None`` or a value ``<= 0``
            skips the boundary contraction step.
    """

    def __init__(self, ftn, max_bond=None):
        super().__init__()

        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # ModuleDict / ParameterDict keys must be strings, so tensor ids and
        # sector labels are stringified here. `amplitude` and
        # `from_vec_to_params(quimb_format=True)` invert this with int() and
        # ast.literal_eval().
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters (same order as self.parameters())
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (exact contraction)':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
        }
        # Normalise "no truncation" to None so later code tests one sentinel.
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond

    def from_params_to_vec(self):
        """Return all parameter values concatenated into one flat tensor."""
        return torch.cat([param.data.flatten() for param in self.parameters()])

    @property
    def num_params(self):
        """Total number of scalar parameters.

        Counted with ``numel`` so no flattened copy is allocated; a model
        without parameters yields 0.
        """
        return sum(param.numel() for param in self.parameters())

    def params_grad_to_vec(self):
        """Return all gradients as one flat tensor, with zeros where ``grad`` is None."""
        return torch.cat([
            param.grad.flatten() if param.grad is not None
            else torch.zeros_like(param).flatten()
            for param in self.parameters()
        ])

    def clear_grad(self):
        """Reset the gradient of every parameter to ``None``."""
        for param in self.parameters():
            param.grad = None

    def from_vec_to_params(self, vec, quimb_format=False):
        """Slice a flat vector back into the nested ``{tid: {sector: block}}`` layout.

        Args:
            vec: 1-D tensor laid out in the iteration order of
                ``self.torch_tn_params``.
            quimb_format: if True, keys are converted back from their string
                form (tensor id via ``int``, sector via ``ast.literal_eval``),
                matching the dictionary that ``amplitude`` passes to
                ``qtn.unpack``. If False, the string keys are kept.

        Returns:
            Nested dict whose leaves are views into ``vec``.
        """
        # XXX: useful at all?
        params = {}
        idx = 0
        for tid, blk_array in self.torch_tn_params.items():
            # Convert the tensor id as well as the sector, so the result has
            # the same key types as the dict built in `amplitude`.
            tid_key = int(tid) if quimb_format else tid
            blocks = {}
            for sector, data in blk_array.items():
                size = data.numel()
                sector_key = ast.literal_eval(sector) if quimb_format else sector
                blocks[sector_key] = vec[idx:idx+size].view(data.shape)
                idx += size
            params[tid_key] = blocks
        return params

    def load_params(self, new_params):
        """Copy values from the flat tensor ``new_params`` into the parameters.

        ``new_params`` is consumed in the order of ``self.parameters()``.
        """
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            # In-place copy outside of autograd tracking.
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    def amplitude(self, x):
        """Compute one amplitude per configuration in the batch ``x``.

        Args:
            x: iterable of configurations, expected to be batched as
                (batch_size, input_dim).

        Returns:
            Tensor obtained by stacking the per-configuration amplitudes.
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # Reconstruct the TN with the current parameters (once per batch)
        psi = qtn.unpack(params, self.skeleton)

        batch_amps = []
        for x_i in x:
            amp = psi.get_amp(x_i, conj=True)
            if self.max_bond is not None:
                # Compress from the ymin side before the final contraction.
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly-2])
            amp_val = amp.contract()
            # A zero result is replaced by a fresh constant tensor, which
            # carries no gradient.
            # NOTE(review): presumably needed because the contraction can
            # return a non-tensor zero that torch.stack would reject - confirm.
            if amp_val==0.0:
                amp_val = torch.tensor(0.0)
            batch_amps.append(amp_val)

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)

    def forward(self, x):
        """Evaluate ``amplitude``; a 1-D input is treated as a batch of one."""
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)

class fTN_backflow_Model(torch.nn.Module):
    """Tensor-network model whose parameters get a configuration-dependent NN correction.

    As well as registering the blocks of ``ftn`` as parameters, the module
    owns a small fully connected network. For each input configuration the
    network outputs one value per tensor-network parameter. That output,
    scaled by ``nn_eta``, is added to the flattened parameters before the
    network is rebuilt and contracted.

    Args:
        ftn: tensor network accepted by ``qtn.pack``. It must expose
            ``arrays``, ``max_bond()``, ``Lx`` and ``Ly``, and its block data
            must be accepted by ``nn.Parameter``.
        max_bond: maximum bond dimension handed to
            ``contract_boundary_from_ymin``. ``None`` or a value ``<= 0``
            skips the boundary contraction step.
        nn_hidden_dim: width of the single hidden layer.
        nn_eta: scale factor applied to the network output before it is added
            to the tensor-network parameters.
        dtype: dtype the network is cast to; inputs are cast to it as well.
    """

    def __init__(self, ftn, max_bond=None, nn_hidden_dim=128, nn_eta=1e-3, dtype=torch.float32):
        super().__init__()
        self.param_dtype = dtype

        # extract the raw arrays and a skeleton of the TN
        params, self.skeleton = qtn.pack(ftn)

        # ModuleDict / ParameterDict keys must be strings, so tensor ids and
        # sector labels are stringified here and converted back in `amplitude`.
        self.torch_tn_params = nn.ModuleDict({
            str(tid): nn.ParameterDict({
                str(sector): nn.Parameter(data)
                for sector, data in blk_array.items()
            })
            for tid, blk_array in params.items()
        })

        # Define the neural network: one input per lattice site, one output
        # per entry of the flattened TN parameter vector.
        input_dim = ftn.Lx * ftn.Ly
        tn_params_vec = flatten_proj_params(params)
        self.nn = nn.Sequential(
            nn.Linear(input_dim, nn_hidden_dim),
            nn.ReLU(),
            nn.Linear(nn_hidden_dim, tn_params_vec.numel())
        )
        self.nn.to(self.param_dtype)

        # Get symmetry
        self.symmetry = ftn.arrays[0].symmetry

        # Store the shapes of the parameters. This runs after self.nn is
        # created, so the list covers TN and NN parameters alike.
        self.param_shapes = [param.shape for param in self.parameters()]

        self.model_structure = {
            'fPEPS (exact contraction)':{'D': ftn.max_bond(), 'Lx': ftn.Lx, 'Ly': ftn.Ly, 'symmetry': self.symmetry},
        }
        # Normalise "no truncation" to None so later code tests one sentinel.
        if max_bond is None or max_bond <= 0:
            max_bond = None
        self.max_bond = max_bond
        self.nn_eta = nn_eta

    def from_params_to_vec(self):
        """Return all parameter values (TN and NN) concatenated into one flat tensor."""
        return torch.cat([param.data.flatten() for param in self.parameters()])

    @property
    def num_params(self):
        """Total number of scalar parameters (TN and NN).

        Counted with ``numel`` so no flattened copy is allocated.
        """
        return sum(param.numel() for param in self.parameters())

    def params_grad_to_vec(self):
        """Return all gradients as one flat tensor, with zeros where ``grad`` is None."""
        return torch.cat([
            param.grad.flatten() if param.grad is not None
            else torch.zeros_like(param).flatten()
            for param in self.parameters()
        ])

    def clear_grad(self):
        """Reset the gradient of every parameter to ``None``."""
        for param in self.parameters():
            param.grad = None

    def load_params(self, new_params):
        """Copy values from the flat tensor ``new_params`` into the parameters.

        ``new_params`` is consumed in the order of ``self.parameters()``.
        """
        pointer = 0
        for param, shape in zip(self.parameters(), self.param_shapes):
            num_param = param.numel()
            new_param_values = new_params[pointer:pointer+num_param].view(shape)
            # In-place copy outside of autograd tracking.
            with torch.no_grad():
                param.copy_(new_param_values)
            pointer += num_param

    def amplitude(self, x):
        """Compute one NN-corrected amplitude per configuration in the batch ``x``.

        Args:
            x: iterable of configurations, expected to be batched as
                (batch_size, input_dim).

        Returns:
            Tensor obtained by stacking the per-configuration amplitudes.
        """
        # Reconstruct the original parameter structure (by unpacking from the flattened dict)
        params = {
            int(tid): {
                ast.literal_eval(sector): data
                for sector, data in blk_array.items()
            }
            for tid, blk_array in self.torch_tn_params.items()
        }
        # The uncorrected parameter vector is the same for every sample.
        params_vec = flatten_proj_params(params)

        batch_amps = []
        for x_i in x:
            # Get the NN correction to the parameters
            nn_correction = self.nn(x_i.to(self.param_dtype))
            # Add the scaled correction to the original parameters
            tn_nn_params = reconstruct_proj_params(params_vec + self.nn_eta*nn_correction, params)
            # Reconstruct the TN with the corrected parameters (per sample)
            psi = qtn.unpack(tn_nn_params, self.skeleton)
            # Get the amplitude
            amp = psi.get_amp(x_i, conj=True)

            if self.max_bond is not None:
                # Compress from the ymin side before the final contraction.
                amp = amp.contract_boundary_from_ymin(max_bond=self.max_bond, cutoff=0.0, yrange=[0, psi.Ly-2])

            amp_val = amp.contract()
            # A zero result is replaced by a fresh constant tensor, which
            # carries no gradient.
            # NOTE(review): presumably needed because the contraction can
            # return a non-tensor zero that torch.stack would reject - confirm.
            if amp_val==0.0:
                amp_val = torch.tensor(0.0)
            batch_amps.append(amp_val)

        # Return the batch of amplitudes stacked as a tensor
        return torch.stack(batch_amps)

    def forward(self, x):
        """Evaluate ``amplitude``; a 1-D input is treated as a batch of one."""
        if x.ndim == 1:
            # If input is not batched, add a batch dimension
            x = x.unsqueeze(0)
        return self.amplitude(x)

In [326]:
# Convert every array stored in `peps` to a torch tensor; the result is not
# reassigned, so `peps` itself is what the models below receive.
peps.apply_to_arrays(torch.tensor)

# Backflow-corrected model and plain tensor-network model built from the same state.
model = fTN_backflow_Model(
    peps,
    max_bond=4,
    nn_hidden_dim=32,
    nn_eta=1.0,
    dtype=torch.float32,
)
model1 = fTNModel(peps, max_bond=4)


def _init_backflow_layers(module):
    """Per-module hook for `model.apply`: small-normal weights (std=1e-2), zero bias."""
    init_weights_to_zero(module, std=1e-2)


model.apply(_init_backflow_layers)

# Parameter counts: (backflow model, plain model)
model.num_params, model1.num_params

(18598, 531)

In [324]:
import jax
import numpy as np
from vmc_torch.hamiltonian import spinful_Fermi_Hubbard_square_lattice

# Hamiltonian parameters
Lx, Ly = 4, 4
symmetry = 'U1'
t, U = 1.0, 8.0
N_f = int(Lx*Ly)
# Same number of fermions for each of the two spin species.
n_fermions_per_spin = (N_f//2, N_f//2)
H = spinful_Fermi_Hubbard_square_lattice(
    Lx, Ly, t, U, N_f,
    pbc=False,
    n_fermions_per_spin=n_fermions_per_spin,
)

# Draw one random configuration from the Hilbert space of H. The seed is
# itself drawn at random, so the configuration differs from run to run.
random_seed = np.random.randint(2**32)
rng_key = jax.random.PRNGKey(random_seed)
random_config = torch.tensor(H.hilbert.random_state(key=rng_key), dtype=torch.int)

# Amplitudes of that configuration: (backflow model, plain model)
model(random_config), model1(random_config)

(tensor([7.4627e-09], dtype=torch.float64, grad_fn=<StackBackward0>),
 tensor([1.0486e-08], dtype=torch.float64, grad_fn=<StackBackward0>))