In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from model import TQS
import torch

In [3]:
EMBED_DIM = 64
MAX_LENGTH = 100
NUM_HEADS = 1
NUM_LAYERS = 1
DIM_FEEDFORWARD = 128
TEST_LENGTH = 50
TEST_BATCH = 32

In [4]:
test_potentials = torch.randn(TEST_LENGTH, TEST_BATCH)
test_spins = torch.zeros(TEST_LENGTH, TEST_BATCH)
test_spins[torch.randint(0, TEST_LENGTH, (TEST_BATCH,)), torch.arange(TEST_BATCH)] = 1

In [5]:
tqs = TQS(
    embed_dim=EMBED_DIM,
    max_chain_len=MAX_LENGTH,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    possible_spins=2,
    dim_feedforward=DIM_FEEDFORWARD,
)



# Model Output

In [6]:
probs, phases = tqs(test_potentials, test_spins)

In [7]:
probs

tensor([[[0.3880, 0.6342],
         [0.3739, 0.5909],
         [0.3262, 0.1927],
         ...,
         [0.2705, 0.2694],
         [0.4124, 0.2042],
         [0.3862, 0.1753]],

        [[0.2565, 0.3889],
         [0.2719, 0.2789],
         [0.2870, 0.5833],
         ...,
         [0.2133, 0.3178],
         [0.4191, 0.2030],
         [0.2331, 0.3933]],

        [[0.2866, 0.5047],
         [0.2910, 0.5431],
         [0.3412, 0.5623],
         ...,
         [0.3558, 0.6297],
         [0.2745, 0.4499],
         [0.2557, 0.4731]],

        ...,

        [[0.5138, 0.6078],
         [0.5418, 0.6153],
         [0.5583, 0.6089],
         ...,
         [0.5002, 0.5850],
         [0.5501, 0.6191],
         [0.5615, 0.5280]],

        [[0.5419, 0.5915],
         [0.5416, 0.6317],
         [0.5403, 0.6200],
         ...,
         [0.5285, 0.6102],
         [0.5311, 0.6396],
         [0.5144, 0.6207]],

        [[0.5273, 0.7059],
         [0.5227, 0.6584],
         [0.5212, 0.6967],
         ...,
 

# Autoregressive Sampling

In [8]:
def autoregressive_sample(model, initial_potentials, max_length):
    model.eval()
    # Allocate a buffer for the sampled basis states
    sampled_spins = torch.zeros(max_length, initial_potentials.size(1))
    batch_size = initial_potentials.size(1)
    batch_remaining_idx = torch.arange(batch_size)

    for i in range(max_length):
        # get P(s_{i+1} | V, s_{1:i}) and phi(s_{i+1} | V, s_{1:i}) distributions
        probs, _ = model(
            initial_potentials[:, batch_remaining_idx],
            sampled_spins[:i, batch_remaining_idx],
        )

        # sample s_{i+1} from P(s_{i+1} | V, s_{1:i})
        last_probs = probs[-1]  # (batch, 2); P(s_{i+1} | V, s_{1:i})
        sampled_spins[i, batch_remaining_idx] = (
            torch.multinomial(last_probs, 1).squeeze().float()
        )

        # a mask with dimension (batch_remaining,); True if we sampled a 1
        newly_completed = sampled_spins[i, batch_remaining_idx] == 1.0

        # mask out the batch_remaining_idx that have been completed
        batch_remaining_idx = batch_remaining_idx[~newly_completed]

    return sampled_spins


sampled_spins = autoregressive_sample(tqs, test_potentials, TEST_LENGTH)
sampled_spins

tensor([[0., 0., 0.,  ..., 1., 1., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [9]:
sampled_spins.sum(dim=0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [10]:
sampled_spins.shape

torch.Size([50, 32])

In [11]:
sampled_spins = tqs.sample_spins(test_potentials, TEST_LENGTH)
sampled_spins

tensor([[1., 0., 1.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [12]:
assert (
    sampled_spins.sum(dim=0) == 1.0
).all()  # each batch should have exactly one 1 spin

In [13]:
TEST_BATCH

32

In [14]:
TEST_LENGTH

50

In [15]:
T = 1.0
print(tqs.compute_psi(sampled_spins, test_potentials, T).shape)

torch.Size([50, 32])


In [16]:
psi_x, psi_l, psi_r = tqs.psi_terms(sampled_spins, test_potentials, T)
assert (
    psi_x.shape == psi_l.shape == psi_r.shape == torch.Size([TEST_LENGTH, TEST_BATCH])
)
psi_x, psi_l, psi_r

(tensor([[ 8.6045e-01-0.2001j,  4.7735e-01+0.1902j,  8.6123e-01-0.1947j,
           ...,  8.5597e-01-0.2224j,  8.2973e-01-0.3186j,
           8.5885e-01-0.2116j],
         [ 4.7721e-01-0.2207j,  4.4824e-01-0.3046j,  4.7875e-01-0.2148j,
           ...,  4.8823e-01-0.1710j,  4.3818e-01-0.3268j,
           4.7186e-01-0.2397j],
         [ 4.2254e-01-0.3376j,  6.4744e-01-0.6134j,  4.2621e-01-0.3316j,
           ...,  4.5389e-01-0.2752j,  3.5103e-01-0.4388j,
           4.1185e-01-0.3547j],
         ...,
         [ 2.7540e-01+0.7052j,  3.7484e-01+0.6695j,  2.6165e-01+0.7087j,
           ...,  2.1205e-01+0.7190j,  4.1202e-01+0.6511j,
           2.9132e-01+0.7007j],
         [-1.3793e-01+0.7691j, -4.6143e-02+0.7894j, -1.4981e-01+0.7655j,
           ..., -1.8211e-01+0.7545j, -5.2191e-04+0.7942j,
          -1.2066e-01+0.7743j],
         [-3.8453e-01+0.7177j, -3.1498e-01+0.7600j, -3.9147e-01+0.7128j,
           ..., -4.0423e-01+0.7029j, -2.7110e-01+0.7805j,
          -3.7147e-01+0.7270j]], grad_fn

In [17]:
E_loc = tqs.E_loc(psi_x, psi_l, psi_r, test_potentials, T)
E_loc

tensor([[-2.0252-0.5088j, -2.5745-0.2215j, -0.2366-0.9237j,  ...,
         -0.8084-0.8268j, -0.1522-1.0205j, -0.2360-0.9256j],
        [-2.7907+0.3581j, -2.5697+0.1430j, -3.0114+0.4719j,  ...,
         -2.6933+0.5022j, -2.0626-0.2703j, -2.7904+0.3153j],
        [-2.2532+0.2022j, -1.7075+0.5655j, -2.4258+0.3280j,  ...,
         -2.5045+0.3018j, -2.0796+0.1011j, -2.1521+0.1252j],
        ...,
        [-2.0713-0.1817j, -1.8650+0.2448j, -2.0019-0.0023j,  ...,
         -1.9699+0.1073j, -1.7641+0.3744j, -2.6508-1.5726j],
        [-1.8399-0.8943j, -1.9672-0.5611j, -1.9661-0.1751j,  ...,
         -1.6152-1.5932j, -2.0006+0.4676j, -1.9523-0.3093j],
        [-1.1517+0.7919j, -2.2078+0.5073j, -0.9665+0.4490j,  ...,
         -1.1032+0.8251j, -1.1559+0.5609j, -1.1045+0.6404j]],
       grad_fn=<AddBackward0>)

In [18]:
psi_x

tensor([[ 8.6045e-01-0.2001j,  4.7735e-01+0.1902j,  8.6123e-01-0.1947j,
          ...,  8.5597e-01-0.2224j,  8.2973e-01-0.3186j,
          8.5885e-01-0.2116j],
        [ 4.7721e-01-0.2207j,  4.4824e-01-0.3046j,  4.7875e-01-0.2148j,
          ...,  4.8823e-01-0.1710j,  4.3818e-01-0.3268j,
          4.7186e-01-0.2397j],
        [ 4.2254e-01-0.3376j,  6.4744e-01-0.6134j,  4.2621e-01-0.3316j,
          ...,  4.5389e-01-0.2752j,  3.5103e-01-0.4388j,
          4.1185e-01-0.3547j],
        ...,
        [ 2.7540e-01+0.7052j,  3.7484e-01+0.6695j,  2.6165e-01+0.7087j,
          ...,  2.1205e-01+0.7190j,  4.1202e-01+0.6511j,
          2.9132e-01+0.7007j],
        [-1.3793e-01+0.7691j, -4.6143e-02+0.7894j, -1.4981e-01+0.7655j,
          ..., -1.8211e-01+0.7545j, -5.2191e-04+0.7942j,
         -1.2066e-01+0.7743j],
        [-3.8453e-01+0.7177j, -3.1498e-01+0.7600j, -3.9147e-01+0.7128j,
          ..., -4.0423e-01+0.7029j, -2.7110e-01+0.7805j,
         -3.7147e-01+0.7270j]], grad_fn=<MulBackward0>)

In [19]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([-9.0678e-01, -3.4904e-01, -5.8869e-01, -5.8170e-02, -4.1572e-01,
         3.4841e-01,  8.9565e-01,  8.1810e-01, -3.4364e-01,  2.9650e-01,
        -2.8507e-01,  1.6111e-03, -1.2105e+00, -1.5277e+00, -8.0490e-01,
         1.0923e+00, -4.3110e-01,  1.2369e+00, -5.4244e-01,  1.2190e+00,
        -1.4323e+00, -2.0283e-01,  9.3111e-01, -3.9909e-01,  1.4032e+00,
         1.0671e+00, -1.0788e-01,  1.7218e-01,  4.4500e-01, -9.8702e-01,
         1.4303e+00,  3.1006e-02,  5.0137e-01,  5.9805e-01, -3.8208e-01,
        -1.9614e+00, -1.5861e+00,  2.2086e-03,  1.3585e+00,  8.5802e-01,
        -1.9340e-01,  2.1199e-01,  2.2446e-01, -1.6655e+00,  3.4554e-01,
        -1.0952e+00, -8.4067e-01, -8.3924e-01, -4.1254e-01,  1.0622e+00,
         5.1169e-01, -1.8868e+00,  2.0948e+00,  4.1218e-01, -1.1691e+00,
         8.3573e-01,  9.1111e-01, -2.5913e-01,  3.3681e-01, -6.5043e-01,
         4.2778e-01, -4.1978e-01, -9.3214e-02,  5.1071e-01],
       requires_grad=True)
Parameter cont

In [20]:
psi_x.shape

torch.Size([50, 32])

In [21]:
psi_x

tensor([[ 8.6045e-01-0.2001j,  4.7735e-01+0.1902j,  8.6123e-01-0.1947j,
          ...,  8.5597e-01-0.2224j,  8.2973e-01-0.3186j,
          8.5885e-01-0.2116j],
        [ 4.7721e-01-0.2207j,  4.4824e-01-0.3046j,  4.7875e-01-0.2148j,
          ...,  4.8823e-01-0.1710j,  4.3818e-01-0.3268j,
          4.7186e-01-0.2397j],
        [ 4.2254e-01-0.3376j,  6.4744e-01-0.6134j,  4.2621e-01-0.3316j,
          ...,  4.5389e-01-0.2752j,  3.5103e-01-0.4388j,
          4.1185e-01-0.3547j],
        ...,
        [ 2.7540e-01+0.7052j,  3.7484e-01+0.6695j,  2.6165e-01+0.7087j,
          ...,  2.1205e-01+0.7190j,  4.1202e-01+0.6511j,
          2.9132e-01+0.7007j],
        [-1.3793e-01+0.7691j, -4.6143e-02+0.7894j, -1.4981e-01+0.7655j,
          ..., -1.8211e-01+0.7545j, -5.2191e-04+0.7942j,
         -1.2066e-01+0.7743j],
        [-3.8453e-01+0.7177j, -3.1498e-01+0.7600j, -3.9147e-01+0.7128j,
          ..., -4.0423e-01+0.7029j, -2.7110e-01+0.7805j,
         -3.7147e-01+0.7270j]], grad_fn=<MulBackward0>)

In [22]:
# https://discuss.pytorch.org/t/example-for-one-of-the-differentiated-tensors-appears-to-not-have-been-used-in-the-graph/58396
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)

output = (2 * a).sum()

torch.autograd.grad(output, (a), allow_unused=True)

(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),)

In [24]:
print(test_potentials.shape)
print(sampled_spins.shape)

torch.Size([50, 32])
torch.Size([50, 32])


In [26]:
print(dlnP_dTheta := tqs.dlnP_dTheta(sampled_spins, test_potentials))

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a scalar Tensor, got tensor with 3 dims. Maybe you wanted to use the vjp or jacrev APIs instead?

Why is `jacrev` the right function to use here and what does it do?

In [37]:
from torch.func import functional_call, vmap, jacrev

In [None]:
# grab a spin chain and a corresponding potential function,
# recalling that their dimensions are (seq, batch)
a_spin_chain = sampled_spins.clone().detach().requires_grad_(True)[:, 1]
a_potential_func = test_potentials.clone().detach().requires_grad_(True)[:, 1]

In [None]:
print(a_spin_chain)
print(a_potential_func)

tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<SelectBackward0>)
tensor([-1.2020,  0.1472, -0.7155, -0.0764, -0.3391, -0.7191, -0.3792, -0.7279,
         0.6413, -0.0596,  0.5746, -0.0770,  0.3573,  0.3304, -0.9190,  0.4531,
         2.0006, -0.2597,  0.3847,  0.9929,  2.4995, -0.9985,  1.0245, -1.8078,
        -0.1615,  0.7566,  1.5770, -1.7874, -0.0235,  0.0087, -0.2809, -0.5713,
         0.2140, -0.9265,  1.5526,  0.9636, -0.0928,  1.3318,  0.5748,  0.2954,
         0.2612,  1.5846,  0.2285, -0.4883, -0.1614,  0.1229,  0.8383,  0.3604,
        -0.7181,  0.6586], grad_fn=<SelectBackward0>)


In [None]:
params = {k: v.detach for k, v in tqs.named_parameters()}
buffers = {k: v.detach for k, v in tqs.named_buffers()}

deriv = jacrev(tqs, argnums=1)(a_potential_func.unsqueeze(1), a_spin_chain.unsqueeze(1))

# Jacobian of probabilities; Jacobian of phases.
for d in deriv:
    print(d.shape)

torch.Size([100, 1, 2, 50, 1])
torch.Size([100, 1, 1, 50, 1])


## Finding the Jacobian of $\ln P(x; \Theta)$
- For a singleton-batch spin chain
- Should be with the jacrev function

In [54]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}

lnP = tqs._ln_P(a_spin_chain, a_potential_func, params, buffers)
lnP

tensor([[[-0.9709, -0.4955]],

        [[-1.3604, -1.4039]],

        [[-1.1900, -0.5641]],

        [[-1.1640, -1.0406]],

        [[-0.8777, -0.7167]],

        [[-0.7257, -0.4631]],

        [[-0.6685, -0.4927]],

        [[-0.6746, -0.3667]],

        [[-0.5620, -1.1268]],

        [[-0.4962, -0.6093]],

        [[-0.4236, -1.1076]],

        [[-0.3556, -0.5925]],

        [[-0.3733, -0.8585]],

        [[-0.4237, -0.7637]],

        [[-0.6484, -0.2778]],

        [[-0.5567, -0.8247]],

        [[-0.6161, -1.2343]],

        [[-0.6695, -0.3852]],

        [[-0.6328, -0.7922]],

        [[-0.6597, -1.1328]],

        [[-0.6509, -1.3098]],

        [[-0.6534, -0.3330]],

        [[-0.5704, -1.1740]],

        [[-0.6532, -0.2968]],

        [[-0.5602, -0.3840]],

        [[-0.6507, -0.8223]],

        [[-0.6980, -1.1214]],

        [[-0.8199, -0.2763]],

        [[-0.7413, -0.5248]],

        [[-0.6266, -0.6431]],

        [[-0.6043, -0.4830]],

        [[-0.6659, -0.3691]],

        

The following is the derivative of tqs._lnP with respect to a single spin chain:

In [56]:
params = {k: v.detach() for k, v in tqs.named_parameters()}
buffers = {k: v.detach() for k, v in tqs.named_buffers()}
deriv = jacrev(tqs._ln_P, argnums=0)(a_spin_chain, a_potential_func, params, buffers)
deriv

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


tensor([[[[-6.9693e-04, -8.1649e-04, -1.7328e-03,  ..., -3.1597e-04,
           -3.0234e-04, -2.9149e-04],
          [-6.0595e-04, -6.6552e-04, -1.5533e-03,  ..., -3.9051e-04,
           -3.5299e-04, -3.3270e-04]]],


        [[[-7.7657e-04, -8.8852e-04, -2.0411e-03,  ..., -5.5390e-04,
           -6.3683e-04, -6.1466e-04],
          [-4.5378e-03, -4.5588e-03, -9.6902e-03,  ..., -4.7605e-03,
           -4.8577e-03, -4.6207e-03]]],


        [[[-1.0534e-03, -1.1253e-03, -2.6640e-03,  ..., -5.3529e-04,
           -5.7606e-04, -5.7928e-04],
          [-8.6301e-04, -8.7886e-04, -2.3665e-03,  ..., -7.0268e-04,
           -6.9369e-04, -6.6843e-04]]],


        ...,


        [[[-9.9364e-04, -1.0009e-03, -2.8760e-03,  ..., -8.5855e-01,
           -5.3808e-04, -5.8313e-04],
          [-1.9138e-03, -1.8252e-03, -5.5144e-03,  ...,  1.4042e+00,
           -1.7142e-03, -1.8204e-03]]],


        [[[-6.6645e-04, -6.4199e-04, -1.5567e-03,  ..., -1.8454e-04,
           -7.6325e-01, -3.5101e-04],
      

...so now we need to map this single-spin-chain `jacrev(tqs)` operator over the spin chains in our batch
- (noting that `jacrev(tqs)` is *itself* an operator)

In [None]:
dlnP_dTheta_sample = jacrev(tqs._ln_P)
vmap_dlnP_dTheta = vmap(dlnP_dTheta_sample, in_dims=(1, 1, None, None))
vmap_dlnP_dTheta(a_spin_chain, a_potential_func, params, buffers)

In [None]:
print(len(d_ln_P_d_theta))

30


In [None]:
# Print the parameters shapes of the model
for i in tqs.parameters():
    print(i.shape)

torch.Size([64])
torch.Size([64])
torch.Size([192, 64])
torch.Size([192])
torch.Size([64, 64])
torch.Size([64])
torch.Size([128, 64])
torch.Size([128])
torch.Size([64, 128])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([192, 64])
torch.Size([192])
torch.Size([64, 64])
torch.Size([64])
torch.Size([128, 64])
torch.Size([128])
torch.Size([64, 128])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([2, 64])
torch.Size([2])
torch.Size([1, 64])
torch.Size([1])


In [None]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([ 0.9367,  0.2668, -1.6303, -1.7862,  0.8277,  0.1077,  1.8067, -2.5265,
         0.7316, -1.4697, -0.4042, -0.8399, -0.2293,  0.3884,  0.7997,  1.0304,
         0.6088, -0.0918, -0.6525, -0.7985,  0.5418,  0.5808,  0.5402,  0.6873,
         0.1134,  0.5988, -1.9142, -0.8318,  0.6114, -0.2724, -0.2606, -1.3139,
         1.0847,  1.9056, -1.2572,  1.5666,  0.2288, -0.4180,  0.6461, -0.3176,
        -0.3120, -0.2958, -0.3488, -0.3325, -0.6533, -0.8155, -0.2047,  0.2941,
        -1.8817, -2.4597, -1.3084,  1.9289,  1.8321, -0.0902,  0.4203,  1.6066,
        -0.7052, -0.0258, -0.1230,  0.3112,  1.2757,  0.0777, -1.9691, -0.9766],
       requires_grad=True)
Parameter containing:
tensor([ 0.4751, -0.3406,  0.2423,  0.2803,  1.7058,  0.7552, -0.9422,  0.0115,
         0.9143,  0.5812,  0.8512, -0.0914,  0.6948, -0.4445,  0.3403, -1.9351,
        -0.4757, -0.5741, -1.0688,  0.9787, -0.3464, -0.4479,  1.8879, -0.2108,
         1.2273,  0.2867,  0.2962, -1.8591,  1.5

In [None]:
for grad, param in zip(d_ln_P_d_theta, tqs.parameters()):
    if grad is not None:
        print(grad.shape, param.shape)

torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([192, 64]) torch.Size([192, 64])
torch.Size([192]) torch.Size([192])
torch.Size([64, 64]) torch.Size([64, 64])
torch.Size([64]) torch.Size([64])
torch.Size([128, 64]) torch.Size([128, 64])
torch.Size([128]) torch.Size([128])
torch.Size([64, 128]) torch.Size([64, 128])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([64]) torch.Size([64])
torch.Size([2, 64]) torch.Size([2, 64])
torch.Size([2]) torch.Size([2])
torch.Size([1, 64]) torch.Size([1, 64])
torch.Size([1]) torch.Size([1])


In [None]:
d_ln_P_d_theta

(tensor([ 12.0460,  -5.3368,  -5.4714,   6.7329, -21.8639,  16.2115,  -1.8841,
           2.7430,   6.1641,  25.9044,  -1.0642,   8.0647, -14.6524,   1.7242,
           3.3404,  19.5245,  -1.9994,  16.3765,   3.0707,  -3.7363,   5.9793,
          -3.7739,  16.6138,  -8.9862,   9.6171, -33.6170,  14.7179,  -7.1137,
          10.2698,   0.3457,   9.2687,   1.9637,  -3.1346, -19.9616,   6.5355,
           5.2249,  11.2312,   8.6126,  -4.7546,   2.2451,   3.1444, -15.6841,
          -0.5689,  -1.4260,  23.9463,   7.9263,   2.7100,   7.9186,  -0.6613,
           6.2767,  -8.4629,  -2.8857,  -9.8538,   7.5645, -11.8798,  13.5053,
          -9.0124,  16.5284,  -0.6507,  14.4156, -19.2918,  -6.8954,  -2.0077,
           8.1826], grad_fn=<ViewBackward0>),
 tensor([-1.4482,  0.8351,  1.4949, -1.6435,  0.7335, -1.1153,  0.9714,  0.2423,
         -2.1453, -1.2152,  0.3682,  0.1144, -0.5094,  0.1646, -0.4671, -0.9084,
         -1.0607,  0.3759,  0.3585,  0.0600,  0.9356,  1.6283, -0.6500, -0.7605,


In [None]:
torch.ones_like(psi_x).shape

torch.Size([50, 32])

In [None]:
print((P := psi_x * psi_x.conj()).shape)

torch.Size([50, 32])


In [None]:
ex_re = torch.randn(10, 2)
ex_im = torch.randn(10, 2)
ex_com = torch.complex(ex_re, ex_im)
ex_com

tensor([[ 0.5769-6.1739e-01j, -0.6152+8.4863e-01j],
        [ 1.7631+7.5909e-01j,  0.1033+1.3899e+00j],
        [ 0.3123-2.5167e-03j,  0.0912-1.1067e+00j],
        [ 0.9518-1.4823e+00j, -0.5423+4.7282e-01j],
        [ 1.5811+1.1638e+00j,  0.5020+2.9848e+00j],
        [-1.6522-1.4240e+00j, -0.2496-1.8722e-01j],
        [-1.4371-4.0499e-01j,  1.7068+1.1778e+00j],
        [-1.1488+1.8205e+00j, -0.1983-1.2034e+00j],
        [ 0.2536+7.6231e-01j,  0.9571+1.1045e+00j],
        [ 1.5606+4.7761e-01j,  0.1221-2.1698e-01j]])

In [None]:
ex_com.norm(dim=0)

tensor([5.1365, 4.6943])

What shape is psi_x * psi_x.conj()?

In [None]:
print(psi_x.shape)
print((psi_x * psi_x.conj()).shape)  # This produces a probability distribution

torch.Size([50, 32])
torch.Size([50, 32])


In [None]:
from torch import vmap

N = 5
M = 12  # Batch dimension
x = torch.randn(N, requires_grad=True)
y = torch.randn(N, requires_grad=True)
f = lambda x, y: x**2 + y**2
z = f(x, y)
print(z.shape)


def get_vjp(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


(jacobian := torch.vmap(get_vjp, randomness="same")(torch.eye(N)))

torch.Size([5])


(tensor([[ 0.3270, -0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -2.3047,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.4629,  0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.0000,  5.6694,  0.0000],
         [ 0.0000, -0.0000,  0.0000,  0.0000,  1.4428]]),)

In [None]:
N = 5
M = 12  # Batch dimension
x = torch.randn((N, M), requires_grad=True)
y = torch.randn((N, M), requires_grad=True)
f = lambda x, y: x**2 + y**2
z = f(x, y)
print(z.shape)


def get_vjp(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


(jacobian := torch.vmap(get_vjp, in_dims=1, randomness="same")(torch.eye(M)))

torch.Size([5, 12])


ValueError: NestedTensor tensor_attr_supported_getter(self: jt_all): expected self to be a jagged layout NestedTensor

In [None]:
# torch.dot has type [D], [D] -> []

# What if we wanted to apply it across a batch dimension where the second dim is the batch dim? [N, D], [N, D] -> [D]

batched_dot = vmap(torch.dot, in_dims=1)
x, y = torch.randn(5, 3), torch.randn(5, 3)
batched_dot(x, y)

tensor([ 0.1966, -0.7046,  0.8144])

We can use this to map torch.autograd.grad over a batch of inputs:

In [None]:
# A function that takes in a sample of dimension (N,) and returns the gradient with respect to
# x
def grad_on_seq(v):
    return torch.autograd.grad(z, x, grad_outputs=v, retain_graph=True)


grad_vmap = vmap(grad_on_seq, in_dims=1, randomness="same")
grad_vmap(torch.eye(N))

ValueError: NestedTensor tensor_attr_supported_getter(self: jt_all): expected self to be a jagged layout NestedTensor