In [55]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
from model import TQS
import torch

In [57]:
EMBED_DIM = 64
MAX_LENGTH = 100
NUM_HEADS = 1
NUM_LAYERS = 1
DIM_FEEDFORWARD = 128
TEST_LENGTH = 50
TEST_BATCH = 32

In [58]:
test_potentials = torch.randn(TEST_LENGTH, TEST_BATCH)
test_spins = torch.zeros(TEST_LENGTH, TEST_BATCH)
test_spins[torch.randint(0, TEST_LENGTH, (TEST_BATCH,)), torch.arange(TEST_BATCH)] = 1

In this situation each site has its own potential value

In [59]:
test_potentials  # (seq_len, batch_size)

tensor([[-0.8355, -1.7400,  1.0942,  ..., -1.1610,  1.1955, -1.3504],
        [-0.7105, -0.3849,  1.4603,  ..., -0.0645,  0.0686, -0.0424],
        [ 0.4510,  2.5088,  0.0482,  ...,  1.5008, -0.0663,  1.2951],
        ...,
        [-0.3750, -1.1446,  0.2271,  ..., -0.3643, -0.4005, -0.4082],
        [ 0.1729, -0.7864,  0.2976,  ...,  1.2363,  2.0191,  0.8710],
        [-0.5019,  0.4500, -0.5023,  ..., -0.2191,  0.0394,  0.5011]])

In [60]:
test_spins  # (seq_len, batch_size)

tensor([[1., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [61]:
tqs = TQS(
    embed_dim=EMBED_DIM,
    max_chain_len=MAX_LENGTH,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    possible_spins=2,
    dim_feedforward=DIM_FEEDFORWARD,
)



In [62]:
probs, phases = tqs(test_potentials, test_spins)

In [63]:
probs

tensor([[[0.4460, 0.4607],
         [0.4032, 0.4227],
         [0.5811, 0.5957],
         ...,
         [0.4324, 0.4341],
         [0.6108, 0.5829],
         [0.4041, 0.4245]],

        [[0.4560, 0.4223],
         [0.4975, 0.3897],
         [0.5926, 0.5843],
         ...,
         [0.5123, 0.4368],
         [0.5567, 0.5168],
         [0.5307, 0.4366]],

        [[0.5944, 0.5031],
         [0.6034, 0.5769],
         [0.5397, 0.4638],
         ...,
         [0.5894, 0.5567],
         [0.5592, 0.4499],
         [0.6072, 0.5547]],

        ...,

        [[0.6344, 0.4580],
         [0.6461, 0.4298],
         [0.6110, 0.4403],
         ...,
         [0.6812, 0.4494],
         [0.6696, 0.4140],
         [0.6364, 0.4233]],

        [[0.6722, 0.4469],
         [0.6540, 0.4431],
         [0.6951, 0.4497],
         ...,
         [0.6424, 0.4508],
         [0.6817, 0.4990],
         [0.6902, 0.5106]],

        [[0.7199, 0.4970],
         [0.7130, 0.5040],
         [0.7044, 0.5347],
         ...,
 

In [64]:
phases

tensor([[[-0.5843],
         [-0.8018],
         [ 0.2631],
         ...,
         [-0.4593],
         [ 0.2807],
         [-0.8551]],

        [[-0.2700],
         [-0.0188],
         [ 0.3330],
         ...,
         [ 0.4155],
         [ 0.0548],
         [ 0.2054]],

        [[ 0.4977],
         [ 0.5756],
         [ 0.0881],
         ...,
         [ 0.4130],
         [-0.6362],
         [-0.2154]],

        ...,

        [[ 2.0218],
         [ 2.1927],
         [ 1.9539],
         ...,
         [ 2.1012],
         [ 2.4222],
         [ 2.2464]],

        [[ 2.2077],
         [ 2.1917],
         [ 2.1699],
         ...,
         [ 2.0066],
         [ 1.9795],
         [ 1.8199]],

        [[ 2.2033],
         [ 1.9242],
         [ 2.0259],
         ...,
         [ 2.2054],
         [ 2.0765],
         [ 1.8804]]], grad_fn=<MulBackward0>)

In [65]:
# Spins sampled from the wave function the TQS models
spins = tqs.sample_spins(test_potentials, tqs.max_chain_len)
spins

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [66]:
T = 1.0
psi_x, psi_l, psi_r = tqs.psi_terms(spins, test_potentials, T)
E_loc = tqs.E_loc(psi_x, psi_l, psi_r, test_potentials, T)

## Model Parameters

In [67]:
for i in tqs.parameters():
    print(i)

Parameter containing:
tensor([ 0.7700, -0.1845, -0.5371,  1.5401,  0.3018,  1.5673, -0.5706,  0.5321,
        -0.0204, -1.1571,  1.4552,  1.0573,  0.3328,  1.2447,  0.1200,  0.9041,
         0.1951, -0.8236,  0.8809,  0.8374, -0.3279,  0.5729, -0.2325,  0.4923,
         1.0704,  0.6647, -0.4704, -0.8055, -1.3952,  0.6256,  0.2798,  1.0896,
         0.4670,  0.0300, -0.9752,  3.0311,  2.4092,  0.3060, -0.0418, -1.2533,
         0.5433, -1.1362,  1.2849,  0.8085, -0.4620,  0.3075, -0.4362,  1.1863,
        -0.2932,  0.5470,  0.3605,  0.4797,  1.2236,  0.3412,  0.1848, -0.2307,
         0.5630, -1.4219,  1.0084,  0.5428, -0.8511,  1.1395, -0.5453,  0.4003],
       requires_grad=True)
Parameter containing:
tensor([-0.1294, -0.9766, -1.2813, -0.9127, -1.9270,  0.5791, -2.0543, -0.5088,
        -1.1811,  0.9136, -2.1359, -0.4256, -0.2987, -0.1058,  0.7262,  0.4244,
        -2.0061, -0.4241, -0.3564, -0.9538, -1.7064,  0.8925, -0.2464,  1.2790,
         0.4191,  0.0980,  0.9135, -0.9046, -1.1

## Per-Sample Gradients

The target is

$$\frac{d}{d\Theta}\left ( P(x; \theta)\right )$$


...which ends up being represented as some tuple of gradients of $P$ with respect to $\theta_1, \theta_2, ...$:

$$\left ( \frac{dP}{d\theta_1}, \frac{dP}{d\theta_2}, ..., \frac{dP}{d\theta_p} \right )$$

...assuming we have $p$ model parameters

### A Simple CNN Example

From https://pytorch.org/tutorials/intermediate/per_sample_grads.html

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

In [69]:
device = "cuda"

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

In [70]:
# Simple training computes an average across a batch
model = SimpleCNN().to(device=device)
predictions = model(data)  # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss  # Averaged over the batch

tensor(2.3039, device='cuda:0', grad_fn=<NllLossBackward0>)

In [71]:
loss.backward()  # back propagate the 'average' gradient of this mini-batch

Instead of computing an averaged loss, we'd like to separate the loss function's computation graph (and thus its loss gradient, post backprop) between each sample

In [72]:
from torch.func import functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

In [None]:
def compute_loss(params, buffers, sample, target):
    """
    Computes loss gradients for a single sample with respect to
    the model parameters (affected by differentiation) and buffers
    (not affected by differentiation).
    """
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    # Swaps model parameters and buffers with the ones provided,
    # using a name -> tensor (or value) mapping
    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

Arguments of `functional_call`:
1. The module to transplant the parameters and buffers into
2. The parameters and buffers themselves (named, using a dictionary, so that PyTorch knows where each one goes)
3. The arguments to the function--in this case, the entire "batch"

In [81]:
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[[[-0.1084,  0.0424, -0.3139],
          [ 0.1655, -0.1783,  0.2891],
          [ 0.2499,  0.0475,  0.1136]]],


        [[[-0.1207, -0.1051, -0.1374],
          [-0.0912,  0.1755, -0.0016],
          [-0.0721,  0.1740, -0.0377]]],


        [[[-0.0946, -0.2715,  0.3007],
          [ 0.2387,  0.0555, -0.1468],
          [-0.2477,  0.2970,  0.1652]]],


        [[[-0.0218,  0.0265, -0.0242],
          [ 0.0138, -0.2554,  0.2154],
          [-0.0746,  0.0563,  0.1636]]],


        [[[ 0.1062,  0.2369,  0.1340],
          [-0.1994,  0.0486, -0.0190],
          [ 0.2310,  0.1852, -0.0057]]],


        [[[-0.2098, -0.1552,  0.0312],
          [ 0.2002,  0.2086, -0.0492],
          [-0.1418,  0.1076, -0.1305]]],


        [[[ 0.2461,  0.2975, -0.0725],
          [-0.2895,  0.1223, -0.1883],
          [-0.1135, -0.1375,  0.2733]]],


        [[[ 0.1612,  0.2042, -0.1575],
          [ 0.2084, -0.3041, -0.0166],
          [-0.1108,  0.1487,  0.0416]]],


        [[

In [82]:
# A function computing the gradient of compute_loss with respect to its first argument
ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
ft_per_sample_grads

{'conv1.weight': tensor([[[[[-1.6652e-02, -2.2975e-02,  1.9531e-04],
            [ 2.5385e-02, -2.2959e-02, -1.2314e-02],
            [-1.4160e-02, -1.8586e-03,  6.4033e-03]]],
 
 
          [[[ 2.1709e-03,  2.7222e-03,  3.5510e-03],
            [ 6.0316e-03, -2.2517e-02,  3.3607e-03],
            [ 1.6695e-02,  4.2553e-02, -2.8156e-03]]],
 
 
          [[[ 7.0681e-03, -1.2002e-02,  1.3895e-02],
            [-2.9352e-02, -9.3097e-03, -5.0439e-03],
            [ 4.7273e-03, -2.4532e-02,  1.0372e-02]]],
 
 
          ...,
 
 
          [[[ 4.8668e-03,  2.0931e-03,  7.6923e-03],
            [-9.2889e-03, -1.1912e-02,  2.5961e-03],
            [-7.6692e-03, -1.2963e-02,  4.3653e-03]]],
 
 
          [[[-9.2796e-04,  8.0672e-03,  4.2809e-04],
            [ 6.0073e-03,  9.2049e-03,  3.8828e-02],
            [ 1.6068e-02,  1.1058e-03, -2.5562e-02]]],
 
 
          [[[ 9.2217e-03, -5.0106e-03, -6.9974e-03],
            [ 9.8014e-03, -1.6241e-03,  1.1001e-02],
            [-2.0466e-02, -3.3676e

In [77]:
ft_per_sample_grads.keys()

dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

`in_dims`?
- Denotes to PyTorch which dimensions to apply the function along
    - Similar to indexing those dimensions and then passing the result into the function being vmap'd
- None indicates no map dimension--i.e., don't map the function over these
- `in_dim` should have a structur reflecting the inputs themselves, selecting dimensions
- In this case:
    - None - Don't iterate over params; treat them as a constant wrt summation indexj
    - None - Don't iterate over buffers; they're a constant with respect to iteration
    - 0 - The first dimension of data is the batch dimension
    - 0 - ...which should be aligned with the first dimension of the parameters

[Documentation about vmap](https://pytorch.org/docs/stable/generated/torch.vmap.html)

This can be thought of as
$$\sum _ {i \in \Gamma_\text{batch}} \text{compute\_loss}(\text{params}, \text{buffers}, \text{samples}[i], \text{targets}[i])$$

The value of `in_dims` places the $i$ in the index-brackets of the input tensors

Performance of `vmap`
- Much faster (documentation claims order of magnitude) than computing in a simple for loop
- TODO: where is this speedup coming from?