In [1]:
import torch

The input has shape (B, N), and the output targets also have shape (B, N) but are shifted by one position: the 0th entry predicts the 1st token, the 1st entry predicts the 2nd, and so on.

- We start with (B, N), where B is the batch size and N is the maximum number of tokens we use to predict the next one. At this point the tokens are integers.
- We index into the token embedding table for each of the tokens in our past history to get (B, N, T), where T is the embedding size of each token.
- To encode positions, we take an (N, T) positional-embedding matrix and add it, position by position, to the (B, N, T) tensor.
- Then we take the (B, N, T) tensor and run it through self-attention (which is what this notebook aims to show).

Here is an example of how the input and output tokens would be paired up:

In [26]:
# Toy string used to show how (context, next-character) pairs line up.
input_str = "Hello there!"
n_chars = len(input_str)

# `inputs` holds the last index of each context; `outputs` holds the index of
# the character that context should predict (always one position later).
inputs = range(n_chars - 1)
outputs = range(1, n_chars)

for i, j in zip(inputs, outputs):
	print(f"'{input_str[:i+1]}' used to predict the next '{input_str[j]}'")

'H' used to predict the next 'e'
'He' used to predict the next 'l'
'Hel' used to predict the next 'l'
'Hell' used to predict the next 'o'
'Hello' used to predict the next ' '
'Hello ' used to predict the next 't'
'Hello t' used to predict the next 'h'
'Hello th' used to predict the next 'e'
'Hello the' used to predict the next 'r'
'Hello ther' used to predict the next 'e'
'Hello there' used to predict the next '!'


So although 'Hello there!' has 12 characters in total, there are only 11 ways to take a prefix of the string and predict the next character.

In other words, during training, when I slice a given string, it will be one less than the context size that I actually want. So account for that!

In [67]:
# Seed the RNG first so the random "embeddings" below are reproducible.
torch.manual_seed(0)

str_slice_size = 7
N = str_slice_size + 1  # sequence length: one more than the slice size
B, T = 1, 2             # batch size and per-token embedding width

# Random stand-in for a (B, N, T) batch of token embeddings.
x = torch.randn(B, N, T)
x

tensor([[[-1.1258, -1.1524],
         [-0.2506, -0.4339],
         [ 0.8487,  0.6920],
         [-0.3160, -2.1152],
         [ 0.3223, -1.2633],
         [ 0.3500,  0.3081],
         [ 0.1198,  1.2377],
         [ 1.1168, -0.2473]]])

One strategy to model the dependencies between the embeddings is to simply average them (a normalized sum), and hope that this information can be used to predict the next token. For example, we could predict the 5th token by averaging the embeddings of the previous 4 tokens and using that information later on.

In [68]:
# First four positions along the sequence axis (dim 1) of x.
prev_4 = x[:, 0:4]
print("prevous 4", prev_4)

# NOTE(review): the label says "summed", but the value printed is the mean
# over the sequence axis (kept as-is so the recorded output still matches).
prev_4_mean = prev_4.mean(dim=1, keepdim=True)
print("summed previous 4", prev_4_mean)

prevous 4 tensor([[[-1.1258, -1.1524],
         [-0.2506, -0.4339],
         [ 0.8487,  0.6920],
         [-0.3160, -2.1152]]])
summed previous 4 tensor([[[-0.2109, -0.7524]]])


Another way to think about this is that every value was weighted equally: we did a weighted sum in which all the previous values are equally important for predicting the next one.

In [69]:
# Uniform weights: each of the 4 previous positions contributes 1/4.
uniform_weight = 1 / 4
prev_4_weighted = torch.ones((B, 4, 1)) * uniform_weight

# (B, 1, 4) @ (B, 4, T) -> (B, 1, T): the weighted sum over the 4 positions.
row_weights = prev_4_weighted.transpose(-2, -1)
torch.bmm(row_weights, prev_4)

tensor([[[-0.2109, -0.7524]]])

I can then compute all of these weighted sums at once, for every prefix length, with a single batched matrix multiply.

In [79]:
# Lower-triangular ones: row i has ones in columns 0..i.
lower_ones = torch.tril(torch.ones((B, N, N)))
# Normalize each row so it sums to 1 (row i becomes uniform weights 1/(i+1)).
mask = lower_ones / lower_ones.sum(dim=-1, keepdim=True)

print(mask)
print(mask.shape)
print(x.shape)

# (B, N, N) @ (B, N, T) -> (B, N, T): row i is the mean of x[:, :i+1].
# NOTE(review): `all` shadows the Python builtin; the name is kept because a
# later cell reads it.
all = torch.bmm(mask, x)
all

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
         [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]])
torch.Size([1, 8, 8])
torch.Size([1, 8, 2])


tensor([[[-1.1258, -1.1524],
         [-0.6882, -0.7931],
         [-0.1759, -0.2981],
         [-0.2109, -0.7524],
         [-0.1043, -0.8546],
         [-0.0286, -0.6608],
         [-0.0074, -0.3896],
         [ 0.1331, -0.3718]]])

In [84]:
# Sanity check: row i of `all` should equal the plain mean of the first
# i + 1 embeddings of x.
for i in range(N):
	prev = x[:, : i + 1]
	expected_mean = prev.mean(dim=1)
	print(torch.allclose(expected_mean, all[:, i]))


True
True
True
True
True
True
True
True


The key insight into attention is that the weight values in that mask are not constant, but learned. So we can learn to model which tokens we need to pay attention to in order to predict the next token accurately.

So let's compute that weight matrix on the fly. 

The "Attention Is All You Need" paper computes:

$$\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

So the softmax portion is what computes the weights of the weighted sum, and $V$ is essentially $x$ (although projected).

But to make sure that no information from later tokens is accessed when a token is predicted, we apply a causal mask to $QK^T$ (setting the entries above the diagonal to $-\infty$). After the softmax this produces a weight matrix with 0s in the upper triangle, so when we multiply by $V$ we are all good!

In [86]:
# Width of each query/key/value vector.
d_k = 10

# Random stand-ins for the projected queries, keys and values, each (B, N, d_k).
# The generator is consumed in order, so Q, K, V are sampled in that sequence.
Q, K, V = (torch.randn((B, N, d_k)) for _ in range(3))

In [89]:
# Scaled dot-product scores: (B, N, d_k) @ (B, d_k, N) -> (B, N, N),
# multiplied by 1/sqrt(d_k).
scale = d_k ** -0.5
K_t = K.transpose(-2, -1)
QKT = torch.bmm(Q, K_t) * scale
QKT

tensor([[[-1.4061,  1.6575, -0.5871, -0.4502,  0.4991, -1.4148, -0.6164,
           0.5815],
         [-0.1199,  0.9972, -0.7784,  0.5000,  0.9681,  2.2110, -0.6460,
           2.5752],
         [-0.8138,  0.2423,  0.2972,  0.5500,  0.5481, -0.3925,  0.3161,
          -1.8521],
         [-1.2321, -0.4832,  0.1016, -0.2666, -0.4511, -1.4859,  0.4228,
           0.3468],
         [ 0.4941,  0.3841,  1.7922,  0.5413, -0.8161, -1.0426, -0.4646,
          -2.1853],
         [ 1.3196, -0.5372, -0.0496, -1.1341, -0.2754,  0.6426, -0.2624,
          -1.8283],
         [ 1.3901,  0.4294, -1.0545, -0.2860,  0.3455,  2.4669, -0.5157,
           2.4567],
         [ 0.2245, -2.2039,  0.6325,  0.3728, -0.2734, -2.2012,  1.3488,
           0.1603]]])

In [115]:
def softmax_mask(QKT, mask):
	"""Return a new tensor equal to ``QKT`` with -inf wherever ``mask`` is True.

	This is the pre-softmax masking step: positions filled with -inf receive
	zero weight once a softmax is applied over them.
	"""
	blocked_score = float("-inf")
	masked_scores = QKT.masked_fill(mask, blocked_score)
	return masked_scores

def softmax_mask2(QKT, neg_infs):
	"""Additive masking: return the elementwise sum of ``QKT`` and ``neg_infs``.

	When ``neg_infs`` holds -inf at blocked positions and 0 elsewhere, the
	result matches filling those positions of ``QKT`` with -inf.
	"""
	masked_scores = QKT + neg_infs
	return masked_scores

# Boolean causal mask: True strictly above the diagonal (the "future" positions).
# NOTE(review): this rebinds `mask`, which earlier held the float averaging matrix.
mask = torch.triu(torch.ones((N, N), dtype=torch.bool), diagonal=1)
# Additive equivalent: -inf strictly above the diagonal, 0 on and below it.
neg_infs = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

# Both masking strategies should produce the same pre-softmax scores.
masked_by_fill = softmax_mask(QKT, mask)
masked_by_add = softmax_mask2(QKT, neg_infs)
torch.allclose(masked_by_fill, masked_by_add)

True

In [120]:
# Block the future positions (where `mask` is True) before the softmax.
QKT_masked = softmax_mask(QKT, mask=mask)

In [124]:
# Softmax over the last axis turns each row of masked scores into attention
# weights; the -inf entries come out as exactly 0.
weights = torch.softmax(input=QKT_masked, dim=-1)
weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2465, 0.7535, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1447, 0.4159, 0.4394, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1049, 0.2218, 0.3980, 0.2754, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1454, 0.1303, 0.5326, 0.1525, 0.0392, 0.0000, 0.0000, 0.0000],
         [0.4530, 0.0707, 0.1152, 0.0389, 0.0919, 0.2302, 0.0000, 0.0000],
         [0.1964, 0.0751, 0.0170, 0.0367, 0.0691, 0.5764, 0.0292, 0.0000],
         [0.1181, 0.0104, 0.1777, 0.1370, 0.0718, 0.0104, 0.3637, 0.1108]]])

In [127]:
weights.sum(dim=-1)  # works: every row of attention weights sums to 1

tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])

Then I can just apply the attention weights computed above to V.

In [None]:
# (B, N, N) @ (B, N, d_k) -> (B, N, d_k): each output row is the
# attention-weighted sum of the rows of V.
attn = torch.bmm(weights, V)
attn

tensor([[[-0.0118,  0.9797, -1.0661,  1.7720, -0.2793, -0.2769,  0.7489,
          -0.6435, -0.9518,  0.2715],
         [ 0.5031,  1.6354,  0.6345, -0.0076,  0.6580, -1.2055,  1.8842,
           0.7672, -0.6004,  0.4087],
         [ 0.8715,  1.1018,  0.1156, -0.0707,  0.4820, -0.6196,  1.2045,
           0.0868, -0.2388,  0.1161],
         [ 1.1207,  0.6184, -0.3822, -0.0261,  0.7907, -0.6119,  1.6589,
          -0.2516, -0.2471,  0.5078],
         [ 1.0384,  0.5649, -0.4053,  0.1306,  0.5143, -0.2784,  1.0980,
          -0.3669, -0.1385,  0.1659],
         [ 0.1054,  0.7243, -0.2848,  0.8400,  0.1375, -0.1596,  0.3091,
          -0.0489, -0.6927,  0.0544],
         [-0.2171,  0.6273,  0.2146,  0.2448,  0.2760, -0.2169, -0.4110,
           0.3250, -0.8813,  0.0926],
         [ 0.6966, -0.3517,  0.2894, -0.8149,  0.1164, -0.2958,  0.3576,
          -0.5450, -0.2323,  0.6254]]])

In [139]:
def attn(Q, K, V, mask=None):
	"""Scaled dot-product attention over batched 3-D tensors.

	Computes ``softmax(Q K^T / sqrt(d_k)) V``. The scale ``d_k`` is read from
	the last dimension of ``Q`` instead of a notebook-level global, so the
	result is correctly scaled for any embedding width and the function does
	not depend on hidden kernel state.

	Args:
		Q: queries, shape (B, N, d_k).
		K: keys, shape (B, M, d_k).
		V: values, shape (B, M, d_v).
		mask: optional boolean tensor broadcastable to (B, N, M). True marks
			positions that must not be attended to. ``None`` (the default)
			applies no masking.

	Returns:
		Tensor of shape (B, N, d_v) holding, for each query, the
		attention-weighted sum of the rows of ``V``.
	"""
	# Derive the scale from the tensors themselves, not from a global `d_k`.
	scale = Q.size(-1) ** -0.5
	scores = torch.bmm(Q, K.transpose(1, 2)) * scale
	if mask is not None:
		# Scores of -inf get exactly zero weight after the softmax.
		scores = scores.masked_fill(mask, float("-inf"))
	weights = torch.softmax(scores, dim=-1)
	return torch.bmm(weights, V)


# Causal mask: True strictly above the diagonal, i.e. the future positions.
causal_mask = torch.triu(torch.ones((N, N), dtype=torch.bool), diagonal=1)
attn(Q, K, V, mask=causal_mask)

tensor([[[-0.0118,  0.9797, -1.0661,  1.7720, -0.2793, -0.2769,  0.7489,
          -0.6435, -0.9518,  0.2715],
         [ 0.5031,  1.6354,  0.6345, -0.0076,  0.6580, -1.2055,  1.8842,
           0.7672, -0.6004,  0.4087],
         [ 0.8715,  1.1018,  0.1156, -0.0707,  0.4820, -0.6196,  1.2045,
           0.0868, -0.2388,  0.1161],
         [ 1.1207,  0.6184, -0.3822, -0.0261,  0.7907, -0.6119,  1.6589,
          -0.2516, -0.2471,  0.5078],
         [ 1.0384,  0.5649, -0.4053,  0.1306,  0.5143, -0.2784,  1.0980,
          -0.3669, -0.1385,  0.1659],
         [ 0.1054,  0.7243, -0.2848,  0.8400,  0.1375, -0.1596,  0.3091,
          -0.0489, -0.6927,  0.0544],
         [-0.2171,  0.6273,  0.2146,  0.2448,  0.2760, -0.2169, -0.4110,
           0.3250, -0.8813,  0.0926],
         [ 0.6966, -0.3517,  0.2894, -0.8149,  0.1164, -0.2958,  0.3576,
          -0.5450, -0.2323,  0.6254]]])