# Let's build GPT

## Loading and Tokenizing the Data

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Read the raw training corpus. The encoding is pinned explicitly so that the
# character vocabulary does not depend on the platform's default locale.
with open('input.txt', encoding='utf-8') as f:
    text = f.read()

# The vocab is the set of distinct characters in the corpus; sorting it makes
# the char <-> id mapping reproducible from run to run.
chars = sorted(set(text))
vocab_size = len(chars)

print(len(text), vocab_size)

1115394 65


In [3]:
# Character-level tokenizer: two lookup tables derived from the sorted vocab
itos = dict(enumerate(chars))                  # id -> char
stoi = {ch: idx for idx, ch in itos.items()}   # char -> id


def encode(s):
    """Turn a string into a list of integer token ids, one per character."""
    return [stoi[c] for c in s]


def decode(l):
    """Turn a list of integer token ids back into the string they spell."""
    return ''.join(itos[i] for i in l)


print(encode("hello world"))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]


In [15]:
# Tokenize the whole corpus once and keep the ids as a 1-D int64 tensor
data = torch.tensor(
    data=encode(text),
    dtype=torch.long,
)
data.shape  # one integer id per character of the corpus

torch.Size([1115394])

Now we can start thinking of what our model will actually do.

It works by taking a chunk of integers, whose length is at most `ctx_len`, and predicting the next integer in the sequence. Note that the chunks can be any length up to that limit; we just have to specify the maximum context length for our model.

When we take a chunk of 8 chars, we don't just predict the next character after this sequence of 8 chars - we train our model to predict **at each and every one of these positions**. This means we have $n$ different training examples for each context of length $n$.

The model is being made to predict at contexts with sizes all the way from 1 till `ctx_len`; this means it has the ability to predict the next token and start generating when it's been given just one token of context.

In [16]:
ctx_len = 8

# Inputs are the first ctx_len tokens; targets are the same window shifted
# right by one position, so y[t] is the token that follows x[t].
x, y = data[:ctx_len], data[1:ctx_len + 1]

# Each prefix x[:t+1] (lengths 1..ctx_len) is its own training example
for t in range(ctx_len):
    context, target = x[:t + 1], y[t]
    print(f"{context.tolist()} --> {target}")

[18] --> 47
[18, 47] --> 56
[18, 47, 56] --> 57
[18, 47, 56, 57] --> 58
[18, 47, 56, 57, 58] --> 1
[18, 47, 56, 57, 58, 1] --> 15
[18, 47, 56, 57, 58, 1, 15] --> 47
[18, 47, 56, 57, 58, 1, 15, 47] --> 58


In [17]:
# Batching: sample several random windows of the corpus at once
batch_size = 4


def get_batch():
    """Sample `batch_size` random (input, target) windows from `data`.

    Returns
    -------
    x: torch.tensor (batch_size, ctx_len)
        Token ids of each sampled window.

    y: torch.tensor (batch_size, ctx_len)
        The same windows shifted right by one token (the prediction targets).
    """
    # Random window starts; the upper bound leaves room for the shifted targets
    starts = torch.randint(len(data) - ctx_len, (batch_size,))  # (B,)

    # Broadcast the starts against the in-window offsets to index every
    # window in one go instead of slicing them one by one
    positions = starts.unsqueeze(1) + torch.arange(ctx_len)     # (B, ctx_len)

    return data[positions], data[positions + 1]

x, y = get_batch()
print(x, y, sep="\n\n", end="\n\n")

# Spell out the training examples hidden in the first sequence of the batch:
# every prefix of length t+1 is paired with the token that follows it
for t in range(ctx_len):
    context, target = x[0, :t + 1], y[0, t]
    print(f"{context.tolist()} --> {target}")

tensor([[41, 58, 47, 53, 52,  1, 42, 47],
        [56, 47, 52, 45,  1, 46, 43, 50],
        [15, 39, 50, 39, 47, 57,  0, 16],
        [58, 46, 53, 59, 57, 39, 52, 42]])

tensor([[58, 47, 53, 52,  1, 42, 47, 45],
        [47, 52, 45,  1, 46, 43, 50, 51],
        [39, 50, 39, 47, 57,  0, 16, 47],
        [46, 53, 59, 57, 39, 52, 42,  1]])

[41] --> 58
[41, 58] --> 47
[41, 58, 47] --> 53
[41, 58, 47, 53] --> 52
[41, 58, 47, 53, 52] --> 1
[41, 58, 47, 53, 52, 1] --> 42
[41, 58, 47, 53, 52, 1, 42] --> 47
[41, 58, 47, 53, 52, 1, 42, 47] --> 45


## Bigram Language Model

The simplest model we can start with is the bigram language model, where we predict the next character in the sequence *purely* by looking at the current character.

This can be formulated as using Embeddings - a simple lookup table - to get the representation of the current character. We could set the embedding dimensionality to be the same as the vocab size so that our output from the table can be interpreted as a probability distribution over the vocabulary.

So, we're taking as input a $(B,T)$ tensor, and the model will output a $(B,T,V)$ tensor, where $V$ is the vocabulary size. This is the case because we are predicting a token (i.e. predicting the probability distribution over the entire vocab), for all $B \times T$ positions in the input.

The point of interest is that the tokens are treated independently of each other, in a way, i.e. only one token is used to predict the next token. We'd like to let the tokens *interact* with each other in a way, so that the model can learn more complex patterns, with the context windows we've defined.

In [23]:
class LanguageModel(nn.Module):
    '''
    Lookup-table language model: the logits for the next token depend only
    on the current token.

    The embedding table has shape (vocab_size, vocab_size), so row `i` is read
    directly as the logits over the next token given that the current token
    is `i`.

    Parameters
    ----------
    vocab_size: int
        Number of distinct tokens.
    '''

    def __init__(self, vocab_size):
        super().__init__()
        self.emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        '''
        Compute the next-token logits, and the loss if targets are given.

        Parameters
        ----------
        x: torch.tensor (B, T)
            Integer token ids.

        targets: torch.tensor (B, T), optional
            The token that follows each position of `x`.

        Returns
        -------
        logits: torch.tensor (B, T, C)
            Unnormalized scores over the vocab at every position.

        loss: torch.tensor (scalar) or None
            Mean cross-entropy over all B*T positions; None without targets.
        '''

        logits = self.emb_table(x)  # (B, T, C)

        if targets is None:
            return logits, None

        # Flatten batch and time so each position is one classification example.
        # The class count is read off the logits instead of the module-level
        # `vocab_size`, so the model works for whatever size it was built with.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),  # (B*T, C) - predicting for each token
            targets.reshape(-1)                   # (B*T)    - the true target
        )

        return logits, loss

    def generate(self, ctx, max_new_tokens):
        '''
        Generate a sequence of new tokens given a context.

        Parameters
        ----------
        ctx: torch.tensor (B, T)
            The starting context to condition on.

        max_new_tokens: int
            The maximum number of new tokens to generate.

        Returns
        -------
        torch.tensor (B, T + max_new_tokens)
            The context with the sampled tokens appended.
        '''

        for _ in range(max_new_tokens):

            # Get the predictions
            logits, _ = self(ctx)
            logits = logits[:, -1, :]  # focus on the last token to predict what comes next - (B, C) now

            # Normalize and sample the next token
            probas = F.softmax(logits, dim=-1)         # (B, C)
            next_token = torch.multinomial(probas, 1)  # (B, 1)

            # Update context
            ctx = torch.cat([ctx, next_token], dim=1)

        return ctx

    
# Sanity check on the batch sampled earlier: logits should be (B, T, vocab_size)
model = LanguageModel(vocab_size)
logits, loss = model(x, targets=y)
print(logits.shape, loss.item(), sep="\n")

torch.Size([4, 8, 65])
4.483732223510742


In [26]:
# Generate from the model with an empty context - 0 is newline char
idx = torch.tensor([0]).view(1,1) # (B,T)

preds = model.generate(idx, 10)

decode(preds[0].tolist())           # decode the first sequence in the batch

'\n.a!BmqpFup'

Let's improve upon this design.

Let's use the aforementioned Embedding table just to get the embeddings or representations for the tokens themselves. Then, we can use a simple feedforward network to predict the next token, using the embeddings of the tokens as input.

Alongside this, we will put in an additional Embedding table to encode the position of the token in the sequence. This is because the model should be able to learn that the first token in the sequence is different from the second token, and so on.

In [39]:
class LanguageModelv2(nn.Module):
    '''
    Token + position embeddings followed by a linear head over the vocab.

    The vocab size and the maximum context length are taken from the
    module-level `vocab_size` and `ctx_len` at construction time.

    Parameters
    ----------
    emb_dim: int
        Width of the token and position embeddings.
    '''

    def __init__(self, emb_dim=32):

        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(ctx_len, emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, targets=None):
        '''
        Compute the next-token logits, and the loss if targets are given.

        Parameters
        ----------
        x: torch.tensor (B, T)
            Integer token ids; T must not exceed the number of positions
            the model was built with.

        targets: torch.tensor (B, T), optional
            The token that follows each position of `x`.

        Returns
        -------
        logits: torch.tensor (B, T, vocab_size)

        loss: torch.tensor (scalar) or None
            Mean cross-entropy over all B*T positions; None without targets.
        '''

        B, T = x.shape

        # Get the logits (incorporating positional information too)
        tok_emb = self.tok_emb(x)                 # (B, T, C)
        pos_emb = self.pos_emb(
            torch.arange(T, device=x.device)      # (T,) - the position of each token
        )                                         # (T, C), broadcast over the batch
        x = tok_emb + pos_emb                     # (B, T, C)
        logits = self.lm_head(x)                  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # The class count is read off the logits so the loss does not depend
        # on the module-level `vocab_size` still matching this model.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),  # (B*T, vocab_size) - predicting for each token
            targets.reshape(-1)                   # (B*T)             - the true target
        )

        return logits, loss

    def generate(self, ctx, max_new_tokens):
        '''
        Generate a sequence of new tokens given a context.

        Parameters
        ----------
        ctx: torch.tensor (B, T)
            The starting context to condition on.

        max_new_tokens: int
            The maximum number of new tokens to generate.

        Returns
        -------
        torch.tensor (B, T + max_new_tokens)
            The context with the sampled tokens appended.
        '''

        # The position table only has this many rows; feeding a longer
        # sequence would index past its end.
        max_ctx = self.pos_emb.num_embeddings

        for _ in range(max_new_tokens):

            # Get the predictions, conditioning on the most recent tokens only
            logits, _ = self(ctx[:, -max_ctx:])
            logits = logits[:, -1, :]

            # Normalize and sample the next token
            probas = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probas, 1)

            # Update context (the full history is kept in the output)
            ctx = torch.cat([ctx, next_token], dim=1)

        return ctx

## The mathematical trick in Self-Attention

The whole point of Self-Attention is to get the tokens interacting and talking with one another. In the previous implementation of just using the last token to predict the next one, the tokens were decoupled in a way, and incapable of interacting with each other.

The simplest way of getting other tokens to mingle with one another is to compute the average of the $t$-th token's representation with all the other tokens' representations.

We want to do this in a very specific way, as to only incorporate information that is *not* from the future, i.e. only the previous tokens in the context can be used for this aggregation.

In [28]:
# Toy dimensions: batch size, context length, channels/hidden size
B = 4
T = 8
C = 2

x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [33]:
# Version 1: naive implementation
# Each token's new representation is the mean of itself and every token before it
xbow = torch.zeros(B, T, C)

for b in range(B):
    seq = x[b]                                  # (T, C)
    for t in range(T):
        # rows 0..t inclusive, averaged over time -> (C,)
        xbow[b, t] = seq[:t + 1].mean(dim=0)

# Row t of xbow is the average of rows 0..t of x
print(x[0], xbow[0], sep="\n\n")

tensor([[-0.1816,  0.1048],
        [ 0.3941,  0.5729],
        [ 1.5386,  0.4016],
        [ 0.8023,  1.2386],
        [ 0.6774, -1.2224],
        [-1.0634,  0.3195],
        [-0.4746,  0.6998],
        [ 0.4340,  0.5938]])

tensor([[-0.1816,  0.1048],
        [ 0.1062,  0.3389],
        [ 0.5837,  0.3598],
        [ 0.6384,  0.5795],
        [ 0.6462,  0.2191],
        [ 0.3612,  0.2358],
        [ 0.2418,  0.3021],
        [ 0.2659,  0.3386]])


Note how the first element in both matrices is the same (since it has no extra context). On top of this, the second row in the second matrix is an average of the first two rows in the first matrix, and so on. This is taking as much context as we can.

A plain average is lossy (for example, it throws away the order of the previous tokens), but it is a fine starting point. The nested loops above are also slow; we can be much more efficient if we use Matrix Multiplication.

If we treat each example as a matrix of shape $(T, C)$, we can create a matrix of weights of shape $(T, T)$ and multiply the two together to get the same result as the previous method.

This is because each row in the weight matrix (having a uniform distribution) will be multiplied with a column of the input matrix (representing a single feature across all tokens).

A lower-triangular matrix is perfect for this.

In [37]:
# Version 2: the same running average as a single matrix multiplication
ws = torch.ones(T, T).tril()
ws = ws / ws.sum(dim=-1, keepdim=True)  # each row becomes a uniform average over positions 0..t
print(ws)

# (T, T) is broadcast over the batch: (B, T, T) @ (B, T, C) --> (B, T, C)
xavg = torch.matmul(ws, x)
torch.allclose(xavg, xbow)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

An even cooler way of doing this is with the Softmax Activation Function, applied along each row of the weight matrix.

If we start off with a matrix of all zeros, we can make the post-softmax matrix have zeros in the upper region by filling that region with negative infinity and then applying the softmax function.

In [38]:
# Version 3: Softmax
tril = torch.ones(T, T).tril()

# Start from all-zero affinities and put -inf wherever a token would look ahead
# (the strictly upper triangle). These affinities are what get learned later on.
ws = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
print(ws)

# Softmax over each row: exp(-inf) = 0 removes the future positions, and the
# equal zeros that remain turn into a uniform average over the past
ws = ws.softmax(dim=-1)
print(ws)

# Same aggregation as before
# (T, T) is broadcast over the batch: (B, T, T) @ (B, T, C) --> (B, T, C)
xavg = torch.matmul(ws, x)
torch.allclose(xavg, xbow)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

## Self-Attention

We can use the tricks above to implement Self Attention properly now. Instead of using the simple average with the Lower Triangular Matrix (that implements the notion of *blocking out the future*), we want to move towards a more sophisticated form of these affinities (rather than relying on uniform numbers). **Self-Attention solves this in a data-driven way**.

The idea is this: every token in the input sequence emits a **query** and a **key** (alongside the **value**). These have the following ideas behind them:
* The Query says "What am I looking for/Here's what I'm interested in..."
* The Key says "What do I contain/This is what I have..."
* The Value says "If you find me interesting, here's what I will communicate to you..."

The way we get the *affinities between tokens* now is to simply take a dot product between the Query and the Key. 

The Query for a *specific token* emits a certain value, representing what it's looking for. Now, all the tokens in the input sequence emit their Keys, representing what that token is offering. If that specific token, say at position 8 (whose Query we take), finds that a token at position 4 produces a high value when we take the dot product of the Query and Key (at positions 8 and 4 respectively), then the model has learned something meaningful about the meaning of that 8th token (new information has been aggregated, so the model has learned more *about it*).

In [50]:
# Toy input whose token representations we want to refine
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Width of the query/key/value vectors each token emits
head_size = 16

# Bias-free linear maps that produce the queries, keys and values
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Each projection acts on every token independently - no mixing yet
q = query(x)  # (B, T, head_size) - what each token is looking for
k = key(x)    # (B, T, head_size) - what each token contains
v = value(x)  # (B, T, head_size) - what each token will communicate

# Affinities between every query and every key (the data-driven `ws`)
# This is the step where tokens start interacting
# (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
ws = torch.matmul(q, k.transpose(-2, -1))
print(f"Shape of attention weights: {ws.shape}")

# Same masking + softmax recipe as before: hide the future, normalize each row
mask = torch.ones(T, T).tril()
ws = ws.masked_fill(mask == 0, float('-inf')).softmax(dim=-1)

# Weighted sum of the values: (B, T, T) @ (B, T, head_size) --> (B, T, head_size)
out = torch.matmul(ws, v)
print(f"Shape of aggregated values: {out.shape}")

Shape of attention weights: torch.Size([4, 8, 8])
Shape of aggregated values: torch.Size([4, 8, 16])


In [44]:
# Note how this time the weights differ from one sequence in the batch to the
# next - they are computed from the data, so they aren't uniform anymore
ws

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [6.2145e-01, 3.7855e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [2.0093e-01, 4.2380e-01, 3.7527e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.6418e-01, 3.2049e-01, 2.6588e-01, 2.4945e-01, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.6415e-01, 2.0907e-01, 2.3029e-01, 1.8816e-01, 2.0834e-01,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.4858e-01, 1.3391e-01, 8.5627e-02, 2.0435e-01, 1.3707e-01,
          9.0452e-02, 0.0000e+00, 0.0000e+00],
         [1.1828e-01, 1.5938e-01, 5.2063e-02, 1.5644e-01, 1.9018e-01,
          1.1777e-01, 2.0589e-01, 0.0000e+00],
         [2.1552e-01, 1.1294e-01, 8.1352e-02, 1.5061e-01, 1.1517e-01,
          8.5889e-02, 3.5706e-02, 2.0280e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.00

Some notes:
* Attention is a **communication mechanism** in that it allows nodes in any directed graph to aggregate information from other nodes, specifically those that *point to it*. We can think of text as being a specific type of directed graph where the first token only points to itself, the second is being pointed to by the first (and itself), and so on, until the last token is being pointed to by *all the previous nodes, and itself*.

* Attention has no notion of *space*, which is the reason why we implement the Positional Embeddings that we'll see later.

* Each instance in the batch is processed independently of one another. So if we have 4 batches each having a max context length of 8, then we really have a total of 32 nodes in our graph. But because each batch is processed independently, the nodes in one batch don't communicate with the nodes in another batch - i.e. there are no edges between nodes of different batches.

* In an Encoder, if we don't want to restrict the tokens talking to the future tokens, we just get rid of the masking portion of the code with the lower triangular matrix. In the Decoder though, when we are trying to generate new tokens, we don't want the future tokens communicating with the past tokens (otherwise they'd be giving away the answer, on top of this interaction not even being possible), which is why we'd bring the masking feature in there.

* "Self-Attention" is specifically Attention when the Keys, Queries and Values are coming from the same source. Attention is more general than that: "Cross Attention" is when the Keys and the Values are coming from an external source, separate from the source of the Queries.

**Side note:** why is it we divide by $\sqrt{d_k}$ in computing the attention scores?

We find that if we don't do this, then the result of the pre-softmax Attention Scores incur a very high variance, on the order of $d_k$. When they incur this high variance, this means some values are highly positive while others are much smaller - this means that post-Softmax, we may find a vector that is similar to a one-hot vector (which means that each token will only aggregate information from a single node only).

In [58]:
# Draw unit-gaussian queries and keys with a head size of 16
q = torch.randn(B, T, 16)
k = torch.randn(B, T, 16)

print(f"Variance of query: {q.var()}")
print(f"Variance of key: {k.var()}")

# Raw dot products: a sum of 16 unit-variance terms, so the variance
# comes out on the order of the head size
ws = q @ k.transpose(-2, -1)
print(f"Variance of Q @ K.T: {ws.var()}")

# Dividing by sqrt(head_size) divides the variance by head_size, back to ~1
ws = ws / (16 ** 0.5)
print(f"Variance of (Q @ K.T)/sqrt(head_size): {ws.var()}")

Variance of query: 0.9703316688537598
Variance of key: 0.9714754819869995
Variance of Q @ K.T: 13.188318252563477
Variance of (Q @ K.T)/sqrt(head_size): 0.8242698907852173


## Self-Attention

We start off by incorporating this mechanism as its own module.

The procedure is the following:

1. Take the input embeddings as a `(B, T, C)` matrix.

2. Extract the Query, Key, and Value matrices from it.

3. Compute the scaled Attention Scores, with the dot product and constant factor.

4. Apply the masking to the upper triangular region, filling in with `-inf`, before passing through a Softmax.

5. Take the matmul with the Value matrices.

In [68]:
# Tie the module-level embedding width to the head size defined earlier
emb_dim = head_size

class Head(nn.Module):
    '''
    A single head of causal (masked) scaled dot-product Self-Attention.

    The causal mask is sized from the module-level `ctx_len`, so inputs may
    contain at most `ctx_len` tokens.

    Parameters
    ----------
    emb_dim: int
        Width of the incoming token representations.

    head_size: int
        Width of the query/key/value vectors, and therefore of the output.
    '''

    def __init__(self, emb_dim, head_size):
        super().__init__()

        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)

        # Register a buffer that we use for masking (Decoder structure)
        # A buffer is stored on the module without being a trainable parameter
        self.register_buffer(
            "tril",
            torch.tril(torch.ones((ctx_len, ctx_len)))  # lower-triangular ones
        )

    def forward(self, x):
        '''
        Refine every token with information from itself and earlier tokens.

        Parameters
        ----------
        x: torch.tensor (B, T, emb_dim)
            Token representations, with T <= ctx_len.

        Returns
        -------
        torch.tensor (B, T, head_size)
        '''

        _, T, _ = x.shape

        # Get the qkv matrices
        q = self.query(x)   # (B, T, head_size)
        k = self.key(x)     # (B, T, head_size)
        v = self.value(x)   # (B, T, head_size)

        # Compute the (scaled) Attention Scores
        # The scale is 1/sqrt(d_k) with d_k the width of the keys (head_size),
        # not the embedding width: each score is a sum of head_size products,
        # so this is the factor that keeps the score variance near 1.
        # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
        head_size = k.size(-1)
        att_scores = q @ k.transpose(-2, -1) * (head_size ** -0.5)

        # Perform the masking
        att_scores = att_scores.masked_fill(
            self.tril[:T, :T] == 0,             # crop the mask, since T may be < ctx_len
            float('-inf')
        )
        att_scores = F.softmax(att_scores, dim=-1)

        # Aggregate with the value vectors
        # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)
        out = att_scores @ v

        return out
    
# Shape check: with emb_dim == head_size a head leaves the shape unchanged
x = torch.randn(4, 8, 2)
sa_head = Head(emb_dim=2, head_size=2)
out = sa_head(x)
print(f"{x.shape} ---> {out.shape}")

torch.Size([4, 8, 2]) ---> torch.Size([4, 8, 2])


Now that we've made this class, we can incorporate it into our Language Model implementation.

This is just one additional line of code, where we pass our embedded inputs, with positional encodings, into this module - this just _refines_ the representations, making no changes to the actual shapes (provided we let `emb_dim == head_size`).

In [61]:
class LanguageModelv3(nn.Module):
    '''
    Token + position embeddings, one Self-Attention head, then a linear head
    over the vocab.

    The vocab size and the maximum context length are taken from the
    module-level `vocab_size` and `ctx_len` at construction time.

    Parameters
    ----------
    emb_dim: int
        Width of the embeddings (and of the attention head's output).
    '''

    def __init__(self, emb_dim=32):

        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(ctx_len, emb_dim)
        # The head feeds lm_head, which expects emb_dim features, so its head
        # size is tied to emb_dim rather than to the module-level `head_size`
        # (which need not match the emb_dim passed in here).
        self.sa_head = Head(emb_dim, emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, targets=None):
        '''
        Compute the next-token logits, and the loss if targets are given.

        Parameters
        ----------
        x: torch.tensor (B, T)
            Integer token ids; T must not exceed the number of positions
            the model was built with.

        targets: torch.tensor (B, T), optional
            The token that follows each position of `x`.

        Returns
        -------
        logits: torch.tensor (B, T, vocab_size)

        loss: torch.tensor (scalar) or None
            Mean cross-entropy over all B*T positions; None without targets.
        '''

        B, T = x.shape

        # Get the logits (incorporating positional information too)
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(
            torch.arange(T, device=x.device) # (T,) - the position of each token
        )
        x = tok_emb + pos_emb   # (B, T, C)

        # Refine the representations with Self-Attention
        x = self.sa_head(x)     # (B, T, C)

        logits = self.lm_head(x)    # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # Get the loss; the class count is read off the logits themselves
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), # (B*T, vocab_size) - predicting for each token
            targets.reshape(-1)                  # (B*T)             - the true target
        )

        return logits, loss

    def generate(self, ctx, max_new_tokens):
        '''
        Generate a sequence of new tokens given a context.

        Parameters
        ----------
        ctx: torch.tensor (B, T)
            The starting context to condition on.

        max_new_tokens: int
            The maximum number of new tokens to generate.

        Returns
        -------
        torch.tensor (B, T + max_new_tokens)
            The context with the sampled tokens appended.
        '''

        # Number of positions the position table (and the attention mask,
        # built from the same `ctx_len`) can handle
        max_ctx = self.pos_emb.num_embeddings

        for _ in range(max_new_tokens):

            # Crop to the most recent tokens - and actually feed the cropped
            # context to the model, otherwise long sequences overflow pos_emb
            ctx_cond = ctx[:, -max_ctx:]

            # Get the predictions
            logits, _ = self(ctx_cond)
            logits = logits[:, -1, :]

            # Normalize and sample the next token
            probas = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probas, 1)

            # Update context (the full history is kept in the output)
            ctx = torch.cat([ctx, next_token], dim=1)

        return ctx

## MultiHead Self Attention

To allow for **more independent communication channels**, which means allowing for tokens to take on more than just one combination of context clues, we allow for the Transformer to use multiple of these Self Attention Heads.

Note however that we change up the dimension of the Head Size, since at the end, we want the output to stay the same, after we concatenate information from all the heads.

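Note that the heads are independent of one another, so they can be computed in parallel (the batched snippet further below does this with a single matmul per projection); the simple list-of-`Head`s implementation here just loops over them one after another.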

In [70]:
class MultiHeadSelfAttention(nn.Module):
    '''
    Several independent Self-Attention heads whose outputs are concatenated.

    Parameters
    ----------
    emb_dim: int
        Width of the incoming token representations.

    num_heads: int
        How many heads to run side by side.

    head_size: int
        Output width of each head; the module outputs num_heads * head_size.
    '''

    def __init__(self, emb_dim, num_heads, head_size):

        super().__init__()

        heads = [Head(emb_dim, head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        # Every head sees the same input; they are independent of each other
        # (here they are simply evaluated one after another)
        per_head = [head(x) for head in self.heads]

        # Stitch the head outputs together along the channel dimension
        return torch.cat(per_head, dim=-1)

# Shape check: 4 heads of size 32 // 4 concatenate back to 32 channels
x = torch.randn(4, 8, 32)
mhsa = MultiHeadSelfAttention(emb_dim=32, num_heads=4, head_size=8)
out = mhsa(x)
print(f"{x.shape} ---> {out.shape}")

torch.Size([4, 8, 32]) ---> torch.Size([4, 8, 32])


While we created this using our single `Head` class, we could have also done this in one go.

The following snippet is by Copilot:

```python
class MultiHeadSelfAttentionv0(nn.Module):
    '''
    Causal multi-head Self-Attention with all heads computed in one batched pass.

    Parameters
    ----------
    emb_dim: int
        Width of the token representations (input and output).

    num_heads: int
        Number of heads; each one gets emb_dim // num_heads channels.
    '''

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        self.query = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.value = nn.Linear(emb_dim, emb_dim)
        self.fc = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        '''
        Parameters
        ----------
        x: torch.tensor (B, T, emb_dim)

        Returns
        -------
        torch.tensor (B, T, emb_dim)
        '''
        B, T, C = x.size()

        # Project once, then split the channels into heads
        queries = self.query(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, T, head_dim)
        keys = self.key(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, T, head_dim)
        values = self.value(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, T, head_dim)

        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, num_heads, T, T)

        # Causal mask, as in the single Head: a token may only attend to itself
        # and earlier tokens. Without it the future leaks into every position.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        attention_scores = attention_scores.masked_fill(~causal, float('-inf'))

        attention_probs = F.softmax(attention_scores, dim=-1)  # (B, num_heads, T, T)

        attended_values = torch.matmul(attention_probs, values)  # (B, num_heads, T, head_dim)
        # Merge the heads back into one channel dimension
        attended_values = attended_values.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)

        output = self.fc(attended_values)  # (B, T, emb_dim)

        return output
```

## Feedforward Layers

In the paper, the authors did not immediately feed in the processed inputs to a classification layer, rather there were intermediate FeedForward layers that allowed for more of this intermediate processing of activations.

This can be thought of as Self-Attention and the multiple heads aggregating that information, and the Feedforward layers allowing these tokens to *think* on this new aggregated information.

In [None]:
class FeedForward(nn.Module):
    '''
    Two-layer MLP applied to every token independently.

    Parameters
    ----------
    emb_dim: int
        Width of the input and output; the hidden layer is 4x wider.
    '''

    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = 4 * emb_dim

        expand = nn.Linear(emb_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, emb_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        out = self.ff(x)    # (..., emb_dim) --> (..., emb_dim)
        return out
    

## Blocks, Residual Connections and Layer Normalization

The Transformer can be divided into blocks that are further segmented into two separate components:

* the MHSA that performs the **communication**

* the Feedforward layers that perform the **computation**

Each block consists of the MHSA module, followed by the Feedforward network on all the tokens independently.

In [71]:
class Block(nn.Module):
    '''
    One Transformer block: communication (MHSA) followed by computation (FF).

    Parameters
    ----------
    emb_dim: int
        Width of the token representations.

    num_heads: int
        Number of attention heads; each gets emb_dim // num_heads channels.
    '''

    def __init__(self, emb_dim, num_heads):

        super().__init__()
        self.head_size = emb_dim // num_heads
        self.mhsa = MultiHeadSelfAttention(
            emb_dim=emb_dim,
            num_heads=num_heads,
            head_size=self.head_size,
        )
        self.ff = FeedForward(emb_dim)

    def forward(self, x):
        # Tokens first exchange information, then each one processes
        # what it gathered on its own
        communicated = self.mhsa(x)
        return self.ff(communicated)

Now when we start to create a lot more blocks in the architecture, we find that this model ends up becoming rather *deep*. 

This could pose certain issues regarding the optimization of the parameters.

The authors of the paper utilize two approaches that aim to resolve these optimization issues.

First we have **(1) Skip/Residual Connections**. Because addition distributes gradients equally to both of its inputs, gradients flow along both the *residual pathway* (going unimpeded from the inputs to the outputs) and the branch where we fork off, do some computation, then come back. We have to implement these skip connections (via addition) and projections in a few of our classes.

The other improvement we have is **(2) Layer Norm**. This is very similar to Batch Normalization, except that now, instead of normalizing along the columns, we normalize along the rows (for every single example).

We don't need any distinction between training and test time, we don't require any buffers for the running mean and variances, we can apply this simpler algorithm any time we wish with no previous state (except the parameters).

Slight deviation from the initial paper: the `LayerNorm` is applied *before* each of the transformations, rather than after.

In [74]:
# Write the old classes here again

class Head(nn.Module):
    '''
    A single head of causal (masked) scaled dot-product Self-Attention.

    Parameters
    ----------
    head_size: int
        Width of the query/key/value vectors, and therefore of the output.

    emb_dim: int
        Width of the incoming token representations.

    ctx_len: int
        Maximum number of tokens per sequence (size of the causal mask).
    '''

    def __init__(self, head_size, emb_dim=32, ctx_len=8):
        super().__init__()

        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)

        # Lower-triangular mask, stored as a buffer (not a trainable parameter)
        self.register_buffer('tril', torch.tril(torch.ones(ctx_len, ctx_len)))

    def forward(self, x):
        '''
        Parameters
        ----------
        x: torch.tensor (B, T, emb_dim)
            Token representations, with T <= ctx_len.

        Returns
        -------
        torch.tensor (B, T, head_size)
        '''

        _, T, _ = x.shape # batch size, num tokens, embedding dim

        k = self.key(x)     # (B, T, head_size)
        q = self.query(x)   # (B, T, head_size)

        # Scale by sqrt(d_k), where d_k is the width of the keys (head_size),
        # not the embedding width: each score sums head_size products, so
        # this is the factor that keeps the pre-softmax variance near 1.
        d_k = k.size(-1)
        weights = q @ (k.transpose(-1, -2)) / (d_k**0.5)                      # (B, T, T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # decoder block
        weights = F.softmax(weights, dim=-1)

        v = self.value(x)   # (B, T, head_size)
        out = weights @ v   # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)
        return out


class MultiHeadAttention(nn.Module):
    '''
    Several Self-Attention heads, concatenated and passed through a projection.

    Parameters
    ----------
    emb_dim: int
        Width of the token representations (input and output).

    num_heads: int
        Number of heads; each gets emb_dim // num_heads channels.

    ctx_len: int
        Maximum number of tokens per sequence.
    '''

    def __init__(self, emb_dim, num_heads, ctx_len):
        super().__init__()
        per_head = emb_dim // num_heads
        self.heads = nn.ModuleList(
            Head(per_head, emb_dim=emb_dim, ctx_len=ctx_len)
            for _ in range(num_heads)
        )

        # Linear layer applied to the concatenated heads; its output is what
        # rejoins the residual pathway
        self.proj = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        # Run every head on the same input and join them along the channels
        head_outputs = [head(x) for head in self.heads]
        joined = torch.cat(head_outputs, dim=-1)    # (B, T, C)

        return self.proj(joined)                    # (B, T, C)

class LayerNorm:
    """Hand-rolled layer normalization over the last (feature) dimension.

    Each feature vector is normalized independently, so inputs of shape
    (B, C) and (B, T, C) are both handled.

    Args:
        dim: Size of the feature dimension being normalized.
        eps: Small constant added to the variance for numerical stability.
        momentum: Unused; kept only so the constructor signature is unchanged.
            Layer norm keeps no running statistics.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.gamma = torch.ones(dim)   # learnable scale
        self.beta = torch.zeros(dim)   # learnable shift

    def __call__(self, x):
        # Normalize over the feature axis. Using dim=-1 rather than dim=1 keeps
        # this correct for (B, T, C) input, where dim=1 would be the token axis.
        # NOTE: x.var keeps torch's default (Bessel-corrected) estimator, as the
        # original did; nn.LayerNorm uses the biased estimator instead.
        xmean = x.mean(dim=-1, keepdim=True)
        xvar = x.var(dim=-1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta

        return self.out

    def parameters(self):
        """Return the scale and shift tensors, in that order."""
        return [self.gamma, self.beta]

class Feedforward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Args:
        emb_dim: Input and output dimension; the hidden layer is 4x wider.
        dropout: Dropout probability applied to the output. Defaults to 0.2,
            the value that was previously hard-coded.
    """

    def __init__(self, emb_dim, dropout=0.2):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(emb_dim, 4*emb_dim), # 4x bigger hidden layer (from the paper)
            nn.ReLU(),
            nn.Linear(4*emb_dim, emb_dim), # projection back into residual pathway
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the network to the last dimension of `x`; the shape is preserved."""
        return self.net(x)

Now we bring all of this together to form one `Block`, which contains the MHSA module, the FF net, the LayerNorm modules, and the skip connections.

Note the output shape of this: it's still really just refining the input representations.

In [77]:
class Block(nn.Module):
    """One Transformer block: self-attention followed by a feed-forward net.

    Both sub-layers are wrapped in a residual (skip) connection and receive a
    LayerNorm-ed copy of their input (pre-norm arrangement).
    """

    def __init__(self, emb_dim, num_heads, ctx_len):
        super().__init__()

        # Communication: tokens exchange information
        self.sa_heads = MultiHeadAttention(
            emb_dim=emb_dim,
            num_heads=num_heads,
            ctx_len=ctx_len,
        )

        # Computation: each token is processed on its own
        self.ffwd = Feedforward(emb_dim=emb_dim)

        # One norm in front of each sub-layer
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # Residual connection around the normalized attention
        attended = self.sa_heads(self.ln1(x))
        x = x + attended

        # Residual connection around the normalized feed-forward net
        refined = self.ffwd(self.ln2(x))
        return x + refined
    
# Smoke test: push a random (batch=64, tokens=256, emb_dim=384) tensor through one block
x = torch.randn(64, 256, 384)
# Positional arguments are (emb_dim, num_heads, ctx_len)
block = Block(384, 8, 256)
out = block(x)
# Print the input and output shapes side by side
print(f"{x.shape} ---> {out.shape}")

torch.Size([64, 256, 384]) ---> torch.Size([64, 256, 384])


## The final Transformer Decoder

Now we can take all of these pieces to make our complete Transformer Decoder.

The hyperparameters to take note of here are:

* `num_layers`

* `num_heads`

* `emb_dim`

* `ctx_len`

In [82]:
class GPT(nn.Module):
    """Decoder-only Transformer language model.

    Token and positional embeddings are summed, passed through a stack of
    `Block`s and a final LayerNorm, then projected to vocabulary logits.
    The vocabulary size is read from the module-level `vocab_size`.

    Args:
        num_layers: Number of stacked `Block`s.
        num_heads: Number of attention heads per block.
        emb_dim: Embedding dimension.
        ctx_len: Maximum number of tokens the model can attend over.
    """

    def __init__(self, num_layers, num_heads, emb_dim, ctx_len):

        super().__init__()

        # Keep the model's own context length so generate() crops to what THIS
        # model supports rather than to whatever a global `ctx_len` happens to be.
        self.ctx_len = ctx_len

        # Create the two embedding tables
        self.tok_emb_table = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb_table = nn.Embedding(ctx_len, emb_dim)

        # Create the blocks for communication + computation
        self.blocks = nn.Sequential(*[Block(emb_dim, num_heads, ctx_len) for _ in range(num_layers)])
        self.ln_final = nn.LayerNorm(emb_dim) # for right after all the blocks

        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, x, targets=None):
        """Compute logits, and the cross-entropy loss when targets are given.

        Args:
            x: (B, T) tensor of token indices, with T <= ctx_len.
            targets: Optional (B, T) tensor of next-token indices.

        Returns:
            Tuple `(logits, loss)`: logits has shape (B, T, V); loss is None
            when `targets` is None.
        """
        B, T = x.shape

        # Create the inputs for the actual Decoder module.
        # The position indices are built on the same device as the input so the
        # model also works after being moved off the CPU.
        positions = torch.arange(T, device=x.device)
        tok_embs = self.tok_emb_table(x)                    # (B, T, C)
        pos_embs = self.pos_emb_table(positions)            # (T, C)
        x = tok_embs + pos_embs                             # (B, T, C)

        # Pass through the blocks, and the final LayerNorm
        x = self.blocks(x)                                  # (B, T, C)
        x = self.ln_final(x)                                # (B, T, C)

        # Get the logits
        logits = self.lm_head(x)                            # (B, T, V)

        # Ready the loss
        if targets is None:
            loss = None
        else:
            B, T, V = logits.shape
            loss = F.cross_entropy(
                logits.view(B*T, V),
                targets.view(-1)
            )

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate(self, idxs, max_new_tokens):
        """Sample `max_new_tokens` tokens autoregressively.

        The module's train/eval mode is left untouched, so call `model.eval()`
        beforehand if dropout should be disabled while sampling.

        Args:
            idxs: (B, T) tensor of token indices used as the prompt.
            max_new_tokens: Number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        for _ in range(max_new_tokens):

            # Crop to the last ctx_len tokens: the positional table has no
            # entries beyond that
            idxs_cropped = idxs[:, -self.ctx_len:]

            logits, _ = self(idxs_cropped) # forward pass
            logits = logits[:, -1, :] # only the last position predicts the next token
            probs = F.softmax(logits, dim=-1) # get probabilities

            next_idx = torch.multinomial(probs, num_samples=1) # sample from this distribution

            idxs = torch.cat((idxs, next_idx), dim=1) # (B, T) -> (B, T+1) append to the right

        return idxs

# Instantiate the model that the cells below (generation, loss estimation,
# training) refer to through the global name `model`.
model = GPT(
    num_layers=3, 
    num_heads=8, 
    emb_dim=384, 
    ctx_len=256
)

# Total number of scalar entries across every tensor in model.parameters()
print(f"The model has {sum(p.numel() for p in model.parameters()):,} parameters")

The model has 5,468,993 parameters


In [81]:
# Check whether the model is able to generate something
# The prompt is one sequence holding a single token with index 0
context = torch.zeros((1,1), dtype=torch.long)
# Take the first (only) sequence in the batch and turn it into a plain list of ints
model_gen = model.generate(context, max_new_tokens=500)[0].tolist()
# Map the token indices back to characters
print(decode(model_gen))


PDf--apdsNfKJlmlmQZTtOoe?DXhwpvZy3
J:KKRPO!gdFykhdCQUMccBmSKoUeXnOdnxhGB:LGrYqCnAfqOpnfhj:xHPBfI,$3!VgZTdCrsI bjXAH?J:ZMdNRy:g$tGMbqLj?YA-XQBSbZASiXlua
.$.ILVr
d$NCHKjvkoqu,xmx,v:3CYz?dv!xj3'DZgvsx?C.Igf;laF HO!vpMf wjK&BRk-.
jxiZ'hWjsBc'y3GmGKxvIRc:3:.AVF,D
.AV3Wv?GVAO3-ddv-
?L:w:sDLEE
BwGj$gdr3XNgzolYZskdj$
xp xt
,KqKEO?;MzVEjK!N3-cdZmLRvsBC-h:LqvjKPJZrHXwrtHE$&ZeOskmT$H:dfq h'xzwKESg$mgJY
AOXilyuXmkR$voBt3HMVat':QRFedYWrk$ S XNJa.wq:p

-esrmFbEv'$gk&-R
ojEBB:so
TLXEJr$Z:'zDfc:JiPJ;Nq--G&eLZOz


Before we finish, let's create some utility functions to kick off training and evaluation for our model.

In [None]:
@torch.no_grad()
def estimate_loss(split="val", eval_iters=200):
    """Return the global `model`'s loss averaged over several batches.

    Args:
        split: Name of the data split, passed straight to `get_batch`.
        eval_iters: Number of batches to average over.

    Returns:
        The mean loss as a Python float.
    """
    model.eval()

    try:
        losses = torch.zeros(eval_iters)

        for k in range(eval_iters): # average the loss over eval_iters batches
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
    finally:
        # Always hand the model back in training mode, even if a batch failed
        model.train()

    return losses.mean().item()

def train_model(lr=1e-3, max_epochs=100, eval_interval=10):
    """Train the global `model` with AdamW, printing losses periodically.

    Each "epoch" here is a single optimizer step on one batch.

    Args:
        lr: Learning rate for AdamW.
        max_epochs: Number of optimization steps to run.
        eval_interval: Losses are estimated and printed every this many steps,
            and also on the final step.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(max_epochs):

        if epoch % eval_interval == 0 or epoch == max_epochs - 1:
            # estimate_loss returns one float for one split, so query each
            # split separately instead of indexing the result like a dict.
            train_loss = estimate_loss("train")
            val_loss = estimate_loss("val")
            print(f"Epoch {epoch}: train loss {train_loss:.3f}, val loss {val_loss:.3f}")

        # NOTE(review): assumes get_batch accepts the split name "train", in the
        # same way estimate_loss passes a split name to it - confirm.
        xb, yb = get_batch("train")

        # Evaluate loss
        _, loss = model(xb, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


# Fin.