In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken

# 1. Code setup and Baseline language modeling

### 1.1. Reading and exploring the data

In [2]:
with open('../data/tiny_shakespear.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print('length of dataset in characters: ', len(text))

length of dataset in characters:  1115393


### 1.2. Tokenization

Trade-off: very long vocabulary size and very short sequences or very short vocabulary size and very long sequences.

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('vocab size: ', vocab_size)
print('vocab: ', ''.join(chars))

vocab size:  65
vocab:  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda string: [stoi[ch] for ch in string]
decode = lambda tokens: ''.join([itos[t] for t in tokens])

print('encoded: ', encode('hello'))
print('decoded: ', decode(encode('hello')))

encoded:  [46, 43, 50, 50, 53]
decoded:  hello


##### 1.2.1. Production-grade example

In [6]:
enc = tiktoken.get_encoding('gpt2')
print('vocab size: ', enc.n_vocab)

vocab size:  50257


In [7]:
print('encoded: ', enc.encode('hello'))
print('decoded: ', enc.decode(enc.encode('hello')))

encoded:  [31373]
decoded:  hello


##### 1.2.2. Tokenize `tiny shakespear` dataset

In [8]:
tokenized_text = torch.tensor(encode(text), dtype=torch.long)
print('Shape of tokenized text: ', tokenized_text.shape)
print('Dtype of tokenized text: ', tokenized_text.dtype)
print('First 10 characters: ', text[:10])
print('First 10 tokens: ', tokenized_text[:10])

Shape of tokenized text:  torch.Size([1115393])
Dtype of tokenized text:  torch.int64
First 10 characters:  First Citi
First 10 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


##### 1.2.3. `Train` and `Validation` datasets

In [9]:
n = int(len(tokenized_text) * 0.85)
train_dataset = tokenized_text[:n]
val_dataset = tokenized_text[n:]

### 1.3. Data loader: batches of chunks of data

**Note**: when a batch is created, each token dimension is an INFORMATION POINT in relation to the next token. Thus, each batch packs multiple examples in relation to the next token.

Example:

Block: [18, 47, 56, 57, 58,  1, 15, 47]
Next token: 58

1. Context: 18         -> 47 likely follows next
2. Context: 18, 47     -> 56 likely follows next
3. Context: 18, 47, 56 -> 57 likely follows next
4. and so on...

**TIME DIMENSION**: The idea behind training in this way is for the **transformer** to be able to predict the next token with as little as one token of context. Then, after `block_size`is reached, the inputs need to be truncated, because the **transformer** will never receive more than `block_size` tokens of context.

**BATCH DIMENSION**: The idea behind batching is to train the model with multiple examples at the same time. This is done to speed up training and to make the model generalize better.

In [10]:
context_length = 8

print('First 9 tokens: ', train_dataset[:context_length + 1])
print('First block: ', train_dataset[:context_length])
print('Next token: ', train_dataset[context_length])

First 9 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
First block:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Next token:  tensor(58)


In [11]:
# time dimension

x = train_dataset[:context_length]
y = train_dataset[1:context_length+1]

for t in range(context_length):
    context = x[:t+1]
    target = y[t]
    print('Context: ', context)
    print('Target: ', target)

Context:  tensor([18])
Target:  tensor(47)
Context:  tensor([18, 47])
Target:  tensor(56)
Context:  tensor([18, 47, 56])
Target:  tensor(57)
Context:  tensor([18, 47, 56, 57])
Target:  tensor(58)
Context:  tensor([18, 47, 56, 57, 58])
Target:  tensor(1)
Context:  tensor([18, 47, 56, 57, 58,  1])
Target:  tensor(15)
Context:  tensor([18, 47, 56, 57, 58,  1, 15])
Target:  tensor(47)
Context:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Target:  tensor(58)


In [12]:
# batch dimension
def get_batch(split, batch_size, verbose=False):
    data = train_dataset if split == 'train' else val_dataset
    if verbose:
        print("Shape of data: ", len(data))
        print('Sample of data: ', data)
    
    random_observations = torch.randint(0, len(data) - context_length, (batch_size,))
    if verbose:
        print("random_observations: ", random_observations)

    input_batch = torch.stack([data[obs:obs+context_length] for obs in random_observations])
    target_batch = torch.stack([data[obs+1:obs+context_length+1] for obs in random_observations])
    
    return input_batch, target_batch

batch_size = 4
input_batch, target_batch = get_batch('train', batch_size, True)
print('Input batch: ', input_batch)
print('Target batch: ', target_batch)

for batch in range(batch_size):           # batch dimension
    for time in range(context_length):    # time dimension
        context = input_batch[batch, :time+1]
        target = target_batch[batch, time]
        if batch == 0:
            print('Batch: ', batch, 'Time: ', time)
            print('Context: ', context)
            print('Target: ', target)

Shape of data:  948084
Sample of data:  tensor([18, 47, 56,  ..., 13, 57,  1])
random_observations:  tensor([475443, 318696,  34022, 391054])
Input batch:  tensor([[35, 46, 43, 52,  1, 23, 47, 52],
        [53, 42, 63,  1, 57, 46, 39, 50],
        [47, 52,  1, 21,  1, 51, 43, 43],
        [43, 52, 58,  8,  0,  0, 16, 33]])
Target batch:  tensor([[46, 43, 52,  1, 23, 47, 52, 45],
        [42, 63,  1, 57, 46, 39, 50, 50],
        [52,  1, 21,  1, 51, 43, 43, 58],
        [52, 58,  8,  0,  0, 16, 33, 23]])
Batch:  0 Time:  0
Context:  tensor([35])
Target:  tensor(46)
Batch:  0 Time:  1
Context:  tensor([35, 46])
Target:  tensor(43)
Batch:  0 Time:  2
Context:  tensor([35, 46, 43])
Target:  tensor(52)
Batch:  0 Time:  3
Context:  tensor([35, 46, 43, 52])
Target:  tensor(1)
Batch:  0 Time:  4
Context:  tensor([35, 46, 43, 52,  1])
Target:  tensor(23)
Batch:  0 Time:  5
Context:  tensor([35, 46, 43, 52,  1, 23])
Target:  tensor(47)
Batch:  0 Time:  6
Context:  tensor([35, 46, 43, 52,  1, 23,

### 1.4. Simplest baseline: bigram language model

**Note**: this implementation is ridiculous by design. As a simple *character-level* bigram model, the prediction throws away all context.

In [13]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        """
        `self.embedding` is 65 x 65, because for each of the 65 tokens in the vocabulary,
        we have a 65-dimensional vector that represents the probability of the next token
        given the context.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, inputs, targets=None):
        logits = self.embedding(inputs)    # B, T, C (batch, time, channels)

        if targets is None:
            loss = None
        else:
            _, _, C = logits.shape
            logits = logits.view(-1, C)  # Flatten to [B * T, C]
            targets = targets.view(-1)   # Flatten to [B * T]
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, inputs, num_predictions):
        predictions = torch.zeros(inputs.shape[0] * num_predictions, dtype=torch.long)

        for i in range(num_predictions):
            logits, _ = self(inputs)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            target = torch.multinomial(probs, num_samples=1)
            predictions[i] = target

        return predictions
    
model = BigramLanguageModel(vocab_size)
logits, loss = model(input_batch, target_batch)
print("Vocab size: ", vocab_size)
print('Input batch shape: ', input_batch.shape)
print('Target batch shape: ', target_batch.shape)
print('Logits shape: ', logits.shape)
print('Loss: ', loss)

NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted targets: ', predicted_targets)
print('Predicted characters: ', decode(predicted_targets.tolist()))


Vocab size:  65
Input batch shape:  torch.Size([4, 8])
Target batch shape:  torch.Size([4, 8])
Logits shape:  torch.Size([32, 65])
Loss:  tensor(4.4551, grad_fn=<NllLossBackward0>)
Inputs:  tensor([[35, 46, 43, 52,  1, 23, 47, 52]])
Predicted targets:  tensor([ 1, 24, 33, 16, 57, 51, 54,  7,  5,  6, 54,  7, 15, 11, 20, 54, 54,  6,
         9, 13,  9, 11, 15, 54,  5,  1, 56, 54,  9, 37, 15, 15, 27, 51, 33,  0,
         7, 20, 54, 15,  0, 45, 15, 54, 49, 15, 28, 54, 15, 27])
Predicted characters:   LUDsmp-',p-C;Hpp,3A3;Cp' rp3YCCOmU
-HpC
gCpkCPpCO


### 1.5. Training the bigram model

In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [15]:
batch_size = 32

for epoch in range(10000):
    input_batch, target_batch = get_batch('train', batch_size)
    logits, loss = model(input_batch, target_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if epoch % 2000 == 0:
        print(f'Epoch {epoch} | Loss {loss.item()}')

print(f'Loss {loss.item()}')

Epoch 0 | Loss 4.719812393188477
Epoch 2000 | Loss 2.9696238040924072
Epoch 4000 | Loss 2.6671085357666016
Epoch 6000 | Loss 2.4390830993652344
Epoch 8000 | Loss 2.452904462814331
Loss 2.477780342102051


In [16]:
NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted characters: ', decode(predicted_targets.tolist()))

Inputs:  tensor([[57,  1, 39,  1, 57, 47, 41, 49]])
Predicted characters:  
neenieeee,u; ,iee n nese iesee . -es ,in iueie.e 


# 2. Building the "self-attention" model

### 2.1. `Version 1` - Weakest form of aggregation: averaging past context

Each token in a batch should communicate information with other tokens in the batch, in such a way that information only flows from past tokens to the current token.

Consider the fifth token in a batch of eight tokens. It should not communicate with tokens in the sixth, seventh and eighth positions, because those are FUTURE tokens in a sequence, but it should communicate with the fourth, third, second and first tokens, because those are PAST tokens in a sequence. This way, information only flows from previous context to the current timestep.

Given this, the easiest way for tokens to communicate is to simply average all previous embeddings. This is the weakest form of aggregation and is extremely lossy, because all information about spatial arrangement of tokens is lost. 

This implementation is also very low performance, because it requires computation to be linear in the number of tokens in the sequence.

In [17]:
B, T, C = 4, 8, 2
logits = torch.randn(B, T, C)
print("Shape of logits: ", logits.shape)

Shape of logits:  torch.Size([4, 8, 2])


In [18]:
# we want bag of words be logits[b, t] = mean of logits[b, i] for i<=t
logits_bow = torch.zeros((B, T, C))

for batch in range(B):
    for time in range(T):
        logits_prev = logits[batch, :time+1]
        logits_bow[batch, time] = torch.mean(logits_prev, dim=0)

In [19]:
# each location of logits_bow is the vertical mean of all previous logits
print('Logits BOW shape: ', logits_bow.shape)
logits[0], logits_bow[0]

Logits BOW shape:  torch.Size([4, 8, 2])


(tensor([[-0.9618, -0.6699],
         [ 0.1400,  1.8193],
         [-0.2196,  2.5470],
         [-1.4458,  0.0633],
         [ 1.3801, -0.7089],
         [-2.2699, -0.1819],
         [ 1.6007,  0.8455],
         [-0.4921, -1.3103]]),
 tensor([[-0.9618, -0.6699],
         [-0.4109,  0.5747],
         [-0.3472,  1.2321],
         [-0.6218,  0.9399],
         [-0.2215,  0.6101],
         [-0.5629,  0.4781],
         [-0.2538,  0.5306],
         [-0.2836,  0.3005]]))

### 2.2. The `mathematical trick` in self-attention: matrix multiplication with triangular mask

Matrix multiplication is a very efficient way to calculate the dot product of each token with all other tokens. By masking the upper triangular part of the matrix, we can ensure that each token only communicates with previous tokens.

In [20]:
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b =
tensor([[8., 9.],
        [3., 1.],
        [5., 9.]])
---
c =
tensor([[16., 19.],
        [16., 19.],
        [16., 19.]])


In [21]:
# this is the whole trick
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [22]:
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)  # normalize
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b =
tensor([[8., 9.],
        [2., 5.],
        [9., 9.]])
---
c =
tensor([[8.0000, 9.0000],
        [5.0000, 7.0000],
        [6.3333, 7.6667]])


### 2.3. `Version 2` - Self-attention with matrix multiplication

In [23]:
weights = torch.tril(torch.ones(T, T))
weights = weights / torch.sum(weights, dim=1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [24]:
logits_bow_2 = weights @ logits
print('Logits BOW 2 shape: ', logits_bow_2.shape)
print('Logits BOW == Logits BOW 2? ', torch.allclose(logits_bow, logits_bow_2))

Logits BOW 2 shape:  torch.Size([4, 8, 2])
Logits BOW == Logits BOW 2?  True


### 2.4. `Version 3`: Self-attention adding softmax

The softmax function exponentiates each element and normalizes the results so that they sum up to **1** along the specified dimension. Since &minus;&infin; values correspond to zeros when exponentiated, the softmax ensures that the attention is focused only on the elements allowed by the lower triangular mask.

An important aspect to note here is that the weights are initialized as `zeroes`, giving room for future `affinities` between tokens to be data dependent. That is, they will start looking at each other and some tokens will find other tokens more or less interesting.

In [25]:
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [26]:
weights = torch.zeros(T, T)
weights = weights.masked_fill(tril == 0, float('-inf'))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [27]:
weights = F.softmax(weights, dim=1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [28]:
logits_bow_3 = weights @ logits
print('Logits BOW 3 shape: ', logits_bow_3.shape)
print('Logits BOW == Logits BOW 3? ', torch.allclose(logits_bow, logits_bow_3))

Logits BOW 3 shape:  torch.Size([4, 8, 2])
Logits BOW == Logits BOW 3?  True


### 2.5. `Version 4`: THE CRUX OF IT ALL - Self-attention with `affinities`

In previous versions, all past tokens are averaged in the context of the current token, resulting in uniform affinities. But self-attention is all about learning data-dependent affinities between tokens so that, for example, a *vowel* token might look for *consonants* in its past and might want to know what those consonants were and let this data to flow to the current token.

In order to solve this data dependency problem, every single token in self-attention will emit two vectors:
- **Query**: roughly speaking is "What am I looking for?"
- **Key**: roughly speaking is "What do I contain?"

These vectors will be produced in parallel and independently by two linear transformations of the token embeddings. After creation, they will communicate through a dot product, resulting in a scalar value that will be used as the **affinity** between tokens. If the `Key` and the `Query` are very similar, the dot product will be high and the value of the token will be weighted more.

Finally, for the aggregations to be calculated, the `Value` vector is created by a third linear transformation of the token embeddings. This way, the logits become private information of the token and the aggregation is done by a weighted sum of the `Value` vector. Roughly speaking, "If you find me interesting, this is what I will communicate to you."

Important notes about self-attention:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate.

In [29]:
# implementation of a single head self-attention mechanism
head_size = 1

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

key, query, value

(Linear(in_features=2, out_features=1, bias=False),
 Linear(in_features=2, out_features=1, bias=False),
 Linear(in_features=2, out_features=1, bias=False))

In [34]:
# matrix multiplication of key and query to get the weights
# NOTE: previously, weights were initialized as `zeroes` and then masked
k = key(logits)
q = query(logits)
v = value(logits)
weights = q @ k.transpose(-2, -1)
k.shape, q.shape, v.shape, weights.shape

(torch.Size([4, 8, 1]),
 torch.Size([4, 8, 1]),
 torch.Size([4, 8, 1]),
 torch.Size([4, 8, 8]))

In [31]:
tril = torch.tril(torch.ones(T, T))

weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2441, 0.7559, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0967, 0.3937, 0.5096, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3368, 0.1882, 0.1692, 0.3058, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1814, 0.2159, 0.2229, 0.1867, 0.1931, 0.0000, 0.0000, 0.0000],
        [0.2389, 0.0812, 0.0666, 0.1998, 0.1620, 0.2513, 0.0000, 0.0000],
        [0.0669, 0.2183, 0.2713, 0.0814, 0.1024, 0.0633, 0.1963, 0.0000],
        [0.1712, 0.0642, 0.0536, 0.1455, 0.1202, 0.1793, 0.0701, 0.1959]],
       grad_fn=<SelectBackward0>)

In [33]:
logits_bow_4 = weights @ v
logits_bow_4.shape

torch.Size([4, 8, 1])

# 3. Building the transformer