In [None]:
# NOTE(review): this cell is only a bare string expression and has no effect
# (it merely echoes the string). It was presumably meant to enable symmray's
# debug checks, e.g. by setting a SYMMRAY_DEBUG environment variable before
# symmray is imported — confirm the intended mechanism and replace this line.
# Its execution count is also "None", so it never ran in the recorded session.
'SYMMRAY_DEBUG'

In [5]:
import numpy as np
import quimb as qu
import quimb.tensor as qtn

import symmray as sr

# ---- lattice / ansatz / sampling configuration ----
Lx, Ly = 4, 4  # lattice shape
nsites = Lx * Ly

# PEPS bond dimension, and the maximum bond dimension used when approximately
# contracting amplitudes (kept equal to D)
D = 6
chi = D

seed = 42

# only the flat backend is compatible with jax.jit
flat = True

# batchsize: number of configurations to generate
B = 1024

# Random Z2 fermionic PEPS with 4 local states per site; every linear
# physical index is assigned an explicit (charge, offset) pair inside its
# charge sector.
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=seed,
)

# split into a pytree of initial parameters plus the reference tn structure
# ("skeleton") needed to rebuild the network from those parameters
params, skeleton = qtn.pack(peps)


def amplitude(x, params):
    """Approximately contract the PEPS amplitude for one configuration.

    The tensor network is rebuilt from ``params`` using the module-level
    ``skeleton``. Each site's physical index is then fixed to the matching
    entry of ``x``, and the resulting 2D network is contracted with HOTRG,
    truncating to the module-level ``chi``.

    Parameters
    ----------
    x : sequence of int
        One linear local-state index per site, consumed in the order of
        ``tn.sites``. NOTE(review): that this order matches how the
        configurations were generated is assumed, not checked — confirm.
    params : pytree
        Network parameters, as produced by ``qtn.pack``.

    Returns
    -------
    tuple
        ``(mantissa, exponent)``. Equalizing norms during the contraction and
        stripping the exponent at the end keeps very small or very large
        amplitudes numerically stable.

    Notes
    -----
    NOTE(review): the recorded run of this notebook fails with
    ``IndexError: index 2 is out of bounds for axis 4 with size 2``. This
    presumably happens while selecting the physical indices: a linear index
    (0..3) seems to be applied directly to a size-2 charge-sector block
    instead of being translated through the (charge, offset) ``phys_dim``
    map. TODO: investigate symmray's selection for ``flat=True`` arrays.
    """
    tn = qtn.unpack(params, skeleton)

    # map each site's physical index name -> chosen local state
    selectors = {}
    for i, site in enumerate(tn.sites):
        selectors[tn.site_ind(site)] = x[i]
    tn_x = tn.isel(selectors)

    contract_opts = dict(
        max_bond=chi,
        cutoff=0.0,
        # together these make the result a (mantissa, exponent) pair
        equalize_norms=1.0,
        final_contract_opts=dict(strip_exponent=True),
    )
    return tn_x.contract_hotrg(**contract_opts)

# Generate B random half-filling configurations. Each of the two occupation
# arrays (xs_u / xs_d — presumably the two spin species) has exactly
# nsites // 2 occupied sites per row, shuffled independently per row.
rng = np.random.default_rng(seed)
half_filled = np.tile(
    np.repeat(np.array([0, 1], dtype=np.int32), nsites // 2),
    (B, 1),
)
xs_u = rng.permuted(half_filled, axis=1)
xs_d = rng.permuted(half_filled, axis=1)

# Combine both occupations into one linear local-state index per site:
# 2 * n_u + n_d in {0, 1, 2, 3}. Its parity matches the Z2 charge listed
# for each linear index in phys_dim.
xs = 2 * xs_u + xs_d

mantissa, exponent = amplitude(xs[0], params)
print(mantissa, exponent)

IndexError: index 2 is out of bounds for axis 4 with size 2