# Train a Transformer to Write Shakespeare

This tutorial notebook is based on [this video](https://www.youtube.com/watch?v=kCc8FmEb1nY).

## Imports

In [1]:
import torch
import torch.nn as nn

from torch.nn import functional as F

## Download and Explore the Dataset

In [2]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [3]:
with open('input.txt', 'r') as f:
    text = f.read()

In [4]:
print(f'Length of dataset in characters: {len(text):,}')

Length of dataset in characters: 1,115,394


In [5]:
# Inspect the first 1k characters

print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
# Extract all unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

print(f'There are {vocab_size} unique characters in the text:')
print(''.join(chars))


There are 65 unique characters in the text:

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


## Create a Tokenizer

Tokenizing converts the text into a form the computer can work with, in this case a list of integers. We're doing something really simple: assigning an integer to each of the 65 unique characters in the text.

In [7]:
# String to Integer Mapping
stoi = {ch:i for i, ch in enumerate(chars)}

# Integer to String Mapping
itos = {i:ch for i, ch in enumerate(chars)}

In [8]:
# Create quick functions using lambda (lambda is used for one-liners like this)
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [9]:
# Test encoding on first 15 characters of text
encoded_txt = encode(text[:15])
encoded_txt

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0]

In [10]:
# Now decode it
decode(encoded_txt)

'First Citizen:\n'
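As a quick sanity check, the sketch below rebuilds the same kind of character-level tokenizer on a tiny stand-in string and confirms that decoding an encoding returns the original text. The `sample` string and the `toy_` names are made up for illustration so they don't overwrite the notebook's own `stoi`, `itos`, `encode`, and `decode`. It also shows a limitation of this approach: a character that never appeared in the text cannot be encoded.

```python
# Toy stand-in for the dataset; the real notebook reads input.txt
sample = "First Citizen:\nSpeak, speak.\n"

toy_chars = sorted(set(sample))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    # Map each character to its integer id
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    # Map each integer id back to its character
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("Speak")
round_trip = toy_decode(ids)
print(ids, repr(round_trip))

# A character outside the vocabulary raises KeyError
try:
    toy_encode("Z")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print("Can encode 'Z'?", unknown_ok)
```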

## Tokenize the Text

We're going to encode the entire text and store it into a `torch.Tensor` object.

In [11]:
data = torch.tensor(encode(text), dtype=torch.long)
data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [12]:
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


## Split into Train/Validation Datasets

We'll use a 90/10 split. Holding out the last 10% of the data as a validation set lets us see how much our model is overfitting.

In [13]:
n = int(0.9*len(data))  # Determine cut point of 90%
train_data = data[:n]  # First 90% is train
val_data = data[n:]  # Last 10% is validation

## Concept of Block Size

Feeding the entire text into the model at once would be computationally prohibitive, so we are going to train on chunks of `block_size` characters (8 here). Because we're interested not just in each character but in how one character leads to the next, we actually pull out `block_size + 1` characters: 8 inputs plus the character that follows the last one.

In [14]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

Illustrating this with code:

In [15]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'When input is {context} the target is: {target}')

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58
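The same context/target pairs are easier to read as characters. The sketch below uses the first nine characters of the dataset, `"First Cit"` (taken from the text printed earlier), and plain Python strings instead of tensors. The names `chunk` and `pairs` are introduced here for illustration only.

```python
# The first block_size + 1 characters of the dataset
chunk = "First Cit"
block_size = 8

pairs = []
for t in range(block_size):
    context = chunk[: t + 1]   # everything up to and including position t
    target = chunk[t + 1]      # the character that follows the context
    pairs.append((context, target))
    print(f"{context!r} -> {target!r}")
```

One chunk of `block_size + 1` characters therefore yields `block_size` training examples, with context lengths from 1 up to `block_size`.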


## Concept of Minibatches

For efficiency we're going to stack multiple blocks together into batches. GPUs are really good at parallelizing this kind of work.

In [16]:
torch.manual_seed(1337)

# Parameters
batch_size = 4  # How many independent sequences will we process in parallel?
block_size = 8  # What is the maximum context length for predictions?

def get_batch(split):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [17]:
xb, yb = get_batch('train')
print('== Inputs ==')
print(xb.shape)
print(xb)

print('== Targets ==')
print(yb.shape)
print(yb)

print('========================')
for b in range(batch_size):  # Batch dimension
    for t in range(block_size):  # Block dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'When input is {context.tolist()} the target is: {target}')

== Inputs ==
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
== Targets ==
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
When input is [24] the target is: 43
When input is [24, 43] the target is: 58
When input is [24, 43, 58] the target is: 5
When input is [24, 43, 58, 5] the target is: 57
When input is [24, 43, 58, 5, 57] the target is: 1
When input is [24, 43, 58, 5, 57, 1] the target is: 46
When input is [24, 43, 58, 5, 57, 1, 46] the target is: 43
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
When input is [44] the target is: 53
When input is [44, 53] the target is: 56
When input is [44, 53, 56] the target is: 1
When input is [44, 53, 56, 1] the target is: 58
When input is [44, 53, 56, 1,

## Feed into a Neural Network

We'll start with the simplest type of NN, a bigram language model.

In [18]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a
        # lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (BxTxC dimensions or 4x8x65)

        if targets is None:
            loss = None
        else:
            # reshape tensors for use in cross_entropy function
            B, T, C = logits.shape
            logits = logits.view(B * T, C)  # Logits needs to be 2 dimensional
            targets = targets.view(B * T)  # Reshape targets to 1 dimensional
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)

            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)

            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(1000)
        for k in range(1000):
            X, Y = get_batch(split)
            logits, loss = m(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

In [19]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)


The `loss` tensor holds our loss. It will change from run to run, but when I wrote this it was 5.0364. As a baseline, a model that spreads its probability evenly over all 65 characters would score the negative natural logarithm of 1 over the vocabulary size: $-\ln\frac{1}{65} \approx 4.174$. Our loss is higher than that, which means the randomly initialized model is not perfectly diffuse: its random logits place some confidence on the wrong characters.
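The baseline figure can be checked with a few lines of plain Python. This is a minimal sketch, not the notebook's code: `cross_entropy_one` is a hypothetical helper that computes the cross-entropy for a single prediction from raw logits, and the logit values are made up.

```python
import math

vocab_size = 65

# Loss of a model that assigns probability 1/65 to every character
uniform_loss = -math.log(1 / vocab_size)
print(f"uniform baseline: {uniform_loss:.4f}")

def cross_entropy_one(logits, target):
    # Numerically stable log-sum-exp minus the logit of the correct class
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target]

# Equal logits reproduce the uniform baseline, whatever the target is
flat = cross_entropy_one([0.0] * vocab_size, target=3)

# Favouring a wrong character costs more than the baseline
wrong = [0.0] * vocab_size
wrong[10] = 3.0
confident_wrong = cross_entropy_one(wrong, target=3)
print(f"flat: {flat:.4f}, confident but wrong: {confident_wrong:.4f}")
```

Any starting loss above the uniform baseline indicates that the random initial logits are, on average, leaning toward incorrect characters.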

## Generate Predictions 

I've now added a `generate` method to the `BigramLanguageModel` class.
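Stripped of tensors, the loop inside `generate` does three things: look up a distribution for the last token, sample from it, and append the sample. The sketch below imitates that with a hand-written probability table over a made-up 3-token vocabulary. `toy_probs` and `toy_generate` are hypothetical names, and `random.choices` stands in for `torch.multinomial`.

```python
import random

# Hypothetical next-token probabilities for a 3-token vocabulary:
# row i is the distribution over the token that follows token i
toy_probs = [
    [0.1, 0.6, 0.3],
    [0.5, 0.2, 0.3],
    [0.3, 0.3, 0.4],
]

def toy_generate(start, max_new_tokens, rng):
    seq = [start]
    for _ in range(max_new_tokens):
        last = seq[-1]               # a bigram model only looks at the last token
        probs = toy_probs[last]      # distribution over the next token
        nxt = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        seq.append(nxt)              # extend the running sequence
    return seq

toy_sample = toy_generate(start=0, max_new_tokens=10, rng=random.Random(1337))
print(toy_sample)
```

The returned sequence includes the starting token, so its length is `max_new_tokens + 1`, just as the tensor returned by `generate` grows from `(B, T)` to `(B, T + max_new_tokens)`.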

In [20]:
# Create a 1x1 tensor holding a zero to kick off the generation
# Remember a 0 corresponds to a newline (\n) character
idx = torch.zeros((1, 1), dtype=torch.long)
idx

tensor([[0]])

In [21]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


## WHAT!? This is Garbage!

It's garbage because this is a randomly initialized model. It hasn't been trained at all. Now let's actually train it.

In [22]:
# Create a pytorch optimization object using AdamW optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

The optimizer we just created takes the gradients and uses them to update the parameters. Another possible optimizer would be Stochastic Gradient Descent (`torch.optim.SGD`).
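To make the update rule concrete, here is a minimal sketch of what the simplest optimizer, plain gradient descent, does on each step. It is not AdamW, which additionally keeps running averages of the gradients and applies weight decay. The one-parameter loss `(w - 3) ** 2` and the names `toy_loss` and `toy_grad` are made up for illustration.

```python
def toy_loss(w):
    return (w - 3.0) ** 2

def toy_grad(w):
    # Analytical derivative of toy_loss with respect to w
    return 2.0 * (w - 3.0)

w = 0.0
lr = 0.1
history = [toy_loss(w)]
for _ in range(50):
    w = w - lr * toy_grad(w)   # step against the gradient, scaled by the learning rate
    history.append(toy_loss(w))

print(f"w = {w:.4f}, loss = {history[-1]:.6f}")
```

In the training loop, `loss.backward()` plays the role of `toy_grad` for every parameter at once, and `optimizer.step()` applies the update.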

Now let's set it all up.

In [23]:
batch_size = 32  # Use a bigger batch size now

for steps in range(10000):
    # every once in a while evaluate the loss on train and val sets
    if steps % 1000 == 0:
        losses = estimate_loss()
        print(f"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data from train
    xb, yb = get_batch('train')

    # Evaluate the loss using a typical training loop
    logits, loss = m(xb, yb)  # Evaluate loss
    optimizer.zero_grad(set_to_none=True)  # Zero gradients from previous step
    loss.backward()  # Get gradients for all params
    optimizer.step()  # Use gradients to update parameters

print(loss.item())

step 0: train loss 4.6367, val loss 4.6454
step 1000: train loss 3.6943, val loss 3.7064
step 2000: train loss 3.1163, val loss 3.1335
step 3000: train loss 2.8047, val loss 2.8225
step 4000: train loss 2.6405, val loss 2.6581
step 5000: train loss 2.5641, val loss 2.5805
step 6000: train loss 2.5160, val loss 2.5428
step 7000: train loss 2.4933, val loss 2.5190
step 8000: train loss 2.4839, val loss 2.5065
step 9000: train loss 2.4701, val loss 2.4969
2.401021718978882


Our loss is now greatly reduced. Now let's generate predictions again. The output looks a little better, but it's not Shakespeare yet! It's not even readable.

In [24]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=300)[0].tolist()))


tamby,
Say d, ango thinouggof d RLUTous ds.
O:
NNoucoferesainteiss,
ENCh;
Heme LEEN:
RYConsprear tiop!
If k arony ob's arin seler e 'sulisthet is. maitr:
Mou e, th se ent arare dow inencary;
LerX?
Ascut w d Dime d tho,
SThtaive b,
Twhe ho.
AYe t y mu d tha'd
Thayofe st sos sowin himy itore pas bethe


# Next... Increase complexity of the model

Our model up to this point has been incredibly simple: the tokens aren't talking to each other. We're only using the last character to figure out what comes next. Now we're going to get these tokens talking to each other.

## The mathematical trick in self-attention

In [25]:
# Consider the following toy example:

torch.manual_seed(1337)

B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
# print(x)

Currently the 8 tokens in a block (the time dimension) are not talking to each other. We want them to talk to each other, but each token should only see the tokens that come before it. For example, the token in position 5 can see tokens 1-4 (and itself) but can't see tokens 6, 7, or 8.

### Version 1: For Loop

In [26]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]

# bow = bag of words, there's a word stored at every one of the T positions
xbow = torch.zeros((B, T, C))

for b in range(B):
    # print('======')
    # print(b)
    for t in range(T):
        # print(f'  ->{t}')
        xprev = x[b, :t+1]
        # print(f'    {xprev}')
        xbow[b,t] = torch.mean(xprev, 0)

In [27]:
(.1808 + -0.3596) / 2

-0.0894

#### Look at the output!
Now it starts to make sense. Let's start by comparing `x[0]` to `xbow[0]`. The first rows are equal, but the second row of `xbow[0]` is the average of the first two rows of `x[0]`, because $(0.1808 + -0.3596) / 2 = -0.0894$ and $(-0.0700 + -0.9152) / 2 = -0.4926$. The third row is the average of the first three, and so on. The last row is the column-wise average of all eight rows.

In [28]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [29]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
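The arithmetic can be reproduced without PyTorch. The sketch below copies the first three rows of `x[0]` from the output above and recomputes the running means with plain Python lists. The names `rows` and `running_means` are introduced here for illustration.

```python
# First three rows of x[0] as printed above
rows = [
    [0.1808, -0.0700],
    [-0.3596, -0.9152],
    [0.6258, 0.0255],
]

running_means = []
for t in range(len(rows)):
    prefix = rows[: t + 1]   # positions 0..t only, never the future
    mean = [sum(col) / len(prefix) for col in zip(*prefix)]
    running_means.append([round(v, 4) for v in mean])

for row in running_means:
    print(row)
```

The results agree with the first three rows of `xbow[0]` to four decimal places.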

### Version 2: Matrix Multiplication
This is great, BUT for loops are inefficient. It would be much better to do this with matrix multiplication. Let's look at an example:

In [30]:
torch.manual_seed(42)

# Create a 3x3 tensor of all ones
a = torch.ones(3, 3)

# Create a random 3x2 tensor 
b = torch.randint(0,10,(3,2)).float()

# Compute dot product of two matrices
c = a @ b

print(a)
print(b)
print(c)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


The value in the top-left position of tensor `c` (14) comes from multiplying the first **row** of tensor `a` by the first **column** of tensor `b`. Since `a` is all ones, it ends up just being the sum of the first column of `b`: $2 + 6 + 6 = 14$.
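Written out by hand, the row-times-column rule looks like the sketch below. It uses plain Python lists with the same numbers as the output above. `matmul`, `ones`, and `vals` are hypothetical names chosen so they don't overwrite the notebook's `a`, `b`, and `c`.

```python
ones = [[1.0, 1.0, 1.0] for _ in range(3)]
vals = [[2.0, 7.0], [6.0, 4.0], [6.0, 5.0]]   # same numbers as b above

def matmul(left, right):
    # Entry (i, j) is the dot product of row i of left with column j of right
    cols = list(zip(*right))
    return [[sum(l * r for l, r in zip(row, col)) for col in cols] for row in left]

prod = matmul(ones, vals)

# Top-left entry spelled out: row 0 of ones times column 0 of vals
top_left = 1.0 * 2.0 + 1.0 * 6.0 + 1.0 * 6.0
print(prod)
print(top_left)
```

Because every row of `ones` is identical, every row of the product holds the same column sums of `vals`.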

But `torch` has another great tool called `torch.tril()` that gives just the lower triangular values of a tensor and zeroes out the rest. Let's try that:

In [31]:
torch.manual_seed(42)

# Create a 3x3 tensor of all ones, wrap in tril to zero out the top right and give running sums
a = torch.tril(torch.ones(3, 3))

# Create a random 3x2 tensor 
b = torch.randint(0,10,(3,2)).float()

# Compute dot product of two matrices
c = a @ b

print(a)
print(b)
print(c)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


When we do this we end up with a **running sum**! To make them averages only takes a little more effort:

In [32]:
torch.manual_seed(42)

# Create a 3x3 lower-triangular tensor of ones, then divide each row by its sum
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)

# Create a random 3x2 tensor 
b = torch.randint(0,10,(3,2)).float()

# Compute dot product of two matrices
c = a @ b

print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Now we can vectorize this and scale to our previous example with `x` and `xbow`:

In [33]:
# Call it wei (short for weights)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)

# New version of xbow, to do all the Matrix multiplication
xbow2 = wei @ x
xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### Version 3: Softmax

Softmax is a normalization operation, so we get the same matrix. The reason to prefer *this* version is that the weights start at zero here, but eventually they will not stay constant at zero: they will be data-dependent. We'll call these affinities. Some tokens find others more interesting, and that is the basis of attention.

In [34]:
# Call it wei (short for weights)
tril = torch.tril(torch.ones(T, T))  # create same tril matrix and set aside
wei = torch.zeros((T, T))  # Initialize wei as all zeros
wei = wei.masked_fill(tril == 0, float('-inf'))  # replace zeros w/ -inf, this is how we keep the past from seeing the future

# Do a softmax along every row (because dim = -1). Softmax is a normalization
# function, so each row already sums to 1 and we get the same weights as in
# version 2 without dividing by the row sums again.
wei = F.softmax(wei, dim=-1)
print(wei)

# New version of xbow, to do all the Matrix multiplication
xbow3 = wei @ x
xbow3[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
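To see why a softmax over zeros and `-inf` reproduces the averaging weights, it helps to compute it by hand. The sketch below is a plain-Python softmax applied to a 3x3 masked matrix. `softmax_row`, `T_toy`, `masked`, and `attn` are names introduced here for illustration.

```python
import math

def softmax_row(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

NEG_INF = float("-inf")
T_toy = 3

# Zeros on and below the diagonal, -inf above it (the "future")
masked = [[0.0 if j <= i else NEG_INF for j in range(T_toy)] for i in range(T_toy)]
attn = [softmax_row(row) for row in masked]

for row in attn:
    print([round(p, 4) for p in row])
```

Each `-inf` entry contributes `exp(-inf) = 0`, and each remaining zero contributes `exp(0) = 1`, so row `t` ends up with `t + 1` equal weights that sum to 1. Once the zeros are replaced by data-dependent scores, the same softmax produces unequal weights.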

### Version 4: Self-Attention!

In [35]:
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # Now C is 32
x = torch.randn(B, T, C)

Previously we initialized `wei` with all zeros. We need to change this, because some tokens will have more or less affinity for others. If we initialize with zeros then `wei` has uniform rows, but that's not what we want in real life: some tokens will find other tokens more or less interesting, and we want this to be data-dependent. For example, a vowel might be looking for consonants in its past; we want to know what they are and have that information flow forward.

How we solve this: every token at each position emits TWO VECTORS, a query and a key. Query vector: what am I looking for? Key vector: what do I contain? The affinities come from taking the dot product of the queries with the keys. **That dot product becomes `wei`.** If a key and a query are aligned, they interact to a higher degree.
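A tiny worked example shows how alignment turns into a score. The 2-dimensional vectors below are chosen by hand purely for illustration. In the notebook the queries and keys come from learned linear layers, not fixed numbers.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical hand-picked vectors
q_vec = [1.0, 0.0]          # query: "what am I looking for?"
k_aligned = [0.9, 0.1]      # points in nearly the same direction as the query
k_unrelated = [0.0, 1.0]    # orthogonal to the query
k_opposed = [-1.0, 0.0]     # points the opposite way

scores = [dot(q_vec, key_vec) for key_vec in (k_aligned, k_unrelated, k_opposed)]
print(scores)
```

The aligned key gets the highest score, the orthogonal key scores zero, and the opposed key scores negative. After the softmax, higher scores become larger attention weights.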

#### Let's see a single head perform self-attention



In [42]:
head_size = 16  # Hyperparameter for head size
key = nn.Linear(C, head_size, bias=False)  # bias=False makes each linear module a plain matrix multiply with no additive bias
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Produce k and q by forwarding key and query on x
k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)

print(k[0])
print(q[0])

tensor([[ 0.0243,  0.7806,  0.3592,  0.2484,  0.2203, -0.2434,  0.5941,  1.2976,
          0.4907, -0.9231, -0.7141,  0.8663,  0.5429, -0.4535, -0.5635, -0.2417],
        [-0.0669,  0.8692, -0.0702,  0.5468,  1.0354, -1.2248, -0.0357, -0.3485,
         -0.3257,  0.8254,  0.6185,  1.2044,  0.7290, -0.6172, -0.1570, -0.2020],
        [-0.1794, -0.5119, -0.6379, -0.2220,  0.7300, -0.0094, -0.3451,  1.1641,
         -0.2302, -0.1707,  0.2256,  0.2680,  0.8527, -0.4898,  0.6744,  0.1902],
        [-0.5010,  0.0236, -0.3745,  0.2033,  0.3464, -0.5504, -0.2041, -0.0295,
         -0.4679,  1.0379,  0.3958,  0.1890,  0.5071,  0.5054,  0.7219, -0.6512],
        [ 0.2949, -0.3837, -0.3887, -0.4149, -0.2714, -0.2772, -0.1827,  0.6626,
          0.7409, -0.6932, -0.0868, -0.8386,  0.5869,  0.3283, -0.7405,  0.5118],
        [-0.0936,  0.6212,  0.5187,  1.2611, -0.8386, -0.1025, -0.3600,  0.1297,
         -0.3774, -0.6213, -0.0153,  0.2046,  0.1706, -1.3380, -0.7898, -0.7704],
        [ 0.0977,  1.0

When we forwarded `key` and `query` on `x`, each token at every position in the B x T arrangement produced a key and a query. No communication has happened yet. That happens now: every query is dot-producted with each of the keys.

In [43]:
# Need to transpose to the right dimension before dot product. What we get is:
# (B, T, 16) @ (B, 16, T) ---> (B, T, T)
wei = q @ k.transpose(-2, -1)

In [44]:
# Now add the other steps in to remove the top triangle of data (otherwise we're leaking data)
tril = torch.tril(torch.ones(T, T))  # create same tril matrix and set aside
wei = wei.masked_fill(tril == 0, float('-inf'))  # replace zeros w/ -inf, this is how we keep the past from seeing the future
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [45]:
# For every batch element we get a T x T matrix of weights.
# The weights are no longer uniform like they were before but indicate
# affinities between tokens.
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6513, 0.3487, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3639, 0.3610, 0.2751, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6837, 0.1142, 0.1494, 0.0528, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0632, 0.3618, 0.1033, 0.3048, 0.1669, 0.0000, 0.0000, 0.0000],
        [0.4138, 0.1132, 0.0976, 0.0557, 0.0559, 0.2638, 0.0000, 0.0000],
        [0.4149, 0.1918, 0.0849, 0.1211, 0.0862, 0.0532, 0.0479, 0.0000],
        [0.1260, 0.0522, 0.1892, 0.3630, 0.0945, 0.0475, 0.0607, 0.0670]],
       grad_fn=<SelectBackward0>)
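The steps of a single head (scores from queries and keys, causal mask, softmax, weighted sum of values) can be collected into one small function. The version below is a dependency-free sketch for one sequence using Python lists. The inputs `toy_q`, `toy_k`, and `toy_v` are hand-picked rather than produced by linear layers, and `causal_attention` is a hypothetical name, not part of the notebook.

```python
import math

def softmax_row(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q, k, v):
    """q, k, v are lists of T vectors; returns (weights, outputs)."""
    steps = len(q)
    dim = len(v[0])
    weights, outputs = [], []
    for t in range(steps):
        # Affinity of position t with every position; future positions are masked
        scores = [
            sum(a * b for a, b in zip(q[t], k[s])) if s <= t else float("-inf")
            for s in range(steps)
        ]
        w_row = softmax_row(scores)
        weights.append(w_row)
        # Output at position t is the weighted sum of the value vectors
        outputs.append([sum(w_row[s] * v[s][d] for s in range(steps)) for d in range(dim)])
    return weights, outputs

# Hand-picked toy inputs (T = 3, head_size = 2)
toy_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_k = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
toy_v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

toy_w, toy_out = causal_attention(toy_q, toy_k, toy_v)
for row in toy_w:
    print([round(p, 4) for p in row])
```

As in `wei[0]` above, every row sums to 1, the upper triangle is zero, and the first position can only attend to itself.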

## NB -> Script Conversion

This will be easier if we work from a script instead of a notebook.