# Preliminaries

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

Load dataset

In [2]:
with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [3]:
print(f"Number of chars in dataset: {len(text)}")

Number of chars in dataset: 1115394


In [4]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


Build vocabulary

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"{vocab_size=}")
"".join(chars)

vocab_size=65


"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

Create "char to index" and "index to char" mappings

In [6]:
stoi = { s:i for i, s in enumerate(chars) }
itos = { i:s for s, i in stoi.items() }

In [7]:
stoi["&"]

4

In [8]:
itos[4]

'&'

In [9]:
def encode(s):
    """Encode a string into a list of integer token ids, one per character.

    Raises KeyError for any character that is not in the vocabulary (`stoi`).
    """
    return [stoi[c] for c in s]


def decode(e):
    """Decode a sequence of integer token ids back into a string (inverse of `encode`)."""
    return "".join(itos[i] for i in e)

In [10]:
encode("Hello world")

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]

In [11]:
decode([20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42])

'Hello world'

Tokenize dataset

In [12]:


data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([1115394])

In [13]:
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

Split dataset into train and validation

In [14]:
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

print(train_data.shape)
print(val_data.shape)

torch.Size([1003854])
torch.Size([111540])


Experimenting with context length

In [15]:
context_len = 8
sample = train_data[:context_len+1].tolist()
sample

[18, 47, 56, 57, 58, 1, 15, 47, 58]

In [16]:
print("Input ==> Target")
print("----------------")

for i in range(context_len):
    x = sample[:i+1]
    y = sample[i+1]
    print(f"{x} ==> {y}")

Input ==> Target
----------------
[18] ==> 47
[18, 47] ==> 56
[18, 47, 56] ==> 57
[18, 47, 56, 57] ==> 58
[18, 47, 56, 57, 58] ==> 1
[18, 47, 56, 57, 58, 1] ==> 15
[18, 47, 56, 57, 58, 1, 15] ==> 47
[18, 47, 56, 57, 58, 1, 15, 47] ==> 58


In [17]:
# NOTE: id 0 is a real token in this vocabulary (chars[0] is "\n"), so padding
# with 0 is ambiguous; it is used here only to visualise fixed-width inputs.
PAD_ID = 0

print("Input (with padding) ==> Target")
print("-------------------------------")

for i in range(context_len):
    n_pad = context_len - (i + 1)  # left-pad so every input has exactly context_len ids
    x = [PAD_ID] * n_pad + sample[:i+1]
    y = sample[i+1]
    print(f"{x} ==> {y}")

Input (with padding) ==> Target
-------------------------------
[0, 0, 0, 0, 0, 0, 0, 18] ==> 47
[0, 0, 0, 0, 0, 0, 18, 47] ==> 56
[0, 0, 0, 0, 0, 18, 47, 56] ==> 57
[0, 0, 0, 0, 18, 47, 56, 57] ==> 58
[0, 0, 0, 18, 47, 56, 57, 58] ==> 1
[0, 0, 18, 47, 56, 57, 58, 1] ==> 15
[0, 18, 47, 56, 57, 58, 1, 15] ==> 47
[18, 47, 56, 57, 58, 1, 15, 47] ==> 58


Function to get a batch

In [18]:
torch.manual_seed(1337)

# params
batch_size = 4
context_len = 8

def get_batch(split):
    """Sample a random batch of (input, target) sequences from one data split.

    Reads the notebook-level ``batch_size`` and ``context_len`` at call time, so
    reassigning either global (as the training cell does for ``batch_size``)
    changes the shape of the returned batch.

    Args:
        split: ``"train"`` or ``"val"``.

    Returns:
        ``(x, y)``: two tensors of shape ``(batch_size, context_len)``, where
        ``y[b, t]`` is the token that follows ``x[b, t]`` in the source text.

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
    """
    # Reject unknown split names instead of silently falling back to validation data.
    if split == "train":
        split_data = train_data
    elif split == "val":
        split_data = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # randint's upper bound is exclusive, so the largest start index is
    # len - context_len - 1, leaving room for the one-step-shifted target window.
    ixs = torch.randint(low=0, high=len(split_data) - context_len, size=(batch_size,))
    x = torch.stack([split_data[i:i+context_len] for i in ixs])
    y = torch.stack([split_data[i+1:i+context_len+1] for i in ixs])
    return x, y

# get a sample batch
xb, yb = get_batch("train")

print("------")
print("inputs")
print(xb.shape)
print(xb)
print("-------")
print("targets")
print(yb.shape)
print(yb)

------
inputs
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
-------
targets
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


# BigramLanguageModel

What is the expected negative log likelihood of a completely uniform model?


In [19]:
-torch.log(torch.tensor(1/vocab_size))

tensor(4.1744)

Define the model

In [20]:
class BigramLanguageModel(nn.Module):
    """Bigram character model.

    Each token's embedding row is read directly as the logits for the next
    token, so a prediction depends only on the current token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # The embedding width equals vocab_size on purpose: row i of this table
        # is the vector of next-token logits for token i.
        self.embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, y=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: token ids, shape (B, T).
            y: optional target ids, shape (B, T).

        Returns:
            (logits, loss). Without targets, logits is (B, T, C) and loss is None.
            With targets, logits is flattened to (B*T, C), as F.cross_entropy expects.
        """
        # The lookup replaces each id with its row of the table: (B, T) -> (B, T, C).
        logits = self.embedding_table(x)
        if y is None:
            return logits, None
        B, T, C = logits.shape
        # F.cross_entropy takes (N, C) logits and (N,) targets. reshape (unlike
        # view) also accepts non-contiguous target tensors.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, y.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, x, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``x``.

        Only the prediction at the last position is used at each step. A bigram
        model needs only the most recent token, but feeding the whole running
        sequence keeps this loop unchanged for models that use longer context.
        Gradient tracking is disabled because sampling is inference only.

        Args:
            x: starting token ids, shape (B, T).
            max_new_tokens: number of tokens to sample.

        Returns:
            Token ids of shape (B, T + max_new_tokens); the first T columns are ``x``.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(x)                # (B, T, C)
            logits = logits[:, -1, :]          # last time step only: (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1, replacement=True)  # (B, 1)
            x = torch.cat((x, idx_next), dim=1)  # (B, T+1)
        return x

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)


Sample generation before training

In [21]:
inputs = torch.zeros((1, 1), dtype=torch.long)
inputs

tensor([[0]])

In [22]:
generation = m.generate(inputs, max_new_tokens=100)
generation

tensor([[ 0, 50,  7, 29, 37, 48, 58,  5, 15, 24, 12, 48, 24, 16, 59, 29, 41, 24,
         64, 63,  5, 30, 21, 53, 11,  5, 23, 42, 46, 54, 34,  0, 60, 24, 47, 62,
         39,  6, 52, 57, 61, 37, 38, 61, 24, 17, 28, 31,  5, 54, 58, 21, 38, 55,
         27, 38, 22,  3, 15, 13,  3, 64, 63,  7, 29, 32, 49, 43, 25, 49,  1, 62,
          8, 45, 29, 31, 18, 15, 24, 45,  2, 47, 35,  9, 44, 27,  2,  9, 16, 19,
         36, 13, 55, 32, 57, 55,  9, 54, 42, 45, 55]])

In [23]:
print(decode(generation.squeeze().tolist()))


l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


Training

In [24]:
# initialize an AdamW optimizer (Adam with decoupled weight decay) over the bigram model's parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [25]:
batch_size = 32

for steps in range(1, 10001):
    # get sample batch
    xb, yb = get_batch("train")
    # calculate loss
    logits, loss = m(xb, yb)
    # set gradients to zero
    optimizer.zero_grad(set_to_none=True)
    # calculate gradients
    loss.backward()
    # update parameters
    optimizer.step()

    if steps % 1000 == 0 or steps == 1:
        print(loss.item())

4.658271312713623
3.6379246711730957
3.089521646499634
2.8084068298339844
2.5052883625030518
2.5904765129089355
2.492192268371582
2.568422555923462
2.520578622817993
2.396061420440674
2.5589075088500977


Calculate loss on datasets

In [26]:
@torch.no_grad()
def calc_loss(eval_iters=200, model=None):
    """Estimate the mean loss on the train and val splits.

    Averages the loss over ``eval_iters`` random batches per split (drawn with
    ``get_batch``), with gradient tracking disabled.

    Args:
        eval_iters: number of batches to average per split.
        model: model to evaluate; defaults to the notebook-level ``m``.

    Returns:
        dict mapping ``"train"`` and ``"val"`` to a scalar tensor with the mean loss.
    """
    model = m if model is None else model
    # Remember the caller's mode and restore it afterwards (even on error),
    # rather than unconditionally switching the model to train mode.
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

calc_loss()

{'train': tensor(2.4627), 'val': tensor(2.4975)}

Generation after training

In [27]:
# Sample 200 new characters, starting from token id 0 (the newline character).
inputs = torch.zeros((1, 1), dtype=torch.long)
generation = m.generate(inputs, max_new_tokens=200)
# decode returns a str; index the single batch row rather than squeeze(), which
# would collapse to a scalar if the sequence had length 1.
print(decode(generation[0].tolist()))


Yo fyour me than!
Sow
Dorce d, ather tod a ping hal ld ot d
Se nel thans ocontherat, aise prmis
Whal ong w veldlaleerMI l-my,


At: awhit Sinealathslle t hie s sh ke,-ck:

Carnth mey d cocthathacer, r


# Attention

In [30]:
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [31]:
# "Bag of words" averaging: xbow[b, t] is the mean of x[b, 0..t], i.e. every
# position up to and including t. Written as an explicit loop for clarity.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]               # (t+1, C)
        xbow[b, t] = xprev.mean(dim=0)   # (C,)

In [32]:
x[0]

tensor([[ 0.5057, -0.0894],
        [-0.6165,  0.7783],
        [ 0.6331,  1.1547],
        [-0.2711,  0.3987],
        [-1.7937, -1.6626],
        [-0.9701, -0.9219],
        [ 0.1350,  1.1365],
        [-1.3586, -0.2989]])

In [33]:
xbow[0]

tensor([[ 0.5057, -0.0894],
        [-0.0554,  0.3444],
        [ 0.1741,  0.6145],
        [ 0.0628,  0.5606],
        [-0.3085,  0.1159],
        [-0.4188, -0.0570],
        [-0.3397,  0.1135],
        [-0.4670,  0.0619]])

In [34]:
# Lower-triangular weights with each row normalised to sum to 1: row t averages
# rows 0..t of b, so `a @ b` computes running means without a Python loop.
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(1, keepdim=True)

b = torch.tensor([[5, 11],
                  [4, 7],
                  [1, 3]], dtype=torch.float)

c = a @ b

print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[ 5., 11.],
        [ 4.,  7.],
        [ 1.,  3.]])
tensor([[ 5.0000, 11.0000],
        [ 4.5000,  9.0000],
        [ 3.3333,  7.0000]])


In [35]:
# another way to create the attention matrix (a)

a = torch.ones((3, 3))
mask = torch.tril(torch.ones((3, 3)))
a_masked = a.masked_fill(mask==0, -torch.inf)
attention = F.softmax(a_masked, dim=1)

print(a)
print(mask)
print(a_masked)
print(attention)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1., -inf, -inf],
        [1., 1., -inf],
        [1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [36]:
x.shape

torch.Size([4, 8, 2])

In [37]:
# Uniform causal averaging weights sized to the sequence length T of x:
# masked (future) positions are set to -inf, so softmax gives them weight 0.
a = torch.ones((T, T))
mask = torch.tril(torch.ones((T, T)))
a_masked = a.masked_fill(mask == 0, -torch.inf)
attention = F.softmax(a_masked, dim=1)

print(attention.shape)
print(attention)

torch.Size([8, 8])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [38]:
# the attention is copied for each element in the batch
# (8, 8) @ (B, 8, 2) => (B, 8, 8) @ (B, 8, 2) = (B, 8, 2)
out = attention @ x 
out.shape

torch.Size([4, 8, 2])

In [39]:
xbow[0]

tensor([[ 0.5057, -0.0894],
        [-0.0554,  0.3444],
        [ 0.1741,  0.6145],
        [ 0.0628,  0.5606],
        [-0.3085,  0.1159],
        [-0.4188, -0.0570],
        [-0.3397,  0.1135],
        [-0.4670,  0.0619]])

In [40]:
out[0]

tensor([[ 0.5057, -0.0894],
        [-0.0554,  0.3444],
        [ 0.1741,  0.6145],
        [ 0.0628,  0.5606],
        [-0.3085,  0.1159],
        [-0.4188, -0.0570],
        [-0.3397,  0.1135],
        [-0.4670,  0.0619]])

These attention values are hardcoded so far, i.e., each embedding pays equal attention to itself and to every embedding before it.

Ideally, these attention values should be data-dependent.

Therefore we would want to somehow ***learn*** these attention values.

In [41]:
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [43]:
head_size = 16

query = nn.Linear(C, head_size)
key = nn.Linear(C, head_size)

q = query(x) # what am I looking for?
k = key(x) # what can I offer?

print(q.shape)
print(k.shape)

torch.Size([4, 8, 16])
torch.Size([4, 8, 16])


Each character embedding (dim=2) produces a query (dim=16) and a key (dim=16).

The query represents what the character is looking for and the key represents what the character can offer.

The query of a particular character *c* is dotted with the keys of every character in the sequence (including *c* itself). When the query of *c* and the key of another character *d* align (i.e. their dot product is high), *c* is interested in *d*'s value.

We want to calculate these dot products between all the queries and all the keys, so that for each character we know how interested it is in every other character.

We can use matrix multiplication to calculate all these dot products.

In [44]:
print(q.shape)
print(k.shape)

torch.Size([4, 8, 16])
torch.Size([4, 8, 16])


In [46]:
# raw attention scores: every query dotted with every key of the same sequence
attention = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) = (B, T, T)
attention.shape

torch.Size([4, 8, 8])

In [47]:
attention[0]

tensor([[ 1.0138, -0.2891,  1.4568,  1.4923,  1.3120, -1.2630,  1.4331,  1.2584],
        [ 0.2317,  0.8337,  0.0379, -0.4659,  0.1128,  0.9699,  0.0516, -0.0364],
        [ 1.2670, -0.6817,  1.9259,  2.1470,  1.7065, -2.0300,  1.8894,  1.6863],
        [ 1.8598, -0.2274,  2.5678,  2.7008,  2.3345, -1.7383,  2.5293,  2.2759],
        [ 1.1707, -0.5648,  1.7582,  1.9210,  1.5634, -1.7876,  1.7259,  1.5333],
        [ 0.0150,  1.9850, -0.6344, -1.6072, -0.4002,  2.8656, -0.5932, -0.6474],
        [ 1.2495, -0.6640,  1.8966,  2.1084,  1.6813, -1.9913,  1.8608,  1.6596],
        [ 1.3425, -0.3456,  1.9152,  2.0192,  1.7266, -1.5699,  1.8841,  1.6779]],
       grad_fn=<SelectBackward0>)

Remember, a character at time *t* shouldn't be able to communicate with a character at time *t+1* or any other character after it: it may only look at itself and the characters before it.

So we still need to mask the attention values.

In [52]:
# causal mask sized to the sequence length: position t may only attend to positions <= t
mask = torch.tril(torch.ones((T, T)))
attention_masked = attention.masked_fill(mask == 0, -torch.inf)
print(attention_masked.shape)
attention_masked[0]

torch.Size([4, 8, 8])


tensor([[0.1079,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0494, 0.1816,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1390, 0.0399, 0.1729,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2514, 0.0629, 0.3285, 0.3951,   -inf,   -inf,   -inf,   -inf],
        [0.1262, 0.0449, 0.1462, 0.1812, 0.2905,   -inf,   -inf,   -inf],
        [0.0397, 0.5743, 0.0134, 0.0053, 0.0408, 0.9808,   -inf,   -inf],
        [0.1366, 0.0406, 0.1679, 0.2185, 0.3268, 0.0076, 0.4942,   -inf],
        [0.1499, 0.0558, 0.1711, 0.1999, 0.3420, 0.0116, 0.5058, 1.0000]],
       grad_fn=<SelectBackward0>)

Now we can apply softmax to normalize.

In [53]:
attention = F.softmax(attention_masked, dim=-1)
attention[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4670, 0.5330, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3401, 0.3080, 0.3519, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2461, 0.2038, 0.2659, 0.2842, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1932, 0.1781, 0.1971, 0.2041, 0.2276, 0.0000, 0.0000, 0.0000],
        [0.1218, 0.2079, 0.1186, 0.1177, 0.1219, 0.3121, 0.0000, 0.0000],
        [0.1326, 0.1204, 0.1368, 0.1439, 0.1603, 0.1165, 0.1895, 0.0000],
        [0.1017, 0.0926, 0.1039, 0.1069, 0.1232, 0.0886, 0.1452, 0.2380]],
       grad_fn=<SelectBackward0>)

Now we have real attention values.

The first character can only pay attention to itself, because it has no previous characters.

The second character has decided to pay a certain amount of attention to itself, and a certain amount of attention to the previous character.

And so on...

These attention values are learned through the query and key projections. As the model trains, the query and key weights learn what each character is looking for and what each character can offer. The results are data dependent: the character "a" might look for consonants preceding it, while the character "c" might look for a "k" nearby.

Now we can apply this attention to our input.

In [58]:
attention[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4670, 0.5330, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3401, 0.3080, 0.3519, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2461, 0.2038, 0.2659, 0.2842, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1932, 0.1781, 0.1971, 0.2041, 0.2276, 0.0000, 0.0000, 0.0000],
        [0.1218, 0.2079, 0.1186, 0.1177, 0.1219, 0.3121, 0.0000, 0.0000],
        [0.1326, 0.1204, 0.1368, 0.1439, 0.1603, 0.1165, 0.1895, 0.0000],
        [0.1017, 0.0926, 0.1039, 0.1069, 0.1232, 0.0886, 0.1452, 0.2380]],
       grad_fn=<SelectBackward0>)

In [56]:
x[0]

tensor([[-0.4274, -0.8996],
        [ 0.9806,  0.2727],
        [-0.8925, -1.3135],
        [-1.5466, -0.6571],
        [-0.7258, -1.1945],
        [ 1.6365,  1.5920],
        [-0.8633, -1.2960],
        [-0.8878, -0.9006]])

In [59]:
attended_x = attention @ x
print(attended_x.shape)
attended_x[0]

torch.Size([4, 8, 2])


tensor([[-0.4274, -0.8996],
        [ 0.3231, -0.2748],
        [-0.1573, -0.6841],
        [-0.5821, -0.7018],
        [-0.5647, -0.7901],
        [ 0.2862,  0.0652],
        [-0.3725, -0.6123],
        [-0.4919, -0.6816]], grad_fn=<SelectBackward0>)

As you can see, the first embedding has remained the same, because it has no previous characters. However the next embeddings now have information from all previous embeddings, and they can control how much attention they pay to each previous embedding.

There's just one more thing. Right now the attention weights are applied directly to x, the input. In practice we don't do this. Instead, we compute another vector *v* (the value) for each embedding with its own learned linear projection, and apply the attention weights to *v*.

In [60]:
# value projection: maps each C-dimensional input embedding to a head_size vector,
# using the same input width C as the query and key projections
value = nn.Linear(C, head_size)
v = value(x)
v.shape

torch.Size([4, 8, 16])

In [61]:
out = attention @ v
out.shape

torch.Size([4, 8, 16])

You can think of x as "private" information, it contains the identity of the embedding. For each x, we compute 3 pieces of "public" information:
- query (what am I looking for)
- key (what can I offer)
- value (what I will give to you if we link up)

This is the full code for a single head of self-attention:

In [10]:
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
print(f"x: {x.shape}")

head_size = 16
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# compute "public" information
q = query(x)
k = key(x)
v = value(x)
print(f"query: {q.shape}")
print(f"key: {k.shape}")
print(f"value: {v.shape}")

# compute dot products between every query and every key
# (the 1/sqrt(head_size) scaling is deliberately omitted in this cell)
unmasked_attention = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) = (B, T, T)
print(f"unmasked_attention: {unmasked_attention.shape}")
print("unmasked_attention of first example:")
print(unmasked_attention[0])

# causal mask: hide future positions so position t only attends to positions <= t
mask = torch.tril(torch.ones(T, T))
self_attention = unmasked_attention.masked_fill(mask==0, -torch.inf)
print(f"self attention: {self_attention.shape}")
print("self-attention of first example:")
print(self_attention[0])
# softmax turns each row into weights that sum to 1 (masked entries become 0)
self_attention = F.softmax(self_attention, dim=-1)
print("self-attention (softmax) of first example:")
print(self_attention[0])

# compute output: each position is a weighted sum of the value vectors
out = self_attention @ v # (B, T, T) @ (B, T, head_size) = (B, T, head_size)
print(f"out: {out.shape}")

x: torch.Size([4, 8, 2])
query: torch.Size([4, 8, 16])
key: torch.Size([4, 8, 16])
value: torch.Size([4, 8, 16])
unmasked_attention: torch.Size([4, 8, 8])
unmasked_attention of first example:
tensor([[-0.1690, -0.5314, -0.2700,  0.0854, -0.2352, -0.2612,  0.2816,  0.0705],
        [-0.3213, -1.0205, -0.4736,  0.1056, -0.3938, -0.5363,  0.4894,  0.1262],
        [-1.1215, -3.4872, -1.9543,  0.7974, -1.7783, -1.5729,  2.0563,  0.5002],
        [ 1.2961,  4.0195,  2.3017, -0.9828,  2.1129,  1.7750, -2.4262, -0.5866],
        [-1.3759, -4.2737, -2.4159,  1.0042, -2.2061, -1.9115,  2.5438,  0.6173],
        [ 0.5834,  1.7951,  1.0930, -0.5234,  1.0273,  0.7424, -1.1579, -0.2753],
        [ 1.2653,  3.9332,  2.2093, -0.9058,  2.0121,  1.7701, -2.3249, -0.5652],
        [ 0.2390,  0.7438,  0.4141, -0.1664,  0.3757,  0.3377, -0.4354, -0.1061]],
       grad_fn=<SelectBackward0>)
self attention: torch.Size([4, 8, 8])
self-attention of first example:
tensor([[-0.1690,    -inf,    -inf,    -inf,  

Thus we get the magical formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Wait, but what's that thing in the denominator? Why are we dividing by $\sqrt{d_k}$? Note: $d_k$ is the head_size (16 in our case).

Well, as usual with scaling, it comes down to a variance issue.

In [28]:
q = torch.randn((B, T, head_size))
q.var() # unit variance

tensor(1.1369)

In [29]:
k = torch.randn((B, T, head_size))
k.var() # unit variance

tensor(1.0263)

In [30]:
unmasked_attention = q @ k.transpose(-2, -1)
unmasked_attention.var()

tensor(16.6851)

The variance is far too large after the matrix multiplication. Each score is a sum of head_size products of independent unit-variance entries, so its variance is roughly head_size (≈16 here), and it grows with the head size. Remember, this unmasked attention will eventually be fed into a softmax. High variance means some extreme values, which push the softmax towards a nearly one-hot distribution. That is not good, at least at initialization: we generally want a diffuse distribution in the beginning, with no extreme peaks. Since the variance scales linearly with the head size, dividing by the square root of the head size restores unit variance:

In [32]:
(unmasked_attention/(head_size**0.5)).var()

tensor(1.0428)