# Train a Transformer to Write Shakespeare

This tutorial notebook is based on [this video](https://www.youtube.com/watch?v=kCc8FmEb1nY).

## Imports

In [1]:
import torch
import torch.nn as nn
import numpy as np

from torch.nn import functional as F

## Download and Explore the Dataset

In [2]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [3]:
with open('input.txt', 'r') as f:
    text = f.read()

In [4]:
print(f'Length of dataset in characters: {len(text):,}')

Length of dataset in characters: 1,115,394


In [5]:
# Inspect the first 500 characters

print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [6]:
# Extract all unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

print(f'There are {vocab_size} unique characters in the text:')
print(''.join(chars))


There are 65 unique characters in the text:

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


## Create a Tokenizer

Tokenizing converts the text into a form the computer can work with, in this case a list of integers. We're doing something really simple: assigning an integer to each of the 65 possible characters in the text.

First we build the mapping ourselves by creating a lookup table in both directions, string to integer (`stoi`) and vice versa (`itos`).

In [7]:
# String to Integer Mapping
stoi = {ch:i for i, ch in enumerate(chars)}

# Integer to String Mapping
itos = {i:ch for i, ch in enumerate(chars)}

Now write two quick functions, one to encode (string to int) and one to decode.

In [8]:
# Create quick functions using lambda (lambda is used for one-liners like this)
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

Now test:

In [9]:
print(encode("hii there!"))
print(decode(encode("hii there!")))

[46, 47, 47, 1, 58, 46, 43, 56, 43, 2]
hii there!


And test on the Shakespeare text:

In [10]:
# Test encoding on first 15 characters of text
encoded_txt = encode(text[:15])
encoded_txt

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0]

In [11]:
# Now decode it
decode(encoded_txt)

'First Citizen:\n'

There are many different ways to build a tokenizer; we just built a very simple one. For example, Google created one called [SentencePiece](https://github.com/google/sentencepiece). It tokenizes at the sub-word level, which is the approach usually adopted in practice for LLMs.
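As a compact recap of the character-level scheme, here is a minimal sketch that packages the two lookup tables into one object. The class name `CharTokenizer` and the sample string are hypothetical, chosen only for illustration; it assumes every character to encode appeared in the text used to build the vocabulary.

```python
class CharTokenizer:
    """Character-level tokenizer (illustrative; not part of the notebook)."""

    def __init__(self, text):
        # Sorted unique characters define the vocabulary, as in the cells above
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    @property
    def vocab_size(self):
        return len(self.chars)

    def encode(self, s):
        # Raises KeyError for any character that was not in the source text
        return [self.stoi[c] for c in s]

    def decode(self, ids):
        return ''.join(self.itos[i] for i in ids)


sample = "hii there!"  # stand-in for the Shakespeare text
tok = CharTokenizer(sample)
ids = tok.encode("hi there")
print(ids)
print(tok.decode(ids))
```

A sub-word tokenizer exposes the same encode/decode contract, but with a larger vocabulary and shorter encoded sequences.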

## Tokenize the Text

We're going to use our basic encoder to encode the entire text and store it in a `torch.Tensor` object.

In [12]:
data = torch.tensor(encode(text), dtype=torch.long)
data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [13]:
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


## Split into Train/Validation Datasets

We use a 90/10 split. Holding out the last 10% of the data as a validation set lets us check how much our model is overfitting.

In [14]:
n = int(0.9*len(data))  # Determine cut point of 90%
train_data = data[:n]  # First 90% is train
val_data = data[n:]  # Last 10% is validation
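The arithmetic of the split can be checked on a stand-in sequence. This sketch assumes a plain Python list in place of the `data` tensor; `split_train_val` is a hypothetical helper, and the length is the character count printed earlier.

```python
def split_train_val(seq, train_frac=0.9):
    """Return (train, val), cutting at the first train_frac of the sequence."""
    n = int(train_frac * len(seq))  # same cut-point rule as the cell above
    return seq[:n], seq[n:]


tokens = list(range(1_115_394))  # stand-in for the encoded text
train, val = split_train_val(tokens)
print(len(train), len(val))
```

Because the cut is positional rather than random, the validation set is the final stretch of the text, which the model never sees during training.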

## Concept of Block Size

Feeding the entire text into the model at once would be computationally prohibitive. So we are going to train on chunks of 8 characters. But because we're interested not just in each character but in the relationship from character to character, we're going to look at `block_size + 1` characters at a time. This way we can use contexts of up to 8 characters and still see the character that follows.

In [15]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

There are 8 examples of data packed in here. There's the first (18), which is followed by 47. There's the first AND second (18, 47), followed by 56, and so on. We're going to refer to this dimension as the **time dimension**. Illustrating this with code:

In [16]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'When input is {context} the target is: {target}')

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


## Concept of Minibatches

For efficiency we're going to stack multiple blocks together into batches (also called *minibatches*). The blocks in a batch are processed independently and don't talk to each other; batching just keeps the hardware busy, since GPUs are really good at parallelizing. This is the **batch dimension**.

In [17]:
torch.manual_seed(1337)  # Now we're setting a random seed because we're grabbing random samples in the data

# Parameters
batch_size = 4  # B: How many independent sequences will we process in parallel?
block_size = 8  # T: What is the maximum context length for predictions?

def get_batch(split):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))  # create 4 random starting places in the text
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [18]:
xb, yb = get_batch('train')
print('== Inputs ==')
print(xb.shape)
print(xb)

print('== Targets ==')
print(yb.shape)
print(yb)

print('========================')
for b in range(batch_size):  # Batch dimension
    for t in range(block_size):  # Block dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'When input is {context.tolist()} the target is: {target}')

== Inputs ==
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
== Targets ==
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
When input is [24] the target is: 43
When input is [24, 43] the target is: 58
When input is [24, 43, 58] the target is: 5
When input is [24, 43, 58, 5] the target is: 57
When input is [24, 43, 58, 5, 57] the target is: 1
When input is [24, 43, 58, 5, 57, 1] the target is: 46
When input is [24, 43, 58, 5, 57, 1, 46] the target is: 43
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
When input is [44] the target is: 53
When input is [44, 53] the target is: 56
When input is [44, 53, 56] the target is: 1
When input is [44, 53, 56, 1] the target is: 58
When input is [44, 53, 56, 1,
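The key property of `get_batch` is that each target row is its input row shifted one position to the left. A minimal sketch with plain lists and the standard `random` module standing in for tensors and `torch.randint` (the function name `get_batch_lists` and the toy data are assumptions for illustration):

```python
import random


def get_batch_lists(data, batch_size, block_size, rng):
    """Sample batch_size windows; y is x shifted one step into the future."""
    # Valid starts are 0 .. len(data) - block_size - 1, so y never runs off the end
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [data[i:i + block_size] for i in starts]
    y = [data[i + 1:i + block_size + 1] for i in starts]
    return x, y


rng = random.Random(1337)
toy_data = list(range(100))  # stand-in for train_data
xb_demo, yb_demo = get_batch_lists(toy_data, batch_size=4, block_size=8, rng=rng)

for x_row, y_row in zip(xb_demo, yb_demo):
    # Every target except the last already appears in the input, one slot later
    assert y_row[:-1] == x_row[1:]
    print(x_row, '->', y_row)
```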

## Feed into a Neural Network

Now we know what kind of input we want to feed to our transformer: a 4x8 tensor (`batch_size` x `block_size`). We'll start with the simplest type of NN, a bigram language model.

### Simplest Bigram Language Model

[Video Chapter](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1331s)

In [19]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a
        # lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (BxTxC dimensions or 4x8x65)
        
        return logits
        

In [20]:
m = BigramLanguageModel(vocab_size)
out = m(xb, yb)  # Feed a minibatch into it

print(out.shape)

torch.Size([4, 8, 65])


So what is this doing? We're feeding our minibatch into it. See the previous section for a reminder of what `xb` and `yb` are. The embedding table is 65x65. As the `forward` method runs, every integer in `xb` plucks out the row of the embedding table corresponding to its index. PyTorch then arranges the result as a *Batch x Time x Channel* tensor. So we have a 65-length tensor for every one of the 4 x 8 positions in our input.

The output is the `logits`, or the scores for the next character in the sequence. A bigram model predicts what comes next based *only* on the identity of the current token; it ignores everything earlier in the block. For example, what are the scores for each character following the token at position 5 of the first sequence in the minibatch?

We can figure that out!

First, use `torch.sort()` to sort the logits at the fifth position of the first sequence in the minibatch (remember zero-indexing). Then look at the top 5 indices. In the run shown below, the most likely character is index 10, and we can remind ourselves which character that is in our vocab list. This really isn't going to be very accurate: we're not measuring loss and we're not optimizing anything. It's about as good as a random guess, as we'll see soon.

In [21]:
# Named sorted_logits to avoid shadowing Python's built-in sorted()
sorted_logits, indices = torch.sort(out[0][4], descending=True)

print(sorted_logits[:5])
print(indices[:5])
print(f'Index {indices[0]} is: "{chars[indices[0]]}"!')


tensor([2.1471, 1.9577, 1.3954, 1.1822, 1.1564], grad_fn=<SliceBackward0>)
tensor([10, 55, 44, 64,  4])
Index 10 is: ":"!
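To see why a bigram model depends only on the current token, it helps to strip the embedding lookup down to plain row indexing. In this sketch the 3-token vocabulary and the table values are made up for illustration; `lookup` is a hypothetical stand-in for `nn.Embedding`.

```python
# Toy "embedding table": row i holds the next-token logits for token i
table = [
    [0.1, 2.0, -1.0],   # logits after token 0
    [1.5, 0.0, 0.3],    # logits after token 1
    [-0.7, 0.2, 0.9],   # logits after token 2
]


def lookup(idx):
    """Map a (B, T) nested list of token ids to (B, T, C) logits."""
    return [[table[token] for token in row] for row in idx]


idx_demo = [[1, 2, 1],
            [0, 1, 1]]
logits_demo = lookup(idx_demo)

# Token 1 appears at several positions and always gets the same logits,
# so position and earlier context play no role in the prediction.
print(logits_demo[0][0] == logits_demo[0][2] == logits_demo[1][1])
```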


### Add Loss

[Video Chapter](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1500s)

Now that we've made predictions about what comes next, let's evaluate them with a **loss function**. A good way to measure the loss, or quality of the prediction, is *negative log likelihood loss*, which PyTorch implements as `F.cross_entropy()` (it applies a log-softmax to the logits and then takes the negative log likelihood of the targets). We measure the quality of the logits with respect to the targets. We have the identity of the next character, so we want to know how well we're predicting it.

In [22]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a
        # lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (BxTxC dimensions or 4x8x65)
        
        # Loss is cross entropy, but we have to reshape the input per the documentation.
        # For multi-dimensional input it wants (B, C, T) rather than (B, T, C), so we
        # flatten logits to (B*T, C) instead
        B, T, C = logits.shape # unpack
        logits = logits.view(B*T, C)  # Flatten batch and time into one dim of length B*T; C stays as the 2nd dim
        
        # Also need to do same to targets
        targets = targets.view(B*T)
        
        loss = F.cross_entropy(logits, targets)
        
        return logits, loss

In [23]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)  # Feed a minibatch into it

print(logits.shape)

torch.Size([32, 65])


It should be clear now how we "stretched" `logits` to make it work with `cross_entropy`. Instead of 4 x 8 x 65 we have 32 x 65. Let's look at the first row, which corresponds to the first position of the first sequence in the minibatch. It's a 65-length tensor holding all the logits for that position. We can again sort it to find the prediction.

In [24]:
print(logits[0])

# Sort
sorted_logits, indices = torch.sort(logits[0], descending=True)

# Output Results
print(sorted_logits[:5])
print(indices[:5])
print(f'Index {indices[0]} is: "{chars[indices[0]]}"!')


tensor([ 1.4849, -0.8661, -1.0219, -1.1038, -0.3763, -0.2251, -0.2825,  0.4622,
        -0.7136, -1.4831,  1.3023, -1.2810,  1.2857,  0.8153, -1.4935, -0.1127,
         0.7719,  2.9956,  0.2084, -1.6306,  1.4533, -1.1483,  0.7007,  1.2882,
         0.7806,  1.2904,  0.0471,  1.4801, -0.6316, -1.1766, -1.5717,  0.5684,
         1.2815,  0.4047,  0.0632,  0.5846, -0.1738,  0.8185, -0.5315, -0.7415,
         0.6128,  0.9535, -0.0584, -0.4370,  0.2026, -0.8318, -0.1020,  0.9157,
        -0.6446, -0.5180,  0.8405, -1.3159,  0.0663, -0.7541,  0.7109, -0.3921,
        -1.4153, -0.0123,  0.2143,  1.5742, -1.7377,  0.9368,  0.1410,  1.5414,
         0.7376], grad_fn=<SelectBackward0>)
tensor([2.9956, 1.5742, 1.5414, 1.4849, 1.4801], grad_fn=<SliceBackward0>)
tensor([17, 59, 63,  0, 27])
Index 17 is: "E"!
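The flattening done by `logits.view(B*T, C)` can be pictured with nested lists: position `(b, t)` of the 3-D tensor lands in row `b*T + t` of the 2-D one. A small sketch under that assumption (row-major layout, which is what a contiguous PyTorch tensor uses); the shapes are shrunk for readability.

```python
B, T, C = 2, 3, 4  # small stand-ins for 4, 8, 65

# Build a (B, T, C) nested list where each entry encodes its own coordinates
logits_3d = [[[(b, t, c) for c in range(C)] for t in range(T)] for b in range(B)]

# Flatten batch and time together, keeping the channel dimension intact
logits_2d = [row for batch in logits_3d for row in batch]

print(len(logits_2d), len(logits_2d[0]))  # B*T rows, C columns
for b in range(B):
    for t in range(T):
        assert logits_2d[b * T + t] == logits_3d[b][t]
```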


In [25]:
loss

tensor(4.7288, grad_fn=<NllLossBackward0>)

Our loss is stated in the `loss` tensor, output in the previous cell. This will change at each run, but when I wrote this it was 4.7288. We can calculate the loss we'd expect from a uniform guess as the negative natural logarithm (hey look, NEGATIVE LOG LIKELIHOOD!) of 1 over the number of classes, i.e. the vocabulary size. In our case that is $-\ln\frac{1}{65} \approx 4.174$. Our loss is higher than that, which means the randomly initialized model is NOT perfectly diffuse: its predictions are not uniform, and some of the probability mass sits on wrong characters. After training we want a loss well below this baseline.
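The baseline can be reproduced by hand. The sketch below is a plain-Python version of cross-entropy for a single prediction (log-softmax followed by the negative log likelihood of the target); `cross_entropy_single` is a hypothetical helper, not PyTorch's implementation.

```python
import math


def cross_entropy_single(logits, target):
    """Negative log of the softmax probability assigned to the target index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


vocab_size = 65

# Uniform logits: every character is equally likely
uniform_loss = cross_entropy_single([0.0] * vocab_size, target=17)
print(round(uniform_loss, 4), round(-math.log(1 / vocab_size), 4))

# Confidently wrong logits give a loss above the uniform baseline
skewed = [0.0] * vocab_size
skewed[3] = 3.0  # extra mass on a wrong character
print(cross_entropy_single(skewed, target=17) > uniform_loss)
```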

### Add Generation

[Video Chapter](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=28m50s)

Despite our garbage model, we're still going to try some **generation**! Let's add that to our model now.

In [26]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a
        # lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (BxTxC dimensions or 4x8x65)
        
        if targets is None:
            loss = None
        else:
            # Loss is cross entropy, but we have to reshape the input per the documentation.
            # For multi-dimensional input it wants (B, C, T) rather than (B, T, C), so we
            # flatten logits to (B*T, C) instead
            B, T, C = logits.shape # unpack
            logits = logits.view(B*T, C)  # Flatten batch and time into one dim of length B*T; C stays as the 2nd dim

            # Also need to do same to targets
            targets = targets.view(B*T)

            loss = F.cross_entropy(logits, targets)
        
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) array of indices in the current context. The job of generate is to
        # take the (B, T) array and extend it to (B, T+1), (B, T+2), etc. until
        # max_new_tokens new tokens have been appended.
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)  # For now we ignore loss, we're only using logits

            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)

            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [27]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)  # Feed a minibatch into it

print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.2793, grad_fn=<NllLossBackward0>)


In [28]:
# Create a 1x1 tensor holding a zero to kick off the generation
# Remember a 0 corresponds to a newline (\n) character
idx = torch.zeros((1, 1), dtype=torch.long)
idx

tensor([[0]])

In [29]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


MufR3!yC NHKGY?C3eDEGMp-VCcZN3aoRN&powfXmB.EOX.aFl3sitGj&WcGpoEfRvHLbR?
JxHofRX,MccAWBTF.ALF$mQ;!GY:
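The sampling loop in `generate` can be sketched without PyTorch: softmax the logits of the last token, draw one index from that distribution, append, repeat. Here `random.choices` stands in for `torch.multinomial`, and the 3-token logit table is invented for illustration.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def generate_ids(table, idx, max_new_tokens, rng):
    """Extend the list of token ids idx by sampling from a bigram logit table."""
    idx = list(idx)
    for _ in range(max_new_tokens):
        logits = table[idx[-1]]  # bigram: only the last token matters
        probs = softmax(logits)
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        idx.append(next_id)
    return idx


toy_table = [
    [0.0, 2.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.5, 0.5, 0.5],
]
out_ids = generate_ids(toy_table, [0], max_new_tokens=10, rng=random.Random(0))
print(out_ids)
```

Sampling, rather than always taking the highest-scoring index, is why each run of the notebook cell prints different text.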


### WHAT!? This is Garbage!
[Video Link](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2093s)

It's garbage because this model was never trained, not at all. Now let's actually train it.

In [30]:
# Create a pytorch optimization object using AdamW optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

The optimizer we just created takes the gradients and uses them to update the parameters. Another possible optimizer would be stochastic gradient descent (`torch.optim.SGD`).
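What an optimizer does with the gradients is easiest to see for plain gradient descent, where each parameter moves a small step against its gradient. The sketch below minimizes a one-parameter toy loss $(w - 3)^2$; it illustrates the basic update rule only, and AdamW adds adaptive per-parameter step sizes, momentum, and decoupled weight decay on top of it.

```python
def loss_fn(w):
    return (w - 3.0) ** 2


def grad_fn(w):
    # Derivative of (w - 3)^2 with respect to w
    return 2.0 * (w - 3.0)


w = 0.0   # initial parameter value
lr = 0.1  # learning rate, analogous to the lr argument above

for step in range(50):
    g = grad_fn(w)  # plays the role of loss.backward()
    w -= lr * g     # plays the role of optimizer.step()

print(w, loss_fn(w))
```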

Now let's set it all up and run with just 10 iterations and see how much loss we get. Now we're actually training the model, we're feeding the data into our model and updating parameters using our `optimizer` object.

In [31]:
batch_size = 32  # Use a bigger batch size now

for steps in range(10):
    # sample a batch of data from train
    xb, yb = get_batch('train')

    # Evaluate the loss using a typical training loop
    logits, loss = m(xb, yb)  # Evaluate loss
    optimizer.zero_grad(set_to_none=True)  # Zero gradients from previous step
    loss.backward()  # Get gradients for all params
    optimizer.step()  # Use gradients to update parameters

    print(loss.item())

4.6318359375
4.627881050109863
4.483257293701172
4.629354476928711
4.51035213470459
4.547611713409424
4.6271467208862305
4.50986385345459
4.6163458824157715
4.603818893432617


Over just 10 steps the loss is noisy and has barely moved, so let's train for longer. We'll do 10k iterations and only print the loss at the end.

In [32]:
batch_size = 32  # Use a bigger batch size now

for steps in range(10000):
    # sample a batch of data from train
    xb, yb = get_batch('train')

    # Evaluate the loss using a typical training loop
    logits, loss = m(xb, yb)  # Evaluate loss
    optimizer.zero_grad(set_to_none=True)  # Zero gradients from previous step
    loss.backward()  # Get gradients for all params
    optimizer.step()  # Use gradients to update parameters

print(f'final loss: {loss.item()}')

final loss: 2.52305006980896


Our loss is now greatly reduced. Now let's generate text again. It will look a little better, but it's not Shakespeare yet! Not even readable.

In [33]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=250)[0].tolist()))


Dugomyof y, beche iberJubs R; thathagutth sedy t jPcant my thes hararis?
Parared u, waviteavedo winegot Wied fon:
LIUD wigotho f beaind ailit bhill spanone w;
Thad f yor m h avingease'r prG t laresothord, s d thept en, char pou;
te'd is!pe; tadsplout


We could even train for 10k more iterations and see if that helps. Probably not much, though, because we've likely hit the floor of what such a simple model can achieve.

In [34]:
batch_size = 32  # Use a bigger batch size now

for steps in range(10000):
    # sample a batch of data from train
    xb, yb = get_batch('train')

    # Evaluate the loss using a typical training loop
    logits, loss = m(xb, yb)  # Evaluate loss
    optimizer.zero_grad(set_to_none=True)  # Zero gradients from previous step
    loss.backward()  # Get gradients for all params
    optimizer.step()  # Use gradients to update parameters

print(f'final loss: {loss.item()}')

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=250)[0].tolist()))

final loss: 2.405074119567871


LUCisurl mar ad
Whith?
Whewid m S:

Hurt ges d at fu.
Ance y h y Wesomy MILOLLAnouceand or at is des
A: buthivequchen thy? wit ay t hinsofo youe r ar p gromig: twiglillldica dritond
NBusthouth:
Whaf be ssothary:
E ou sour'dod

In AKINUSuseaze.

Mas 


Yeah... training this over and over isn't helping much; there's only so much such a simple architecture can do!

# Next... Increase complexity of the model

Our model up to this point has been incredibly simple: the tokens aren't talking to each other. Though we're training on blocks of up to 8, we're only using the last character to figure out what comes next. Now we're going to get these tokens talking to each other!

## The mathematical trick in self-attention

[Video Link](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2533s)

In [35]:
# Consider the following toy example with random data in a 4x8x2 array:

torch.manual_seed(1337)

B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
# print(x)

Currently the 8 tokens in each sequence (the time dimension) are not talking to each other. We want them to talk to each other, but each token should only see the tokens that come before it. For example, the token in position 5 can't see tokens 6, 7, 8 but can see tokens 1-4 (and itself). The simplest way to do this is a running average of the current and all previous tokens, which never looks into the future. An average is a lossy summary, since it throws away the ordering of the tokens, but that's fine for this example.

So for every step *t*, we want to capture the average of step *t* and all previous steps.

### Version 1: For Loop

In [36]:
# We want x[b,t] = mean_{i<=t} x[b,i]

# bow = bag of words, there's a word stored at every one of the T positions
xbow = torch.zeros((B, T, C))

for b in range(B):
    # print('======')
    # print(b)
    for t in range(T):
        # print(f'  ->{t}')
        xprev = x[b, :t+1]
        # print(f'    {xprev}')
        xbow[b,t] = torch.mean(xprev, 0)
        
print(x.shape)
print(xbow.shape)

torch.Size([4, 8, 2])
torch.Size([4, 8, 2])


#### Look at the output!
Now it starts to make sense. Let's start by comparing `x[0]` to `xbow[0]`. The first row is equal because it's just an average of itself, but the second row is now the average of the first two rows because $(0.1808 + -0.3596) / 2 = -0.0894$ and $(-0.0700 + -0.9152) / 2 = -0.4926$. Then the third row is the average of the first three, and so on. The last row is the column-wise average of all 8 rows.

In [37]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [38]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Version 2: Matrix Multiplication
[Video Link](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2831s)

This is great BUT for loops are inefficient. It would be much better to do this with matrix multiplication. Let's look at an example:

In [39]:
torch.manual_seed(42)

# Create a 3x3 tensor of all ones
a = torch.ones(3, 3)

# Create a random 3x2 tensor 
b = torch.randint(0,10,(3,2)).float()

# Compute the matrix product of the two tensors
c = a @ b

print(a)
print(b)
print(c)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


The value in the top-left position of tensor `c` (14) is the **dot product** of the first **row** of tensor `a` with the first **column** of tensor `b`. Since `a` is all ones, it ends up just being the sum of column 1 of `b` ($2 + 6 + 6 = 14$). See here for a [dot product refresher](https://www.khanacademy.org/math/precalculus/x9e81a4f98389efdf:matrices/x9e81a4f98389efdf:multiplying-matrices-by-matrices/a/multiplying-matrices).
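The row-times-column rule can be written out in a few lines of plain Python, using the same `a` and `b` values printed above. `matmul` here is a hypothetical helper for illustration, not a replacement for `@`.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


a = [[1.0, 1.0, 1.0] for _ in range(3)]
b = [[2.0, 7.0],
     [6.0, 4.0],
     [6.0, 5.0]]

c = matmul(a, b)
print(c)  # each row is [2 + 6 + 6, 7 + 4 + 5]
```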

But `torch` has another great tool called `torch.tril()` that gives just the lower triangular values of a tensor and zeroes out the rest. Let's try that:

In [40]:
torch.manual_seed(42)

# Create a 3x3 tensor of all ones, wrap in tril to zero out the top right and give running sums
a = torch.tril(torch.ones(3, 3))

# Create a random 3x2 tensor 
b = torch.randint(0,10,(3,2)).float()

# Compute the matrix product of the two tensors
c = a @ b

print(a)
print(b)
print(c)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


When we do this we end up with a **running sum**! To make them averages only takes a little more effort:

In [41]:
torch.manual_seed(42)

# Lower-triangular ones, with each row divided by its sum so the rows average
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)

# Create a random 3x2 tensor
b = torch.randint(0,10,(3,2)).float()

# Compute the matrix product of the two tensors
c = a @ b

print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Now we can vectorize this and scale to our previous example with `x` and `xbow`:

In [42]:
# Call it wei (short for weights)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)

# New version of xbow: (T, T) @ (B, T, C) broadcasts over the batch to give (B, T, C)
xbow2 = wei @ x
xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
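A quick way to convince yourself that the matrix version matches the for loop is to compare them on a small example. This sketch uses nested lists and hypothetical helpers (`causal_average_weights`, `prefix_means`, `apply_weights`) in place of tensors:

```python
def causal_average_weights(T):
    """Row t has weight 1/(t+1) on positions 0..t and 0 on the future."""
    return [[1.0 / (t + 1) if s <= t else 0.0 for s in range(T)] for t in range(T)]


def prefix_means(x):
    """Loop version: the mean of rows 0..t for every t (x is T x C)."""
    out = []
    for t in range(len(x)):
        prev = x[:t + 1]
        out.append([sum(col) / len(prev) for col in zip(*prev)])
    return out


def apply_weights(wei, x):
    """Matrix version: (T x T) @ (T x C)."""
    return [[sum(w * row[c] for w, row in zip(w_row, x)) for c in range(len(x[0]))]
            for w_row in wei]


x_demo = [[2.0, 7.0], [6.0, 4.0], [6.0, 5.0]]
loop_result = prefix_means(x_demo)
matrix_result = apply_weights(causal_average_weights(3), x_demo)
print(loop_result)
print(matrix_result)
```

In the notebook itself, an equivalent check would be `torch.allclose(xbow, xbow2)`.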

### Version 3: Softmax
[Video Link](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=3282s)

Softmax is a normalization operation, so we end up with the same weight matrix as before. The reason we want to use *this version* is that the weights all start at zero here, but eventually they will not be constant; they will be data dependent. We'll start calling these values **affinities**: some tokens find others more interesting, and that is the **basis of attention**.

In [102]:
# Create the same lower-triangular matrix as before and set it aside
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [106]:
wei = torch.zeros((T, T))  # Initialize wei as all zeros
wei = wei.masked_fill(tril == 0, float('-inf'))  # replace zeros w/ -inf, this is how we keep the past from seeing the future
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

If we do a softmax across every row of `wei`, what does it do? Softmax is a normalization operation: it *exponentiates* each element along a row ($e^x$), and then divides by the sum of those exponentials. $e^0 = 1$, so we get 1 everywhere `wei` holds a zero (the positions where `tril` is 1). Likewise $e^{-\infty} = 0$, so the masked future positions contribute nothing.

With this in mind, we can calculate each row below by altering the `wei_idx` variable in the following cell:

In [108]:
wei_idx = 1

print(f'Row {wei_idx} of wei is:                {wei[wei_idx]}')
print(f'Exponentiated that row becomes: {torch.exp(wei[wei_idx])}')
print(f'The sum of the exponents is:    {torch.sum(torch.exp(wei[wei_idx]))}')
print(f'So softmax is:                  {torch.exp(wei[wei_idx]) / torch.sum(torch.exp(wei[wei_idx]))}')

Row 1 of wei is:                tensor([0., 0., -inf, -inf, -inf, -inf, -inf, -inf])
Exponentiated that row becomes: tensor([1., 1., 0., 0., 0., 0., 0., 0.])
The sum of the exponents is:    2.0
So softmax is:                  tensor([0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])


Fortunately PyTorch includes a softmax function. We just need to tell it along which dimension to calculate, and it computes the softmax for each of our 8 rows.

In [109]:
# Doing a softmax along every row (because dim = -1). Softmax is a normalization
# function, so it gives the same weights as in version 2.
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

And with that creatively devised mask in place, we can do the matrix multiplication and once again we get our nice rolling average.

In [113]:
# Each row of wei already sums to 1 after softmax, so this division leaves it unchanged
wei = wei / wei.sum(1, keepdim=True)

# New version of xbow, computed with a single matrix multiplication
xbow3 = wei @ x
xbow3[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

#### Why do we care about softmax and all these different methods of doing a rolling average!?
The reason we'll use this in self-attention (next) is that we will start changing the affinities. Up until now we've made them all zero (we initialized `wei` with `torch.zeros()`), which gives a plain average. Once the affinities vary, softmax turns them into a weighted average, which is what self-attention computes.
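To preview why softmax matters, the sketch below applies the same mask to a row of non-zero affinities. The scores are invented for illustration; the point is that future positions still get zero weight while past positions are no longer weighted equally.

```python
import math


def masked_softmax(scores, num_visible):
    """Softmax over a row where positions >= num_visible are set to -inf.

    Assumes num_visible >= 1 so that at least one score stays finite.
    """
    masked = [s if i < num_visible else float('-inf') for i, s in enumerate(scores)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# All-zero affinities: a plain average over the visible positions
uniform = masked_softmax([0.0, 0.0, 0.0, 0.0], num_visible=3)
print(uniform)

# Data-dependent affinities: a weighted average, future still masked out
weights = masked_softmax([2.0, 0.0, 1.0, 5.0], num_visible=3)
print(weights)
```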



### Version 4: Self-Attention!
[Video Link](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=3720s)

In [None]:
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # Now C is 32
x = torch.randn(B, T, C)

Previously we initialized `wei` with all zeros. We need to change this, because some tokens will have more or less affinity for others. If we initialize with zeros then every row of `wei` is uniform, but that's not what we want in real life. Some tokens will find other tokens more or less interesting, and we want this to be data-dependent. For example, a vowel might be looking for consonants in its past; it wants to know what they are and have that information flow to it.

How we solve this: every token at each position emits TWO VECTORS, a query and a key. Query vector: what am I looking for? Key vector: what do I contain? The affinities come from taking the dot product of queries with keys. **That dot product becomes `wei`.** If a key and a query are aligned, they interact to a higher degree.

#### Let's see a single head perform self-attention



In [None]:
head_size = 16  # Hyperparameter for head size
key = nn.Linear(C, head_size, bias=False)  # Linear layers with bias=False, so each one is just a matrix multiply
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Produce k and q by forwarding key and query on x
k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)

print(k[0])
print(q[0])

When we forwarded `key` and `query` on `x`, each token at every position in the B x T arrangement independently produced a key and a query. No communication has happened yet. That happens now: every query takes a dot product with each of the keys.

In [None]:
# Need to transpose to the right dimension before dot product. What we get is:
# (B, T, 16) @ (B, 16, T) ---> (B, T, T)
wei = q @ k.transpose(-2, -1)

In [None]:
# Now add the other steps in to remove the top triangle of data (otherwise we're leaking data)
tril = torch.tril(torch.ones(T, T))  # create same tril matrix and set aside
wei = wei.masked_fill(tril == 0, float('-inf'))  # replace zeros w/ -inf, this is how we keep the past from seeing the future
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

In [None]:
# For each of the B sequences we get a (T, T) matrix of weights.
# The weights are no longer uniform like they were before but indicate
# affinities between tokens.
wei[0]
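Putting the pieces together for one sequence: score every query against every key, mask the future, softmax each row, then take the weighted sum of the values. The sketch below does this with nested lists and hand-picked 2-dimensional vectors; in the notebook the queries, keys, and values come from the `nn.Linear` layers instead, so all numbers here are assumptions for illustration.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def causal_attention(q, k, v):
    """Single-head causal attention for one sequence (lists of T vectors)."""
    T = len(q)
    out, all_weights = [], []
    for t in range(T):
        # Affinity of query t with each key; future keys are masked with -inf
        scores = [dot(q[t], k[s]) if s <= t else float('-inf') for s in range(T)]
        m = max(scores)
        exps = [math.exp(sc - m) for sc in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors
        out.append([sum(w * v[s][d] for s, w in enumerate(weights))
                    for d in range(len(v[0]))])
        all_weights.append(weights)
    return out, all_weights


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

out, wei_demo = causal_attention(q, k, v)
for row in wei_demo:
    print([round(w, 3) for w in row])
```

Each printed row sums to 1 and has zeros to the right of the diagonal, mirroring the structure of `wei[0]` in the cell above.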

## NB -> Script Conversion

This will be easier if we work from a script instead of a notebook. Use version control to see the various steps we took to go from a basic implementation to something much more fully featured. 