# Creating a GPT from Scratch

In this notebook, I implement a GPT from scratch following Andrej Karpathy's YouTube series, along with my notes of the lecture.

In [53]:
# imports
import torch
import torch.nn as nn
from torch.nn import functional as F

In [54]:
# globals
batch_size = 32
block_size = 8
lr = 1e-3
epochs = 10_000
eval_interval = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200

torch.manual_seed(1337);

### Creating the Dataset

We create the vocabulary and the dataset as we have been doing in the other `makemore` notebooks. But here, instead of using the `names` dataset, we will be using Tiny Shakespeare. 

In [55]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("Length of the dataset in characters: ", len(text))

Length of the dataset in characters:  1115394


Here, we are building a character level language model. Our vocabulary is going to be all the characters in the dataset, and the *tokens* in our language model are the characters mapped to integers. In LLMs, this tokenization could be at *subword* level, or something else also! 

The larger the vocabulary, the larger integer to token mapping you have. That means, that you can represent larger sentences using fewer tokens. On the contrary, if you have less number of tokens in your vocabulary, you will need more tokens to represent larger sentence. 

For example, with character level language model, we need `len(sentence)` tokens to represent it. But if we had a word level tokenization, then we would need `len(sentence.split(" "))` tokens, which would be fewer than the characters.

In [56]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

print("Vocab Size is: ", vocab_size)

Vocab Size is:  65


In [57]:
# Create an integer to character mapping- i.e. the tokenizer that encodes and decodes tokens

stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]  # takes an input string, and outputs a list of integers. i.e. the character map
decode = lambda l: "".join([itos[i] for i in l]) # takes the token list, and produces the string for it

print(encode("Hii, there!"))
print(decode(encode("Hii, there!")))

[20, 47, 47, 6, 1, 58, 46, 43, 56, 43, 2]
Hii, there!


We encode the text into PyTorch tensor now, and split the encoded dataset into train and validation split.

In [58]:
data = torch.tensor(encode(text), dtype=torch.long)

In [59]:
cut = int(0.9 * len(data))
train_data = data[:cut]
validation_data = data[cut:]

We define the context length first. This context length is the maximum context that the model can look at when making a prediction. However, there doesn't have to be 8 characters always- you can have less than that. Thus, you get something as this. But notice that now we're dealing with tokens and not integers.

In [60]:
block_size = 8 # context length: maximum 8 tokens can be taken as context

sample_x = train_data[:block_size]
sample_y = train_data[1:block_size + 1]

for t in range(block_size):
    context = sample_x[:t+1]
    target = sample_y[t]

    print(f"When input is {context} the target: {target}")

When input is tensor([18]) the target: 47
When input is tensor([18, 47]) the target: 56
When input is tensor([18, 47, 56]) the target: 57
When input is tensor([18, 47, 56, 57]) the target: 58
When input is tensor([18, 47, 56, 57, 58]) the target: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


### Creating Batches

Now that the dataset is there, we need to think about how the input text can be passed as a batch. 

Before that, an important thing to note about transformers is that there is a maximum number of tokens that you can pass to them. They are able to handle sequential inputs of arbitrary length, but this arbitrary length is also capped to some number such as 512. This number is the context length. You can have at maximum that many tokens but at minimum, you can have any number of tokens. 

Now let's think about how would we create and pass a batch of sequences to the model. Our wishlist is the following:

1. We want to pick arbitrary sequences so that the model can generalize well. How do we pick random sequences? Just pick out random starting indexes.
2. How big a sequence should you pick? Well, it cannot be more than the context length of the model. For the moment, assume you would pick the input of size `block_size` i.e. the context size. For example, if you have a `batch_size` of 4 and `block_size` of 8, then you would randomly pick 4 indices in the dataset, and index 8 characters from that index. 
3. What should be the targets? The targets are just the next character. 

As we have seen before, one sequence of 8 characters gives us 8 training examples ( in cell above ). So when we have a batch of size 4, with each having a sequence of 32, it is going to give 32 training samples. 

**Important:** Each training sample can be passed independently to the transformer!

The key is going to be figuring out how to pass this to the transformer.

In [61]:
torch.manual_seed(1337)

batch_size = 4
block_size = 8

def get_batch(split:str):
    data = train_data if split == 'train' else validation_data
    ix = torch.randint(len(data) - block_size, (batch_size, )) # randomly select batch_size many indices. len(data) - block_size just handles edge case
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch('train')

print(f"Shape of inputs: {xb.shape}")
print(f"Shape of outputs: {yb.shape}")

Shape of inputs: torch.Size([4, 8])
Shape of outputs: torch.Size([4, 8])


For this batched input, we can again split the training examples. But note, this is NOT relevant till we get to transformers. At the moment, we are just training a bigram model.

In [62]:
i = 0
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension ( PyTorch convention: (B, T, C) = (Batch, Time, Channel))
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"{i}: When input is {context} the target is: {target}")
        i+=1

0: When input is tensor([24]) the target is: 43
1: When input is tensor([24, 43]) the target is: 58
2: When input is tensor([24, 43, 58]) the target is: 5
3: When input is tensor([24, 43, 58,  5]) the target is: 57
4: When input is tensor([24, 43, 58,  5, 57]) the target is: 1
5: When input is tensor([24, 43, 58,  5, 57,  1]) the target is: 46
6: When input is tensor([24, 43, 58,  5, 57,  1, 46]) the target is: 43
7: When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) the target is: 39
8: When input is tensor([44]) the target is: 53
9: When input is tensor([44, 53]) the target is: 56
10: When input is tensor([44, 53, 56]) the target is: 1
11: When input is tensor([44, 53, 56,  1]) the target is: 58
12: When input is tensor([44, 53, 56,  1, 58]) the target is: 46
13: When input is tensor([44, 53, 56,  1, 58, 46]) the target is: 39
14: When input is tensor([44, 53, 56,  1, 58, 46, 39]) the target is: 58
15: When input is tensor([44, 53, 56,  1, 58, 46, 39, 58]) the target is: 1
16: Wh

## First Self-Attention Block

The bigram model was not paying any attention to the context. It was just considering the previous character. Now, we will build a self-attention block that pays attention to context of characters.

### Attention, Mathematically

Think about what attention means. Attention is just a mechanism to let the past context interact with the current token. How do you make two vectors *interact*? By the way of addition and multiplication.

One simple way of encoding context would be to take average of embeddings of all the previous tokens in the context, and remember that the context length is fixed. But what taking a simple average means is that the current token is paying equal attention to the all the previous tokens in the context. We don't want that. We want the attention to be data dependent, which in other words means that this attention should be *learned* from the model.

So we're going to take a weighted average of the previous tokens but these weights are going to be learned by the model.

These weights need to sum up to 1. And each token should only look at itself and the previous token- not the tokens that come after it.

How do you encode this operation mathematically? Why, matrix multiplication, of course! If you think about it, if you have a lower diagonal weight matrix, then you can multiply it by the context and you're going to satisfy the condition that each token should only look at the token that comes before it.

```
weights @ context
```

In PyTorch, `torch.tril` gives us a lower triangular matrix. You need to pass the matrix that you want to convert to lower triangular.

How do we encode a simple average with this matrix multiplication?
Just divide each row by the number of non-zero entries in it. Think about it, in a lower triangular matrix, everything above the main diagonal is zero. So if you divide the each row by the number of non-zero entries in it, then essentially you are taking average. See the example below for the weights:

In [63]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [64]:
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

Now we can do matrix multiplication to get the simple average. But think about the dimensions of this multiplication.

```
weights.shape = (T, T)
x.shape = (B, T, C)
```

So, PyTorch is going to *slide* the matrix multiplication across the batch dimension and we're going to get an output that is of shape: ```(B, T, C)```

In [65]:
xbow = weights @ x
xbow.shape

torch.Size([4, 8, 2])

Our main requirement is that these attention weights need to be trainable. For the moment, they are not. So we need a way that allows us to learn those attention weights while still preserving the averaging mechanism.

That's where softmax is going to come in. Look up Mitesh Sir's Teacher Forcing or Masked Attention lectures for the intuition behind the `-inf` values. But essentially, these `-inf` and `0` values achieve the same thing as before:

1. They let the model only focus on previous tokens
2. Each row sums to 1.

Because:
1. $softmax(- \infty) = 0$
2. $exp(0) = 1$

What we want is that the weights at each row should sum to 1 and they should be learnable. We know a function that lets generate a row that sums to 1 and that is softmax.

In [66]:
tril = torch.tril(torch.ones(T, T))

weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf')) # wherever tril == 0, replace with -inf
weights = F.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

Again we can do matrix multiplication and we get the same matrix as above:

In [67]:
(weights @ x).allclose(xbow)

True

Now, let's take a moment to appreciate what is happening here. For each token, we said we would combine the information from the previous tokens by addition and taking the average. With this mask, we effectively achieve the same. We prevent the information present in the future tokens to be combined with the present token because that information is going to get multiplied by zero.

The `weights` matrix is essentially what is going to be defining the how much attention to pay to each of the previous tokens given a certain token. We want the model to learn this matrix.

### Building a Self-Attention Block

Now, let's think about how self-attention is implemented. Each token (of the block_size tokens) *emits* two vectors- a key and a query. You'll introduce a Key and Query matrix as parameters which transform a given token. Then linear transformation of token with Key matrix gives the key vector and similarly for the query vector.

- **Key Vector:** Captures "What I contain"
- **Query Vector:** Captures "What am I looking for"

So far, there is no communication between tokens.

But then, the way we're going to find whether a token is worth paying attention to is by doing dot product of queries and keys. This is how we're going to find attention weights.   For each token, you want to find out whether the other tokens contain what the current token is looking for. If the dot product of these two vectors is high, then that means that the key vector contains what the query vector was looking for. 

This dot product is what the attention weights matrix will be!

In [68]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16 # hyperaparam
key = nn.Linear(C, head_size, bias=False) # W_K matrix
query = nn.Linear(C, head_size, bias=False) # W_Q matrix

# Apply linear transforms to get key and query vectors for each
k = key(x) # shape is (B, T, 16)
q = query(x) # shape is (B, T, 16)

weights = q @ k.transpose(-2, -1) # Transpose T, C of `k`


Small comment on the matrix multiplication of `q` and `k`. This is essentially going to achieve pairwise dot product of queries and keys. But notice that we have a batch dimension. We don't want to use batch dimension when doing matrix multiplication. So we need to transpose only the last two dimensions. So the shape of weights now is: `(B, T, T)`. $ T \times T$ is what we want- pairwise attention weights, and we have those pairwise attention weights for each the batches. 

Now that we have the `weights`, we can do what we did before, we can mask it and apply softmax to get attention weighted aggregates of the previous tokens. Again, these are just weights to take a weighted average of the previous tokens such that the information in them is encoded in the current token.

In [69]:
tril = torch.tril(torch.ones(T, T))

weights = weights.masked_fill(tril==0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

weights[0]


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6002, 0.3998, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2444, 0.0169, 0.7387, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6302, 0.1891, 0.1313, 0.0494, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2442, 0.3194, 0.0989, 0.2715, 0.0659, 0.0000, 0.0000, 0.0000],
        [0.0893, 0.2987, 0.0265, 0.1255, 0.0576, 0.4024, 0.0000, 0.0000],
        [0.1572, 0.0057, 0.3695, 0.0077, 0.3306, 0.0557, 0.0735, 0.0000],
        [0.0229, 0.5707, 0.0145, 0.0094, 0.0155, 0.2216, 0.0624, 0.0830]],
       grad_fn=<SelectBackward0>)

As we can see, `weights` now contains the attention weights to take the weighted average of only the past tokens and they are not uniform- so we have learnable parameters now!

And `out` is now the modified input `x` by taking the weighted average using the attention weights.

There is one additional component in the self attention block. That is, we don't modify the input `x` directly. We pass `x` through a linear transformation to produce a `value` for that input and then we modify this value. 

Another way of thinking about this is that the information contained in the `x` is private. So it tells the other vectors what it contains (key), what it is looking for (query), and what it will communicate with them (value). So we need a linear transform, which again are parameters.

In [70]:
value = nn.Linear(C, head_size)
v = value(x)
out = weights @ v
out.shape # out shape is now 16D instead of 32D as before

torch.Size([4, 8, 16])

All this Q, K, V matrix multiplications and the weighted aggregation constitutes one *head* of attention. 

### Andrej's Notes on Self-Attention

1. Attention is just a communication mechanism- not specific to language modeling. You can look at it as a directed graph where each node points to other nodes with which it is communicating, and the weights of those edges? The attention weights! So you can apply attention to ANY directed graph!

2. There is no notion of space in attention. Attention simply acts over a set of vectors. So if you have a problem where the sequence is important, then you need to encode the positional information, like we did.

3. There is no communication across batches. With the matrix multiplications, the elements within a batch communicate but across batches there is no communication. 

4. Not all tasks require you to have a masked attention. In this task, we wanted to predict the future token and thus we didn't want the tokens to communicate with the future. But let's say your task is sentiment analysis, then there is no need for masking. You will have the entire sentiment with you and you can allow the communication with future tokens also. Without masking the self-attention block is called an "encoder block", and with masking it is called the "decoder block". Look up Mitesh sir's lectures on Encoder only and Decoder only models. Attention by itself has no notion about which nodes it can communicate with.

5. Self-attention means that the keys, queries, and values are coming from same source- the input `x`. But there is cross-attention also where the queries come from different source and the keys & values come from an entirely different source, such as an encoder-decoder architecture.

6. In the original "Attention is All You Need" paper, we have the equation:
$softmax(\dfrac{QK^T}{\sqrt{d_k}})V$. We have done all the matrix multiplication part but there is also `sqrt(head_size)`. The attention is being *scaled* so it's called Scaled attention, and it is to make sure that when the input to Q and K is unit variance, then the `weights` is unit variance too, so as to avoid saturation when being passed through the softmax. So at least at initialization, you want `weights` to avoid taking too extreme values. So when we are doing the `q @ k` matrix multiplication, we need to multiply it by `C ** -0.5`.

### Creating a Head Module:

In [71]:
n_embd = 32 # this is a global variable
block_size = 8

class Head(nn.Module):
    """Single head of self-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)

        weights = q @ k.transpose(-2, -1) * C**-0.5

        # TODO: need to understand indexing in tril for this line. It's most likely to handle less than block_size inputs
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf') )
        weights = F.softmax(weights, dim=-1)

        v = self.value(x)

        out = weights @ v
        return out


In PyTorch, whatever is not a parameter of the model is called a `buffer`. It is just creating the `tril` variable and assigning it to the module but it is not a parameter of the module. 

## A Language Model With Attention Head

We introduce an additional linear layer called `lm_head` which will transform the output of attention head to match the dimensions of the targets. 

We also need to make sure that there are at maximum only `block_size` tokens that are passed to the model, otherwise the `position_embedding_table`, which captures the positional information upto block size, is going to throw an error. 

**Pro Tip:** Self-attention cannot handle very high learning rates. So you need something like `1e-3` or lower, and you need to train for more number of iterations. 

In [72]:
n_embd = 32
block_size = 8

class LangModelWithOneHead(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)

        # Final linear layer to produce logits of equal dim as targets
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Combine the information of the token and the position
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # pass the position + token information through the attention head
        x = self.sa_head(x) # -> (B, T, head_size)

        # produce logits
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            
            # You can't have more than block_size tokens in context now because pos_embd will throw an error.
            # so keep only the last block_size tokens
            idx_cond = idx[:, -block_size:]

            logits, loss = self(idx_cond) # logits.shape is (4, x, 65)

            # same as bigram
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)

            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        
        # idx will be the sequence generated for each batch
        return idx

### Training the Attention Based Model

In [73]:
model = LangModelWithOneHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

We also want a smoother loss estimate since the loss can vary batch to batch based on what sample is drawn. So we do what we did before to smooth the loss estimate.

In [74]:
@torch.no_grad()
def estimate_loss():
    out = { }

    # set model to eval mode
    model.eval()

    # for train and val data, take mean of 300 iters
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    
    # set model back to train mode
    model.train()
    return out

In [75]:
for epoch in range(epochs):
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=150)[0].tolist()))

Step 0: train loss 4.2181, and val loss:4.2123
Step 1000: train loss 2.8455, and val loss:2.8353
Step 2000: train loss 2.5760, and val loss:2.6110
Step 3000: train loss 2.5465, and val loss:2.4903
Step 4000: train loss 2.4745, and val loss:2.5097
Step 5000: train loss 2.4687, and val loss:2.5062
Step 6000: train loss 2.4688, and val loss:2.4536
Step 7000: train loss 2.4384, and val loss:2.4494
Step 8000: train loss 2.4328, and val loss:2.4681
Step 9000: train loss 2.4604, and val loss:2.4686
Loss is: 2.6558

Generated sequence: 

Sthithang angh drar did yo my pulas grires lounowm tor.

Pr ing
I acrevont, atheen thitanimis BEfe no.
 gep,
Wif sere
no age Your wlo, phourind ind wi


The loss is slightly better but it needs to be a whole lot better than this also, as the text being produced is still not as impressive. 

## Multi-Head Attention

Instead of having just one attention head, we can have multiple attention heads and we can concatenate their outputs, and then we can pass this concatenated output to the feedforward neural network. This is much like convolutions. You train multiple small filters in the hopes that they specialize in capturing something very specific. Similarly, here also, you have multiple smaller heads instead of one big one in similar hopes. Further, we want all these attention heads to be running in parallel. 

Running parallely is very easy to do in PyTorch and since we're only concatenating the outputs, the heads are also easy to implement. 

In [76]:
class MultiHeadedAttention(nn.Module):
    """Multiple attention heads running in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)]) # all heads run in parallel

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

Then we can include this multi-headed attention in the language model. But there is one more component that is there in the "Attention is All You Need" paper, that we need to add in the model. And that is the feedforward network. 

Currently, we have one linear layer at the end to produce logits. The feedforward neural network that we want is a simple MLP, but why is this needed?

Think of it this way: The multi-headed attention did the communication part. It found out and reflected which context tokens are important for each token. Then if we just have a linear layer that directly produces the logits, then we are going too fast! We need to let the model *think* on what it found on the other tokens. So to allow the model to do this, we add a linear layer also. We add a simple MLP here.

All the tokens have gathered information about their specific attentions via the attention heads. Now, they need to *think* on this information on their own. So this linear level is applied on each of the individual token independently, without worrying about the context now.

In [77]:
class FeedForward(nn.Module):
    """Simple MLP followed by non-linearity"""

    def __init__(self, n_embd):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd), 
            nn.ReLU()
        )

    def forward(self, x):
        return self.net(x)

In [78]:
class LangModelWithMultiHead(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadedAttention(4, n_embd // 4) # 4 heads of size (32 / 4 = 8)
        
        self.ffwd = FeedForward(n_embd)

        # Final linear layer to produce logits of equal dim as targets
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Combine the information of the token and the position
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb # (B, T, C)

        # pass the position + token information through the attention heads
        x = self.sa_heads(x) # For this example, the output is: (B, T, 32)

        # Call feedforward
        x = self.ffwd(x) # -> (B, T, n_embd=32)

        # produce logits
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            
            # keep only the last block_size tokens
            idx_cond = idx[:, -block_size:]

            logits, loss = self(idx_cond) # logits.shape is (4, x, 65)

            # same as bigram
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)

            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        
        # idx will be the sequence generated for each batch
        return idx

This is a language model that has one attention block. In the transformer architecture, you have multiple such blocks that are stacked one after the other. So we are interspersing the attention, followed by linear, followed again by attention and so on.

But for the moment let's train this single block model and see the output that we get.

In [79]:
model = LangModelWithMultiHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [80]:
for epoch in range(epochs):
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=150)[0].tolist()))

Step 0: train loss 4.2029, and val loss:4.2021
Step 1000: train loss 2.7353, and val loss:2.7016
Step 2000: train loss 2.5472, and val loss:2.5305
Step 3000: train loss 2.4571, and val loss:2.4705
Step 4000: train loss 2.4199, and val loss:2.4289
Step 5000: train loss 2.4011, and val loss:2.4005
Step 6000: train loss 2.3493, and val loss:2.3913
Step 7000: train loss 2.3530, and val loss:2.3418
Step 8000: train loss 2.3074, and val loss:2.3463
Step 9000: train loss 2.2718, and val loss:2.3193
Loss is: 2.3365

Generated sequence: 

KNIENGLORWERNGERNLESS:'s
And led frour she, a thave
Whe canle, seld spnow froscoune maghe to cout gom.
Thed, the Maim mart dent
Halor; cot hert rut ho


The loss is significantly improved. But the text still isn't great, however we are able to see some words there. We need to improve this loss even more, which we will do by stacking up these multi-headed attention blocks. But simply stacking up these blocks won't necessarily improve performance. Why?

Because stacking up the blocks will make the neural network quite deep, and you need things like batch & layer normalization, skip connections, etc. to be able to successfully train deep networks. 

## Transformer Block With Skip Connections

We will implement the block with skip connections, as was introduced in the paper. But let's first think a bit about skip connections. Refer to FastAI course notebook on ResNets to know the basics of skip connections. But here, I am adding Andrej's take on Skip Connections which is super intuitive. 

### Skip Connections

Skip connections is one way of letting us train deeper neural networks. 

![Skip Connection](skip-connection.png)

You have a residual pathway that directly produces the output. But you also have an option to drift off and go into the other direction, do some computation, and add this output to the residual output.

One important thing about this is that the residual connection and the output is combined via addition. Addition operation equally distributes gradients to all the branches, in this case- the residual path and the other branch. 

So, by allowing this residual connection, the gradients have a highway to flow, instead of getting saturated because of the depth.

Further, the "branches" are initialized such that they contribute very little, if at all, at the beginning. So initially, we're pretty much training a shallow model. But over time, we hope that these branches will learn something that is not known to the shallow model and improve the model performance. 

This lets us train deeper network as if the deeper network has nothing to contribute, the model can learn to *ignore* the deep layers as it has a shallow network within it also.

### Layer Normalization

It's very similar to batch normalization, and we're going to implement it directly inside the code block. In the original paper, LayerNorm was applied after the feedforward and the attention blocks. But now, it is much more common to apply layer normalization before these transformations. It's called "pre-norm" formulation.

We will implement residual connections and refactor the attention mechanism into transformer blocks. We will also add dropouts to avoid overfitting. Dropouts are generally added as the last layer before the residual and the branch is merged into one.

In [81]:
class Head(nn.Module):
    """Single head of self-attention"""

    def __init__(self, head_size, dropout=0.2):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)

        weights = q @ k.transpose(-2, -1) * C**-0.5

        # TODO: need to understand indexing in tril for this line. It's most likely to handle less than block_size inputs
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf') )
        weights = F.softmax(weights, dim=-1)

        weights = self.dropout(weights)

        v = self.value(x)

        out = weights @ v
        return out


In [82]:
class Block(nn.Module):
    """Single Transformer block"""

    def __init__(self, n_embd, n_head):
        super().__init__()

        head_size = n_embd // n_head # make head size smaller based on num of heads
        self.sa = MultiHeadedAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # branch off, do some computation (LayerNorm applied before transformation) and come back i.e. skip connection
        x = x + self.ffwd(self.ln2(x)) # branch off, do some computation (LayerNorm applied before transformation) and come back i.e. skip connection

        return x

There is one slight problem here. That is in the `forward` method of this block. We have naively added the current input with the output of the self-attention block and the feedforward block. But this addition may not be compatible as the dimensions may change. So in the feedforward and the self-attention layers, we need to introduce a projection that maps their output to the same dimension as the input and so that we can perform the addition mentioned in the `forward` method of the `Block`.

In the original paper, the dimension of inner layer of feedforward network is four times the dimensionality of the self-attention. 

In [83]:

class MultiHeadedAttention(nn.Module):
    """Multiple attention heads running in parallel"""

    def __init__(self, num_heads, head_size, dropout=0.2):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)]) # all heads run in parallel
        self.proj = nn.Linear( n_embd, n_embd) # (the output of self-attention, the required output). In this case, both are same. But they need not be
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))

        return out
    
class FeedForward(nn.Module):
    """Simple MLP followed by non-linearity"""

    def __init__(self, n_embd, dropout=0.2):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), 
            nn.ReLU(), 
            nn.Linear( 4 * n_embd, n_embd), # (dim of output of ffwd, the required dim). In this case, both are same. But they need not be
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

Let's define the architecture for this, train the model and see the output now.

In [84]:
class LanguageModelTransformer(nn.Module):

    def __init__(self, n_head, n_blocks, block_size, n_embd):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.blocks = nn.Sequential( *[Block(n_embd, n_head) for _ in range(n_blocks)]  )

        self.ln_final = nn.LayerNorm(n_embd)

        # Final linear layer to produce logits of equal dim as targets
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Combine the information of the token and the position
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb # (B, T, C)

        # pass the position + token information through the attention heads
        x = self.blocks(x) # For this example, the output is: (B, T, 32)

        x = self.ln_final(x)

        # produce logits
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            
            # keep only the last block_size tokens
            idx_cond = idx[:, -block_size:]

            logits, loss = self(idx_cond) # logits.shape is (4, x, 65)

            # same as bigram
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)

            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        
        # idx will be the sequence generated for each batch
        return idx

In [85]:
model = LanguageModelTransformer(n_blocks=4, n_head=4, block_size=8, n_embd=32)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [86]:
for epoch in range(epochs):
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=250)[0].tolist()))


Step 0: train loss 4.3513, and val loss:4.3609
Step 1000: train loss 2.5287, and val loss:2.5301
Step 2000: train loss 2.4189, and val loss:2.3963
Step 3000: train loss 2.3714, and val loss:2.4043
Step 4000: train loss 2.3527, and val loss:2.3573
Step 5000: train loss 2.3200, and val loss:2.3082
Step 6000: train loss 2.2617, and val loss:2.2762
Step 7000: train loss 2.2245, and val loss:2.2669
Step 8000: train loss 2.2005, and val loss:2.2368
Step 9000: train loss 2.1924, and val loss:2.2188
Loss is: 2.2543

Generated sequence: 


I:
But thour geing?
Thats Ieed bet hey,
A:
And nome singnbay:
KI with sad, balll hingt or a stam's tio man and thif at y,maper lemens and with shor's than:

onish nochas tiull lomede porsit uyrron.
Comydfolgeae muall takepto as ham tnou dit tho
to c


Now we have all the building blocks to train a much deeper model. So let's do that:

In [87]:
batch_size = 64
block_size = 256
epochs = 5000
eval_interval = 1000
learning_rate = 3e-4
eval_iters = 200
n_embd = 384
n_head = 6
n_blocks = 6
dropout = 0.2

In [88]:
model = LanguageModelTransformer(n_blocks=n_blocks, n_head=n_head, block_size=block_size, n_embd=n_embd)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
for epoch in range(epochs):
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=250)[0].tolist()))

I trained this model on Google Colab with GPU. After 20,000 epochs, I had to reduce the learning rate by dividing it by 10, as the loss was not decreasing and was oscillating between 1.90 to 1.95. After decreasing it, I got a better loss at around 1.85. After running at this learning rate of `3e-5` for 10,000 epochs I got a better loss of around 1.80. After repeating this process several times, we get a comparable performance as Andrej's lecture. So here's my tiny Shaekespeare when it outputs 10,000 characters: