# Creating a GPT from Scratch

In this notebook, I implement a GPT from scratch following Andrej Karpathy's YouTube series, along with my notes on the lecture.

In [1]:
# imports
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# globals
batch_size = 32
block_size = 8
epochs = 3000
eval_interval = 300 
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200

torch.manual_seed(1337);

### Creating the Dataset

We create the vocabulary and the dataset as we have been doing in the other `makemore` notebooks. But here, instead of using the `names` dataset, we will be using Tiny Shakespeare. 

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("Length of the dataset in characters: ", len(text))

Length of the dataset in characters:  1115394


Here, we are building a character-level language model. Our vocabulary is the set of all characters in the dataset, and the *tokens* of our language model are these characters mapped to integers. LLMs typically tokenize at the *subword* level instead, and other schemes exist too!

The larger the vocabulary, the more text each token can cover, so you can represent a sentence using fewer tokens. Conversely, with a smaller vocabulary you need more tokens to represent the same sentence.

For example, a character-level model needs `len(sentence)` tokens to represent a sentence. With word-level tokenization, we would need only `len(sentence.split(" "))` tokens, which is far fewer than the number of characters.
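To make the trade-off concrete, here is a small sketch that counts tokens for one Shakespeare-style sentence under both schemes. The sentence is an arbitrary choice, and splitting on spaces is only a naive stand-in for real word-level tokenization (punctuation stays attached to words).

```python
# Compare token counts for the same sentence under two tokenization schemes.
sentence = "Before we proceed any further, hear me speak."

char_tokens = list(sentence)        # character level: one token per character
word_tokens = sentence.split(" ")   # naive word level: one token per space-separated word

print(len(char_tokens))  # 45
print(len(word_tokens))  # 8
```

The price of the shorter sequence is a much larger vocabulary: every distinct word needs its own entry, whereas 65 characters cover all of Tiny Shakespeare.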

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

print("Vocab Size is: ", vocab_size)

Vocab Size is:  65


In [5]:
# Create the character-to-integer and integer-to-character mappings, i.e. the tokenizer that encodes and decodes tokens

stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]  # takes a string and outputs a list of integers (token ids)
decode = lambda l: "".join([itos[i] for i in l]) # takes a list of token ids and returns the corresponding string

print(encode("Hii, there!"))
print(decode(encode("Hii, there!")))

[20, 47, 47, 6, 1, 58, 46, 43, 56, 43, 2]
Hii, there!


We now encode the text into a PyTorch tensor and split the encoded dataset into training and validation sets.

In [6]:
data = torch.tensor(encode(text), dtype=torch.long)

In [7]:
cut = int(0.9 * len(data))
train_data = data[:cut]
validation_data = data[cut:]

We first define the context length: the maximum number of tokens the model can look at when making a prediction. The context doesn't always have to be 8 characters long; it can be shorter. So a single chunk of `block_size + 1` tokens packs in `block_size` training examples, with contexts of length 1 to 8, as shown below. Notice that we are now dealing with tokens (integer ids), not characters.

In [8]:
block_size = 8 # context length: maximum 8 tokens can be taken as context

sample_x = train_data[:block_size]
sample_y = train_data[1:block_size + 1]

for t in range(block_size):
    context = sample_x[:t+1]
    target = sample_y[t]

    print(f"When input is {context} the target: {target}")

When input is tensor([18]) the target: 47
When input is tensor([18, 47]) the target: 56
When input is tensor([18, 47, 56]) the target: 57
When input is tensor([18, 47, 56, 57]) the target: 58
When input is tensor([18, 47, 56, 57, 58]) the target: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


### Creating Batches

Now that the dataset is there, we need to think about how the input text can be passed as a batch. 

Before that, an important thing to note about transformers is that there is a maximum number of tokens you can pass to them. They can handle sequential inputs of variable length, but that length is capped at some number, such as 512. This number is the context length. You can pass at most that many tokens, but any smaller number of tokens is fine.

Now let's think about how we would create a batch of sequences and pass it to the model. Our wishlist is the following:

1. We want to pick arbitrary sequences so that the model can generalize well. How do we pick random sequences? Just pick random starting indices.
2. How long should each sequence be? It cannot be longer than the context length of the model. For the moment, assume we pick inputs of size `block_size`, i.e. the context size. For example, with a `batch_size` of 4 and a `block_size` of 8, we randomly pick 4 indices in the dataset and take the 8 characters starting at each index.
3. What should the targets be? The targets are just the next characters, i.e. the input sequence shifted by one position.

As we saw above, one sequence of 8 characters gives us 8 training examples. So a batch of 4 sequences, each of length 8, gives us 4 × 8 = 32 training examples.
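Here is a dependency-free sketch of the wishlist above. It uses a toy dataset where token `n` is always followed by `n + 1` (an assumption that makes the targets easy to check) and shows one batch of 4 sequences of length 8 unrolling into 32 (context, target) examples.

```python
import random

# Toy "dataset" of token ids: token n is always followed by n + 1.
data = list(range(100))
batch_size, block_size = 4, 8

random.seed(0)
# Random start indices; the upper bound keeps i + block_size + 1 <= len(data).
starts = [random.randrange(len(data) - block_size) for _ in range(batch_size)]

examples = []
for i in starts:
    x = data[i:i + block_size]            # input chunk
    y = data[i + 1:i + block_size + 1]    # targets: the input shifted by one
    for t in range(block_size):
        examples.append((x[:t + 1], y[t]))  # context of length t+1 -> next token

print(len(examples))  # 32
```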

**Important:** Each training sample can be passed independently to the transformer!

The key is going to be figuring out how to pass this to the transformer.

In [9]:
torch.manual_seed(1337)

batch_size = 4
block_size = 8

def get_batch(split:str):
    data = train_data if split == 'train' else validation_data
    ix = torch.randint(len(data) - block_size, (batch_size, )) # batch_size random start indices; the upper bound keeps i + block_size + 1 <= len(data), so y never runs off the end
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch('train')

print(f"Shape of inputs: {xb.shape}")
print(f"Shape of outputs: {yb.shape}")

Shape of inputs: torch.Size([4, 8])
Shape of outputs: torch.Size([4, 8])


For this batched input, we can again enumerate the individual training examples. Note, though, that this context is NOT used until we get to transformers: for now, we are just training a bigram model.

In [10]:
i = 0
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension ( PyTorch convention: (B, T, C) = (Batch, Time, Channel))
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"{i}: When input is {context} the target is: {target}")
        i+=1

0: When input is tensor([24]) the target is: 43
1: When input is tensor([24, 43]) the target is: 58
2: When input is tensor([24, 43, 58]) the target is: 5
3: When input is tensor([24, 43, 58,  5]) the target is: 57
4: When input is tensor([24, 43, 58,  5, 57]) the target is: 1
5: When input is tensor([24, 43, 58,  5, 57,  1]) the target is: 46
6: When input is tensor([24, 43, 58,  5, 57,  1, 46]) the target is: 43
7: When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) the target is: 39
8: When input is tensor([44]) the target is: 53
9: When input is tensor([44, 53]) the target is: 56
10: When input is tensor([44, 53, 56]) the target is: 1
11: When input is tensor([44, 53, 56,  1]) the target is: 58
12: When input is tensor([44, 53, 56,  1, 58]) the target is: 46
13: When input is tensor([44, 53, 56,  1, 58, 46]) the target is: 39
14: When input is tensor([44, 53, 56,  1, 58, 46, 39]) the target is: 58
15: When input is tensor([44, 53, 56,  1, 58, 46, 39, 58]) the target is: 1
...

## Bigram Model

We built a simple bigram model in an earlier part of this series. But since the dataset is new, and there are some slight tweaks in the implementation, I am reimplementing the code.


In [11]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # vocab_size X vocab_size lookup table

    def forward(self, idx, targets=None):
        
        logits = self.token_embedding_table(idx) # logits.shape = (B, T, C) = (4, 8, 65) in our case

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B*T, C)
            targets = targets.view(B*T) # targets are of shape (B, T) 

            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor of token indices: the context to start generating from

        for _ in range(max_new_tokens):
            logits, loss = self(idx) # logits.shape is (B, T, vocab_size)

            # we want the row corresponding to the last character in each sequence to predict the next character, i.e. the last element in the T dimension
            logits = logits[:, -1, :]

            probs = F.softmax(logits, dim=-1) # softmax over the vocabulary dimension

            next_idx = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, next_idx), dim=1)
        
        # idx will be the sequence generated for each batch
        return idx
    
m = BigramLanguageModel(vocab_size).to(device)  # same device as xb, yb from get_batch

logits, loss = m(xb, yb)

print(f"Loss is: {loss.item():.4f}")
print("Generated sequence: ")
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=50)[0].tolist()))

Loss is: 5.0364
Generated sequence: 

l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'pt


**Note on Forward Pass:** Observe that this bigram model uses no context, so we can treat each character in each sequence of the batch as a separate training example. For this model, the training examples are just single characters:

- When input is `tensor([24])` the target is: 43
- When input is `tensor([43])` the target is: 58

In the forward pass, each character index in each sequence of the batch plucks out a row of the `token_embedding_table`, and we interpret that row as the `logits` for the next character. Since our `vocab_size` is 65 and the batch has shape `(4, 8)`, we get `logits` of shape `(4, 8, 65)`.

But there is one issue with this. `F.cross_entropy()` expects the class dimension second, i.e. inputs of shape `(N, C)` (or `(N, C, d1, ...)`), while our logits are `(B, T, C)`. So we use `view` to flatten the logits to `(B*T, C)` and the targets to `(B*T)`. Imagining the logits as a 3D cube helps a lot!
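As a sanity check on the reshaping, the sketch below (random logits and targets, shapes chosen to match ours) compares the notebook's `view`-based flattening with the other layout `F.cross_entropy` accepts, `(B, C, T)`. With the default `reduction='mean'`, both average over all `B*T` positions, so the two losses should agree.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)
targets = torch.randint(0, C, (B, T))

# Option 1 (used in the notebook): merge batch and time into one dimension.
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Option 2: move the class dimension to position 1, giving (B, C, T).
loss_perm = F.cross_entropy(logits.transpose(1, 2), targets)

print(loss_flat.item(), loss_perm.item())
```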

**Note on Generate function:** What is the wishlist for the generate function? For each sequence in the batch, we want to generate the next token. Since this is a bigram model, the next token depends only on the last token generated, *not* on the entire sequence! We haven't added context yet.

Further, we need to apply softmax to the logits and draw a sample from the resulting distribution. And we don't just want the next predicted token: we want to append it to the current sequence, which is then used to predict the following token.
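One generation step can be sketched in isolation. Here random logits stand in for `logits[:, -1, :]`, and `B` and `vocab_size` are small illustrative values, not the notebook's.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, vocab_size = 2, 5

idx = torch.zeros((B, 1), dtype=torch.long)         # current sequences, shape (B, 1)
logits_last = torch.randn(B, vocab_size)            # stand-in for logits[:, -1, :]

probs = F.softmax(logits_last, dim=-1)              # each row is a distribution summing to 1
next_idx = torch.multinomial(probs, num_samples=1)  # sample one token id per row, shape (B, 1)
idx = torch.cat((idx, next_idx), dim=1)             # append to the context, shape (B, 2)

print(idx.shape)  # torch.Size([2, 2])
```

Repeating this step `max_new_tokens` times is exactly what the loop in `generate` does.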

### Training the Bigram Model

**Pro Tip:** For Adam, `lr=3e-4` works quite well in practice. But for smaller datasets, you can use much higher learning rates, such as the `1e-3` used here.

In [12]:
model = BigramLanguageModel(vocab_size=vocab_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

We also want a smoother loss estimate, since the loss can vary from batch to batch depending on which samples are drawn. So, as before, we average the loss over many batches.

In [13]:
@torch.no_grad()
def estimate_loss():
    out = { }

    # set model to eval mode
    model.eval()

    # for train and val data, take the mean loss over eval_iters batches
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    
    # set model back to train mode
    model.train()
    return out

In [14]:
for epoch in range(epochs):
    
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=150)[0].tolist()))

Step 0: train loss 4.7741, and val loss:4.7923
Step 300: train loss 4.5333, and val loss:4.5842
Step 600: train loss 4.3361, and val loss:4.3582
Step 900: train loss 4.1544, and val loss:4.1761
Step 1200: train loss 3.9806, and val loss:3.9866
Step 1500: train loss 3.8436, and val loss:3.8441
Step 1800: train loss 3.7120, and val loss:3.7218
Step 2100: train loss 3.5824, and val loss:3.5902
Step 2400: train loss 3.4792, and val loss:3.4799
Step 2700: train loss 3.3325, and val loss:3.3767
Loss is: 3.3229

Generated sequence: 

otoOm ixALIntXZy'?mec-.
St3-R
crotha h? ALMtvegakVre,
shoEJQKZ;v?WN3???QmRfU-ENnV3q&XMmea; ik,hotwAya'R,PHJxAYWisJU'Pe;:weX?AqothyoiBr
 se I3DVr,ES.xy


Certainly, the outputs are not Shakespeare-like, and a bigram model is never going to get there, but this is a decent improvement over the untrained model.

## First Self-Attention Block

The bigram model was not paying any attention to the context: it only considered the previous character. Now we will build a self-attention block that lets each token attend to the preceding characters in its context.

### Attention, Mathematically

Think about what attention means. Attention is just a mechanism that lets the past context interact with the current token. How do you make two vectors *interact*? Through addition and multiplication.

One simple way of encoding context would be to average the embeddings of all the previous tokens in the context (remember that the context length is capped at `block_size`). But a simple average means that the current token pays equal attention to all the previous tokens in the context. We don't want that. We want the attention to be data-dependent, which in other words means that it should be *learned* by the model.

So we're going to take a weighted average of the previous tokens but these weights are going to be learned by the model.

These weights need to sum up to 1. And each token should only look at itself and the previous tokens, not the tokens that come after it.

How do you encode this operation mathematically? Why, matrix multiplication, of course! If you multiply the context by a lower triangular weight matrix, row `t` of the result only mixes tokens `0..t`, which satisfies the condition that each token only looks at itself and the tokens before it.

```
weights @ context
```

In PyTorch, `torch.tril(A)` returns the lower triangular part of `A`, setting every entry above the main diagonal to zero.

How do we encode a simple average with this matrix multiplication?
Start from a lower triangular matrix of ones and divide each row by the number of non-zero entries in it (its row sum). Since everything above the main diagonal is zero, row `t` then holds `1/(t+1)` in its first `t+1` entries, so multiplying by it averages tokens `0..t`. See the example below for the weights:

In [15]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [16]:
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

Now we can do matrix multiplication to get the simple average. But think about the dimensions of this multiplication.

```
weights.shape = (T, T)
x.shape = (B, T, C)
```

Since `weights` has no batch dimension, PyTorch *broadcasts* the matrix multiplication across the batch dimension: it applies the same `(T, T) @ (T, C)` product to each of the `B` sequences, giving an output of shape `(B, T, C)`.

In [17]:
xbow = weights @ x
xbow.shape

torch.Size([4, 8, 2])
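To confirm that the broadcast matrix multiplication really computes a running average, here is a sketch that compares it against an explicit double loop over batch and time (random `x` with the same shapes as above):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Explicit loop: for every sequence b and position t, average x[b, 0..t].
xbow_loop = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xbow_loop[b, t] = x[b, :t + 1].mean(dim=0)

# Vectorized: a (T, T) lower-triangular averaging matrix broadcast over B.
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(dim=1, keepdim=True)
xbow_matmul = weights @ x   # (T, T) @ (B, T, C) -> (B, T, C)

print(torch.allclose(xbow_loop, xbow_matmul, atol=1e-6))  # True
```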

Our main requirement is that these attention weights need to be trainable. For the moment, they are not. So we need a way that allows us to learn those attention weights while still preserving the averaging mechanism.

That's where softmax comes in. Look up Mitesh Sir's Teacher Forcing or Masked Attention lectures for the intuition behind the `-inf` values. Essentially, filling the masked positions with `-inf` and the rest with `0` achieves the same thing as before:

1. They let the model only focus on previous tokens
2. Each row sums to 1.

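Because:
1. $\exp(-\infty) = 0$, so the masked (future) positions get zero weight after softmax.
2. $\exp(0) = 1$, so all unmasked positions get equal weight, and the normalization in softmax makes each row sum to 1.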
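Writing out row $t$ of the masked scores (positions indexed from 0) shows why this reproduces the uniform average from before:

$$
z_j = \begin{cases} 0, & j \le t \\ -\infty, & j > t \end{cases}
\qquad\Longrightarrow\qquad
\mathrm{softmax}(z)_j = \frac{e^{z_j}}{\sum_{k=0}^{T-1} e^{z_k}} = \begin{cases} \dfrac{1}{t+1}, & j \le t \\ 0, & j > t \end{cases}
$$

For example, row 2 gets $1/3 \approx 0.3333$ in its first three entries, matching the third row of the averaging matrix printed earlier.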

What we want is for the weights in each row to sum to 1 and to be learnable. We know a function that turns arbitrary (learnable) scores into a row that sums to 1: softmax.

In [18]:
tril = torch.tril(torch.ones(T, T))

weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf')) # wherever tril == 0, replace with -inf
weights = F.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

Again, we do the matrix multiplication and get the same result as `xbow` above:

In [19]:
(weights @ x).allclose(xbow)

True

Now, let's take a moment to appreciate what is happening here. For each token, we said we would combine the information from the previous tokens by summing and averaging. With this mask, we achieve exactly that. We prevent the information present in the future tokens from being combined with the present token, because that information gets multiplied by zero.

The `weights` matrix essentially defines how much attention each token pays to each of the previous tokens. We want the model to learn this matrix.
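A quick way to convince yourself that the mask blocks the future is to perturb only the last position and check which outputs change. This is a sketch with arbitrary sizes and an arbitrary perturbation of `+100`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 8, 4
x = torch.randn(B, T, C)

# Same masked-softmax weights as above: uniform over positions 0..t.
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

# Change only the last position, which is "the future" for every earlier token.
x_changed = x.clone()
x_changed[:, -1, :] += 100.0
out_changed = weights @ x_changed

# Positions 0..T-2 give zero weight to the last token, so their outputs are unchanged.
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True
print(torch.allclose(out[:, -1], out_changed[:, -1]))    # False
```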

### Building a Self-Attention Block

Now, let's think about how self-attention is implemented. Each token (of the `block_size` tokens) *emits* two vectors: a key and a query. We introduce a Key matrix and a Query matrix as learnable parameters: applying the Key matrix to a token's embedding gives its key vector, and applying the Query matrix gives its query vector.

- **Key Vector:** Captures "What I contain"
- **Query Vector:** Captures "What am I looking for"

So far, there is no communication between tokens.

But then, the way we find out whether a token is worth paying attention to is by taking the dot product of queries and keys. This is how we get the attention weights. For each token, you want to find out whether the other tokens contain what the current token is looking for. If the dot product of a query with a key is high, the key's token contains what the query's token was looking for.

These pairwise dot products are what the attention weight matrix is built from!

In [34]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16 # hyperparameter
key = nn.Linear(C, head_size, bias=False) # W_K matrix
query = nn.Linear(C, head_size, bias=False) # W_Q matrix

# Apply linear transforms to get key and query vectors for each token
k = key(x) # shape is (B, T, 16)
q = query(x) # shape is (B, T, 16)

weights = q @ k.transpose(-2, -1) # transpose the last two dims of k: (B, T, 16) @ (B, 16, T) -> (B, T, T)


A short comment on the matrix multiplication of `q` and `k`: it computes the pairwise dot products of queries and keys. But notice that we have a batch dimension, and we don't want to transpose it. So we transpose only the last two dimensions of `k`, and the shape of `weights` is now `(B, T, T)`. $T \times T$ is what we want: pairwise attention weights, one set for each sequence in the batch.
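To see that `q @ k.transpose(-2, -1)` really computes pairwise dot products, the sketch below (random `q` and `k` with the shapes used above) compares it with an explicit `torch.einsum` and with one hand-computed dot product. The indices `1`, `3`, `5` are arbitrary.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)

weights = q @ k.transpose(-2, -1)                    # (B, T, 16) @ (B, 16, T) -> (B, T, T)
weights_einsum = torch.einsum('bid,bjd->bij', q, k)  # same contraction, spelled out

# Entry [b, i, j] is the dot product of query i with key j within sequence b.
manual = torch.dot(q[1, 3], k[1, 5])
print(weights[1, 3, 5].item(), manual.item())
```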

Now that we have the `weights`, we can do what we did before: mask them and apply softmax to get attention-weighted aggregates of the previous tokens. Again, these are just weights for a weighted average of the previous tokens, so that the information in them is encoded in the current token.

In [36]:
tril = torch.tril(torch.ones(T, T))

weights = weights.masked_fill(tril==0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

weights


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6157, 0.3843, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2814, 0.2690, 0.4496, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3489, 0.2214, 0.2302, 0.1995, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1765, 0.1930, 0.2380, 0.2135, 0.1790, 0.0000, 0.0000, 0.0000],
         [0.1336, 0.1421, 0.3111, 0.1353, 0.1436, 0.1342, 0.0000, 0.0000],
         [0.1263, 0.1691, 0.1358, 0.1442, 0.1668, 0.1258, 0.1321, 0.0000],
         [0.1175, 0.1080, 0.1113, 0.1101, 0.1083, 0.2207, 0.1159, 0.1083]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5956, 0.4044, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3859, 0.3271, 0.2870, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2111, 0.2441, 0.2455, 0.2993, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1561, 0.1779, 0.1557, 0.1563, 0.3540, 0.0000, 0.0000, 0.0000],
         ...

As we can see, `weights` now contains attention weights over only the past tokens, and they are no longer uniform: they depend on the data through the key and query projections, which are learnable parameters!

And `out` is now the input `x` aggregated by a weighted average using these attention weights.

There is one additional component in the self-attention block: we don't aggregate the input `x` directly. Instead, we pass `x` through a linear transformation to produce a `value` for each token, and aggregate these values.

Another way of thinking about this is that `x` is the token's private information. Each token tells the others what it contains (key), what it is looking for (query), and what it will communicate if attended to (value). The value is produced by another linear transform, which again has learnable parameters.

In [37]:
value = nn.Linear(C, head_size, bias=False) # W_V matrix, no bias like W_K and W_Q
v = value(x)
out = weights @ v
out.shape # (B, T, head_size) = (4, 8, 16), instead of (4, 8, 32) for `x`

torch.Size([4, 8, 16])

All these Q, K, V matrix multiplications and the weighted aggregation together constitute one *head* of attention.

### Andrej's Notes on Self-Attention

1. Attention is just a communication mechanism, not something specific to language modeling. You can look at it as a directed graph where each node points to the nodes it communicates with, and the weights of those edges? The attention weights! So you can apply attention to ANY directed graph!

2. There is no notion of space in attention. Attention simply acts over a set of vectors. So if the order of the sequence matters for your problem, you need to encode positional information explicitly, as we do below with a position embedding table.

3. There is no communication across the batch. The matrix multiplications let tokens within a sequence communicate, but sequences in different batch elements never communicate with each other.

4. Not all tasks require masked attention. In this task, we want to predict the next token, so we don't want tokens to communicate with future tokens. But for a task like sentiment analysis, there is no need for masking: you have the entire sentence available, and you can allow communication with future tokens too. A self-attention block without masking is called an "encoder block", and one with masking is called a "decoder block". Look up Mitesh sir's lectures on Encoder-only and Decoder-only models. Attention by itself has no notion of which nodes it is allowed to communicate with; the mask is what imposes that.

5. Self-attention means that the keys, queries, and values all come from the same source: the input `x`. There is also cross-attention, where the queries come from one source and the keys and values come from a different one, as in an encoder-decoder architecture.

6. In the original "Attention is All You Need" paper, we have the equation:
$\mathrm{softmax}\left(\dfrac{QK^T}{\sqrt{d_k}}\right)V$, where $d_k$ is the key/query dimension (our `head_size`). We have done all the matrix multiplication, but there is also the division by $\sqrt{d_k}$. Because the attention is *scaled*, it is called scaled dot-product attention. The scaling makes sure that when Q and K have unit variance, `weights` has unit variance too, rather than a variance of roughly `head_size`. This matters because softmax saturates on extreme inputs: it converges toward a one-hot vector, so each token would aggregate information from essentially a single other token. So, at least at initialization, you want `weights` to avoid taking too extreme values. Hence, when we do the `q @ k` matrix multiplication, we multiply the result by `head_size ** -0.5`.
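The variance argument can be checked numerically. This sketch assumes standard-normal queries and keys as a stand-in for "unit variance at initialization"; the sizes are arbitrary, just large enough for stable variance estimates.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 32, 64, 16

# Unit-variance queries and keys, mimicking what we expect around initialization.
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)

unscaled = q @ k.transpose(-2, -1)      # each entry sums head_size products: variance ~ head_size
scaled = unscaled * head_size ** -0.5   # dividing by sqrt(head_size) brings variance back to ~1

print(f"unscaled var: {unscaled.var().item():.2f}")  # roughly 16
print(f"scaled var:   {scaled.var().item():.2f}")    # roughly 1

# Average of the largest softmax probability per row: unscaled scores are much peakier.
peak_unscaled = torch.softmax(unscaled, dim=-1).max(dim=-1).values.mean().item()
peak_scaled = torch.softmax(scaled, dim=-1).max(dim=-1).values.mean().item()
print(peak_unscaled, peak_scaled)
```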

### Creating a Head Module

In [38]:
n_embd = 32 # this is a global variable
block_size = 8

class Head(nn.Module):
    """Single head of self-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)

        # scale by 1/sqrt(head_size), i.e. 1/sqrt(d_k); k.shape[-1] is head_size (not n_embd)
        weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5

        # slice the mask to the current length T: inputs (e.g. during generation) can be shorter than block_size
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf') )
        weights = F.softmax(weights, dim=-1)

        v = self.value(x)

        out = weights @ v
        return out


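In PyTorch, a tensor that is part of a module's state but is not a parameter is called a `buffer`. `register_buffer` creates the `tril` tensor and assigns it to the module: it is saved in the `state_dict` and moved along with `.to(device)`, but it is not returned by `parameters()`, so the optimizer never updates it.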
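Here is a small sketch of the parameter/buffer distinction. `MaskedProjection` is a hypothetical module made up for illustration; only `register_buffer`, `named_parameters`, `named_buffers`, and `state_dict` are real `nn.Module` APIs.

```python
import torch
import torch.nn as nn

class MaskedProjection(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, block_size=8):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)  # learnable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))  # buffer

m = MaskedProjection()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(param_names)               # ['proj.weight']
print(buffer_names)              # ['tril']
print('tril' in m.state_dict())  # True: buffers are saved with the model
```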

## A Language Model With Attention Head

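We introduce an additional linear layer, `lm_head`, which maps the `n_embd`-dimensional output of the attention head to `vocab_size` logits, matching the targets. The token embedding table now produces `n_embd`-dimensional embeddings instead of logits, and a position embedding table adds an embedding for each position.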

We also need to make sure that at most `block_size` tokens are passed to the model; otherwise the `position_embedding_table`, which only has embeddings for positions up to `block_size`, throws an error.
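The sketch below (sizes matching ours; the all-zeros `idx` is just a placeholder sequence) shows the failure mode and the cropping fix that `generate` uses. The exception raised for an out-of-range embedding lookup may differ across PyTorch versions, so the sketch catches both `IndexError` and `RuntimeError`.

```python
import torch
import torch.nn as nn

block_size, n_embd = 8, 32
pos_table = nn.Embedding(block_size, n_embd)   # rows only for positions 0..block_size-1

idx = torch.zeros((1, 12), dtype=torch.long)   # placeholder context of 12 tokens > block_size

raised = False
try:
    pos_table(torch.arange(idx.shape[1]))      # asks for positions 0..11
except (IndexError, RuntimeError):             # exception type may vary by PyTorch version
    raised = True
print(raised)  # True

idx_cond = idx[:, -block_size:]                # keep only the last block_size tokens
pos_emb = pos_table(torch.arange(idx_cond.shape[1]))
print(pos_emb.shape)  # torch.Size([8, 32])
```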

**Pro Tip:** Self-attention cannot tolerate very high learning rates, so you need something like `1e-3` or lower, and you need to train for more iterations.

In [42]:
n_embd = 32
block_size = 8

class LangModelWithOneHead(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)

        # Final linear layer to produce logits of equal dim as targets
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
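        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)

        x = tok_emb + pos_emb # pos_emb broadcasts over the batch: (B, T, n_embd)

        x = self.sa_head(x) # (B, T, head_size), where head_size == n_embd here

        logits = self.lm_head(x) # (B, T, vocab_size)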

        if targets is None:
            loss = None
        
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):

        for _ in range(max_new_tokens):
            
            # You can't have more than block_size tokens in context now because pos_embd will throw an error.
            # so keep only the last block_size tokens
            idx_cond = idx[:, -block_size:]

            logits, loss = self(idx_cond) # logits.shape is (B, T, vocab_size)

            logits = logits[:, -1, :]

            probs = F.softmax(logits, dim=1)

            next_idx = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, next_idx), dim=1)
        
        # idx will be the sequence generated for each batch
        return idx

### Training the Attention Based Model

In [44]:
lr = 1e-3
epochs = 10_000
eval_interval = 1000

model = LangModelWithOneHead().to(device) # same device as the batches from get_batch
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [45]:
for epoch in range(epochs):
    
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"Step {epoch}: train loss {losses['train']:.4f}, and val loss:{losses['val']:.4f}")
    
    # sample a batch from the dataset
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


print(f"Loss is: {loss.item():.4f}\n")
print("Generated sequence: ")
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=150)[0].tolist()))

Step 0: train loss 4.2447, and val loss:4.2429
Step 1000: train loss 2.7729, and val loss:2.7496
Step 2000: train loss 2.5968, and val loss:2.5753
Step 3000: train loss 2.4999, and val loss:2.5141
Step 4000: train loss 2.4827, and val loss:2.4808
Step 5000: train loss 2.4716, and val loss:2.5132
Step 6000: train loss 2.4719, and val loss:2.4548
Step 7000: train loss 2.4696, and val loss:2.4049
Step 8000: train loss 2.4089, and val loss:2.4405
Step 9000: train loss 2.4348, and val loss:2.4504
Loss is: 2.6167

Generated sequence: 

hont;
I nd st thand ginereay idigailith oursh, ble Pomont hag homavom, mhetos core arve d'd,
RE:
G:
S:
Y Eved wite, bou suefat rint to onkll why wr he


The loss is better than the bigram model's (a validation loss of around 2.45, versus about 3.38 for the bigram model), but it needs to get a whole lot better, as the generated text is still not very impressive.

## Multi-Head Attention