In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken

# 1. Code setup and Baseline language modeling

### 1.1. Reading and exploring the data

In [2]:
with open('../data/tiny_shakespear.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print('length of dataset in characters: ', len(text))

length of dataset in characters:  1115393


### 1.2. Tokenization

Trade-off: very long vocabulary size and very short sequences or very short vocabulary size and very long sequences.

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('vocab size: ', vocab_size)
print('vocab: ', ''.join(chars))

vocab size:  65
vocab:  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda string: [stoi[ch] for ch in string]
decode = lambda tokens: ''.join([itos[t] for t in tokens])

print('encoded: ', encode('hello'))
print('decoded: ', decode(encode('hello')))

encoded:  [46, 43, 50, 50, 53]
decoded:  hello


##### 1.2.1. Production-grade example

In [6]:
enc = tiktoken.get_encoding('gpt2')
print('vocab size: ', enc.n_vocab)

vocab size:  50257


In [7]:
print('encoded: ', enc.encode('hello'))
print('decoded: ', enc.decode(enc.encode('hello')))

encoded:  [31373]
decoded:  hello


##### 1.2.2. Tokenize `tiny shakespear` dataset

In [8]:
tokenized_text = torch.tensor(encode(text), dtype=torch.long)
print('Shape of tokenized text: ', tokenized_text.shape)
print('Dtype of tokenized text: ', tokenized_text.dtype)
print('First 10 characters: ', text[:10])
print('First 10 tokens: ', tokenized_text[:10])

Shape of tokenized text:  torch.Size([1115393])
Dtype of tokenized text:  torch.int64
First 10 characters:  First Citi
First 10 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


##### 1.2.3. `Train` and `Validation` datasets

In [9]:
n = int(len(tokenized_text) * 0.85)
train_dataset = tokenized_text[:n]
val_dataset = tokenized_text[n:]

### 1.3. Data loader: batches of chunks of data

**Note**: when a batch is created, each token dimension is an INFORMATION POINT in relation to the next token. Thus, each batch packs multiple examples in relation to the next token.

Example:

Block: [18, 47, 56, 57, 58,  1, 15, 47]
Next token: 58

1. Context: 18         -> 47 likely follows next
2. Context: 18, 47     -> 56 likely follows next
3. Context: 18, 47, 56 -> 57 likely follows next
4. and so on...

**TIME DIMENSION**: The idea behind training in this way is for the **transformer** to be able to predict the next token with as little as one token of context. Then, after `block_size`is reached, the inputs need to be truncated, because the **transformer** will never receive more than `block_size` tokens of context.

**BATCH DIMENSION**: The idea behind batching is to train the model with multiple examples at the same time. This is done to speed up training and to make the model generalize better.

In [10]:
context_length = 8

print('First 9 tokens: ', train_dataset[:context_length + 1])
print('First block: ', train_dataset[:context_length])
print('Next token: ', train_dataset[context_length])

First 9 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
First block:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Next token:  tensor(58)


In [11]:
# time dimension

x = train_dataset[:context_length]
y = train_dataset[1:context_length+1]

for t in range(context_length):
    context = x[:t+1]
    target = y[t]
    print('Context: ', context)
    print('Target: ', target)

Context:  tensor([18])
Target:  tensor(47)
Context:  tensor([18, 47])
Target:  tensor(56)
Context:  tensor([18, 47, 56])
Target:  tensor(57)
Context:  tensor([18, 47, 56, 57])
Target:  tensor(58)
Context:  tensor([18, 47, 56, 57, 58])
Target:  tensor(1)
Context:  tensor([18, 47, 56, 57, 58,  1])
Target:  tensor(15)
Context:  tensor([18, 47, 56, 57, 58,  1, 15])
Target:  tensor(47)
Context:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Target:  tensor(58)


In [12]:
# batch dimension
def get_batch(split, batch_size, verbose=False):
    data = train_dataset if split == 'train' else val_dataset
    if verbose:
        print("Shape of data: ", len(data))
        print('Sample of data: ', data)
    
    random_observations = torch.randint(0, len(data) - context_length, (batch_size,))
    if verbose:
        print("random_observations: ", random_observations)

    input_batch = torch.stack([data[obs:obs+context_length] for obs in random_observations])
    target_batch = torch.stack([data[obs+1:obs+context_length+1] for obs in random_observations])
    
    return input_batch, target_batch

batch_size = 4
input_batch, target_batch = get_batch('train', batch_size, True)
print('Input batch: ', input_batch)
print('Target batch: ', target_batch)

for batch in range(batch_size):           # batch dimension
    for time in range(context_length):    # time dimension
        context = input_batch[batch, :time+1]
        target = target_batch[batch, time]
        if batch == 0:
            print('Batch: ', batch, 'Time: ', time)
            print('Context: ', context)
            print('Target: ', target)

Shape of data:  948084
Sample of data:  tensor([18, 47, 56,  ..., 13, 57,  1])
random_observations:  tensor([700887, 404345,  74139, 266002])
Input batch:  tensor([[ 1, 46, 39, 42,  1, 40, 43, 43],
        [ 1, 39, 56, 58,  1, 42, 39, 51],
        [ 1, 63, 53, 59,  1, 42, 53, 59],
        [43,  1, 51, 43,  1, 57, 53, 51]])
Target batch:  tensor([[46, 39, 42,  1, 40, 43, 43, 52],
        [39, 56, 58,  1, 42, 39, 51, 52],
        [63, 53, 59,  1, 42, 53, 59, 40],
        [ 1, 51, 43,  1, 57, 53, 51, 43]])
Batch:  0 Time:  0
Context:  tensor([1])
Target:  tensor(46)
Batch:  0 Time:  1
Context:  tensor([ 1, 46])
Target:  tensor(39)
Batch:  0 Time:  2
Context:  tensor([ 1, 46, 39])
Target:  tensor(42)
Batch:  0 Time:  3
Context:  tensor([ 1, 46, 39, 42])
Target:  tensor(1)
Batch:  0 Time:  4
Context:  tensor([ 1, 46, 39, 42,  1])
Target:  tensor(40)
Batch:  0 Time:  5
Context:  tensor([ 1, 46, 39, 42,  1, 40])
Target:  tensor(43)
Batch:  0 Time:  6
Context:  tensor([ 1, 46, 39, 42,  1, 40, 

### 1.4. Simplest baseline: bigram language model

**Note**: this implementation is ridiculous by design. As a simple *character-level* bigram model, the prediction throws away all context.

In [13]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        """
        `self.embedding` is 65 x 65, because for each of the 65 tokens in the vocabulary,
        we have a 65-dimensional vector that represents the probability of the next token
        given the context.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, inputs, targets=None):
        logits = self.embedding(inputs)    # B, T, C (batch, time, channels)

        if targets is None:
            loss = None
        else:
            _, _, C = logits.shape
            logits = logits.view(-1, C)  # Flatten to [B * T, C]
            targets = targets.view(-1)   # Flatten to [B * T]
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, inputs, num_predictions):
        predictions = torch.zeros(inputs.shape[0] * num_predictions, dtype=torch.long)

        for i in range(num_predictions):
            logits, _ = self(inputs)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            target = torch.multinomial(probs, num_samples=1)
            predictions[i] = target

        return predictions
    
model = BigramLanguageModel(vocab_size)
logits, loss = model(input_batch, target_batch)
print("Vocab size: ", vocab_size)
print('Input batch shape: ', input_batch.shape)
print('Target batch shape: ', target_batch.shape)
print('Logits shape: ', logits.shape)
print('Loss: ', loss)

NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted targets: ', predicted_targets)
print('Predicted characters: ', decode(predicted_targets.tolist()))


Vocab size:  65
Input batch shape:  torch.Size([4, 8])
Target batch shape:  torch.Size([4, 8])
Logits shape:  torch.Size([32, 65])
Loss:  tensor(4.4892, grad_fn=<NllLossBackward0>)
Inputs:  tensor([[ 1, 46, 39, 42,  1, 40, 43, 43]])
Predicted targets:  tensor([46, 14, 36, 27,  3, 54, 48, 60,  3,  4, 16, 15, 32, 35,  2, 34, 12,  4,
        51,  3, 50, 32,  5, 60, 22, 47, 54,  3, 46, 10, 60, 27, 16,  5, 47, 54,
        21,  2,  2, 61, 27, 59,  8, 30, 37,  5, 54, 12,  8, 54])
Predicted characters:  hBXO$pjv$&DCTW!V?&m$lT'vJip$h:vOD'ipI!!wOu.RY'p?.p


### 1.5. Training the bigram model

In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [37]:
batch_size = 32

for epoch in range(10000):
    input_batch, target_batch = get_batch('train', batch_size)
    logits, loss = model(input_batch, target_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 2000 == 0:
        print(f'Epoch {epoch + 1} | Loss {loss.item()}')

Epoch 2000 | Loss 2.318941116333008
Epoch 4000 | Loss 2.4488961696624756
Epoch 6000 | Loss 2.4095680713653564
Epoch 8000 | Loss 2.5401923656463623
Epoch 10000 | Loss 2.5409693717956543


In [40]:
NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted characters: ', decode(predicted_targets.tolist()))

Inputs:  tensor([[10,  0, 21,  1, 42, 53,  1, 52]])
Predicted characters:  es te g e e;gt.dogetdgc!ge gd cec ee d:d  do?d  i



# 2. Building the `transformer`

<div align="center">
  <img src="../assets/transformer.jpg" width="400"/>
</div>

### 2.2. The "self-attention" mechanism

##### 2.2.1. `Version 1` - Weakest form of aggregation: averaging past context

Each token in a batch should communicate information with other tokens in the batch, in such a way that information only flows from past tokens to the current token.

Consider the fifth token in a batch of eight tokens. It should not communicate with tokens in the sixth, seventh and eighth positions, because those are FUTURE tokens in a sequence, but it should communicate with the fourth, third, second and first tokens, because those are PAST tokens in a sequence. This way, information only flows from previous context to the current timestep.

Given this, the easiest way for tokens to communicate is to simply average all previous embeddings. This is the weakest form of aggregation and is extremely lossy, because all information about spatial arrangement of tokens is lost. 

This implementation is also very low performance, because it requires computation to be linear in the number of tokens in the sequence.

In [17]:
B, T, C = 4, 8, 2
logits = torch.randn(B, T, C)
print("Shape of logits: ", logits.shape)

Shape of logits:  torch.Size([4, 8, 2])


In [18]:
# we want bag of words be logits[b, t] = mean of logits[b, i] for i<=t
logits_bow = torch.zeros((B, T, C))

for batch in range(B):
    for time in range(T):
        logits_prev = logits[batch, :time+1]
        logits_bow[batch, time] = torch.mean(logits_prev, dim=0)

In [19]:
# each location of logits_bow is the vertical mean of all previous logits
print('Logits BOW shape: ', logits_bow.shape)
logits[0], logits_bow[0]

Logits BOW shape:  torch.Size([4, 8, 2])


(tensor([[-0.2065,  1.0065],
         [-0.4074,  1.5871],
         [ 0.5334,  1.0289],
         [ 0.0042, -1.1801],
         [-0.3949, -0.9695],
         [-0.7069,  0.3713],
         [-0.5134,  0.0714],
         [ 0.8952, -1.1771]]),
 tensor([[-0.2065,  1.0065],
         [-0.3070,  1.2968],
         [-0.0268,  1.2075],
         [-0.0191,  0.6106],
         [-0.0942,  0.2946],
         [-0.1964,  0.3073],
         [-0.2416,  0.2736],
         [-0.0995,  0.0923]]))

##### 2.2.2. The `mathematical trick` in self-attention: Matrix multiplication with triangular mask

Matrix multiplication is a very efficient way to calculate the dot product of each token with all other tokens. By masking the upper triangular part of the matrix, we can ensure that each token only communicates with previous tokens.

In [20]:
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b =
tensor([[8., 4.],
        [4., 1.],
        [5., 0.]])
---
c =
tensor([[17.,  5.],
        [17.,  5.],
        [17.,  5.]])


In [21]:
# this is the whole trick
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [22]:
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)  # normalize
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b =
tensor([[1., 1.],
        [4., 6.],
        [2., 5.]])
---
c =
tensor([[1.0000, 1.0000],
        [2.5000, 3.5000],
        [2.3333, 4.0000]])


##### 2.2.3. `Version 2` - Averaging past context with matrix multiplication

In [23]:
weights = torch.tril(torch.ones(T, T))
weights = weights / torch.sum(weights, dim=1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [24]:
logits_bow_2 = weights @ logits
print('Logits BOW 2 shape: ', logits_bow_2.shape)
print('Logits BOW == Logits BOW 2? ', torch.allclose(logits_bow, logits_bow_2))

Logits BOW 2 shape:  torch.Size([4, 8, 2])
Logits BOW == Logits BOW 2?  True


##### 2.2.4. `Version 3`: Adding softmax to self-attention

The softmax function exponentiates each element and normalizes the results so that they sum up to **1** along the specified dimension. Since &minus;&infin; values correspond to zeroes when exponentiated, the softmax ensures that the attention is focused only on the elements allowed by the lower triangular mask.

An important aspect to note here is that the weights are initialized as `zeroes`, giving room for future `affinities` between tokens to be data dependent. That is, they will start looking at each other and some tokens will find other tokens more or less interesting.

In [25]:
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [26]:
weights = torch.zeros(T, T)
weights = weights.masked_fill(tril == 0, float('-inf'))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [27]:
weights = F.softmax(weights, dim=1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [28]:
logits_bow_3 = weights @ logits
print('Logits BOW 3 shape: ', logits_bow_3.shape)
print('Logits BOW == Logits BOW 3? ', torch.allclose(logits_bow, logits_bow_3))

Logits BOW 3 shape:  torch.Size([4, 8, 2])
Logits BOW == Logits BOW 3?  True


##### 2.2.5. `Version 4`: THE CRUX OF IT ALL - Self-attention with `affinities`

In previous versions, all past tokens are averaged in the context of the current token, resulting in uniform affinities. But self-attention is all about learning data-dependent affinities between tokens so that, for example, a *vowel* token might look for *consonants* in its past and might want to know what those consonants were and let this data to flow to the current token.

In order to solve this data dependency problem, every single token in self-attention will emit two vectors:
- **Query**: roughly speaking is "What am I looking for?"
- **Key**: roughly speaking is "What do I contain?"

These vectors will be produced in parallel and independently by two linear transformations of the token embeddings. After creation, they will communicate through a dot product, resulting in a scalar value that will be used as the **affinity** between tokens. If the `Key` and the `Query` are very similar, the dot product will be high and the value of the token will be weighted more.

Finally, for the aggregations to be calculated, the `Value` vector is created by a third linear transformation of the token embeddings. This way, the logits become private information of the token and the aggregation is done by a weighted sum of the `Value` vector. Roughly speaking, "If you find me interesting, this is what I will communicate to you."

IMPORTANT NOTES ABOUT SELF-ATTENTION:
- Attention is a **communication mechanism**. It can be seen as tokens in a directed graph looking at each other and aggregating information with a weighted sum from all tokens that point to them, with data-dependent weights.
- **There is no notion of space**. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across a batch dimension is processed completely independently and never "talk" to each other.
- An attention mechanism that uses triangular masking is called a "decoder" and is usually used in autoregressive settings, eg. for language modeling.  In an "encoder" attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate, eg. for 'sentiment analysis'.
- `self-attention` means that queries, keys and values are all produced from the same source. In `cross-attention`, the queries are produced from the logits, but the keys and values como from some other, external source, eg. an encoder module.
- As implemented in the **Attention is all your need** paper, `scaled attention` divides the weights by $\frac{1}{\sqrt{\text{head\_size}}}$. This makes it so when input Q, K are unit variance, the weights will be unit variance as well and softmax will stay diffuse and not saturate the gradients.

<div align="center">
  <img src="../assets/scaled-dot-product-attention.jpg" width="200"/>
</div>

In [29]:
# implementation of a single head self-attention mechanism
head_size = 1

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

key, query, value

(Linear(in_features=2, out_features=1, bias=False),
 Linear(in_features=2, out_features=1, bias=False),
 Linear(in_features=2, out_features=1, bias=False))

In [30]:
# matrix multiplication of key and query to get the weights
# NOTE: previously, weights were initialized as `zeroes` and then masked
k = key(logits)
q = query(logits)
v = value(logits)
weights = q @ k.transpose(-2, -1)
k.shape, q.shape, v.shape, weights.shape

(torch.Size([4, 8, 1]),
 torch.Size([4, 8, 1]),
 torch.Size([4, 8, 1]),
 torch.Size([4, 8, 8]))

In [31]:
tril = torch.tril(torch.ones(T, T))

weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4608, 0.5392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3229, 0.3471, 0.3299, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2279, 0.2050, 0.2208, 0.3463, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1803, 0.1678, 0.1765, 0.2398, 0.2355, 0.0000, 0.0000, 0.0000],
        [0.1788, 0.1897, 0.1820, 0.1412, 0.1433, 0.1650, 0.0000, 0.0000],
        [0.1478, 0.1516, 0.1489, 0.1336, 0.1344, 0.1428, 0.1410, 0.0000],
        [0.0963, 0.0838, 0.0924, 0.1667, 0.1610, 0.1160, 0.1241, 0.1596]],
       grad_fn=<SelectBackward0>)

In [32]:
logits_bow_4 = weights @ v
logits_bow_4.shape

torch.Size([4, 8, 1])

### 2.3. Multi-head self-attention

The idea behind multi-head attention is to have multiple heads and then concatenate their results over the channel dimension. In this scenarion, each head represents one communication channel between tokens and each of these channels will be typically correspondingly smaller than the original embedding dimension. For example, consider a 512-dimensional embedding split into 8 heads of 64 dimensions each.

Multiple independent channels of communication helps to improve the loss because the model can learn different types of relationships between tokens. For example, one head might learn to look for verbs, another for nouns, and so on.

<div align="center">
  <img src="../assets/multi-head-attention.jpg" width="300"/>
</div>

In [None]:
class SelfAttentionHead(nn.Module):
    pass

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(head_size) for _ in range(num_heads)])

    def forward(self, inputs):
        return torch.cat([head(inputs) for head in self.heads], dim=-1)
    
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, context_length)
        self.position_embedding_table = nn.Embedding(context_length, context_length)
        self.self_attention_heads = MultiHeadSelfAttention(num_heads=4, head_size=context_length // 4)
        self.language_model_head = nn.Linear(context_length, vocab_size)

### 2.4. Feedforward layers

The feedforward layer is a simple linear transformation followed by a non-linearity. The non-linearity is usually a `GELU` or a `ReLU`. The feedforward layer is applied independently to each token in the sequence.

The idea here is to allow the model to learn complex relationships between tokens. The feedforward layer is a simple way to allow the model to learn non-linear relationships between tokens.

In [None]:
class FeedForward(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.layer(inputs)

### 2.5. Residual connections

<div align="center">
  <img src="../assets/residual-connections.jpg" width="300"/>
</div>

