TODO:
1. If we want to use a customized QR/SVD in the batched version, extra work is needed.
2. The resulting fPEPS from my SU code is not a flat fTNS but a sparse one. A conversion between flat and sparse fTNs is needed.

BUGS:
1. D!=4, bug in fTS.isel:
    `IndexError: index 2 is out of bounds for axis 4 with size 2`
2. Lx=Ly=D=4, chi=16, bug in vmap amp:
    `RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.`
3. Pickle-saving an SU flat fPEPS: when loading the saved skeleton and unpacking it back to an fPEPS, we get the error:
    `ValueError: Cannot squeeze flat fermionic index with possible odd parity if array has no ordering `.label` attribute.`
   workaround: initialize a flat-rfPEPS and set its params to saved params.

In [1]:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["OMP_NUM_THREADS"] = '1'
import numpy as np
import quimb as qu
import quimb.tensor as qtn
import time
import symmray as sr

# lattice extents handed to PEPS_fermionic_rand; nsites is their product
Lx = 4
Ly = 4
nsites = Lx * Ly
# D is passed to PEPS_fermionic_rand (presumably the bond dimension — confirm)
D = 4
# chi is used as max_bond for the boundary contractions inside amplitude()
chi = 3*D
# seed is shared by the random PEPS and by the NumPy generator that draws
# the sample configurations
seed = 42
# only the flat backend is compatible with jax.jit
# NOTE(review): the rest of this notebook uses torch.vmap / torch.jit.trace,
# not jax — confirm this remark still describes the actual constraint
flat = True

# random fermionic PEPS with "Z2" symmetry; each phys_dim entry is a
# (charge, offset) pair for one of the four physical linear indices
peps = sr.networks.PEPS_fermionic_rand(
    "Z2",
    Lx,
    Ly,
    D,
    phys_dim=[
        (0, 0),  # linear index 0 -> charge 0, offset 0
        (1, 1),  # linear index 1 -> charge 1, offset 1
        (1, 0),  # linear index 2 -> charge 1, offset 0
        (0, 1),  # linear index 3 -> charge 0, offset 1
    ],
    subsizes="equal",
    flat=flat,
    seed=seed,
)

In [2]:
# get pytree of initial parameters, and reference tn structure
# (amplitude() rebuilds the network from these two via qtn.unpack)
params, skeleton = qtn.pack(peps)


def amplitude(x, params):
    """Contract the network amplitude for a single configuration.

    ``x`` is read as interleaved pairs: each entry at an even position is
    doubled and added to the following odd-position entry, giving one
    physical index per site. These indices are fixed on the network rebuilt
    from ``params`` and the module-level ``skeleton``; the result is then
    boundary-contracted from both ``y`` edges with ``max_bond=chi`` and
    ``cutoff=0.0`` before the final full contraction.

    Args:
        x: one configuration, indexable with strided slices.
        params: parameter pytree matching ``skeleton``.

    Returns:
        Whatever ``amp.contract()`` yields for the fully selected network.
    """
    network = qtn.unpack(params, skeleton)

    # fold every (even, odd) pair of entries into one index: 2 * even + odd
    even_entries = x[::2]
    odd_entries = x[1::2]
    fx = 2 * even_entries + odd_entries

    # might need to specify the right site ordering here
    selection = {}
    for i, site in enumerate(network.sites):
        selection[network.site_ind(site)] = fx[i]
    amp = network.isel(selection)

    # absorb the lower half of the y-range from ymin ...
    ymin_range = [0, amp.Ly // 2 - 1]
    amp.contract_boundary_from_ymin_(max_bond=chi, cutoff=0.0, yrange=ymin_range)
    # ... and the upper half from ymax
    ymax_range = [amp.Ly // 2, amp.Ly - 1]
    amp.contract_boundary_from_ymax_(max_bond=chi, cutoff=0.0, yrange=ymax_range)

    return amp.contract()

In [3]:
# Draw B random configurations. xs_u and xs_d each start as rows holding
# nsites // 2 zeros followed by nsites // 2 ones (the half-filling pattern)
# and are shuffled independently along axis 1.
B = 1024  # batch size
rng = np.random.default_rng(seed)

xs_u = np.zeros((B, 2 * (nsites // 2)), dtype=np.int32)
xs_u[:, nsites // 2:] = 1
xs_d = xs_u.copy()

# keep the draw order: xs_u is permuted first, then xs_d
xs_u = rng.permuted(xs_u, axis=1)
xs_d = rng.permuted(xs_d, axis=1)

# interleave per position: row b reads xs_u[b, 0], xs_d[b, 0], xs_u[b, 1], ...
xs = np.stack([xs_u, xs_d], axis=2).reshape(B, -1)

First test non eager version:

In [4]:
# evaluate one configuration with the inputs as generated above, i.e. before
# xs and params are converted to torch tensors further down
amplitude(xs[0], params)

np.float64(35246.57543812603)

Then test the version with torch tensors (CPU by default; switch the default device to run on the GPU):

In [5]:
import torch

# pick the device for all newly created tensors; swap the two lines below to
# run on the GPU instead
# torch.set_default_device("cuda:0") # GPU
torch.set_default_device("cpu") # CPU

# convert bitstrings and arrays to torch
xs = torch.tensor(xs)
# every leaf of the params pytree becomes a float32 tensor, regardless of
# the dtype it had before
params = qu.tree_map(
    lambda x: torch.tensor(x, dtype=torch.float32),
    params,
)

In [21]:
# %%timeit
# mantissa, exponent = amplitude(xs[0], params)
# mantissa, exponent

Then test and warm up torch vmapped version:

In [6]:
# batched version of amplitude: maps over axis 0 of the first argument (the
# configurations) and passes the second argument (params) through unbatched
vamp = torch.vmap(
    amplitude,
    # batch on configs, not parameters
    in_dims=(0, None),
)

In [7]:
%%time
# warmup time
# first vmapped call, on only the first 10 configurations
vamp(xs[:10], params)

CPU times: user 586 ms, sys: 28.2 ms, total: 615 ms
Wall time: 630 ms


RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

In [13]:
from autoray import do
# sanity check: autoray's 'take' along axis 0 should match plain xs[0] indexing
do('take', xs, 0, axis=0), xs[0]

(tensor([0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
         1, 1, 0, 1, 1, 1, 1, 0], dtype=torch.int32),
 tensor([0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
         1, 1, 0, 1, 1, 1, 1, 0], dtype=torch.int32))

In [30]:
# Time the vmapped amplitude over the full batch and derive a per-sample cost.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock time and may jump if the system clock
# is adjusted.
t0 = time.perf_counter()
# final time (to compute full batch)
vamp(xs, params)
t1 = time.perf_counter()
print(f"Time taken: {t1 - t0} seconds")
vamp_time = (t1 - t0)/len(xs)
print(f"Time per sample: {vamp_time} seconds")

Time taken: 0.638314962387085 seconds
Time per sample: 0.0006233544554561377 seconds


In [43]:
# Time one un-batched amplitude call and compare it with the per-sample cost
# of the vmapped run (vamp_time, set by the batch-timing cell).
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock time and may jump if the system clock
# is adjusted.
t0 = time.perf_counter()
amplitude(xs[0], params)
t1 = time.perf_counter()
print(f"Time taken: {(t1 - t0)} seconds")
amp_time = (t1 - t0)

print(f'speedup: {amp_time/vamp_time}')

Time taken: 0.046881914138793945 seconds
speedup: 232.55358382350732


Then test a traced and jit compiled version:

In [25]:
%%time
# tracing time
# trace amplitude once, using the first configuration and the parameter
# pytree as example inputs
amplitude_jit = torch.jit.trace(amplitude, (xs[0], params))

CPU times: user 13.3 s, sys: 774 ms, total: 14.1 s
Wall time: 14.1 s


In [26]:
%%time
# jit time
mantissa, exponent = amplitude_jit(xs[0], params)
mantissa, exponent

CPU times: user 7.95 s, sys: 141 ms, total: 8.09 s
Wall time: 8.41 s


(tensor(-1., device='cuda:0'), tensor(4.4671, device='cuda:0'))

In [27]:
%%timeit
# warmed up time
mantissa, exponent = amplitude_jit(xs[0], params)

77 ms ± 6.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
