In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken

# 1. Baseline language modeling and Code setup

### 1.1. Reading and exploring the data

In [2]:
with open('../data/tiny_shakespear.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print('length of dataset in characters: ', len(text))

length of dataset in characters:  1115393


### 1.2. Tokenization

Trade-off: very long vocabulary size and very short sequences or very short vocabulary size and very long sequences.

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('vocab size: ', vocab_size)
print('vocab: ', ''.join(chars))

vocab size:  65
vocab:  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda string: [stoi[ch] for ch in string]
decode = lambda tokens: ''.join([itos[t] for t in tokens])

print('encoded: ', encode('hello'))
print('decoded: ', decode(encode('hello')))

encoded:  [46, 43, 50, 50, 53]
decoded:  hello


##### 1.2.1. Production-grade example

In [6]:
enc = tiktoken.get_encoding('gpt2')
print('vocab size: ', enc.n_vocab)

vocab size:  50257


In [7]:
print('encoded: ', enc.encode('hello'))
print('decoded: ', enc.decode(enc.encode('hello')))

encoded:  [31373]
decoded:  hello


##### 1.2.2. Tokenize `tiny shakespear` dataset

In [8]:
tokenized_text = torch.tensor(encode(text), dtype=torch.long)
print('Shape of tokenized text: ', tokenized_text.shape)
print('Dtype of tokenized text: ', tokenized_text.dtype)
print('First 10 characters: ', text[:10])
print('First 10 tokens: ', tokenized_text[:10])

Shape of tokenized text:  torch.Size([1115393])
Dtype of tokenized text:  torch.int64
First 10 characters:  First Citi
First 10 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


##### 1.2.3. `Train` and `Validation` datasets

In [9]:
n = int(len(tokenized_text) * 0.85)
train_dataset = tokenized_text[:n]
val_dataset = tokenized_text[n:]

### 1.3. Data loader: batches of chunks of data

**Note**: when a batch is created, each token dimension is an INFORMATION POINT in relation to the next token. Thus, each batch packs multiple examples in relation to the next token.

Example:

Block: [18, 47, 56, 57, 58,  1, 15, 47]
Next token: 58

1. Context: 18         -> 47 likely follows next
2. Context: 18, 47     -> 56 likely follows next
3. Context: 18, 47, 56 -> 57 likely follows next
4. and so on...

**TIME DIMENSION**: The idea behind training in this way is for the **transformer** to be able to predict the next token with as little as one token of context. Then, after `block_size`is reached, the inputs need to be truncated, because the **transformer** will never receive more than `block_size` tokens of context.

**BATCH DIMENSION**: The idea behind batching is to train the model with multiple examples at the same time. This is done to speed up training and to make the model generalize better.

In [10]:
context_length = 8

print('First 9 tokens: ', train_dataset[:context_length + 1])
print('First block: ', train_dataset[:context_length])
print('Next token: ', train_dataset[context_length])

First 9 tokens:  tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
First block:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Next token:  tensor(58)


In [11]:
# time dimension

x = train_dataset[:context_length]
y = train_dataset[1:context_length+1]

for t in range(context_length):
    context = x[:t+1]
    target = y[t]
    print('Context: ', context)
    print('Target: ', target)

Context:  tensor([18])
Target:  tensor(47)
Context:  tensor([18, 47])
Target:  tensor(56)
Context:  tensor([18, 47, 56])
Target:  tensor(57)
Context:  tensor([18, 47, 56, 57])
Target:  tensor(58)
Context:  tensor([18, 47, 56, 57, 58])
Target:  tensor(1)
Context:  tensor([18, 47, 56, 57, 58,  1])
Target:  tensor(15)
Context:  tensor([18, 47, 56, 57, 58,  1, 15])
Target:  tensor(47)
Context:  tensor([18, 47, 56, 57, 58,  1, 15, 47])
Target:  tensor(58)


In [12]:
# batch dimension
def get_batch(split, batch_size, verbose=False):
    data = train_dataset if split == 'train' else val_dataset
    if verbose:
        print("Shape of data: ", len(data))
        print('Sample of data: ', data)
    
    random_observations = torch.randint(0, len(data) - context_length, (batch_size,))
    if verbose:
        print("random_observations: ", random_observations)

    input_batch = torch.stack([data[obs:obs+context_length] for obs in random_observations])
    target_batch = torch.stack([data[obs+1:obs+context_length+1] for obs in random_observations])
    
    return input_batch, target_batch

batch_size = 4
input_batch, target_batch = get_batch('train', batch_size, True)
print('Input batch: ', input_batch)
print('Target batch: ', target_batch)

for batch in range(batch_size):           # batch dimension
    for time in range(context_length):    # time dimension
        context = input_batch[batch, :time+1]
        target = target_batch[batch, time]
        if batch == 0:
            print('Batch: ', batch, 'Time: ', time)
            print('Context: ', context)
            print('Target: ', target)

Shape of data:  948084
Sample of data:  tensor([18, 47, 56,  ..., 13, 57,  1])
random_observations:  tensor([724804, 822947, 340160, 803020])
Input batch:  tensor([[43, 57, 57,  6,  0, 27, 56,  1],
        [ 1, 43, 63, 43, 57,  1, 43, 50],
        [10,  0, 32, 46, 47, 52, 45, 57],
        [43,  1, 61, 47, 50, 50,  1, 39]])
Target batch:  tensor([[57, 57,  6,  0, 27, 56,  1, 43],
        [43, 63, 43, 57,  1, 43, 50, 57],
        [ 0, 32, 46, 47, 52, 45, 57,  1],
        [ 1, 61, 47, 50, 50,  1, 39, 50]])
Batch:  0 Time:  0
Context:  tensor([43])
Target:  tensor(57)
Batch:  0 Time:  1
Context:  tensor([43, 57])
Target:  tensor(57)
Batch:  0 Time:  2
Context:  tensor([43, 57, 57])
Target:  tensor(6)
Batch:  0 Time:  3
Context:  tensor([43, 57, 57,  6])
Target:  tensor(0)
Batch:  0 Time:  4
Context:  tensor([43, 57, 57,  6,  0])
Target:  tensor(27)
Batch:  0 Time:  5
Context:  tensor([43, 57, 57,  6,  0, 27])
Target:  tensor(56)
Batch:  0 Time:  6
Context:  tensor([43, 57, 57,  6,  0, 27, 

### 1.4. Simplest baseline: bigram language model

**Note**: this implementation is ridiculous by design. As a simple *character-level* bigram model, the prediction throws away all context.

In [13]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        """
        `self.embedding` is 65 x 65, because for each of the 65 tokens in the vocabulary,
        we have a 65-dimensional vector that represents the probability of the next token
        given the context.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, inputs, targets=None):
        logits = self.embedding(inputs)    # B, T, C (batch, time, channels)

        if targets is None:
            loss = None
        else:
            _, _, C = logits.shape
            logits = logits.view(-1, C)  # Flatten to [B * T, C]
            targets = targets.view(-1)   # Flatten to [B * T]
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, inputs, num_predictions):
        predictions = torch.zeros(inputs.shape[0] * num_predictions, dtype=torch.long)

        for i in range(num_predictions):
            logits, _ = self(inputs)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            target = torch.multinomial(probs, num_samples=1)
            predictions[i] = target

        return predictions
    
model = BigramLanguageModel(vocab_size)
logits, loss = model(input_batch, target_batch)
print("Vocab size: ", vocab_size)
print('Input batch shape: ', input_batch.shape)
print('Target batch shape: ', target_batch.shape)
print('Logits shape: ', logits.shape)
print('Loss: ', loss)

NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted targets: ', predicted_targets)
print('Predicted characters: ', decode(predicted_targets.tolist()))


Vocab size:  65
Input batch shape:  torch.Size([4, 8])
Target batch shape:  torch.Size([4, 8])
Logits shape:  torch.Size([32, 65])
Loss:  tensor(4.7661, grad_fn=<NllLossBackward0>)
Inputs:  tensor([[43, 57, 57,  6,  0, 27, 56,  1]])
Predicted targets:  tensor([58, 31,  8, 33,  2, 31, 50, 11, 30, 20, 17, 12,  6,  1, 61, 30, 26,  6,
        17, 54, 20, 30, 35, 12, 49, 46, 34, 31, 54, 49, 48, 61, 46,  4, 48,  5,
        35, 12,  2, 10,  6, 11, 48, 35,  1, 58, 55, 14, 59, 48])
Predicted characters:  tS.U!Sl;RHE?, wRN,EpHRW?khVSpkjwh&j'W?!:,;jW tqBuj


### 1.5. Training the bigram model

In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [15]:
batch_size = 32

for epoch in range(10000):
    input_batch, target_batch = get_batch('train', batch_size)
    logits, loss = model(input_batch, target_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if epoch % 2000 == 0:
        print(f'Epoch {epoch} | Loss {loss.item()}')

print(f'Loss {loss.item()}')

Epoch 0 | Loss 4.661503314971924
Epoch 2000 | Loss 3.0199105739593506
Epoch 4000 | Loss 2.684917449951172
Epoch 6000 | Loss 2.561742067337036
Epoch 8000 | Loss 2.456726312637329
Loss 2.36138653755188


In [16]:
NUM_PREDICTIONS = 50
INPUTS = input_batch[:1]

print('Inputs: ', INPUTS)
predicted_targets = model.generate(INPUTS, NUM_PREDICTIONS)
print('Predicted characters: ', decode(predicted_targets.tolist()))

Inputs:  tensor([[50, 47, 62, 43, 52, 43, 57,  0]])
Predicted characters:  PSW!SLBIETR
u
RLHTI
TMPWTWATA
KI'DTHWFQEWOIOTNOC




# 2. Building the "self-attention" model

### 2.1. `Version 1` - Weakest form of aggregation: averaging past context

Each token in a batch should communicate information with other tokens in the batch, in such a way that information only flows from past tokens to the current token.

Consider the fifth token in a batch of eight tokens. It should not communicate with tokens in the sixth, seventh and eighth positions, because those are FUTURE tokens in a sequence, but it should communicate with the fourth, third, second and first tokens, because those are PAST tokens in a sequence. This way, information only flows from previous context to the current timestep.

Given this, the easiest way for tokens to communicate is to simply average all previous embeddings. This is the weakest form of aggregation and is extremely lossy, because all information about spatial arrangement of tokens is lost. 

In [17]:
B, T, C = 4, 8, 2
logits = torch.randn(B, T, C)
print("Shape of logits: ", logits.shape)

Shape of logits:  torch.Size([4, 8, 2])


In [18]:
# we want logits[b, t] = mean of logits[b, i] for i<=t
logits_bow = torch.zeros((B, T, C))    # bow: bag of words

for batch in range(B):
    for time in range(T):
        logits_prev = logits[batch, :time+1]
        logits_bow[batch, time] = torch.mean(logits_prev, dim=0)

In [19]:
# each location of logits_bow is the vertical mean of all previous logits
logits[0], logits_bow[0]

(tensor([[ 0.0372, -0.8731],
         [ 0.8561,  1.5495],
         [-0.7364, -0.5302],
         [-0.7129,  1.1072],
         [-0.2394, -0.3602],
         [ 0.8442,  0.8462],
         [ 1.2206,  1.9347],
         [-1.4610, -0.5564]]),
 tensor([[ 0.0372, -0.8731],
         [ 0.4466,  0.3382],
         [ 0.0523,  0.0487],
         [-0.1390,  0.3133],
         [-0.1591,  0.1786],
         [ 0.0081,  0.2899],
         [ 0.1813,  0.5249],
         [-0.0240,  0.3897]]))

### 2.2. The `mathematical trick` in self-attention: matrix multiplication with triangular mask

Matrix multiplication is a very efficient way to calculate the dot product of each token with all other tokens. By masking the upper triangular part of the matrix, we can ensure that each token only communicates with previous tokens.

In [20]:
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b =
tensor([[4., 9.],
        [3., 8.],
        [3., 0.]])
---
c =
tensor([[10., 17.],
        [10., 17.],
        [10., 17.]])


In [21]:
# this is the whole trick
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [22]:
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b

print('a =')
print(a)
print('---')
print('b =')
print(b)
print('---')
print('c =')
print(c)

a =
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b =
tensor([[2., 1.],
        [4., 3.],
        [1., 5.]])
---
c =
tensor([[2.0000, 1.0000],
        [3.0000, 2.0000],
        [2.3333, 3.0000]])


### 2.3. `Version 2` - Self-attention with matrix multiplication

### 2.4. `Version 3`: Selft-attention adding softmax
### 2.5. positional encoding
### 2.6. THE CRUX OF IT ALL: version 4: self-attention
### 2.7. note 1: attention as communication
### 2.8. note 2: attention has no notion of space, operates over sets
### 2.9. note 3: there is no communication across batch dimension
### 2.10. note 4: encoder blocks vs. decoder blocks
### 2.11. note 5: attention vs. self-attention vs. cross-attention
### 2.12. note 6: "scaled" self-attention. why divide by sqrt(head_size)

# 3. Building the transformer