In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
from model import TQS
import torch

In [38]:
EMBED_DIM = 64
MAX_LENGTH = 100
NUM_HEADS = 1
NUM_LAYERS = 1
DIM_FEEDFORWARD = 128
TEST_LENGTH = 50
TEST_BATCH = 32

In [39]:
test_potentials = torch.randn(TEST_LENGTH, TEST_BATCH)
test_spins = torch.zeros(TEST_LENGTH, TEST_BATCH)
test_spins[torch.randint(0, TEST_LENGTH, (TEST_BATCH,)), torch.arange(TEST_BATCH)] = 1

In [40]:
tqs = TQS(
    embed_dim=EMBED_DIM,
    max_chain_len=MAX_LENGTH,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    possible_spins=2,
    dim_feedforward=DIM_FEEDFORWARD,
)



# Model Output

In [41]:
probs, phases = tqs(test_potentials, test_spins)

In [42]:
probs

tensor([[[0.5380, 0.5356],
         [0.4481, 0.6791],
         [0.5979, 0.5281],
         ...,
         [0.5739, 0.4870],
         [0.6269, 0.3589],
         [0.6757, 0.3046]],

        [[0.6940, 0.3112],
         [0.6776, 0.3211],
         [0.6816, 0.3187],
         ...,
         [0.4421, 0.6806],
         [0.4159, 0.7102],
         [0.6719, 0.3144]],

        [[0.6981, 0.3292],
         [0.6812, 0.3502],
         [0.6492, 0.3499],
         ...,
         [0.6708, 0.3952],
         [0.6402, 0.4293],
         [0.6183, 0.4653]],

        ...,

        [[0.5910, 0.2289],
         [0.6081, 0.2579],
         [0.6124, 0.2532],
         ...,
         [0.6440, 0.2477],
         [0.6285, 0.2718],
         [0.5965, 0.2648]],

        [[0.6129, 0.2108],
         [0.6105, 0.2613],
         [0.5909, 0.2515],
         ...,
         [0.6097, 0.2118],
         [0.6132, 0.2060],
         [0.6112, 0.2572]],

        [[0.5790, 0.2679],
         [0.6032, 0.2958],
         [0.7620, 0.2118],
         ...,
 

# Autoregressive Sampling

In [43]:
def autoregressive_sample(model, initial_potentials, max_length):
    model.eval()
    # Allocate a buffer for the sampled basis states
    sampled_spins = torch.zeros(max_length, initial_potentials.size(1))
    batch_size = initial_potentials.size(1)
    batch_remaining_idx = torch.arange(batch_size)

    for i in range(max_length):
        # get P(s_{i+1} | V, s_{1:i}) and phi(s_{i+1} | V, s_{1:i}) distributions
        probs, _ = model(
            initial_potentials[:, batch_remaining_idx],
            sampled_spins[:i, batch_remaining_idx],
        )

        # sample s_{i+1} from P(s_{i+1} | V, s_{1:i})
        last_probs = probs[-1]  # (batch, 2); P(s_{i+1} | V, s_{1:i})
        sampled_spins[i, batch_remaining_idx] = (
            torch.multinomial(last_probs, 1).squeeze().float()
        )

        # a mask with dimension (batch_remaining,); True if we sampled a 1
        newly_completed = sampled_spins[i, batch_remaining_idx] == 1.0

        # mask out the batch_remaining_idx that have been completed
        batch_remaining_idx = batch_remaining_idx[~newly_completed]

    return sampled_spins


sampled_spins = autoregressive_sample(tqs, test_potentials, TEST_LENGTH)
sampled_spins

tensor([[0., 0., 1.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [44]:
sampled_spins.sum(dim=0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [45]:
sampled_spins.shape

torch.Size([50, 32])

In [46]:
sampled_spins = tqs.sample_spins(test_potentials, TEST_LENGTH)
sampled_spins

tensor([[0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 1.,  ..., 0., 1., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [47]:
assert (
    sampled_spins.sum(dim=0) == 1.0
).all()  # each batch should have exactly one 1 spin

In [48]:
TEST_BATCH

32

In [49]:
TEST_LENGTH

50

In [50]:
T = 1.0
print(tqs.compute_psi(sampled_spins, test_potentials, T).shape)

torch.Size([50, 32])


In [51]:
psi_x, psi_l, psi_r = tqs.psi_terms(sampled_spins, test_potentials, T)
assert (
    psi_x.shape == psi_l.shape == psi_r.shape == torch.Size([TEST_LENGTH, TEST_BATCH])
)
psi_x, psi_l, psi_r

(tensor([[ 0.0152-0.7677j,  0.0049-0.7665j, -0.0710-0.7565j,  ...,
           0.5052+0.7007j, -0.0490-0.7604j,  0.5660+0.6489j],
         [ 0.1236-0.7585j,  0.1155-0.7587j,  0.4753+0.7158j,  ...,
           0.1146-0.7587j,  0.4615+0.7254j,  0.0094-0.7560j],
         [-0.1148-0.7627j,  0.6276+0.5856j, -0.1579-0.7483j,  ...,
          -0.1236-0.7602j, -0.1524-0.7505j, -0.1822-0.7376j],
         ...,
         [-0.4433-0.6510j, -0.4596-0.6382j, -0.4987-0.6033j,  ...,
          -0.4542-0.6422j, -0.4922-0.6096j, -0.5244-0.5762j],
         [-0.4751-0.6324j, -0.4895-0.6198j, -0.5281-0.5826j,  ...,
          -0.4843-0.6243j, -0.5202-0.5908j, -0.5531-0.5547j],
         [-0.1791-0.7542j, -0.2028-0.7466j, -0.2713-0.7195j,  ...,
          -0.1915-0.7502j, -0.2560-0.7263j, -0.3198-0.6948j]],
        grad_fn=<MulBackward0>),
 tensor([[ 1.0397e-02-0.7674j,  1.4821e-03-0.7662j, -6.8553e-02-0.7569j,
           ..., -6.4291e-04-0.7660j, -4.8013e-02-0.7604j,
          -1.0023e-01-0.7486j],
         [ 1.20

In [52]:
E_loc = tqs.E_loc(psi_x, psi_l, psi_r, test_potentials, T)
E_loc

tensor([[-1.7213+7.5433e-04j,  0.0045+9.3430e-03j,  0.0737-6.3609e-01j,
          ...,  1.3469+1.0397e+00j, -0.4282-6.3878e-01j,
         -1.2808+9.6873e-01j],
        [-4.0353+1.8561e-03j, -1.5842-7.1718e-01j, -0.0490+1.0840e+00j,
          ...,  0.9579-7.1503e-01j,  2.6905+1.0798e+00j,
         -1.3292-6.6351e-01j],
        [-0.9715-6.7107e-01j,  0.6846+1.0984e+00j, -0.7564-6.9959e-01j,
          ..., -2.2852+1.5685e-04j, -0.4408-6.7715e-01j,
         -2.3531-2.6554e-04j],
        ...,
        [-1.4233+1.8915e-03j, -1.4623-3.0339e-03j, -1.8690-2.1616e-03j,
          ..., -1.3657-6.2466e-04j, -2.1946+3.6514e-04j,
         -2.8395+1.5085e-03j],
        [ 0.1083+6.7896e-04j, -2.4013-4.9192e-04j, -1.3422-2.6135e-03j,
          ..., -0.9792+1.1694e-04j, -2.9170+4.3970e-04j,
         -1.7497+3.8087e-04j],
        [-2.0055-1.9759e-05j, -3.5131+2.7206e-03j, -2.0150-4.5414e-03j,
          ..., -1.6840-6.9516e-01j, -1.7233+6.1059e-04j,
         -0.4289-6.4541e-01j]], grad_fn=<AddBackward0>)

In [53]:
psi_x

tensor([[ 0.0152-0.7677j,  0.0049-0.7665j, -0.0710-0.7565j,  ...,
          0.5052+0.7007j, -0.0490-0.7604j,  0.5660+0.6489j],
        [ 0.1236-0.7585j,  0.1155-0.7587j,  0.4753+0.7158j,  ...,
          0.1146-0.7587j,  0.4615+0.7254j,  0.0094-0.7560j],
        [-0.1148-0.7627j,  0.6276+0.5856j, -0.1579-0.7483j,  ...,
         -0.1236-0.7602j, -0.1524-0.7505j, -0.1822-0.7376j],
        ...,
        [-0.4433-0.6510j, -0.4596-0.6382j, -0.4987-0.6033j,  ...,
         -0.4542-0.6422j, -0.4922-0.6096j, -0.5244-0.5762j],
        [-0.4751-0.6324j, -0.4895-0.6198j, -0.5281-0.5826j,  ...,
         -0.4843-0.6243j, -0.5202-0.5908j, -0.5531-0.5547j],
        [-0.1791-0.7542j, -0.2028-0.7466j, -0.2713-0.7195j,  ...,
         -0.1915-0.7502j, -0.2560-0.7263j, -0.3198-0.6948j]],
       grad_fn=<MulBackward0>)

In [54]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([ 0.6717, -2.0510,  1.5898, -1.0533, -0.5701, -0.8790,  1.6793, -1.3156,
         0.2563,  0.9460,  1.6996,  2.0750,  0.4715,  0.4805,  0.0139, -0.6473,
         0.3041, -2.0161, -0.0596,  1.3679, -0.7808, -1.3750, -1.7872, -0.4623,
        -0.4465,  0.5659, -0.4160,  0.7223,  0.0461, -0.9342,  0.2083,  1.9021,
        -2.6492,  0.3302,  0.0999,  0.7141, -0.2292, -0.4616, -2.3555,  2.7235,
        -0.8708, -1.2030, -0.3637, -0.7009, -0.9202, -0.5260,  0.2257,  2.2966,
        -2.8187, -0.2005, -0.2349, -0.6985, -1.5180,  1.0470, -1.2092,  0.0643,
         1.4487, -1.5407,  0.8551, -0.9517,  0.2826,  0.2067, -1.1548, -0.8621],
       requires_grad=True)
Parameter containing:
tensor([ 0.4390, -1.4166,  0.1547, -0.0180,  1.6928, -0.6779, -0.3299,  1.4186,
        -1.1165,  1.5816,  1.0712, -0.8469, -0.5574, -0.7589,  0.9012, -0.2221,
        -1.4166, -1.5427,  1.2969,  1.0006,  0.1996, -0.4680, -0.7375, -0.6803,
         0.5467, -0.1454, -0.2836, -0.8476,  0.6

In [55]:
psi_x.shape

torch.Size([50, 32])

In [56]:
psi_x

tensor([[ 0.0152-0.7677j,  0.0049-0.7665j, -0.0710-0.7565j,  ...,
          0.5052+0.7007j, -0.0490-0.7604j,  0.5660+0.6489j],
        [ 0.1236-0.7585j,  0.1155-0.7587j,  0.4753+0.7158j,  ...,
          0.1146-0.7587j,  0.4615+0.7254j,  0.0094-0.7560j],
        [-0.1148-0.7627j,  0.6276+0.5856j, -0.1579-0.7483j,  ...,
         -0.1236-0.7602j, -0.1524-0.7505j, -0.1822-0.7376j],
        ...,
        [-0.4433-0.6510j, -0.4596-0.6382j, -0.4987-0.6033j,  ...,
         -0.4542-0.6422j, -0.4922-0.6096j, -0.5244-0.5762j],
        [-0.4751-0.6324j, -0.4895-0.6198j, -0.5281-0.5826j,  ...,
         -0.4843-0.6243j, -0.5202-0.5908j, -0.5531-0.5547j],
        [-0.1791-0.7542j, -0.2028-0.7466j, -0.2713-0.7195j,  ...,
         -0.1915-0.7502j, -0.2560-0.7263j, -0.3198-0.6948j]],
       grad_fn=<MulBackward0>)

In [57]:
# https://discuss.pytorch.org/t/example-for-one-of-the-differentiated-tensors-appears-to-not-have-been-used-in-the-graph/58396
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)

output = (2 * a).sum()

torch.autograd.grad(output, (a), allow_unused=True)

(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),)

In [58]:
print(test_potentials.shape)
print(sampled_spins.shape)

torch.Size([50, 32])
torch.Size([50, 32])


In [59]:
params = {k: v for k, v in tqs.named_parameters()}
buffers = {k: v for k, v in tqs.named_buffers()}
print(dlnP_dTheta := tqs.dlnP_dTheta(sampled_spins, test_potentials, params, buffers))

tensor([[[[6.9006e-01, 1.3752e-03, 1.3267e-03,  ..., 8.9926e-04,
           8.8654e-04, 9.6735e-04]],

         [[1.1664e-03, 7.0071e-01, 1.1851e-03,  ..., 7.7769e-04,
           7.7244e-04, 8.6914e-04]],

         [[1.0144e-03, 1.1042e-03, 6.7375e-01,  ..., 7.0872e-04,
           7.0849e-04, 8.1121e-04]],

         ...,

         [[9.9526e-04, 1.0698e-03, 1.1388e-03,  ..., 6.0544e-01,
           1.1210e-03, 1.1624e-03]],

         [[1.0441e-03, 1.1067e-03, 1.1484e-03,  ..., 1.0022e-03,
           6.1370e-01, 1.0807e-03]],

         [[1.3325e-03, 1.4025e-03, 1.4130e-03,  ..., 1.1025e-03,
           1.1270e-03, 6.9311e-01]]],


        [[[6.7785e-01, 1.3981e-03, 1.8666e-03,  ..., 9.1634e-04,
           9.0292e-04, 9.8103e-04]],

         [[1.1806e-03, 6.9862e-01, 2.1555e-03,  ..., 7.8924e-04,
           7.8390e-04, 8.8104e-04]],

         [[3.3260e-04, 4.2896e-04, 3.0662e-02,  ..., 3.4048e-04,
           2.8368e-04, 2.7715e-04]],

         ...,

         [[9.4954e-04, 1.0205e-03, 1.2223

Why is `jacrev` the right function to use here and what does it do?

In [60]:
from torch.func import functional_call, vmap, jacrev

In [61]:
# grab a spin chain and a corresponding potential function,
# recalling that their dimensions are (seq, batch)
a_spin_chain = sampled_spins.clone().detach().requires_grad_(True)[:, 1]
a_potential_func = test_potentials.clone().detach().requires_grad_(True)[:, 1]

In [62]:
print(a_spin_chain)
print(a_potential_func)

tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<SelectBackward0>)
tensor([ 2.0037, -1.4423, -0.7348, -0.2919,  1.9261, -1.0678,  0.2192,  1.7884,
         0.1938, -0.5133,  0.6416, -0.8084, -1.2946,  1.0344,  0.6041,  1.4357,
         1.0964,  1.6152,  0.9775,  1.0709,  2.1710, -1.0971, -0.9162, -1.4216,
         0.4786,  0.2975,  0.3549,  0.8011, -0.2619,  0.3683,  0.5602, -0.2319,
         0.4204, -1.2452, -0.5627, -1.2707, -0.2916,  1.2549, -0.0316, -1.0407,
        -0.1653, -0.7021,  0.7244, -1.6090,  0.2124,  0.3022,  1.3344,  0.5378,
        -0.4012, -1.5133], grad_fn=<SelectBackward0>)


In [63]:
params = {k: v.detach for k, v in tqs.named_parameters()}
buffers = {k: v.detach for k, v in tqs.named_buffers()}

deriv = jacrev(tqs, argnums=1)(a_potential_func.unsqueeze(1), a_spin_chain.unsqueeze(1))

# Jacobian of probabilities; Jacobian of phases.
for d in deriv:
    print(d.shape)

torch.Size([100, 1, 2, 50, 1])
torch.Size([100, 1, 1, 50, 1])


## Finding the Jacobian of $\ln P(x; \Theta)$
- For a singleton-batch spin chain
- Should be with the jacrev function

In [64]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}

lnP = tqs._ln_P(a_spin_chain, a_potential_func, params, buffers)
lnP

tensor([[-0.5317],
        [-0.5293],
        [-0.3054],
        [-0.5118],
        [-0.5230],
        [-0.5531],
        [-0.6074],
        [-0.6630],
        [-0.6746],
        [-0.6420],
        [-0.6014],
        [-0.5678],
        [-0.5241],
        [-0.4478],
        [-0.3698],
        [-0.3058],
        [-0.2782],
        [-0.2892],
        [-0.3390],
        [-0.4384],
        [-0.5318],
        [-0.5516],
        [-0.5257],
        [-0.4989],
        [-0.4790],
        [-0.4778],
        [-0.4789],
        [-0.4532],
        [-0.4123],
        [-0.3854],
        [-0.3704],
        [-0.3600],
        [-0.3510],
        [-0.3439],
        [-0.3298],
        [-0.3283],
        [-0.3443],
        [-0.3707],
        [-0.3992],
        [-0.4071],
        [-0.3948],
        [-0.3800],
        [-0.3814],
        [-0.4184],
        [-0.4819],
        [-0.5290],
        [-0.5132],
        [-0.4805],
        [-0.4720],
        [-0.5133]], grad_fn=<LogBackward0>)

The following is the derivative of tqs._lnP with respect to a single spin chain:

In [65]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}
deriv = jacrev(tqs._ln_P, argnums=2)(a_spin_chain, a_potential_func, params, buffers)
deriv

{'embedding.base_potential_vect': tensor([[[ 0.0056,  0.0063, -0.0010,  ..., -0.0011, -0.0080, -0.0040]],
 
         [[ 0.0046,  0.0044, -0.0012,  ..., -0.0010, -0.0060, -0.0018]],
 
         [[ 0.0030,  0.0002, -0.0004,  ...,  0.0001, -0.0014, -0.0009]],
 
         ...,
 
         [[ 0.0023,  0.0015, -0.0005,  ..., -0.0007, -0.0019, -0.0017]],
 
         [[ 0.0025,  0.0019,  0.0005,  ..., -0.0010, -0.0027, -0.0019]],
 
         [[ 0.0030,  0.0026,  0.0019,  ..., -0.0017, -0.0033, -0.0017]]],
        grad_fn=<ViewBackward0>),
 'embedding.base_spin_vect': tensor([[[ 3.7625e-04,  3.6970e-04, -9.7771e-05,  ...,  2.6439e-05,
           -3.5320e-04, -4.2557e-04]],
 
         [[ 4.6195e-04,  4.1429e-04, -1.6471e-04,  ...,  6.2683e-05,
           -3.7176e-04, -5.2453e-04]],
 
         [[-1.4932e-02,  1.7162e-05, -1.1031e-02,  ..., -1.2226e-02,
           -1.0215e-02,  2.7979e-02]],
 
         ...,
 
         [[ 2.4470e-04,  1.8577e-04, -9.1214e-05,  ...,  4.9140e-05,
           -1.7601e-04, -

...so now we need to map this single-spin-chain `jacrev(tqs)` operator over the spin chains in our batch
- (noting that `jacrev(tqs)` is *itself* an operator)

In [66]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}
dlnP_dTheta_sample = jacrev(tqs._ln_P, argnums=2)
vmap_dlnP_dTheta = vmap(dlnP_dTheta_sample, in_dims=(1, 1, None, None))
per_sample_grads = vmap_dlnP_dTheta(sampled_spins, test_potentials, params, buffers)
per_sample_grads

{'embedding.base_potential_vect': tensor([[[[ 6.7307e-03,  7.4979e-03, -1.4580e-03,  ..., -7.5334e-04,
            -8.7267e-03, -4.8599e-03]],
 
          [[ 5.6669e-03,  5.4228e-03, -1.5140e-03,  ..., -1.1094e-03,
            -7.2215e-03, -2.6041e-03]],
 
          [[ 5.3460e-03,  4.2751e-03, -1.5961e-03,  ..., -8.4391e-04,
            -6.2321e-03, -2.0589e-03]],
 
          ...,
 
          [[ 3.4617e-03,  2.1413e-03, -6.8780e-04,  ..., -6.5126e-04,
            -2.9268e-03, -2.8343e-03]],
 
          [[ 3.4734e-03,  2.5773e-03,  4.1507e-04,  ..., -9.6181e-04,
            -3.6066e-03, -3.2446e-03]],
 
          [[ 4.1812e-03,  3.4472e-03,  2.0736e-03,  ..., -1.7527e-03,
            -4.2566e-03, -3.0993e-03]]],
 
 
         [[[ 5.5753e-03,  6.3499e-03, -9.8158e-04,  ..., -1.1419e-03,
            -8.0055e-03, -4.0245e-03]],
 
          [[ 4.5625e-03,  4.3962e-03, -1.1791e-03,  ..., -1.0387e-03,
            -6.0414e-03, -1.8027e-03]],
 
          [[ 2.9748e-03,  1.7371e-04, -4.4045e-04, 

The output above is $\frac{d \ln P}{d\Theta}$!

In [67]:
per_sample_grads.keys()

dict_keys(['embedding.base_potential_vect', 'embedding.base_spin_vect', 'encoder_layer.self_attn.in_proj_weight', 'encoder_layer.self_attn.in_proj_bias', 'encoder_layer.self_attn.out_proj.weight', 'encoder_layer.self_attn.out_proj.bias', 'encoder_layer.linear1.weight', 'encoder_layer.linear1.bias', 'encoder_layer.linear2.weight', 'encoder_layer.linear2.bias', 'encoder_layer.norm1.weight', 'encoder_layer.norm1.bias', 'encoder_layer.norm2.weight', 'encoder_layer.norm2.bias', 'encoder.layers.0.self_attn.in_proj_weight', 'encoder.layers.0.self_attn.in_proj_bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.linear1.weight', 'encoder.layers.0.linear1.bias', 'encoder.layers.0.linear2.weight', 'encoder.layers.0.linear2.bias', 'encoder.layers.0.norm1.weight', 'encoder.layers.0.norm1.bias', 'encoder.layers.0.norm2.weight', 'encoder.layers.0.norm2.bias', 'prob_head.weight', 'prob_head.bias', 'phase_head.weight', 'phase_head.bias'])

Apply the same concept to $E_{l}$ to get $\frac{d E_\text{l}}{d \Theta}$

In [68]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}

T = 1.0

tqs._E_loc(sampled_spins, test_potentials, T, params, buffers)

Psi shapes: torch.Size([50, 32]) torch.Size([50, 32]) torch.Size([50, 32])
V shapes: torch.Size([50, 32]) torch.Size([50, 32]) torch.Size([50, 32])


tensor([[-1.9956-0.2130j, -1.9893-1.5266j, -0.0678-0.7360j,  ...,
          1.3918+0.9762j, -0.0460-0.3618j, -0.0903-0.8115j],
        [-2.2513+1.5457j, -0.3085+0.3771j,  0.7130+0.0444j,  ...,
         -0.0097-1.5441j,  2.0001+2.0099j, -0.0885+0.2834j],
        [-0.0160-0.0173j,  0.9582+0.6681j, -0.0230-0.2256j,  ...,
         -1.9646+0.2171j, -0.0563-0.4267j, -1.9355+0.2603j],
        ...,
        [-2.2554-0.3734j, -2.2473-0.3462j, -2.0657-0.0813j,  ...,
         -2.2884-0.4081j, -1.9041+0.1191j, -1.5597+0.4852j],
        [-3.0016-1.3327j, -1.8036+0.2482j, -2.3477-0.3860j,  ...,
         -2.4944-0.6372j, -1.5229+0.5422j, -2.1384-0.1384j],
        [-1.9989+0.0042j, -1.6929+1.1325j, -1.9964+0.0060j,  ...,
          0.1907+0.4852j, -2.0707-0.2003j,  0.0591-0.3885j]])

In [69]:
dE_loc_dTheta = jacrev(tqs._E_loc, argnums=3)(
    a_spin_chain, a_potential_func, T, params, buffers
)

Psi shapes: torch.Size([50, 1]) torch.Size([50, 1]) torch.Size([50, 1])
V shapes: torch.Size([50, 1]) torch.Size([50, 1]) torch.Size([50, 1])


RuntimeError: jacrev: Expected all outputs to be real but received complex tensor at flattened input idx: 0