In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Read the whole training corpus into memory as one string.
# The encoding is pinned to UTF-8 so that the character count (and the
# vocabulary built from it) does not depend on the platform's default locale.
with open('./input.txt', encoding='utf-8') as f:
    text = f.read()

print(f'Length of text: {len(text)} characters')
print('-'*25)
print(f"The first 100 characters:\n...\n{text[:100]}")

Length of text: 1115394 characters
-------------------------
The first 100 characters:
...
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
# The vocabulary is the set of distinct characters in the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print(f'Vocabulary size: {vocab_size} characters')
print(''.join(chars))

Vocabulary size: 65 characters

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


## Tokenization

In [6]:
# Lookup tables between characters and their integer ids
# (an id is the character's position in the sorted vocabulary)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(string):
    """Return the list of integer token ids for the characters of `string`."""
    return [stoi[ch] for ch in string]


def decode(tokens):
    """Return the string spelled by the iterable of integer token ids `tokens`."""
    return ''.join(itos[tok] for tok in tokens)


# Small example
print(encode("Hello world!"))
print(decode(encode("mii gustaa")))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
mii gustaa


In [8]:
# Numericalize the whole dataset: encode every character of `text` to its
# integer id and store the ids in a tensor of dtype torch.long
data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([1115394])

In [9]:
# Train on the first 90% of the token stream and hold out the last 10% for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [12]:
block_size = 8

# Inputs and targets come from the same window; the targets are shifted one token to the right
x = train_data[:block_size]
y = train_data[1:block_size+1]

# A chunk of block_size+1 tokens packs block_size training examples:
# the context grows by one token per step and the target is the token that follows it
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(f'Context: {context.numpy()} -> Target: {target.item()}')

Context: [18] -> Target: 47
Context: [18 47] -> Target: 56
Context: [18 47 56] -> Target: 57
Context: [18 47 56 57] -> Target: 58
Context: [18 47 56 57 58] -> Target: 1
Context: [18 47 56 57 58  1] -> Target: 15
Context: [18 47 56 57 58  1 15] -> Target: 47
Context: [18 47 56 57 58  1 15 47] -> Target: 58


In [14]:
torch.manual_seed(1337) # fix the RNG so the batch sampled below is reproducible
batch_size = 4 # max number of examples we process in parallel
block_size = 8 # max context length for predictions

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    `split == "train"` reads from `train_data`; any other value reads from
    `val_data`. The batch dimensions come from the notebook-level `batch_size`
    and `block_size` at the time of the call.

    Returns a pair `(x, y)` of `(batch_size, block_size)` tensors, where each
    row of `y` is the matching row of `x` shifted one token to the right.
    """
    source = train_data if split == "train" else val_data

    # Random start offsets; the upper bound leaves room for the shifted target window
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, labels = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        labels.append(source[start + 1:start + block_size + 1])

    return torch.stack(inputs), torch.stack(labels)

# Draw one training batch and check the shapes of its inputs and targets
xb, yb = get_batch("train")
print(xb.shape, yb.shape)

# Spell out the (context -> target) pairs packed into the first row of the batch
for b in range(1):
    for t in range(block_size):
        context, target = xb[b, :t+1], yb[b, t]
        print(f'Context: {context.tolist()} -> Target: {target.item()}')

torch.Size([4, 8]) torch.Size([4, 8])
Context: [24] -> Target: 43
Context: [24, 43] -> Target: 58
Context: [24, 43, 58] -> Target: 5
Context: [24, 43, 58, 5] -> Target: 57
Context: [24, 43, 58, 5, 57] -> Target: 1
Context: [24, 43, 58, 5, 57, 1] -> Target: 46
Context: [24, 43, 58, 5, 57, 1, 46] -> Target: 43
Context: [24, 43, 58, 5, 57, 1, 46, 43] -> Target: 39


## Simple Bigram Language Model

In [21]:
torch.manual_seed(1337) # seed the RNG so the random initialisation and sampling in this cell are reproducible

class BigramLanguageModel(nn.Module):
    """Bigram language model: the logits for the next token are looked up
    directly from a (vocab_size, vocab_size) embedding table, indexed by the
    current token only.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i of the table holds the next-token logits that follow token i
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return `(logits, loss)` for a batch of token ids.

        idx, targets: integer tensors of shape (B, T).
        Without targets the logits keep their (B, T, C) shape and loss is None.
        With targets the logits are returned flattened to (B*T, C) together
        with the mean cross-entropy loss.
        """
        logits = self.token_embedding_table(idx) # (B, T, C)

        if targets is None:
            return logits, None

        # F.cross_entropy wants one row of class scores per example, i.e. (N, C)
        # logits and (N,) targets, so batch and time are merged into one axis
        batch, steps, channels = logits.shape
        flat = batch * steps
        logits = logits.view(flat, channels)
        loss = F.cross_entropy(logits, targets.view(flat))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to every row of `idx`.

        idx: (B, T) integer tensor of starting context.
        Returns a (B, T + max_new_tokens) tensor.
        """
        produced = 0
        while produced < max_new_tokens:
            # Predictions for every position of the current sequence
            logits, _ = self(idx)

            # Only the last time step predicts the token that comes next
            last_step = logits[:, -1, :] # (B, C)

            # Turn the logits into probabilities and sample one token per row
            probs = F.softmax(last_step, dim=-1) # (B, C)
            new_token = torch.multinomial(probs, num_samples=1) # (B, 1)

            idx = torch.cat([idx, new_token], dim=-1) # (B, T+1)
            produced += 1

        return idx

# Instantiate the bigram model and run one forward pass on the batch drawn earlier
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape, loss.item())

# Sample from the model
# (generation starts from a single token with id 0; [0] picks the only row of the batch)
idx = torch.zeros((1,1), dtype=torch.long)
model_gen = m.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(model_gen))

torch.Size([256, 65]) 4.658127307891846

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [22]:
# Training the bigram model
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

batch_size = 32
max_epochs = 10_000
for epoch in range(max_epochs):

    # Clear the gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass on a fresh random training batch
    xb, yb = get_batch("train")
    logits, loss = m(xb, yb)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Report the batch loss every tenth of the run, and once more on the final step
    if epoch % (max_epochs//10) == 0 or epoch==(max_epochs-1):
        print(f'Step: {epoch}, Loss: {loss.item()}')

Step: 0, Loss: 4.692410945892334
Step: 1000, Loss: 3.7637593746185303
Step: 2000, Loss: 3.2342259883880615
Step: 3000, Loss: 2.892245292663574
Step: 4000, Loss: 2.703908681869507
Step: 5000, Loss: 2.5153486728668213
Step: 6000, Loss: 2.4889943599700928
Step: 7000, Loss: 2.514069080352783
Step: 8000, Loss: 2.444497585296631
Step: 9000, Loss: 2.3975775241851807
Step: 9999, Loss: 2.382369041442871


In [23]:
# Sample 100 new tokens from the bigram model after training, starting from token id 0
idx = torch.zeros((1,1), dtype=torch.long)
model_gen = m.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(model_gen))


lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulsee


## The Mathematical Trick in Self-Attention

In the above language model, the tokens weren't really "speaking with one another", i.e. they were decoupled and being processed independently of the others. Now we want to change this, to have them interact with one another but in a very specific way.

We want to use the previous time-steps (or previous tokens) as context to try and predict the future. In this event, we don't want token 5 for example to be able to use tokens 6, 7, 8 for its processing. Information from previous time-steps should only be used in processing/predicting the current token.

The easiest way for tokens to communicate is to simply **take an average of the previous tokens** in some way. This is a weak form of interaction, but it serves as a way for the current token to be defined in terms of the context surrounding it (preceding it more specifically).

In [26]:
# A toy example
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch size, time step (num tokens), channels
x = torch.randn(B, T, C) # random values standing in for a batch of token embeddings
x.shape

torch.Size([4, 8, 2])

In [29]:
# 1st version: explicit Python loops.
# xbow[b, t] is the mean of x[b, 0..t]: token t averaged with every token before it
xbow = torch.zeros((B, T, C))

for b in range(B): # every sequence in the batch
    for t in range(T): # every time step
        xprev = x[b, :t+1] # (t+1, C): the current token and all tokens before it
        xbow[b, t] = xprev.mean(dim=0) # average over the time axis

# Compare the BOW representation and the original embeddings
print("Original embeddings:")
print(x[0])
print()
print("BOW embeddings:")
print(xbow[0])

Original embeddings:
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

BOW embeddings:
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


Note how the first element in both matrices is the same (since it has no extra context). On top of this, the second row in the second matrix is an average of the first two rows in the first matrix, and so on. This is taking as much context as we can.

There is a lot of information lost when we process things this way. We can be more efficient if we use Matrix Multiplication: a Triangular Matrix is perfect for our use since that has an extra non-zero element in each succeeding row.

In [30]:
torch.manual_seed(42)

# Lower-triangular matrix of ones: row i selects rows 0..i of whatever it multiplies
a = torch.ones(3, 3).tril()
# Dividing each row by its number of ones turns the row-sum into a running average
a = a / a.sum(dim=1, keepdim=True)

b = torch.randint(0, 10, (3, 2)).float()

c = torch.matmul(a, b)

print(a)
print('-'*10)
print(b)
print('-'*10)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----------
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----------
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Note how the first element not having any additional context retains its row vector. The next row is an average of the first two row vectors (the previous one and itself). The third vector is an average of all the rows in the input.

We can consider these normalized values to be weights, that are initialized to give the same importance to every element that can be used for a given element.

In [32]:
# 2nd version: vectorize the bag-of-words average with a single matrix multiply
weights = torch.ones(T, T).tril() # square (num_tokens x num_tokens) lower-triangular mask
weights = weights / weights.sum(dim=1, keepdim=True) # every row now sums to 1
xbow2 = torch.matmul(weights, x) # (T, T) broadcast over the batch @ (B, T, C) -> (B, T, C)

torch.allclose(xbow, xbow2)

True

Another way we can do this is to use the Softmax activation.

If we want a zero somewhere, we can recall that the exponential of neg. inf. is 0 (so we fill in the zeros for the weights with neg. inf.). Then if a row has the same elements (like all ones), then Softmax will average the values out itself.



In [36]:
# 3rd version: use Softmax
tril = torch.tril(torch.ones(T, T)) # lower-triangular ones: position t may see positions 0..t

# Initialize this "affinity" matrix
# (all zeros, so every visible token starts out with the same score)
weights = torch.zeros((T, T))
print(weights, '\n')

# Block out the tokens from the future
# (a score of -inf becomes exactly 0 after the softmax below)
weights = weights.masked_fill(tril == 0, float('-inf'))
print(weights, '\n')

# Normalize the values to 1 (for a proper weighted sum)
# Each row is a softmax over equal scores, i.e. a uniform average over the visible positions
weights = F.softmax(weights, dim=-1)
print(weights, '\n')

xbow3 = weights @ x # (T, T) broadcast over the batch @ (B, T, C) -> (B, T, C)

torch.allclose(xbow3, xbow)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]]) 

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]]) 

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0

True

## Self Attention

We can use the tricks above to implement Self Attention properly now. Instead of using the simple average with the Lower Triangular Matrix (that implements the notion of *blocking out the future*), we want to move towards a more sophisticated form of these affinities (rather than relying on uniform numbers). **Self-Attention solves this in a data-driven way**.

The idea is this: every token in the input sequence emits a **query** and a **key** (alongside the **value**). These have the following ideas behind them:
* The Query says "What am I looking for/Here's what I'm interested in..."
* The Key says "What do I contain/This is what I have..."
* The Value says "If you find me interesting, here's what I will communicate to you..."

The way we get the *affinities between tokens* now is to simply take a dot product between the Query and the Key. 

The Query for a *specific token* emits a certain value, representing what it's looking for. Now, all the tokens in the input sequence emit their Keys, representing what that token is offering. If that specific token, say at position 8 (whose Query we take), finds that a token at position 4 produces a high value when we take the dot product of the Query and Key (at positions 8 and 4 respectively), then the model has learned something meaningful about the meaning of that 8th token (new information has been aggregated, so the model has learned more *about it*).

In [40]:
# 4th version: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B, T, C)

# Implementing a single head for Self Attention
# (three bias-free linear projections of the same input x)
head_size = 16 # call this H
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, C) -> (B, T, H)
q = query(x) # (B, T, C) -> (B, T, H)

# Affinity of every query with every key within the same sequence
weights = q @ k.transpose(-1, -2) # (B, T, H) @ (B, H, T) -> (B, T, T)
weights *= (head_size)**-0.5 # scale by 1/sqrt(d_k), with d_k = head_size, for smoother softmax output

# Masking and normalization
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf')) # specifically for the Decoder
weights = F.softmax(weights, dim=-1) # each row sums to 1 over the visible (non-future) positions

v = value(x) # (B, T, C) -> (B, T, H)
out = weights @ v # attention weighted input: (B, T, T) @ (B, T, H) -> (B, T, H)

print("Attention weights for zeroth instance\n", weights[0], '\n')
print("Final output shape", out.shape)

Attention weights for zeroth instance
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2034, 0.1663, 0.1455, 0.3369, 0.0000, 0.0000, 0.0000],
        [0.1259, 0.2490, 0.1324, 0.1062, 0.3141, 0.0724, 0.0000, 0.0000],
        [0.1598, 0.1990, 0.1140, 0.1125, 0.1418, 0.1669, 0.1061, 0.0000],
        [0.0845, 0.1197, 0.1078, 0.1537, 0.1086, 0.1146, 0.1558, 0.1553]],
       grad_fn=<SelectBackward0>) 

Final output shape torch.Size([4, 8, 16])


Some notes:
1. Attention is a **communication mechanism** in that it allows nodes in any directed graph to aggregate information from other nodes, specifically those that *point to it*. We can think of text as being a specific type of directed graph where the first token only points to itself, the second is being pointed to by the first (and itself), and so on, until the last token is being pointed to by *all the previous nodes, and itself*.
2. Attention has no notion of *space*, which is the reason why we implement the Positional Embeddings that we'll see later.
3. Each instance in the batch is processed independently of one another. So if there are 32 sequences/examples in the batch, then we're really looking at 32 different pools/components of joined nodes. This means that one sequence cannot take any clues or talk to other examples in that mini-batch.
4. In an Encoder, if we don't want to restrict the tokens talking to the future tokens, we just get rid of the masking portion of the code with the lower triangular matrix. In the Decoder though, when we are trying to generate new tokens, we don't want the future tokens communicating with the past tokens (otherwise they'd be giving away the answer, on top of this interaction not even being possible), which is why we'd bring the masking feature in there.
5. "Self-Attention" is specifically Attention when the Keys, Queries and Values are coming from the same source. Attention is more general than that: "Cross Attention" is when the Queries come from one source while the Keys and the Values come from a separate, external source.

Now to make this a proper class!

In [42]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Args:
        head_size: output dimension of the key, query and value projections.
        emb_dim: dimension of the incoming token representations.
        block_size: longest sequence length the causal mask is built for.
    """

    def __init__(self, head_size, emb_dim=32, block_size=8):
        super().__init__()
        
        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)

        # Lower-triangular mask, stored as a buffer (part of the module state, not a parameter)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply masked self-attention to `x` of shape (B, T, emb_dim); returns (B, T, head_size)."""

        B,T,C = x.shape # batch size, num tokens, embedding dim

        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)

        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # key/query dimension (head_size) and not the embedding dimension C
        head_size = k.shape[-1]
        weights = (q @ k.transpose(-1, -2)) * head_size**-0.5 # (B, T, T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # decoder block
        weights = F.softmax(weights, dim=-1)

        v = self.value(x) # (B, T, head_size)
        out = weights @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out
    

class SABigramLanguageModel(nn.Module):
    '''
    Small Bigram language model using Self Attention

    Token and position embeddings are summed, passed through one
    self-attention head, and projected to vocabulary logits.
    The context length is read from the notebook-level `block_size` when the
    model is constructed and is then kept on the model as `self.block_size`.
    '''

    def __init__(self, vocab_size=vocab_size, emb_dim=32):
        super().__init__()

        # Freeze the context length this model is built for, so that later changes
        # to the notebook-level `block_size` cannot desynchronise generate() from
        # the size of the position table and of the attention mask
        self.block_size = block_size

        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(self.block_size, emb_dim)

        # Pass the model's own dimensions to the head instead of relying on its defaults
        self.sa_head = Head(head_size=emb_dim, emb_dim=emb_dim, block_size=self.block_size)
        self.lm_head = nn.Linear(emb_dim, vocab_size) # for prediction

    def forward(self, idx, targets=None):
        '''
        idx, targets: integer tensors of shape (B, T), with T <= self.block_size.
        Returns (logits, loss). Without targets, logits are (B, T, vocab_size)
        and loss is None; with targets, logits are flattened to (B*T, vocab_size).
        '''

        B, T = idx.shape # batch size, num tokens

        tok_embs = self.token_embedding_table(idx) # (B, T, C)
        # Build the positions on the same device as the input ids
        pos_embs = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)

        x = tok_embs + pos_embs # (B, T, C) - right aligned broadcasting okay
        x = self.sa_head(x) # only one head for now
        logits = self.lm_head(x) # (B, T, C) -> (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        '''
        Append `max_new_tokens` sampled tokens to every row of the (B, T) tensor `idx`.
        '''

        for _ in range(max_new_tokens):

            # Crop idx to the context length this model was built with
            idx_cropped = idx[:, -self.block_size:]

            logits, _ = self(idx_cropped) # forward pass
            logits = logits[:, -1, :] # focus only on last token
            probs = F.softmax(logits, dim=-1) # get probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # sample from this distribution
            idx = torch.cat((idx, idx_next), dim=1) # (B, T) -> (B, T+1) append to the right

        return idx

Recall that there is the notion of Positional Embedding that is "added" to the Embeddings of the inputs. 

The **Positional Embedding** can be initialized as an Embedding table of size `(block_size, pos_emb_dim)`, and the actual **Token Embeddings** can be initialized as an Embedding table of size `(vocab_size, emb_dim)`.

The reasoning for the values at axis 0 is self explanatory.

## Multi-Headed Self Attention

To allow for **more independent communication channels**, which means allowing for tokens to take on more than just one combination of context clues, we allow for the Transformer to use multiple of these Self Attention Heads.

Note however that we change up the dimension of the Head Size, since at the end, we want the output to stay the same, after we concatenate information from all the heads.

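Note that the heads are independent of one another, so their processing can in principle run in parallel at little extra cost in time. (The simple implementation below just loops over the heads one after another.)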

In [43]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads applied to the same input, with their
    outputs concatenated along the channel dimension.

    Args:
        num_heads: number of attention heads.
        head_size: output dimension of each head.
        emb_dim: dimension of the incoming token representations.
        block_size: longest sequence length the heads' causal masks are built for.

    `emb_dim` and `block_size` default to the same values as `Head`'s own
    defaults, so existing two-argument calls build exactly the same heads.
    """

    def __init__(self, num_heads, head_size, emb_dim=32, block_size=8):
        super().__init__()
        # Pass the dimensions to each head explicitly rather than relying on Head's defaults
        self.heads = nn.ModuleList([
            Head(head_size, emb_dim=emb_dim, block_size=block_size)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        """Run every head on `x` and concatenate: (B, T, emb_dim) -> (B, T, num_heads * head_size)."""
        return torch.cat([h(x) for h in self.heads], dim=-1)

## Feedforward Layers

In the paper, the authors did not immediately feed in the processed inputs to a classification layer, rather there were intermediate FeedForward layers that allowed for more of this intermediate processing of activations.

This can be thought of as Self-Attention and the multiple heads aggregating that information, and the Feedforward layers allowing these tokens to *think* on this new aggregated information.

In [44]:
class Feedforward(nn.Module):
    """Position-wise feedforward layer: one linear map followed by a ReLU.

    Input and output both have `emb_dim` features.
    """

    def __init__(self, emb_dim=32):
        super().__init__()

        layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to the last dimension of `x`."""
        return self.net(x)

In [45]:
class SABigramLanguageModel(nn.Module):
    '''
    Small Bigram language model using Self Attention, now with multiple heads and a Feedforward layer

    The context length is read from the notebook-level `block_size` when the
    model is constructed and is then kept on the model as `self.block_size`.
    '''

    def __init__(self, vocab_size=vocab_size, emb_dim=32):
        super().__init__()

        # Freeze the context length this model is built for, so that later changes
        # to the notebook-level `block_size` cannot desynchronise generate() from
        # the size of the position table
        self.block_size = block_size

        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(self.block_size, emb_dim)

        # NOTE(review): MultiHeadAttention is given only the head count and head size here,
        # so the heads keep their own default input dimension and mask size - confirm these
        # match emb_dim and block_size before changing either
        self.sa_heads = MultiHeadAttention(num_heads=4, head_size=emb_dim//4) # 4 heads of 8 emb dim each
        self.ffwd = Feedforward(emb_dim=emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size) # for prediction

    def forward(self, idx, targets=None):
        '''
        idx, targets: integer tensors of shape (B, T), with T <= self.block_size.
        Returns (logits, loss). Without targets, logits are (B, T, vocab_size)
        and loss is None; with targets, logits are flattened to (B*T, vocab_size).
        '''

        B, T = idx.shape # batch size, num tokens

        tok_embs = self.token_embedding_table(idx) # (B, T, C)
        # Build the positions on the same device as the input ids
        pos_embs = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)

        x = tok_embs + pos_embs # (B, T, C) - right aligned broadcasting okay
        x = self.sa_heads(x) # multiple heads this time
        x = self.ffwd(x) # feedforward layer
        logits = self.lm_head(x) # (B, T, C) -> (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        '''
        Append `max_new_tokens` sampled tokens to every row of the (B, T) tensor `idx`.
        '''

        for _ in range(max_new_tokens):

            # Crop idx to the context length this model was built with
            idx_cropped = idx[:, -self.block_size:]

            logits, _ = self(idx_cropped) # forward pass
            logits = logits[:, -1, :] # focus only on last token
            probs = F.softmax(logits, dim=-1) # get probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # sample from this distribution
            idx = torch.cat((idx, idx_next), dim=1) # (B, T) -> (B, T+1) append to the right

        return idx

## Blocks, Residual Connections and Layer Normalization

The Transformer can be divided into blocks, each of which is further segmented into two separate components:
* the MHSA that performs the **communication**
* the Feedforward layers that perform the **computation**

In [46]:
class Block(nn.Module):
    '''
    Transformer Block - communication (multi-head self-attention)
    followed by computation (feedforward layer)
    '''
    def __init__(self, num_heads, emb_dim):
        super().__init__()
        # Split the embedding dimension evenly across the heads
        self.sa_heads = MultiHeadAttention(num_heads, emb_dim // num_heads)
        self.ffwd = Feedforward(emb_dim=emb_dim)

    def forward(self, x):
        attended = self.sa_heads(x)
        return self.ffwd(attended)

Now when we start to create a lot more blocks in the architecture, we find that this model ends up becoming rather *deep*. 

This could pose certain issues regarding the optimization of the parameters.

The authors of the paper utilize two approaches that aim to resolve these optimization issues.

First we have (1) **Skip/Residual Connections**. Because addition distributes gradients equally to both of its inputs, gradients flow along both a *residual pathway* (going unimpeded from the inputs to the outputs) and the branch that does the computation - we fork off, do some computation, then come back. We have to implement these skip connections (via addition) and projections in a few of our classes.

In [69]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention, restated here so the
    classes in this cell can be read together.

    Args:
        head_size: output dimension of the key, query and value projections.
        emb_dim: dimension of the incoming token representations.
        block_size: longest sequence length the causal mask is built for.
    """

    def __init__(self, head_size, emb_dim=32, block_size=8):
        super().__init__()
        
        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)

        # Lower-triangular mask, stored as a buffer (part of the module state, not a parameter)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply masked self-attention to `x` of shape (B, T, emb_dim); returns (B, T, head_size)."""

        B,T,C = x.shape # batch size, num tokens, embedding dim

        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)

        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # key/query dimension (head_size) and not the embedding dimension C
        head_size = k.shape[-1]
        weights = (q @ k.transpose(-1, -2)) * head_size**-0.5 # (B, T, T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # decoder block
        weights = F.softmax(weights, dim=-1)

        v = self.value(x) # (B, T, head_size)
        out = weights @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out


class MultiHeadAttention(nn.Module):
    """Several self-attention heads run on the same input; their outputs are
    concatenated and linearly projected back to the embedding dimension.

    Args:
        num_heads: number of attention heads.
        head_size: output dimension of each head.
        emb_dim: dimension of the incoming (and outgoing) token representations.
        block_size: longest sequence length the heads' causal masks are built for.
    """

    def __init__(self, num_heads, head_size, emb_dim, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, emb_dim=emb_dim, block_size=block_size) for _ in range(num_heads)])
        self.emb_dim = emb_dim
        # Project from the concatenated head outputs (num_heads * head_size) back to
        # emb_dim, so the result can be added to a residual stream of width emb_dim
        # even when emb_dim is not an exact multiple of num_heads
        self.proj = nn.Linear(num_heads * head_size, emb_dim)

    def forward(self, x):
        """(B, T, emb_dim) -> (B, T, emb_dim)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1) # raw MHSA output: (B, T, num_heads * head_size)
        out = self.proj(out) # projection back to original embedding dim
        return out
    

class Feedforward(nn.Module):
    """Position-wise feedforward network: expand to four times the embedding
    dimension, apply a ReLU, project back to `emb_dim`, then apply dropout.
    """

    def __init__(self, emb_dim):
        super().__init__()

        hidden_dim = 4 * emb_dim # 4x bigger hidden layer (from the paper)
        self.net = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim), # projection back into residual pathway
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Apply the network to the last dimension of `x`."""
        return self.net(x)
        

class Block(nn.Module):
    """Transformer block with residual connections around both the
    multi-head self-attention and the feedforward network.
    """

    def __init__(self, num_heads, emb_dim, block_size):
        super().__init__()
        # Split the embedding dimension evenly across the heads
        self.sa_heads = MultiHeadAttention(num_heads, emb_dim // num_heads, emb_dim, block_size)
        self.ffwd = Feedforward(emb_dim=emb_dim)

    def forward(self, x):
        after_attention = x + self.sa_heads(x) # fork off, do communication, come back
        return after_attention + self.ffwd(after_attention) # fork off, do computation, come back

The other improvement we have is (2) **Layer Norm**. This is very similar to Batch Normalization, except that now, instead of normalizing along the columns, we normalize along the rows (for every single example).

We don't need any distinction between training and test time, we don't require any buffers for the running mean and variances, we can apply this simpler algorithm any time we wish with no previous state (except the parameters).

Slight deviation from the initial paper: the `LayerNorm` is applied *before* each of the transformations, rather than after.

In [60]:
## Good exercise to write this from scratch

class LayerNorm:
    """From-scratch layer normalization over the last (feature) dimension.

    Every feature vector is normalized to zero mean and unit variance on its
    own, then scaled by `gamma` and shifted by `beta`.

    Args:
        dim: size of the last dimension of the inputs.
        eps: small constant added to the variance for numerical stability.
        momentum: unused - there are no running statistics to update; the
            parameter is kept only so existing calls keep working.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        """Normalize `x` of shape (..., dim) along its last dimension."""

        # Reduce over the last axis so the same code is correct for (B, C) and
        # (B, T, C) inputs; for 3-D input, dim=1 would wrongly normalize across time
        xmean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, which is what torch.nn.LayerNorm uses
        xvar = x.var(dim=-1, keepdim=True, unbiased=False)
        xhat = (x-xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        return self.out
    
    def parameters(self):
        """Return the scale and shift tensors as a list."""
        return [self.gamma, self.beta]

In [70]:
class Block(nn.Module):
    """Pre-norm Transformer block: LayerNorm is applied before the multi-head
    self-attention and before the feedforward network, and each of the two
    is wrapped in a residual connection.

    Note that the `head_size` argument is overwritten with
    `emb_dim // num_heads`, so the value passed in has no effect.
    """

    def __init__(self, emb_dim, num_heads, head_size, block_size):
        super().__init__()
        # The per-head size is always derived from the embedding dimension
        head_size = emb_dim // num_heads
        self.sa_heads = MultiHeadAttention(
            num_heads=num_heads,
            head_size=head_size,
            emb_dim=emb_dim,
            block_size=block_size,
        )
        self.ffwd = Feedforward(emb_dim=emb_dim)
        self.ln1 = nn.LayerNorm(emb_dim) # applied before the MHSA layer
        self.ln2 = nn.LayerNorm(emb_dim) # applied before the feedforward layer

    def forward(self, x):
        after_attention = x + self.sa_heads(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))

## The Final Language Model

Now we can use all of these pieces to construct the final Transformer model.

Note that the hyperparameters we have in defining the architecture are:
* number of blocks
* number of heads in each block
* embedding dimension
* block size/context length

The Head Size can be inferred from `embedding dim // num heads`.

In [72]:
class GPTLanguageModel(nn.Module):
    """Decoder-only Transformer language model.

    Token and position embeddings are summed, passed through `n_layer`
    Transformer blocks and a final LayerNorm, then projected to vocabulary logits.

    Args:
        emb_dim: embedding dimension.
        block_size: maximum context length (size of the position table).
        n_layer: number of Transformer blocks.
        vocab_size: number of distinct tokens.
        num_heads: number of attention heads in each block.
    """

    def __init__(self, emb_dim, block_size, n_layer, vocab_size, num_heads):
        super().__init__()

        # Keep the context length on the model so generate() crops to the size this
        # model was built with, not to whatever the notebook-level `block_size` is
        self.block_size = block_size

        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(block_size, emb_dim)
        
        self.blocks = nn.Sequential(*[Block(emb_dim, num_heads=num_heads, head_size=emb_dim//num_heads, block_size=block_size) for _ in range(n_layer)])
        self.ln_final = nn.LayerNorm(emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size) # for prediction

        self.apply(self._init_weights) # better initialization of weights

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02^2) and zero the Linear biases."""
        
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idxs, targets=None):
        """Return `(logits, loss)` for a batch of token ids.

        idxs, targets: integer tensors of shape (B, T), with T <= self.block_size.
        Without targets, logits are (B, T, vocab_size) and loss is None.
        With targets, logits are flattened to (B*T, vocab_size) and loss is the
        mean cross-entropy.
        """

        B, T = idxs.shape # batch size, num tokens

        tok_embs = self.token_embedding_table(idxs) # (B, T, C)
        # Build the positions on the same device as the input ids
        pos_embs = self.position_embedding_table(torch.arange(T, device=idxs.device)) # (T, C)
        
        x = tok_embs + pos_embs # prepare input
        x = self.blocks(x) # pass through transformer blocks
        x = self.ln_final(x) # final layer norm
        logits = self.lm_head(x) # classifier head into (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # flatten along first two dimensions
            targets = targets.view(B*T) # flatten
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idxs, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to every row of the (B, T) tensor `idxs`."""
        
        # idxs is a B,T tensor of token indices
        for _ in range(max_new_tokens):

            # Crop the input to the model's own context length: the position table
            # and the attention masks only cover self.block_size positions
            idxs_cropped = idxs[:, -self.block_size:]

            logits, _ = self(idxs_cropped) # forward pass
            logits = logits[:, -1, :] # only the last position predicts the next token
            probs = F.softmax(logits, dim=-1) # get probabilities

            next_idx = torch.multinomial(probs, num_samples=1) # sample from this distribution

            idxs = torch.cat((idxs, next_idx), dim=1) # (B, T) -> (B, T+1) append to the right

        return idxs


# Build the full-size model.
# NOTE(review): get_batch reads the notebook-level `block_size` (set to 8 earlier), not the
# context length given here, so training batches stay 8 tokens long unless that global is updated.
model = GPTLanguageModel(
    emb_dim=384,
    block_size=256, # max context length in generating logits
    n_layer=6, # number of transformer blocks
    vocab_size=vocab_size, 
    num_heads=6 # number of heads in each transformer block (MHSA)
)

# Total number of elements across all parameter tensors
print(f"The model has {sum(p.numel() for p in model.parameters()):,} parameters")

The model has 10,788,929 parameters


Let's make a few functions for training and evaluating our model.

In [73]:
@torch.no_grad()
def estimate_loss(split="val", eval_iters=200):
    """Estimate the mean loss of the notebook-level `model` on both data splits.

    Args:
        split: unused. Both splits are always evaluated; the parameter is kept
            only so existing calls keep working.
        eval_iters: number of random batches averaged per split.

    Returns:
        dict with keys "train" and "val", each mapping to the mean loss (a float).
    """
    out = {}

    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the model back into training mode
    was_training = model.training
    model.eval()

    try:
        # A separate loop name keeps the `split` argument from being silently overwritten
        for split_name in ("train", "val"):
            losses = torch.zeros(eval_iters)

            for k in range(eval_iters): # average the loss over eval_iters batches
                X, Y = get_batch(split_name)
                _, loss = model(X, Y)
                losses[k] = loss.item()

            out[split_name] = losses.mean().item()
    finally:
        model.train(was_training)

    return out

def train_model(lr=1e-3, max_epochs=100, eval_interval=10):
    """Train the notebook-level `model` with AdamW on random training batches.

    Args:
        lr: learning rate passed to the optimizer.
        max_epochs: number of optimisation steps (one batch per step).
        eval_interval: train and validation losses are estimated and printed
            every `eval_interval` steps, and on the final step.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    last_epoch = max_epochs - 1

    for epoch in range(max_epochs):

        # Periodic progress report
        if epoch % eval_interval == 0 or epoch == last_epoch:
            losses = estimate_loss()
            print(f"Epoch {epoch}: train loss {losses['train']:.3f}, val loss {losses['val']:.3f}")

        # One optimisation step on a fresh random batch
        optimizer.zero_grad()
        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        loss.backward()
        optimizer.step()


And to check whether the model works.

In [74]:
# Generate from the model before training
# (start from a single token with id 0 and sample 500 more; [0] picks the only row of the batch)
context = torch.zeros((1,1), dtype=torch.long)
model_gen = model.generate(context, max_new_tokens=500)[0].tolist()
print(decode(model_gen))


p-LmuTEvh3Qp Eifg-RevUIyVJnkB$UNwpO!icNuN Q
ljGL$BuHG'QdorcNNWRbnHEyIC
MIkWczhK E AaqGKnkDTBweCCZR;vkWLBLmv
dRLQlZhT,,Bfma-zH-SLHyNzaZD:nDM
GVIljiEz3&DeNrJhCVnqJ;HOlwIGIH.QUSMDnPGt?
g3x!QzXiSbdgGr.' LcvA:tjbLlwy3GnZgBMzytpFaOQ;c:AnAbY''xQ!cEPgl
Z.VNaVa!! lC:OnZ
3p MGIzqGvwB
 QvkbqVltRDkYZSzNvXIqwmm
J'YlzTwxzMvy'sCwXCdk.Ngg'ovQmP,qFjIdCjYOIq?jyA?HCI,KY:kM-r
Sh:hYbM.QGVF
AZRhKuNflKSrTEBgynmz.U;WjZIDTbvwwAi3Ko'q.d:HF'Ks:d
S$qXIwfAKXKP
Sic$qwl!cRIlTKzIQvBMaqG!xL;bVYNf
',IAFINdZYOgFoFu-Sd&l:l
' sGqvS


## FIN