In [618]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [619]:
from model import TQS
import torch

In [620]:
EMBED_DIM = 64
MAX_LENGTH = 100
NUM_HEADS = 1
NUM_LAYERS = 1
DIM_FEEDFORWARD = 128
TEST_LENGTH = 50
TEST_BATCH = 32

In [621]:
test_potentials = torch.randn(TEST_LENGTH, TEST_BATCH)
test_spins = torch.zeros(TEST_LENGTH, TEST_BATCH)
test_spins[torch.randint(0, TEST_LENGTH, (TEST_BATCH,)), torch.arange(TEST_BATCH)] = 1

In [622]:
tqs = TQS(
    embed_dim=EMBED_DIM,
    max_chain_len=MAX_LENGTH,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    possible_spins=2,
    dim_feedforward=DIM_FEEDFORWARD,
)



# Model Output

In [623]:
probs, phases = tqs(test_potentials, test_spins)

In [624]:
probs

tensor([[[0.4685, 0.5502],
         [0.4831, 0.5676],
         [0.3663, 0.5296],
         ...,
         [0.4775, 0.5025],
         [0.4807, 0.5330],
         [0.3822, 0.5043]],

        [[0.5155, 0.5580],
         [0.4286, 0.5356],
         [0.3694, 0.3971],
         ...,
         [0.3951, 0.4800],
         [0.3627, 0.4351],
         [0.3857, 0.4661]],

        [[0.3079, 0.4326],
         [0.3701, 0.4135],
         [0.3565, 0.4776],
         ...,
         [0.4106, 0.4472],
         [0.3604, 0.4803],
         [0.4186, 0.5195]],

        ...,

        [[0.1973, 0.5313],
         [0.2029, 0.5303],
         [0.2131, 0.5967],
         ...,
         [0.2172, 0.5763],
         [0.2065, 0.4921],
         [0.1831, 0.5414]],

        [[0.2584, 0.5622],
         [0.2390, 0.5443],
         [0.2392, 0.5655],
         ...,
         [0.2285, 0.5596],
         [0.2363, 0.5448],
         [0.2393, 0.5657]],

        [[0.2659, 0.5688],
         [0.2887, 0.5277],
         [0.2696, 0.5833],
         ...,
 

# Autoregressive Sampling

In [625]:
def autoregressive_sample(model, initial_potentials, max_length):
    model.eval()
    # Allocate a buffer for the sampled basis states
    sampled_spins = torch.zeros(max_length, initial_potentials.size(1))
    batch_size = initial_potentials.size(1)
    batch_remaining_idx = torch.arange(batch_size)

    for i in range(max_length):
        # get P(s_{i+1} | V, s_{1:i}) and phi(s_{i+1} | V, s_{1:i}) distributions
        probs, _ = model(
            initial_potentials[:, batch_remaining_idx],
            sampled_spins[:i, batch_remaining_idx],
        )

        # sample s_{i+1} from P(s_{i+1} | V, s_{1:i})
        last_probs = probs[-1]  # (batch, 2); P(s_{i+1} | V, s_{1:i})
        sampled_spins[i, batch_remaining_idx] = (
            torch.multinomial(last_probs, 1).squeeze().float()
        )

        # a mask with dimension (batch_remaining,); True if we sampled a 1
        newly_completed = sampled_spins[i, batch_remaining_idx] == 1.0

        # mask out the batch_remaining_idx that have been completed
        batch_remaining_idx = batch_remaining_idx[~newly_completed]

    return sampled_spins


sampled_spins = autoregressive_sample(tqs, test_potentials, TEST_LENGTH)
sampled_spins

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [626]:
sampled_spins.sum(dim=0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [627]:
sampled_spins.shape

torch.Size([50, 32])

In [628]:
sampled_spins = tqs.sample_spins(test_potentials, TEST_LENGTH)
sampled_spins

tensor([[0., 1., 1.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [629]:
assert (
    sampled_spins.sum(dim=0) == 1.0
).all()  # each batch should have exactly one 1 spin

In [630]:
TEST_BATCH

32

In [631]:
TEST_LENGTH

50

In [632]:
T = 1.0
print(tqs.compute_psi(sampled_spins, test_potentials, T).shape)

torch.Size([50, 32])


In [633]:
psi_x, psi_l, psi_r = tqs.psi_terms(sampled_spins, test_potentials, T)
assert (
    psi_x.shape == psi_l.shape == psi_r.shape == torch.Size([TEST_LENGTH, TEST_BATCH])
)
psi_x, psi_l, psi_r

(tensor([[ 0.6965-0.1732j,  0.2971-0.6767j,  0.2579-0.6938j,  ...,
           0.3527-0.6484j,  0.7091-0.1241j,  0.6904-0.2013j],
         [ 0.4320-0.5739j,  0.6759+0.0923j,  0.6806-0.0020j,  ...,
           0.6415+0.2341j,  0.4454-0.5647j,  0.4077-0.5937j],
         [ 0.6814-0.1205j,  0.6776-0.1526j,  0.6497-0.2423j,  ...,
           0.6965+0.0055j,  0.6930-0.0560j,  0.6752-0.1600j],
         ...,
         [-0.5646+0.4825j, -0.5629+0.4853j, -0.5520+0.4984j,  ...,
          -0.5897+0.4516j, -0.5793+0.4649j, -0.5622+0.4862j],
         [-0.4920+0.5745j, -0.4923+0.5744j, -0.4857+0.5797j,  ...,
          -0.5279+0.5422j, -0.5137+0.5558j, -0.4917+0.5749j],
         [-0.4723+0.5699j, -0.4764+0.5660j, -0.4802+0.5628j,  ...,
          -0.5119+0.5360j, -0.4970+0.5496j, -0.4755+0.5667j]],
        grad_fn=<MulBackward0>),
 tensor([[ 0.6965-0.1739j,  0.6904-0.2015j,  0.6652-0.2727j,  ...,
           0.7163-0.0793j,  0.7092-0.1250j,  0.6917-0.1954j],
         [ 0.6675+0.1256j,  0.4044-0.5963j,  0.36

In [634]:
E_loc = tqs.E_loc(psi_x, psi_l, psi_r, test_potentials, T)
E_loc

tensor([[-1.7186+0.8070j, -1.7591-0.3370j, -1.2952-1.4781j,  ...,
         -1.3619-1.1539j, -2.2010+0.9252j, -1.5639+0.7484j],
        [-1.6782-0.5832j, -1.9106+0.8872j, -1.1814+0.9077j,  ...,
         -1.5238+0.9304j,  0.0401-2.7400j, -0.1941-2.6570j],
        [-1.1188+0.7164j, -1.1177-0.1952j, -2.0296+0.0053j,  ...,
          0.2159+0.0247j, -1.5791+0.8763j, -2.2518+0.9412j],
        ...,
        [-1.9615-0.0327j, -2.2533+0.2210j, -2.5262+0.4753j,  ...,
         -2.1406+0.1091j, -1.9447-0.0432j, -2.6785+0.5875j],
        [-1.2037-0.9294j, -1.9260-0.0843j, -1.6937-0.3651j,  ...,
         -1.4054-0.6091j, -1.6908-0.3349j, -1.5309-0.5471j],
        [-2.3453+0.4176j, -0.8352+1.1407j, -0.2810+0.4942j,  ...,
         -0.3876+0.6609j, -2.8730+0.9634j, -0.9564-1.2422j]],
       grad_fn=<AddBackward0>)

In [635]:
psi_x

tensor([[ 0.6965-0.1732j,  0.2971-0.6767j,  0.2579-0.6938j,  ...,
          0.3527-0.6484j,  0.7091-0.1241j,  0.6904-0.2013j],
        [ 0.4320-0.5739j,  0.6759+0.0923j,  0.6806-0.0020j,  ...,
          0.6415+0.2341j,  0.4454-0.5647j,  0.4077-0.5937j],
        [ 0.6814-0.1205j,  0.6776-0.1526j,  0.6497-0.2423j,  ...,
          0.6965+0.0055j,  0.6930-0.0560j,  0.6752-0.1600j],
        ...,
        [-0.5646+0.4825j, -0.5629+0.4853j, -0.5520+0.4984j,  ...,
         -0.5897+0.4516j, -0.5793+0.4649j, -0.5622+0.4862j],
        [-0.4920+0.5745j, -0.4923+0.5744j, -0.4857+0.5797j,  ...,
         -0.5279+0.5422j, -0.5137+0.5558j, -0.4917+0.5749j],
        [-0.4723+0.5699j, -0.4764+0.5660j, -0.4802+0.5628j,  ...,
         -0.5119+0.5360j, -0.4970+0.5496j, -0.4755+0.5667j]],
       grad_fn=<MulBackward0>)

In [636]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([ 0.4510,  1.7612,  0.0871,  0.2424,  0.1194,  0.0306, -0.7946,  0.1000,
        -0.6361, -0.7865,  1.0720, -1.7430, -0.6640,  0.5250,  0.7989, -0.8286,
         0.1016, -0.7357, -1.6222,  0.6305,  0.8901,  0.1104,  1.0467, -0.9099,
         0.8417, -0.0781,  0.5098,  0.2046, -0.9705, -0.2621,  0.0456, -1.3932,
         1.7074,  0.4026,  2.1397,  0.2063,  0.9070,  0.9167,  0.0660,  1.1285,
         0.6552,  1.6142,  0.3154, -0.4186, -1.9631, -2.1122, -0.5247,  1.4032,
         0.8821,  1.0708, -1.1718,  1.0588,  0.9746, -0.1113, -0.9600,  1.8159,
        -0.5182,  0.1059, -0.3379,  0.5837, -1.3769, -0.7077,  0.7027, -0.0786],
       requires_grad=True)
Parameter containing:
tensor([ 0.0074,  0.3213, -1.0453, -0.0669,  0.3513, -0.2405, -1.9639,  1.1031,
         2.2512,  0.2668,  0.4270,  0.3283,  0.5749,  1.2590, -1.2979, -2.1353,
         0.3623,  0.3440,  0.9656, -1.2559,  0.3636,  1.7245,  0.2520,  0.3070,
         0.8970,  0.2844,  1.0224, -0.0115,  0.4

In [637]:
psi_x.shape

torch.Size([50, 32])

In [638]:
# https://discuss.pytorch.org/t/example-for-one-of-the-differentiated-tensors-appears-to-not-have-been-used-in-the-graph/58396
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)

output = (2 * a).sum()

torch.autograd.grad(output, (a), allow_unused=True)

(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),)

In [639]:
print((d_ln_P_d_theta := tqs.d_ln_P_dtheta(psi_x)))

ValueError: NestedTensor tensor_attr_supported_getter(self: jt_all): expected self to be a jagged layout NestedTensor

In [None]:
print(len(d_ln_P_d_theta))

30


In [None]:
# Print the parameters shapes of the model
for i in tqs.parameters():
    print(i.shape)

torch.Size([64])
torch.Size([64])
torch.Size([192, 64])
torch.Size([192])
torch.Size([64, 64])
torch.Size([64])
torch.Size([128, 64])
torch.Size([128])
torch.Size([64, 128])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([192, 64])
torch.Size([192])
torch.Size([64, 64])
torch.Size([64])
torch.Size([128, 64])
torch.Size([128])
torch.Size([64, 128])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([2, 64])
torch.Size([2])
torch.Size([1, 64])
torch.Size([1])


In [None]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([ 0.9367,  0.2668, -1.6303, -1.7862,  0.8277,  0.1077,  1.8067, -2.5265,
         0.7316, -1.4697, -0.4042, -0.8399, -0.2293,  0.3884,  0.7997,  1.0304,
         0.6088, -0.0918, -0.6525, -0.7985,  0.5418,  0.5808,  0.5402,  0.6873,
         0.1134,  0.5988, -1.9142, -0.8318,  0.6114, -0.2724, -0.2606, -1.3139,
         1.0847,  1.9056, -1.2572,  1.5666,  0.2288, -0.4180,  0.6461, -0.3176,
        -0.3120, -0.2958, -0.3488, -0.3325, -0.6533, -0.8155, -0.2047,  0.2941,
        -1.8817, -2.4597, -1.3084,  1.9289,  1.8321, -0.0902,  0.4203,  1.6066,
        -0.7052, -0.0258, -0.1230,  0.3112,  1.2757,  0.0777, -1.9691, -0.9766],
       requires_grad=True)
Parameter containing:
tensor([ 0.4751, -0.3406,  0.2423,  0.2803,  1.7058,  0.7552, -0.9422,  0.0115,
         0.9143,  0.5812,  0.8512, -0.0914,  0.6948, -0.4445,  0.3403, -1.9351,
        -0.4757, -0.5741, -1.0688,  0.9787, -0.3464, -0.4479,  1.8879, -0.2108,
         1.2273,  0.2867,  0.2962, -1.8591,  1.5

In [None]:
for grad, param in zip(d_ln_P_d_theta, tqs.parameters()):
    if grad is not None:
        print(grad.shape, param.shape)

torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([192, 64]) torch.Size([192, 64])
torch.Size([192]) torch.Size([192])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64]) torch.Size([64])
torch.Size([128, 64]) torch.Size([128, 64])
torch.Size([128]) torch.Size([128])
torch.Size([64, 128]) torch.Size([64, 128])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([2, 64]) torch.Size([2, 64])
torch.Size([2]) torch.Size([2])
torch.Size([1, 64]) torch.Size([1, 64])
torch.Size([1]) torch.Size([1])


In [None]:
d_ln_P_d_theta

(tensor([ 12.0460,  -5.3368,  -5.4714,   6.7329, -21.8639,  16.2115,  -1.8841,
           2.7430,   6.1641,  25.9044,  -1.0642,   8.0647, -14.6524,   1.7242,
           3.3404,  19.5245,  -1.9994,  16.3765,   3.0707,  -3.7363,   5.9793,
          -3.7739,  16.6138,  -8.9862,   9.6171, -33.6170,  14.7179,  -7.1137,
          10.2698,   0.3457,   9.2687,   1.9637,  -3.1346, -19.9616,   6.5355,
           5.2249,  11.2312,   8.6126,  -4.7546,   2.2451,   3.1444, -15.6841,
          -0.5689,  -1.4260,  23.9463,   7.9263,   2.7100,   7.9186,  -0.6613,
           6.2767,  -8.4629,  -2.8857,  -9.8538,   7.5645, -11.8798,  13.5053,
          -9.0124,  16.5284,  -0.6507,  14.4156, -19.2918,  -6.8954,  -2.0077,
           8.1826], grad_fn=<ViewBackward0>),
 tensor([-1.4482,  0.8351,  1.4949, -1.6435,  0.7335, -1.1153,  0.9714,  0.2423,
         -2.1453, -1.2152,  0.3682,  0.1144, -0.5094,  0.1646, -0.4671, -0.9084,
         -1.0607,  0.3759,  0.3585,  0.0600,  0.9356,  1.6283, -0.6500, -0.7605,


In [None]:
torch.ones_like(psi_x).shape

torch.Size([50, 32])

In [None]:
print((P := psi_x * psi_x.conj()).shape)

torch.Size([50, 32])


In [None]:
ex_re = torch.randn(10, 2)
ex_im = torch.randn(10, 2)
ex_com = torch.complex(ex_re, ex_im)
ex_com

tensor([[ 0.5769-6.1739e-01j, -0.6152+8.4863e-01j],
        [ 1.7631+7.5909e-01j,  0.1033+1.3899e+00j],
        [ 0.3123-2.5167e-03j,  0.0912-1.1067e+00j],
        [ 0.9518-1.4823e+00j, -0.5423+4.7282e-01j],
        [ 1.5811+1.1638e+00j,  0.5020+2.9848e+00j],
        [-1.6522-1.4240e+00j, -0.2496-1.8722e-01j],
        [-1.4371-4.0499e-01j,  1.7068+1.1778e+00j],
        [-1.1488+1.8205e+00j, -0.1983-1.2034e+00j],
        [ 0.2536+7.6231e-01j,  0.9571+1.1045e+00j],
        [ 1.5606+4.7761e-01j,  0.1221-2.1698e-01j]])

In [None]:
ex_com.norm(dim=0)

tensor([5.1365, 4.6943])

What shape is psi_x * psi_x.conj()?

In [None]:
print(psi_x.shape)
print((psi_x * psi_x.conj()).shape)  # This produces a probability distribution

torch.Size([50, 32])
torch.Size([50, 32])


In [640]:
from torch import vmap

N = 5
M = 12  # Batch dimension
x = torch.randn(N, requires_grad=True)
y = torch.randn(N, requires_grad=True)
f = lambda x, y: x**2 + y**2
z = f(x, y)
print(z.shape)


def get_vjp(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


(jacobian := torch.vmap(get_vjp, randomness="same")(torch.eye(N)))

torch.Size([5])


(tensor([[ 0.3270, -0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -2.3047,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.4629,  0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.0000,  5.6694,  0.0000],
         [ 0.0000, -0.0000,  0.0000,  0.0000,  1.4428]]),)

In [None]:
N = 5
M = 12  # Batch dimension
x = torch.randn((N, M), requires_grad=True)
y = torch.randn((N, M), requires_grad=True)
f = lambda x, y: x**2 + y**2
z = f(x, y)
print(z.shape)


def get_vjp(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


(jacobian := torch.vmap(get_vjp, in_dims=1, randomness="same")(torch.eye(M)))

torch.Size([5, 12])


ValueError: NestedTensor tensor_attr_supported_getter(self: jt_all): expected self to be a jagged layout NestedTensor

In [644]:
# torch.dot has type [D], [D] -> []

# What if we wanted to apply it across a batch dimension where the second dim is the batch dim? [N, D], [N, D] -> [D]

batched_dot = vmap(torch.dot, in_dims=1)
x, y = torch.randn(5, 3), torch.randn(5, 3)
batched_dot(x, y)

tensor([ 0.1966, -0.7046,  0.8144])

We can use this to map torch.autograd.grad over a batch of inputs:

In [648]:
# A function that takes in a sample of dimension (N,) and returns the gradient with respect to
# x
def grad_on_seq(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


grad_vmap = vmap(grad_on_seq, in_dims=1, randomness="same")
grad_vmap(torch.eye(N))

ValueError: NestedTensor tensor_attr_supported_getter(self: jt_all): expected self to be a jagged layout NestedTensor