The goal of this notebook is to convince myself that the implementation of multi-headed attention is correct.

In [92]:
%load_ext autoreload
%autoreload 2
from boilerplate import setup_nb
import torch
device = setup_nb()

CUDA is available.
Changed working dir to /home/spandan/Projects/hubbard-transformer


In [None]:
TOKEN_DIMS = 16
PARAM_DIMS = 32
EMBED_DIMS = TOKEN_DIMS + PARAM_DIMS
N_PARAMS = 10
MAX_LEN = 100
WAVELEN_FACT = 1e6
MAX_OCC = 5
MIN_OCC = 0
BATCH = 32
N_HEADS = 8
KEY_DIMS = 12

In [94]:
from embedding import HubbardEmbedding

init_len = 10

he = HubbardEmbedding(TOKEN_DIMS, PARAM_DIMS, N_PARAMS, MAX_LEN, WAVELEN_FACT)
he.to(device)
params = torch.randn(N_PARAMS, BATCH).to(device)
occupations = torch.randint(1, 5, (init_len, BATCH)).to(device)
logits = he(params, occupations).to(device, dtype=torch.complex64)

In [95]:
import einops as ein

In [96]:
logits.shape

torch.Size([20, 32, 64])

In [97]:
from complex_model import ComplexAttention

print("Shape of logits before attention:", logits.shape)
ca = ComplexAttention(EMBED_DIMS, KEY_DIMS, N_HEADS, MAX_LEN).to(device)
logits = ca(logits)  # call the module rather than .forward() so any registered hooks run
print("Shape of logits after attention:", logits.shape)

Shape of logits before attention: torch.Size([20, 32, 64])
Using a value space of dimension 8
Shape of logits after attention: torch.Size([20, 32, 64])


In [98]:
logits[:, 0, :]

tensor([[  8.3989+13.3636j,   0.4310-29.0144j,   5.3808+13.9602j,
          ...,   7.4629+1.6668j,  38.0070+9.0520j,
          19.0019-2.2060j],
        [  9.2172+13.3212j,   0.3921-28.9608j,   5.8645+13.9536j,
          ...,   7.4354+1.6831j,  37.9700+9.0356j,
          18.9245-2.1504j],
        [ 12.5984+20.0453j,   0.6466-43.5216j,   8.0712+20.9404j,
          ...,  11.1943+2.5002j,  57.0105+13.5780j,
          28.5028-3.3091j],
        ...,
        [ 30.9294-22.2259j, -25.0534+1.1900j, -12.9271+7.1966j,
          ...,  10.1066+12.5140j,  21.1985+12.8976j,
          27.8866-35.0836j],
        [  6.7175+8.7586j, -17.1525+18.2251j,   3.9634+21.8257j,
          ...,  61.8793+20.0443j, -12.9123-15.1276j,
          34.3297+23.7823j],
        [-14.0472-23.5252j, -10.2845+6.1496j,   1.6564+2.2612j,
          ...,  24.9164+22.6090j,  -7.1596-17.9396j,
          19.0006+38.1621j]], device='cuda:0', grad_fn=<SliceBackward0>)

Attention patterns as einsum inner products

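Multiplying two matrices produces a matrix of inner products: entry $C_{ij}$ is the inner product of row $i$ of $A$ with column $j$ of $B$, summed over the shared index $k$.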

$$
C_{ij} = \sum_k A_{ik} B_{kj}
$$
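A quick numerical check of the index notation, as a minimal sketch with small random real matrices (sizes chosen arbitrarily): `torch.einsum` with the repeated index `k` summed away reproduces an ordinary matmul, and any single entry is the dot product of a row of `A` with a column of `B`.

```python
import torch

# Two small random matrices: A is (i, k) and B is (k, j)
A = torch.randn(3, 4)
B = torch.randn(4, 5)

# einsum sums over the repeated index k, which is exactly a matmul
C = torch.einsum("ik,kj->ij", A, B)
assert torch.allclose(C, A @ B, atol=1e-5)

# Each entry C[i, j] is the inner product of row i of A with column j of B
i, j = 1, 2
assert torch.allclose(C[i, j], torch.dot(A[i, :], B[:, j]), atol=1e-5)
```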

For the key-query matmul specifically, where $k$ is the feature index and $s$, $S$ index sequence positions:

$$
C_{sS} = \sum_k \overline{K_{ks}}\, Q_{kS} = \sum_k K'_{sk}\, Q_{kS}, \qquad K' = K^H, \qquad C = K^H Q
$$

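The axes $s$ and $S$ both index sequence positions ($s$ for keys, $S$ for queries), but einsum needs distinct labels for them: a shared label would be treated as one axis instead of producing a score for every (key, query) pair.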

This matmul computes every possible key-query inner product, one for each (key position, query position) pair, and is the $\braket{K, Q}$ described in the paper.
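A minimal sketch of the key-query product with `torch.einsum`, assuming keys and queries are stored as complex matrices with the feature axis first, i.e. `K` of shape (k, s) and `Q` of shape (k, S) as the indices above suggest. The sizes are illustrative, and this layout may not match what `ComplexAttention` uses internally. The distinct labels `s` and `S` keep both position axes in the output; reusing a single label collapses the result to the diagonal.

```python
import torch

d_key, seq_len = 12, 6  # illustrative sizes, not the notebook's

# Hypothetical complex keys and queries laid out as (k, s) and (k, S):
# k is the feature axis, s / S index key / query positions
K = torch.randn(d_key, seq_len, dtype=torch.complex64)
Q = torch.randn(d_key, seq_len, dtype=torch.complex64)

# Distinct labels s and S keep both position axes in the output;
# only the shared feature axis k is summed over
C = torch.einsum("ks,kS->sS", K.conj(), Q)

# Same result as the conjugate-transpose matmul C = K^H Q
assert torch.allclose(C, K.mH @ Q, atol=1e-5)

# Reusing one label ("ks,ks->s") pairs each position only with itself,
# which gives just the diagonal of C
diag = torch.einsum("ks,ks->s", K.conj(), Q)
assert torch.allclose(diag, torch.diagonal(C), atol=1e-5)
```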

We need one $\braket{k|q}$ for every pair of key vector $k$ and query vector $q$.
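The same scores can be built one bra-ket at a time. This sketch (random complex tensors, with the same assumed (k, s) layout where each column is one vector) loops over every (key, query) pair and uses `torch.vdot`, which conjugates its first argument just as the bra $\bra{k}$ does. It also shows that swapping key and query gives the complex conjugate, so the order matters.

```python
import torch

d_key, seq_len = 12, 6  # illustrative sizes

K = torch.randn(d_key, seq_len, dtype=torch.complex64)  # columns are key vectors
Q = torch.randn(d_key, seq_len, dtype=torch.complex64)  # columns are query vectors

C = K.mH @ Q  # all pairwise scores at once: C[s, S] = <k_s|q_S>

# Explicit double loop: one bra-ket per (key, query) pair
C_loop = torch.empty(seq_len, seq_len, dtype=torch.complex64)
for s in range(seq_len):
    for S in range(seq_len):
        # torch.vdot conjugates its first argument, matching the bra <k|
        C_loop[s, S] = torch.vdot(K[:, s], Q[:, S])

assert torch.allclose(C, C_loop, atol=1e-5)

# <q|k> is the complex conjugate of <k|q>, so C is not symmetric in general
assert torch.allclose(torch.vdot(Q[:, 0], K[:, 1]), C[1, 0].conj(), atol=1e-5)
```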

The softmax over the attention pattern is taken along the key axis $s$, the same axis that is summed over when the value vectors are linearly combined to turn the embeddings E into the updated embeddings E'.
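A minimal sketch of the softmax axis, using real-valued random scores in the $C_{sS}$ layout (rows are key positions $s$, columns are query positions $S$) and hypothetical value vectors `V` with one column per key position. How `ComplexAttention` maps its complex scores to softmax inputs is not reproduced here. Normalising along `dim=0` makes each query's weights sum to 1 over the keys, and those weights then combine the value columns into the updated embedding for that query.

```python
import torch

seq_len, d_value = 6, 8  # illustrative sizes

# Real-valued scores laid out as C[s, S]: rows are keys, columns are queries.
# (Turning complex scores into softmax inputs is deliberately not modelled here.)
scores = torch.randn(seq_len, seq_len)
V = torch.randn(d_value, seq_len)  # hypothetical values, one column per key position s

# Normalise over s (dim=0), the axis the value combination sums over
weights = torch.softmax(scores, dim=0)
assert torch.allclose(weights.sum(dim=0), torch.ones(seq_len))

# E'[:, S] = sum_s weights[s, S] * V[:, s]: a convex combination of values per query S
E_new = torch.einsum("sS,vs->vS", weights, V)
assert E_new.shape == (d_value, seq_len)
assert torch.allclose(E_new, V @ weights, atol=1e-5)
```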

In [99]:
import torch

torch.triu(torch.ones(5, 5), diagonal=1).bool()

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [100]:
# e.g., the fourth embedding update (row) should only be
# computed from the first four tokens, which is why
# only four scores in the fourth row are left unmasked
torch.randn(5, 5).masked_fill(
    torch.triu(torch.ones(5, 5), diagonal=1).bool(), float("-inf")
)

tensor([[ 0.0710,    -inf,    -inf,    -inf,    -inf],
        [-0.7091, -0.4091,    -inf,    -inf,    -inf],
        [ 0.4236,  0.5352,  0.5777,    -inf,    -inf],
        [-0.0455, -0.6346, -0.2280, -0.6564,    -inf],
        [-1.0733, -0.9297,  0.2490,  0.1138, -0.4720]])
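Continuing the masking example as a sketch: since $e^{-\infty} = 0$, a softmax along each row (rows are query positions here) gives the masked future positions exactly zero weight while each row still sums to 1. This relies on the diagonal staying unmasked, because a row that is entirely `-inf` would produce NaN. The last lines note that in the keys-by-queries layout $C_{sS}$ used earlier, the equivalent mask is the transpose, with the softmax then taken along `dim=0`.

```python
import torch

n = 5
# True above the diagonal: query position i must not see key position j > i
mask = torch.triu(torch.ones(n, n), diagonal=1).bool()

# Rows are query positions, columns are key positions (the layout of the cell above)
scores = torch.randn(n, n).masked_fill(mask, float("-inf"))

# exp(-inf) = 0, so masked entries get exactly zero weight after a softmax over keys
weights = torch.softmax(scores, dim=-1)

assert torch.all(weights[mask] == 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(n))
# The first query can only attend to itself
assert weights[0, 0] == 1.0

# In the keys-by-queries layout C[s, S], mask where key s comes after query S:
# that is the transpose, i.e. the strictly lower triangle
mask_sS = mask.T
assert torch.equal(mask_sS, torch.tril(torch.ones(n, n), diagonal=-1).bool())
```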