In [16]:
import quimb as qu
import quimb.tensor as qtn
import symmray as sr
import torch

from vmap_utils import Transformer_fPEPS_Model_batchedAttn, Transformer_fPEPS_Model

# System parameters
Lx = 4
Ly = 2
nsites = Lx * Ly
D = 4
seed = 42
flat = True
# Cap passed to the model as `max_bond` (kept at 4x D, as before).
chi = 4 * D

# random fPEPS
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],  # charge-0 sector holds linear indices (0, 3), charge-1 sector holds (2, 1), in offset order
    # put an odd number of odd sites in, for testing
    # site_charge=lambda site: int(site in [(0, 0), (0, 1), (1, 0)]),
    subsizes="equal",
    flat=flat,
    seed=seed,
)

# Test configuration: one linear physical index per site, alternating 1, 2, 1, 2, ...
# (both 1 and 2 carry Z2 charge 1 per phys_dim above). Built from nsites so its
# length always matches the lattice instead of a hard-coded 8-entry literal.
fx = torch.tensor([1 + (i % 2) for i in range(nsites)])

model = Transformer_fPEPS_Model_batchedAttn(
    tn=peps,
    max_bond=chi,
    nn_eta=1,
    nn_hidden_dim=16,
    embed_dim=16,
    attn_heads=4,
    dtype=torch.float64,
)
nparams = sum(p.numel() for p in model.parameters())
# unsqueeze(0) adds a leading batch dimension of size 1 for the batched model.
model(fx.unsqueeze(0)), nparams

(tensor([-10.0554], dtype=torch.float64, grad_fn=<MulBackward0>), 4816)

In [8]:
# Pack a copy of the fPEPS into (array pytree, skeleton), then flatten the
# pytree into a single list so each raw array can become a torch Parameter.
tn = peps.copy()
params, skeleton = qtn.pack(tn)
# for torch, further flatten pytree into a single list
ftn_params_flat, ftn_params_pytree = qu.utils.tree_flatten(
    params, get_ref=True
)
# .clone() forces fresh storage: torch.as_tensor may alias the source array's
# memory, and NOTE(review): peps.copy() is presumably shallow (shares array data
# with `peps`) — without the clone, in-place optimizer updates on these
# Parameters could silently leak back into `peps`.
ftn_params = torch.nn.ParameterList([
    torch.as_tensor(x, dtype=torch.float64).clone() for x in ftn_params_flat
])
ftn_params

ParameterList(
    (0): Parameter containing: [torch.float64 of size 4x2x2x2]
    (1): Parameter containing: [torch.float64 of size 4x2x2x2]
    (2): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (3): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (4): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (5): Parameter containing: [torch.float64 of size 8x2x2x2x2]
    (6): Parameter containing: [torch.float64 of size 4x2x2x2]
    (7): Parameter containing: [torch.float64 of size 4x2x2x2]
)