In [None]:
import torch
@torch.inference_mode()
def evaluate_energy_reuse(fxs, fpeps_model, H, current_amps, verbose=False, chunk_size=None):
    r"""Estimate the energy from a batch of sampled configurations via local energies.

    For every sampled configuration ``s`` (a row of ``fxs``) the local energy is

        E_loc(s) = sum_{s'} H_{s,s'} <s'|psi> / <s|psi>,

    where the connected configurations ``s'`` and matrix elements ``H_{s,s'}`` come
    from ``H.get_conn(s)`` and the amplitudes ``<s'|psi>`` are evaluated with
    ``fpeps_model``. The returned energy is the mean of the local energies.

    Args:
        fxs: batch of configurations, iterated along dim 0 (``B = fxs.shape[0]``).
        fpeps_model: callable evaluated on a stacked batch of connected
            configurations; assumed to return one amplitude per row.
        H: object exposing ``get_conn(fx) -> (etas, coeffs)``.
        current_amps: amplitudes ``<s|psi>`` of the rows of ``fxs``, indexed by batch position.
        verbose: print intermediate shapes and the final energy.
        chunk_size: if given, call ``fpeps_model`` on at most this many connected
            configurations at a time to bound peak memory (e.g. ``fxs.shape[0]``).
            ``None`` (default) evaluates all connected configurations in one call.

    Returns:
        ``(energy, local_energies)``: the mean local energy and the stacked
        per-sample local energies (length ``B``).

    Raises:
        ValueError: if ``fxs`` is empty or ``chunk_size`` is not positive.
    """
    B = fxs.shape[0]
    if B == 0:
        raise ValueError("fxs must contain at least one configuration")
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # Connected configurations s' and matrix elements H_{s,s'} for every sample s.
    conn_eta_num = []
    conn_etas = []
    conn_eta_coeffs = []
    for fx in fxs:
        eta, coeffs = H.get_conn(fx)
        conn_eta_num.append(len(eta))
        # as_tensor avoids a redundant copy (and torch's copy-construct warning)
        # when get_conn already returns tensors.
        conn_etas.append(torch.as_tensor(eta))
        conn_eta_coeffs.append(torch.as_tensor(coeffs))

    conn_etas = torch.cat(conn_etas, dim=0)
    conn_eta_coeffs = torch.cat(conn_eta_coeffs, dim=0)

    if verbose:
        print(f'Prepared batched conn_etas and coeffs: {conn_etas.shape}, {conn_eta_coeffs.shape} (batch size {B})')

    # Amplitudes <s'|psi> of all connected configurations, optionally chunked to avoid OOM.
    # TODO: reuse TN contractions across connected configs to speed this up (would need a
    # control parameter that is not batched over).
    n_total = conn_etas.shape[0]
    if chunk_size is None or chunk_size >= n_total:
        conn_amps = fpeps_model(conn_etas)
    else:
        conn_amps = torch.cat(
            [fpeps_model(conn_etas[i:i + chunk_size]) for i in range(0, n_total, chunk_size)],
            dim=0,
        )

    # Local energy \sum_{s'} H_{s,s'} <s'|psi>/<s|psi>. The connected configs of each
    # sample are contiguous, so torch.split recovers the per-sample segments.
    amps_per_sample = torch.split(conn_amps, conn_eta_num)
    coeffs_per_sample = torch.split(conn_eta_coeffs, conn_eta_num)
    local_energies = torch.stack(
        [
            torch.sum(coeffs_b * (amps_b / current_amps[b]))
            for b, (amps_b, coeffs_b) in enumerate(zip(amps_per_sample, coeffs_per_sample))
        ],
        dim=0,
    )
    if verbose:
        print(f'Batched local energies: {local_energies.shape}')

    # Energy: (1/N) * \sum_s <s|H|psi>/<s|psi> = (1/N) * \sum_s \sum_{s'} H_{s,s'} <s'|psi>/<s|psi>
    energy = torch.mean(local_energies)
    if verbose:
        print(f'Energy: {energy.item()}')

    return energy, local_energies


In [2]:
from torch import nn
import quimb as qu
import quimb.tensor as qtn
class SimpleModel_reuse(nn.Module):
    """Batched amplitude model wrapping a quimb/symmray tensor network.

    The network's raw arrays are registered as a flat ``nn.ParameterList``. At call
    time they are rebuilt into a tensor network, the physical indices are fixed to
    the input configuration with ``isel``, and the resulting amplitude network is
    contracted. ``forward`` vectorizes this over a batch with ``torch.vmap``.

    Args:
        tn: tensor network to parametrize; split into arrays and a skeleton with ``qtn.pack``.
        max_bond: boundary-MPS bond dimension (``self.chi``) used for approximate
            contraction. ``None`` or a value ``<= 0`` skips the boundary contraction,
            so ``amp.contract()`` contracts the remaining network exactly.
        dtype: dtype the parameters are cast to.
    """

    def __init__(self, tn, max_bond, dtype=torch.float64):
        super().__init__()

        # Separate the numerical arrays (params) from the structure-only skeleton.
        params, skeleton = qtn.pack(tn)
        self.dtype = dtype
        self.skeleton = skeleton
        self.chi = max_bond
        # For torch, further flatten the params pytree into a single list, keeping
        # the tree reference so `vamp` can rebuild it.
        params_flat, params_pytree = qu.utils.tree_flatten(
            params, get_ref=True
        )
        self.params_pytree = params_pytree

        # Register the flat list as trainable parameters.
        self.params = nn.ParameterList([
            torch.as_tensor(x, dtype=self.dtype) for x in params_flat
        ])

    def amplitude(self, x, params, cache_bmps=False):
        """Amplitude of a single configuration ``x``.

        Args:
            x: one configuration; ``x[i]`` selects the physical index of the i-th
                site in ``tn.sites`` order.
            params: parameter pytree (already unflattened) matching ``self.skeleton``.
            cache_bmps: if True, also compute and return ``compute_x_environments``.

        Returns:
            The contracted amplitude, or ``(amplitude, env_x)`` when ``cache_bmps`` is True.
        """
        tn = qtn.unpack(params, self.skeleton)
        # NOTE(review): assumes x is ordered like tn.sites -- verify the site ordering.
        amp = tn.isel({tn.site_ind(site): x[i] for i, site in enumerate(tn.sites)})
        if self.chi is not None and self.chi > 0:
            # Boundary-contract the lower and upper halves of the columns (presumably
            # leaving the two middle columns for the final exact contraction).
            amp.contract_boundary_from_ymin_(max_bond=self.chi, cutoff=0.0, yrange=[0, amp.Ly//2-1])
            amp.contract_boundary_from_ymax_(max_bond=self.chi, cutoff=0.0, yrange=[amp.Ly//2, amp.Ly-1])
        if cache_bmps:
            # NOTE(review): env_x may not be a pytree that torch.vmap can batch --
            # confirm before using cache_bmps=True through `vamp`/`forward`.
            env_x = amp.compute_x_environments(max_bond=self.chi, cutoff=0.0)
            return amp.contract(), env_x

        return amp.contract()

    def vamp(self, x, params, cache_bmps=False):
        """Apply ``amplitude`` over the leading dimension of ``x`` with ``torch.vmap``.

        ``params`` is the flat parameter list; it is rebuilt into the pytree expected
        by ``qtn.unpack`` and shared (not batched) across the batch.
        """
        params = qu.utils.tree_unflatten(params, self.params_pytree)
        return torch.vmap(
            self.amplitude,
            in_dims=(0, None, None),
        )(x, params, cache_bmps)

    def forward(self, x, cache_bmps=False):
        """Batched amplitudes of the configurations in ``x`` using the registered parameters."""
        return self.vamp(x, self.params, cache_bmps)

In [3]:
import symmray as sr
from vmc_torch.hamiltonian_torch import spinful_Fermi_Hubbard_square_lattice_torch
# Lattice and model configuration.
Lx = 4
Ly = 4
nsites = Lx * Ly
N_f = nsites  # half-filling
D = 4  # PEPS bond dimension
chi = 2 * D  # boundary-MPS bond dimension. BUG: chi=4*D then the contraction becomes input dependent?
dtype = torch.float64
n_samples = 10  # number of random configurations to draw
# only the flat backend is compatible with jax.jit
flat = True

peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 0),  # linear index 1 -> charge 1, offset 0
        (1, 1),  # linear index 2 -> charge 1, offset 1
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=42,
)
peps.apply_to_arrays(lambda x: torch.tensor(x, dtype=dtype))

fpeps_model = SimpleModel_reuse(peps, max_bond=chi, dtype=dtype)
n_params = sum(p.numel() for p in fpeps_model.parameters())

# Spinful Fermi-Hubbard Hamiltonian on the open-boundary square lattice.
t = 1.0
U = 8.0
n_fermions_per_spin = (N_f // 2, N_f // 2)
H = spinful_Fermi_Hubbard_square_lattice_torch(
    Lx,
    Ly,
    t,
    U,
    N_f,
    pbc=False,
    n_fermions_per_spin=n_fermions_per_spin,
    no_u1_symmetry=False,
)
graph = H.graph

# TODO: random_state() is not seeded here, so the sampled configurations change between
# runs -- check whether the hilbert space accepts a seed/key for reproducibility.
fxs = torch.stack(
    [torch.tensor(H.hilbert.random_state()) for _ in range(n_samples)],
    dim=0,
)

In [4]:
def cache(fx):
    """Exactly contract the amplitude of configuration ``fx`` on the global ``peps``.

    The x-environments (computed with the global ``chi``) are built but not returned:
    this function exists to check whether that computation runs under ``torch.vmap``,
    and only the contracted amplitude is handed back.
    """
    selectors = {}
    for idx, site in enumerate(peps.sites):
        selectors[peps.site_ind(site)] = fx[idx]
    amp_tn = peps.isel(selectors)
    # Computed only to exercise compute_x_environments; the result is intentionally unused.
    env_x = amp_tn.compute_x_environments(max_bond=chi, cutoff=0.0)
    return amp_tn.contract()

def vcache(fx):
    """Apply ``cache`` to every configuration along dim 0 of ``fx`` via ``torch.vmap``."""
    batched_cache = torch.vmap(cache, in_dims=0, out_dims=0)
    return batched_cache(fx)

# BUG: the cache function must only return `pytree` to allow vmap
batched_amps = vcache(fxs)
single_amp = cache(fxs[0])
# The vmapped path should agree with a direct, un-batched call on the same sample.
torch.testing.assert_close(batched_amps[0], single_amp)
# TODO: do one time non-vmap contraction for cached skeletons of bMPSs
single_amp

tensor(21168.4573, dtype=torch.float64)