# NanoGPT

In [1]:
# read it in to inspect it
with open('tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [3]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:
# Build the character-level vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)  # size of the character vocabulary
print("".join(chars), vocab_size, sep="\n")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# Character-level tokenizer: two lookup tables plus the encode/decode helpers built on them
itos = dict(enumerate(chars))                  # integer id -> character
stoi = {ch: i for i, ch in itos.items()}       # character -> integer id


def encode(s):
    """Encoder: take a string, output the list of integer ids of its characters."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Decoder: take a list of integer ids, output the string they spell."""
    return "".join(itos[i] for i in l)


print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
import torch 

In [7]:
# Encode the entire text dataset and store it in a torch.Tensor (one int64 token id per character)
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the first 1000 characters we looked at earlier, as the GPT will see them

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
# Train/validation split: the first 90% of the tokens are for training, the remaining tail is held out
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

len(train_data), len(val_data)

(1003854, 111540)

In [9]:
block_size = 8 # context length
train_data[:block_size+1] 

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
# In a chunk of 9 characters, there are 8 individual training examples that get sent to the network.
# This also helps the network to see different context sizes, from 1 to block_size. This allows the network to learn how to predict in multiple context lengths.
# Let's take a look. 
X = train_data[:block_size]
Y = train_data[1:block_size + 1]
for t in range(block_size):
    context = X[:t+1]
    target = Y[t]
    print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 47
When input is tensor([18, 47]), the target is: 56
When input is tensor([18, 47, 56]), the target is: 57
When input is tensor([18, 47, 56, 57]), the target is: 58
When input is tensor([18, 47, 56, 57, 58]), the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 58


In [11]:
# For efficiency purposes, we'll add an extra "batch" dimension to take advantage of GPU parallelism
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: "train" selects `train_data`; any other value selects `val_data`.

    Returns:
        A pair (x, y) of stacked tensors, each with `batch_size` rows of `block_size`
        tokens. Row r of `y` is the window that starts one token after row r of `x`,
        i.e. the next-token targets for `x`.
    """
    if split == "train":
        source = train_data
    else:
        source = val_data

    # Random window starts. The exclusive upper bound leaves room for the target
    # window, which extends one token further than the input window.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start : start + block_size])
        targets.append(source[start + 1 : start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

xb, yb = get_batch("train")
print("Inputs:")
print(xb.shape)
print(xb)
print("Targets:")
print(yb.shape)
print(yb)
print("---")

# The total number of examples here would be 8 x 4 = 32 independent examples, packed into a single batch x with targets y, which will all be simultaneously processed by the transformers
for b in range(batch_size): # batch dim
    for t in range(block_size): # time dim (context length)
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context}, the target is: {target}")


Inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
---
When input is tensor([24]), the target is: 43
When input is tensor([24, 43]), the target is: 58
When input is tensor([24, 43, 58]), the target is: 5
When input is tensor([24, 43, 58,  5]), the target is: 57
When input is tensor([24, 43, 58,  5, 57]), the target is: 1
When input is tensor([24, 43, 58,  5, 57,  1]), the target is: 46
When input is tensor([24, 43, 58,  5, 57,  1, 46]), the target is: 43
When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]), the target is: 39
When input is tensor([44]), the target is: 53
When input is tensor([44, 53]), the target is: 56
When input is tensor([44, 53,

## Simplest version: Bigram Language Model

In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The only parameter is a (vocab_size, vocab_size) embedding table; row i holds the
    logits over the next token given that the current token is i.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of integer token ids (B = batch size, T = context length).
            targets: optional (B, T) tensor of integer token ids to predict.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and loss is None,
            where C is the vocabulary size. With targets, logits is returned flattened
            to (B*T, C) and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        # F.cross_entropy expects the class/channel dimension second, so fold the
        # batch and time dimensions together: (B, T, C) -> (B*T, C)
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip building an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the given context.

        Args:
            idx: (B, T) tensor of integer token ids forming the current context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) tensor: the input context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # get the predictions; no targets are passed, so the returned loss is None
            logits, _loss = self(idx)

            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)

            # convert to probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)

        return idx
    
model = BigramLanguageModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

# First result is random if it's not trained
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)

lfJeukRuaRJKXAYtXzfJ:HEPiu--sDioi;ILCo3pHNTmDwJsfheKRxZCFs
lZJ XQc?:s:HEzEnXalEPklcPU cL'DpdLCafBheH


In [13]:
# Create an optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [14]:
# Train the bigram model on random mini-batches.
# NOTE: batch_size is a module-level name that get_batch reads, so reassigning it here
# changes the size of every batch sampled from now on (it was 4 in the earlier demo cell).
batch_size = 32
for steps in range(10000):
    
    # sample a batch of data
    xb, yb = get_batch("train")
    
    # evaluate the loss
    logits, loss = model(xb, yb)
    # clear the gradients left over from the previous step before computing new ones
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    

# loss of the last mini-batch only, i.e. a noisy single-batch estimate rather than an average
print(loss.item())

2.362440586090088


In [15]:
# This is a very simple model because we're only looking at the last character to see what the next character will be!
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=300)[0].tolist()))


M:
IUSh t,
F th he d ke alved.
Thupld, cipbll t
I: ir w, l me sie hend lor ito'l an e

I:
Gochosen ea ar btamandd halind
Aust, plt t wadyotl
I bel qunganonoth he m he de avellis k'l, tond soran:

WI he toust are bot g e n t s d je hid t his IAces I my ig t
Ril'swoll e pupat inouleacends-athiqu heame


## Self-attention: The mathematical trick

In [16]:
# Starting with a toy example
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

### What do we want? We want tokens to talk to each other (couple-them)

For example: the token at the 5th location should only receive information from tokens at previous locations, and NOT from tokens at "future" locations.

Information should only flow in one direction.

How to do this? I'd like to average the features on tokens 1-5, which would represent a summary of the fifth token in the context of its history (tokens 1-4).

A downside is that we'd lose a lot of information by compressing all the tokens like that but we'll worry about that later.

For now, 

```
for every batch:
    for every t-th token in that sequence:
        calculate the average of all the tokens up to the current one.
```




#### Version 0: For-loop

In [17]:
# We want: xbow[b, t] = mean_{i <= t} x[b, i]
# ("bow" = bag of words, a term used here to mean "averaging"; x_mean would be an equally good name)
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t+1, C): token t together with everything before it
        xbow[b, t] = xprev.mean(dim=0)

#### Version 1: Cool matrix multiplication trick

In [18]:
# There is a more efficient way of doing this: a matrix multiplication trick.
# Let's see it on a small example first.
torch.manual_seed(42)

# torch.tril keeps the lower triangle (including the diagonal) and zeros out everything above it,
# so `a` has ones on and below the diagonal and zeros above it.
# Why is this useful? In a product C = A @ B (A of shape m x n, B of shape n x k) every entry is
#     C[i, j] = A[i, :] @ B[:, j]
# With this triangular A, row i has ones in columns 0..i, so C[i, j] is the sum of column j of B
# over rows 0..i (row i plays the role of the current time step "t").
a = torch.tril(torch.ones(3, 3))
print(f"{a=}")

# Dividing every row by its sum makes each row add up to 1, which turns the running sum into a running average:
# dividing a sum of n numbers by n is the same as multiplying each of them by 1/n first, and row i of `a`
# now holds exactly that 1/n weight (n = i + 1) for every row of B it includes.
# -> `a` defines how much each row of `b` contributes to each averaged output row.
a = a / torch.sum(a, dim=1, keepdim=True)
print(f"{a=}")
print(f"---")

b = torch.randint(0, 10, (3, 2)).float()
print(f"{b=}")
print(f"---")

c = a @ b
print(f"{c=}")

a=tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [19]:
# Ok, then let's vectorize our operation: xbow[b, t] = mean_{i <= t} x[b, i]
# REMEMBER, the goal here is that the token in the t-th position only gets information from itself and the tokens preceding it!
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
print(f"{weights=}")


# (T, T) @ (B, T, C) -- weights is broadcast across the batch dim --> (B, T, T) @ (B, T, C) = (B, T, C)
xbow2 = weights @ x 
# check that the vectorized result matches the for-loop version
torch.allclose(xbow, xbow2)

weights=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

#### Version 2: Use the softmax

In [20]:
# Lastly, a third way of computing the same running average: use softmax.
tril = torch.tril(torch.ones(T, T))

# INTERPRETATION PT 1: Weights begin at 0 and represent interaction strength/affinity. How much of each token from the past do we want to aggregate/average up?
#     At the beginning, all tokens can communicate with all tokens, and every pair has the same "affinity" of zero
weights = torch.zeros((T, T))

# Next, the entries where tril is zero (the upper triangle) are set to negative infinity, because...
print(f"Tril: \n{tril}")
print(f"Weights before masking: \n{weights}")

# INTERPRETATION PT 2: This limits communication to tokens from the past: every token in the future of the current token gets an affinity of negative infinity (so we WON'T aggregate ANYTHING from those tokens).
weights = weights.masked_fill(tril == 0, float("-inf"))
print(f"Weights after masking: \n{weights}")

# ... softmax maps the -inf entries to zero and normalizes the remaining entries of each row so that they sum to 1.
# INTERPRETATION PT 3: This sets the "affinity" each past token has when it gets aggregated/averaged. For now all unmasked entries are equal,
#     so row t (0-based) holds 1/(t+1) in each of its first t+1 columns, i.e. a uniform average over token t and its past.
weights = F.softmax(weights, dim=-1)
print(f"Weights after softmax: \n{weights}")

# INTERPRETATION PT 4: Here is where we aggregate the tokens with the affinity matrix/weights, which will end up being a reflection of how "interesting" each token finds each other (for now, equally interesting)
xbow3 = weights @ x
# check that the softmax version matches the for-loop version
torch.allclose(xbow, xbow3)


# CONCLUSION: You can do weighted aggregation of your past elements by multiplying with a lower-triangular matrix, whose entries tell you how much of each past element "fuses" into the current token/position

Tril: 
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
Weights before masking: 
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
Weights after masking: 
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0

True

#### Version 3: Another way, more nuanced: Self-attention

What's our actual goal here? Grabbing information from the past... but in a data-dependent way - in other words, don't get information uniformly! Which is what we were doing earlier.


This is what self-attention helps with :)

**But how does self-attention do that?**

Every single token at each position will emit 2 vectors:
- _Query:_ What am I looking for?
- _Key:_ What do I contain?


*Then how do we get "the affinity"?*
- By calculating the dot product between the query vector of the current token and key vectors of this and other tokens. This becomes `weights`. Interpretation - if 2 tokens "align" strongly, then we'll get to learn more about that other token as opposed to learning about other tokens equally.


*For example*

- Let's say that we're the 8th token, and we're a vowel looking for a consonant up to the 4th position. 
- Then we "encode" that information in our query vector. 
- In the key vectors that the other tokens emit, we could imagine one token "encoding" the fact that it is a consonant and that it comes before the 4th position, in one of the channels, as represented by a high number in that specific channel.
- When we get the dot product between the query and key vectors for the 8th token and that other token, we'll see a high affinity represented by the "high" number resulting in the dot product, more than other keys from other tokens. In the weights vector here, we can see that the 8th token found the 4th token pretty interesting and therefore had a large affinity (`0.2297`) for it:
```
[0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]
```
- Ultimately, when the softmax comes, I'll end up including a lot of the information from that token in the aggregate of information.



But what is this "value" vector we see? Think of the original X as the "private" information about a token, kept in X. The final vector:
- _Value:_ What will I communicate to you, if you find me "interesting"?

The value is what gets aggregated, not X!

In [32]:
# A single head of masked self-attention on random input.
# Statement order matters for reproducing the printed numbers: the seed is followed by the
# random draw for x and then by the random initialization of the key, query and value layers.
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # Batch, time, channels
x = torch.randn(B, T, C)

# Time to build a single head that performs self-attention
head_size = 16 # hyperparameter
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
# raw (unscaled) affinities: dot product of every query with every key within the same batch element
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) -> (B, T, T)


tril = torch.tril(torch.ones(T, T))
# Version 2 started from an all-zero (T, T) matrix at this point; here the affinities are data-dependent (q @ k^T)
weights = weights.masked_fill(tril == 0, float("-inf")) # hide future positions from each token
weights = F.softmax(weights, dim=-1) # each row becomes a distribution over the current and past positions

v = value(x)
out = weights @ v 
# Version 2 aggregated x directly (weights @ x); here the projected values v are aggregated instead

out.shape

torch.Size([4, 8, 16])

In [33]:
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

### Notes on attention:

**1. Attention is a communication mechanism**

Given a set of nodes in a directed graph, where every node has a vector of information, and it gets to aggregate information via a weighted sum from all of the nodes that point to it - where the criteria for aggregation will be based on the data. In our example above, the graph would look like this: 
- The first token node has an edge to itself. 
- The second token node has an edge pointing to itself, and an edge from the first token node pointing to it. 
- The third node has an edge pointing to itself and 2 edges pointing to it: One coming from the first node and one from the second
- And so on, until the 8th token node.

But in principle, "attention" can be applied to any directed graph-like structure similar to this!

**2. There is no notion of space**

Attention acts over a set of vectors in the graph, without any knowledge of "where" each node is in the graph. This is why we use positional encodings! This is a crucial difference from other mechanisms like convolution, where it's clear how the filter acts in "space".

**3. Each example across the batch dimension is of course processed completely independently; examples never "talk" to each other**

The matrix multiplications happen independently of each other. For our directed graph analogy, this means we have 4 separate pools of 8 nodes each, and the nodes only talk to other nodes within their own pool.

**4. In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.**

Future tokens don't talk to past tokens! But if you want to include information from the future and allow nodes to talk to each other (encoder), don't use the masking! 
- Encoder blocks: All nodes talk to each other. 
- Decoder blocks: Triangular masking to avoid "looking at the answer" when making a prediction!

**5. "self-attention" just means that the keys, queries and values are produced from the same source X. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)**

But in Encoder/Decoder transformers, Keys and Values can come from another source and queries come from X. So it's self-attention if we only look at our set of nodes, but it's cross-attention if we look at other nodes and grab information from them through the key/value vectors. The attention from our example is basically "self-attention".

**6. "Scaled" attention additionally multiplies `weights` by `1/sqrt(head_size)`. This makes it so that when the inputs `Q, K` have unit variance, `weights` will have unit variance too, and the softmax will stay diffuse and not saturate too much. Illustration below**


- The variance of `weights` without scaling will be on the order of `head_size` (in our case, around 16)
- But the variance of `weights` with scaling will be 1!

This is important because `weights` gets fed to softmax, and it is important that it stays "diffuse" (no very positive or very negative numbers in it). *What happens if it contains very positive or very negative numbers?* **The softmax output will converge towards one-hot vectors!** This means that every node will only be aggregating information from **a single node**, which is not ideal, especially at initialization.

In [49]:
# Random keys and queries whose entries are drawn from a unit Gaussian
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)

# Raw affinities, and the same affinities multiplied by 1/sqrt(head_size)
weights = q @ k.transpose(-2, -1)
weights_with_scaling = weights * head_size**-0.5

In [48]:
print(f"{k.var()=}")
print(f"{q.var()=}")
print(f"{weights.var()=}")
print(f"{weights_with_scaling.var()=}")

k.var()=tensor(1.0966)
q.var()=tensor(0.9416)
weights.var()=tensor(16.1036)
weights_with_scaling.var()=tensor(1.0065)


In [60]:
# Weights that are diffuse
w1 = torch.tensor([0.1, 0.2, 0.3, -0.2, 0.5])
print(f"{w1}")
print(f"{torch.softmax(w1, dim=-1)=}")

tensor([ 0.1000,  0.2000,  0.3000, -0.2000,  0.5000])
torch.softmax(w1, dim=-1)=tensor([0.1799, 0.1988, 0.2197, 0.1333, 0.2684])


In [61]:
# Weights that are NOT diffuse and have large values
w2 = w1*8
print(f"{w2}")
print(f"{torch.softmax(w2, dim=-1)=}")

tensor([ 0.8000,  1.6000,  2.4000, -1.6000,  4.0000])
torch.softmax(w2, dim=-1)=tensor([0.0305, 0.0678, 0.1510, 0.0028, 0.7479])


In [65]:
# Weights with even larger values
w3 = w1*32
print(f"{w3}")
print(f"{torch.softmax(w3, dim=-1)=}")

tensor([ 3.2000,  6.4000,  9.6000, -6.4000, 16.0000])
torch.softmax(w3, dim=-1)=tensor([2.7560e-06, 6.7612e-05, 1.6587e-03, 1.8666e-10, 9.9827e-01])


### More optimizations to the transformer block

#### LayerNorm

First, recall that `BatchNorm` is a mechanism to ensure that any individual neuron had unit gaussian distribution (0 mean, 1 std output) across the batch dimension.

In [68]:
class BatchNorm1d:
    """Hand-written batch normalization over dimension 0 of the input.

    In training mode each column of the input is normalized with the mean and variance
    computed over dimension 0 of the current input, and running estimates of those
    statistics are updated. Otherwise the stored running estimates are used instead.

    Args:
        dim: number of features (size of the gamma/beta vectors).
        eps: constant added to the variance before the square root.
        momentum: weight of the newest batch statistics in the running estimates.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.training = True
        # scale and shift, exposed through parameters() (meant to be trained with backprop)
        self.gamma, self.beta = torch.ones(dim), torch.zeros(dim)
        # running statistics, maintained by a momentum update rather than by backprop
        self.running_mean, self.running_var = torch.zeros(dim), torch.ones(dim)

    def __call__(self, x):
        """Normalize `x`, store the result in `self.out` and return it."""
        if not self.training:
            mean, var = self.running_mean, self.running_var
        else:
            mean = x.mean(0, keepdim=True)  # statistics over dimension 0
            var = x.var(0, keepdim=True)

        normalized = (x - mean) / torch.sqrt(var + self.eps)
        self.out = self.gamma * normalized + self.beta

        if self.training:
            keep = 1 - self.momentum
            # the running estimates are bookkeeping only, so keep them out of the autograd graph
            with torch.no_grad():
                self.running_mean = keep * self.running_mean + self.momentum * mean
                self.running_var = keep * self.running_var + self.momentum * var
        return self.out

    def parameters(self):
        """Return the list [gamma, beta]."""
        return [self.gamma, self.beta]
    
torch.manual_seed(1337)
module = BatchNorm1d(100)
x = torch.randn(32, 100) # Batch size: 32, 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [70]:
x[:, 0].mean(), x[:, 0].std() # mean and std of one feature across all batch inputs

(tensor(1.4901e-08), tensor(1.0000))

In [71]:
x[0, :].mean(), x[0, :].std() # mean and std of the features of a single input from the batch

(tensor(0.0411), tensor(1.0431))

So, what's the difference between `BatchNorm` and `LayerNorm`? Well, instead of normalizing each feature across the batch (down each column), we normalize across the features of each example (along each row): for every individual example, its 100-dimensional vector will be normalized. And that's pretty much it. LOL.

This means the computation doesn't span across examples anymore!

In [75]:
class LayerNorm:
    """Hand-written layer normalization over the last (feature) dimension.

    Every feature vector is normalized on its own, so the computation never spans
    examples. For 2-D input of shape (batch, dim) this is identical to normalizing
    along dimension 1. For input with extra leading dimensions, e.g. (B, T, C), each
    position's C-dimensional vector is normalized independently of the others.

    The variance is computed with `Tensor.var` using its default settings.

    Args:
        dim: number of features (size of the gamma/beta vectors).
        eps: constant added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        # scale and shift, exposed through parameters() (meant to be trained with backprop)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        """Normalize `x` along its last dimension, store the result in `self.out` and return it."""
        # statistics over the features of each individual vector (not over the batch)
        xmean = x.mean(-1, keepdim=True)
        xvar = x.var(-1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        """Return the list [gamma, beta]."""
        return [self.gamma, self.beta]
    
torch.manual_seed(1337)
module = LayerNorm(100)
x = torch.randn(32, 100) # Batch size: 32, 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [76]:
x[:, 0].mean(), x[:, 0].std() # mean and std of one feature across all batch inputs

(tensor(0.1469), tensor(0.8803))

In [77]:
x[0, :].mean(), x[0, :].std() # mean and std of the features of a single input from the batch

(tensor(-3.5763e-09), tensor(1.0000))

BAM!

## The Transformer and Machine Translation

**But why does the "Attention Is All You Need" paper use an Encoder + Decoder Transformer?**

*Because we're concerned with a different task: conditioning text generation on additional information (here, a sentence in another language), rather than only on the text generated so far, as our decoder-only language model does.*

```
# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>
```

- The cross-attention component of this architecture is essential - the keys and values come from the encoder (the French sentence), NOT just the decoder. The queries are still generated from X though!
- Conditioning the decoding not just on the past of the current decoding but also on the fully encoded french sentence.