In [11]:
import torch

### Data Loading and Tokenization

In [12]:
with open('ansturm.txt', 'r', encoding='utf-8') as f:
    text = f.read()

    punctuation = {
        u'\u2003': " ",
    }
    
    for key, value in punctuation.items():
        text = text.replace(key, value)

# A sample of the text
print(text[:1000])

Sokrate  Flucht
Montag, 22. September 2014

Ich sehe die Nacht hereinbrechen. Dunkelheit umhüllt nach und nach die karge Landschaft, mir wird zunehmend kälter. Ich muss mich beeilen, mir bleibt nicht mehr viel Zeit. In wenigen Minuten wird es zu spät sein, alles vorbei, mein Leben sich in Luft auflösen, meine Erinnerungen sich verflüchtigen, mein Ich im Nichts verschwinden. Was habe ich mir nur dabei gedacht. Dass ich jetzt hier stehe, an diesem Ort, von dem niemand etwas ahnt, geschweige denn weiß, habe ich wieder einzig meinem sturen Kopf zu verdanken.
Ich erreiche den Wald, mein Herz schlägt höher. Was werden die Folgen von meinem unerwünschten Eindringen hier sein? Schreckliche Gedanken gehen mir durch den Kopf. Aber es ist die Gefahr wert, oder? Ich habe viele wichtige Informationen bekommen, welche die Menschheit vor einer gewaltigen Bedrohung schützen könnten. Vorausgesetzt, es kommt jemals dazu, dass sie einen Menschen erreichen. Und die Chance, dass der mir diese unvorstellbar

In [13]:
# EDA on the text
print('Number of characters:', len(text))
print('Number of words:', len(text.split(' ')))

# Unique characters
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('Vocabulary size:', vocab_size)
print(chars)

Number of characters: 295869
Number of words: 46581
Vocabulary size: 81
['\n', ' ', '!', ',', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '«', '»', 'Ä', 'Ö', 'Ü', 'ß', 'ä', 'ö', 'ü', '–', '’', '‚']


In [14]:
# Tokenization (simple character-level encoding)
stoi = { ch: i for i, ch in enumerate(chars) }
itos = { i: ch for i, ch in enumerate(chars) }
encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: ''.join([itos[i] for i in l])

print('Encoded:', encode('Hallo Welt!'))
print('Decoded:', decode(encode('Hallo Welt!')))

Encoded: [24, 42, 53, 53, 56, 1, 39, 46, 53, 61, 2]
Decoded: Hallo Welt!
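
The lookup tables only cover characters that occur in the corpus, so `encode` cannot represent anything else. The following is a minimal sketch of that behaviour; it uses a tiny stand-in corpus (`sample_text`, an assumption for illustration) instead of the novel text.

```python
# Toy corpus standing in for the novel text (assumption for illustration only)
sample_text = "Hallo Welt!"

chars = sorted(set(sample_text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[ch] for ch in s]

def decode(ids):
    """Map a list of integer token ids back to a string."""
    return ''.join(itos[i] for i in ids)

# Round trip: decoding the encoding returns the original string
assert decode(encode("Hallo")) == "Hallo"

# Characters outside the vocabulary raise a KeyError in the dictionary lookup
try:
    encode("Xylophon")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print("unseen characters encodable:", unknown_ok)
```

For a character-level model trained and sampled on the same corpus this is usually not a problem, but prompts typed by a user may need filtering first.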


In [15]:
# Encode the text into data tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:250])

torch.Size([295869]) torch.int64
tensor([35, 56, 52, 59, 42, 61, 46,  1,  1, 22, 53, 62, 44, 49, 61,  0, 29, 56,
        55, 61, 42, 48,  3,  1,  7,  7,  4,  1, 35, 46, 57, 61, 46, 54, 43, 46,
        59,  1,  7,  5,  6,  9,  0,  0, 25, 44, 49,  1, 60, 46, 49, 46,  1, 45,
        50, 46,  1, 30, 42, 44, 49, 61,  1, 49, 46, 59, 46, 50, 55, 43, 59, 46,
        44, 49, 46, 55,  4,  1, 20, 62, 55, 52, 46, 53, 49, 46, 50, 61,  1, 62,
        54, 49, 77, 53, 53, 61,  1, 55, 42, 44, 49,  1, 62, 55, 45,  1, 55, 42,
        44, 49,  1, 45, 50, 46,  1, 52, 42, 59, 48, 46,  1, 28, 42, 55, 45, 60,
        44, 49, 42, 47, 61,  3,  1, 54, 50, 59,  1, 64, 50, 59, 45,  1, 67, 62,
        55, 46, 49, 54, 46, 55, 45,  1, 52, 75, 53, 61, 46, 59,  4,  1, 25, 44,
        49,  1, 54, 62, 60, 60,  1, 54, 50, 44, 49,  1, 43, 46, 46, 50, 53, 46,
        55,  3,  1, 54, 50, 59,  1, 43, 53, 46, 50, 43, 61,  1, 55, 50, 44, 49,
        61,  1, 54, 46, 49, 59,  1, 63, 50, 46, 53,  1, 41, 46, 50, 61,  4,  1,
       

### Data Loader with Batch and Time Dimension

In [16]:
# Train / validation split (first 80% for training, the rest for validation)
n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]

In [17]:
# Block / Context size
block_size = 8

In [18]:
# Time dimension
# Note: Transformer will be able to predict from context as short as a single character, up to the whole block size
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'When context: {context}, -> target: {target}')

When context: tensor([35]), -> target: 56
When context: tensor([35, 56]), -> target: 52
When context: tensor([35, 56, 52]), -> target: 59
When context: tensor([35, 56, 52, 59]), -> target: 42
When context: tensor([35, 56, 52, 59, 42]), -> target: 61
When context: tensor([35, 56, 52, 59, 42, 61]), -> target: 46
When context: tensor([35, 56, 52, 59, 42, 61, 46]), -> target: 1
When context: tensor([35, 56, 52, 59, 42, 61, 46,  1]), -> target: 1


Note: A single batch of shape (4, 8) holds 4 × 8 = 32 training examples (one per batch row and context length), each of them completely independent as far as the transformer is concerned.
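
To make the count concrete, here is a small sketch that enumerates every (context, target) pair in one batch. The random `data` tensor is a stand-in for the encoded corpus (an assumption); the slicing mirrors the batching logic used in this notebook.

```python
import torch

torch.manual_seed(0)
batch_size, block_size = 4, 8

# Stand-in for the encoded corpus: 100 random token ids (assumption)
data = torch.randint(0, 81, (100,))

# Random start offsets, then stacked input and target slices
ix = torch.randint(len(data) - block_size, (batch_size,))
xb = torch.stack([data[i:i + block_size] for i in ix])
yb = torch.stack([data[i + 1:i + block_size + 1] for i in ix])

# Every (row b, position t) pair yields one (context, target) example
examples = [(xb[b, :t + 1], yb[b, t])
            for b in range(batch_size)
            for t in range(block_size)]
print(len(examples))  # batch_size * block_size

# The targets are the inputs shifted left by one position
assert torch.equal(xb[:, 1:], yb[:, :-1])
```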

In [19]:
# Batch & time dimension
torch.manual_seed(42)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)) # get batch_size # of random integer offsets within the data
    x = torch.stack([data[i:i+block_size] for i in ix]) # stack the tensors of context slices
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # stack the tensors of predicted slices, offset by +1 from context
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

for b in range(batch_size):
    print(f'Batch {b}:')
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'When context: {context}, -> target: {target}')



inputs:
torch.Size([4, 8])
tensor([[53, 46, 55,  1, 20, 42, 55, 52],
        [46, 59, 60, 50, 57,  1, 64, 50],
        [46, 59,  1, 39, 42, 53, 45,  4],
        [56, 44, 49,  1, 46, 60,  1, 64]])
targets:
torch.Size([4, 8])
tensor([[46, 55,  1, 20, 42, 55, 52,  1],
        [59, 60, 50, 57,  1, 64, 50, 59],
        [59,  1, 39, 42, 53, 45,  4,  1],
        [44, 49,  1, 46, 60,  1, 64, 42]])
Batch 0:
When context: tensor([53]), -> target: 46
When context: tensor([53, 46]), -> target: 55
When context: tensor([53, 46, 55]), -> target: 1
When context: tensor([53, 46, 55,  1]), -> target: 20
When context: tensor([53, 46, 55,  1, 20]), -> target: 42
When context: tensor([53, 46, 55,  1, 20, 42]), -> target: 55
When context: tensor([53, 46, 55,  1, 20, 42, 55]), -> target: 52
When context: tensor([53, 46, 55,  1, 20, 42, 55, 52]), -> target: 1
Batch 1:
When context: tensor([46]), -> target: 59
When context: tensor([46, 59]), -> target: 60
When context: tensor([46, 59, 60]), -> target: 50
When 

### Baseline: Bigram Language Model

In [20]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensor of integers
        # logits are basically the scores for the next character in the sequence
        logits = self.token_embedding_table(idx) # (B,T,C) dimension where C is the vocab size
        
        if targets is None:
            loss = None
        else:
            # need to reshape logits as expected by F.cross_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # stretch out the first two dimensions, keep channel C as 2nd dimension
            targets = targets.view(B*T) # make one dimensional
            loss = F.cross_entropy(logits, targets) # also called "negative log-likelihood" loss

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# Generate some text on untrained model
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 81])
tensor(4.9485, grad_fn=<NllLossBackward0>)

gvIVlTdM9WD»F7äüAVzpk4eDmYybÜQ«o.,zEhS‚Y43AdfZDm Lßf–R»!GHm«R’nq HqmQpoqBm’p8rGH«|9Rxc‚MA0»äJNGC9wu2
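
As a sanity check on the initial loss: a model that assigns equal probability to all 81 characters has a cross-entropy of ln(81) ≈ 4.39. The untrained model's 4.9485 is somewhat higher because `nn.Embedding` initialises its weights from a standard normal distribution, so the initial logits are not uniform. A short sketch of the reference value:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 81  # vocabulary size reported earlier in the notebook

# Cross-entropy of a uniform prediction: -ln(1 / vocab_size)
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.4f}")

# Same value via F.cross_entropy: all-zero logits are uniform after softmax
logits = torch.zeros(5, vocab_size)
targets = torch.randint(0, vocab_size, (5,))
loss = F.cross_entropy(logits, targets)
print(f"{loss.item():.4f}")
```

Any trained model should end up clearly below this value; the bigram model in this notebook does so after training.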


In [21]:
# Create optimizer
# AdamW is available as torch.optim.AdamW, so no extra import is needed
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [22]:
batch_size = 32
for steps in range(10000): # increase number of steps for good results... 
    
    # sample a batch of data
    xb, yb = get_batch('train') # get batch of data

    # forward pass
    logits, loss = m(xb, yb) # evaluate the loss

    # clear the gradients
    optimizer.zero_grad(set_to_none=True) 

    # get gradients for all parameters and update parameters accordingly
    loss.backward()
    optimizer.step()

print(loss.item())


2.4390273094177246
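
The value printed above is the loss of the last training batch only, which is a noisy estimate. A common alternative is to average the loss over several batches of both splits. The sketch below assumes a hypothetical helper `estimate_loss` and uses random stand-in data plus a bare embedding table instead of the notebook's model, so that it runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, block_size, batch_size = 81, 8, 32

# Stand-in data: random token ids instead of the encoded novel (assumption)
data = torch.randint(0, vocab_size, (2000,))
splits = {'train': data[:1600], 'val': data[1600:]}

def get_batch(split):
    d = splits[split]
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x, y

# Minimal bigram table: next-token logits are read from a lookup table
table = nn.Embedding(vocab_size, vocab_size)

@torch.no_grad()  # evaluation only, so no gradients are tracked
def estimate_loss(eval_iters=50):
    """Average the loss over several batches per split (hypothetical helper)."""
    out = {}
    for split in splits:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            logits = table(xb)  # (B, T, C)
            losses[k] = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
        out[split] = losses.mean().item()
    return out

losses = estimate_loss()
print(losses)
```

Called every few hundred steps inside the training loop, such a helper also shows whether the validation loss starts to diverge from the training loss.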


In [23]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


Re ert siaurens ust wase miert ewaräch benwin agtanen gl st.
intteugeund Schau d brt hererchneleinen


### Mathematical trick for self-attention

#### Version 1: Weighted aggregation of past context with for loops

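We want $\mathrm{xbow}[b,t] = \mathrm{mean}_{i \le t}\; x[b,i]$, i.e. the output at each position is the average of its own and all previous token vectors.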

In [24]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch size, time steps, channels

x = torch.randn(B, T, C)
print(x.shape)

torch.Size([4, 8, 2])


In [25]:
xbow = torch.zeros((B,T,C)) # Bag-of-Words representation
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1, C)
        xbow[b,t] = torch.mean(xprev, 0)

In [26]:
x[0] # the 0th batch element

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [27]:
xbow[0] # e.g. the row at index 2 is now the average of rows 0, 1 and 2 of x

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

The nested loops above are **really** inefficient, so we want to use matrix operations instead:

#### Version 2: Weighted aggregation for past context with matrix operations

In [28]:
# toy example
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3)) # lower triangular matrix
a = a / torch.sum(a, 1, keepdim=True) # normalize rows to sum to 1 (helps us average!)
b = torch.randint(0, 10, (3,2)).float() 
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [29]:
# version 2: using matrix multiply for a weighted aggregation
weights = torch.tril(torch.ones(T, T)) # again, our triangular matrix
weights = weights / weights.sum(1, keepdim=True) # normalize rows to sum to 1

# batched matrix multiply
xbow2 = weights @ x # (T, T) @ (B, T, C) ----> (B, T, C) | weights is broadcast over the batch: (T, T) @ (T, C) for each batch element
torch.allclose(xbow, xbow2)

True
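
The broadcast in `weights @ x` can be spelled out in two equivalent ways: as an explicit contraction with `torch.einsum` and as a loop over the batch dimension. This is a small sketch with fresh random data (the seed and tensor are assumptions, not the notebook's values):

```python
import torch

torch.manual_seed(0)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Row-normalized lower triangular matrix: row t averages positions 0..t
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)

# Broadcasting: (T, T) is treated as (1, T, T) and reused for every batch element
out_matmul = weights @ x

# The same contraction with the summed index s written out explicitly
out_einsum = torch.einsum('ts,bsc->btc', weights, x)

# And once more as an explicit loop over the batch dimension
out_loop = torch.stack([weights @ x[b] for b in range(B)])

print(torch.allclose(out_matmul, out_einsum), torch.allclose(out_matmul, out_loop))
```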

#### Version 3: Weighted aggregation for past context with softmax

In [30]:
# version 3: use Softmax
lower_triangular = torch.tril(torch.ones(T, T))
weights = torch.zeros((T,T)) # all zeros for now, but can be viewed as "interaction strength" between tokens! (how interesting they find each other)
weights = weights.masked_fill(lower_triangular == 0, float('-inf')) # tokens from the past cannot interact with tokens from the future! therefore, mask out the upper triangular part with -inf
weights = F.softmax(weights, dim=-1) # softmax along each row: -inf entries become 0, the remaining entries sum to 1
xbow3 = weights @ x # aggregate the token's values depending on their interaction strength 
torch.allclose(xbow, xbow3)

True
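
Why this reproduces the plain average: `exp(-inf)` is 0, so masked positions get exactly zero weight, and equal scores share the remaining probability mass evenly. The sketch below shows this for `T = 3` and then changes one score (an arbitrary value chosen for illustration) to show how non-zero scores shift the weights.

```python
import torch
import torch.nn.functional as F

T = 3
mask = torch.tril(torch.ones(T, T))

# All-zero scores with the future masked out
scores = torch.zeros(T, T).masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
print(weights)  # rows are uniform over the visible positions

# Raise one score: the last token now prefers position 0
scores_biased = scores.clone()
scores_biased[2, 0] = 2.0
weights_biased = F.softmax(scores_biased, dim=-1)
print(weights_biased[2])  # most of the weight moves to position 0
```

This is the hook that self-attention uses: the scores stop being constant zeros and are computed from the tokens themselves.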

#### Version 4: Self-attention!

Instead of simply averaging, we want the "interaction strength" between tokens to be data-dependent.

How self-attention solves this:
- Every token emits two vectors: a query vector (what am I looking for?) and a key vector (what do I contain?).
- We compute the dot product between the query vector of a given token and the key vectors of all tokens in the context, yielding a vector of scores. If a key and a query are "aligned", the dot product is large and the two tokens interact strongly. This directs attention to the tokens that are most relevant to each other and is the core of self-attention.
- There is a third vector, the value (if you find me interesting, here's what I have to communicate to you). The output for a token is the weighted sum of the value vectors, with the softmax-normalized scores as weights.

In [31]:
# starting point for version 4: the uniform averaging from version 3, now with more channels
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels | note: we now use 32 channels instead of 2, meaning we have 32 features per token
x = torch.randn(B, T, C)

lower_triangular = torch.tril(torch.ones(T, T))
weights = torch.zeros((T,T)) # still all zeros (uniform averaging); we want this "interaction strength" to be data-dependent instead
weights = weights.masked_fill(lower_triangular == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

out.shape

torch.Size([4, 8, 32])

In [32]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels | note: we now use 32 channels instead of 2, meaning we have 32 features per token
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)

# for every row of B we now get a matrix of size (T, T) giving us the interaction strength between tokens!
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T) | note: we only transpose the last two dimensions, not the batch dimension!

lower_triangular = torch.tril(torch.ones(T, T))
#weights = torch.zeros((T,T))
weights = weights.masked_fill(lower_triangular == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = value(x)
#weights = weights @ x
out = weights @ v # perform the weighted aggregation of the values

out.shape

torch.Size([4, 8, 16])

Looking at the last row of the first (T, T) matrix, we can see which tokens the 8th token attends to strongly, for example the 4th token:

[0.0210, 0.0843, 0.0555, 4th: **_0.2297_**, 0.0573, 0.0709, 0.2423, 8th: **0.2391**]

In [33]:
# now every single batch element has a different interaction strength matrix! 
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:
- Attention is a **communication mechanism**. For an autoregressive language model, this can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
  
<img src="assets/attention-graph-viz.svg" alt="attention-graph-viz" width="400"/>

- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Examples across the batch dimension are processed completely independently and never "talk" to each other.
- In an **"encoder"** attention block, just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).
- "Scaled" attention additionally divides `weights` by sqrt(head_size). This way, when the inputs Q and K have unit variance, `weights` will have unit variance too, and softmax will stay diffuse instead of saturating. Illustration below.
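
The notes above can be condensed into one small module. This is a sketch, not a definitive implementation: the class name `Head`, the `causal` flag, and the argument names `n_embd` and `block_size` are assumptions chosen for illustration. It combines the query/key/value projections, the scaling by sqrt(head_size), and the optional triangular mask.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One head of scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, n_embd, head_size, block_size, causal=True):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.causal = causal
        # Buffer: follows the module across devices but is not a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)  # each (B, T, head_size)
        # Scaled scores: divide the dot products by sqrt(head_size)
        weights = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])  # (B, T, T)
        if self.causal:
            # "Decoder" block: hide future positions; skipped for an "encoder" block
            weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        return weights @ v  # (B, T, head_size)

torch.manual_seed(0)
x = torch.randn(4, 8, 32)
head = Head(n_embd=32, head_size=16, block_size=8)
out = head(x)
print(out.shape)

# Causality check: changing the last token must not affect earlier outputs
x2 = x.clone()
x2[:, -1] += 1.0
out2 = head(x2)
causal_ok = torch.allclose(out[:, :-1], out2[:, :-1])
print(causal_ok)
```

Constructing the head with `causal=False` gives the "encoder" variant, in which a change to the last token does propagate to all earlier positions.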

Illustrating the effect of scaled attention:

In [34]:
import math

k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
weights_raw = q @ k.transpose(-2, -1)
weights_scaled = weights_raw / math.sqrt(head_size)

In [36]:
# print variances for each
print('k var', k.var())
print('q var', q.var())
print('weights_raw var', weights_raw.var())
print('weights_scaled var', weights_scaled.var())

k var tensor(1.0449)
q var tensor(1.0700)
weights_raw var tensor(17.4690)
weights_scaled var tensor(1.0918)
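
The raw variance of about 17 is no accident: each score is a sum of `head_size` products of independent, zero-mean, unit-variance numbers, so its variance is approximately `head_size` (16 here). The sketch below checks this empirically for a few head sizes; the sample count `n` is an arbitrary choice for illustration.

```python
import math
import torch

torch.manual_seed(0)
n = 1000  # number of query and key vectors per trial (assumption)

results = {}
for head_size in (4, 16, 64):
    q = torch.randn(n, head_size)
    k = torch.randn(n, head_size)
    raw = q @ k.T                        # (n, n) matrix of dot products
    scaled = raw / math.sqrt(head_size)  # scaled attention scores
    results[head_size] = (raw.var().item(), scaled.var().item())
    print(head_size, results[head_size])
```

The raw variance grows linearly with `head_size`, while the scaled variance stays close to 1 regardless of the head size.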


In [41]:
# demo with example data
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [42]:
# with large differences between the values (as happens without scaling), softmax gets too peaky, focuses on the max and converges towards a one-hot vector
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])
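
The same effect can be followed across several scale factors. This sketch reuses the example logits from the two cells above and records the largest probability for each factor; the chosen factors are arbitrary.

```python
import torch

logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

max_probs = []
for scale in (1, 2, 4, 8, 16):
    probs = torch.softmax(logits * scale, dim=-1)
    max_probs.append(probs.max().item())
    # The larger the scale, the more mass the largest logit collects
    print(scale, [round(p, 4) for p in probs.tolist()])
```

Since unscaled attention scores have a variance of roughly `head_size`, larger heads would push softmax further towards this one-hot regime, which is what the division by sqrt(head_size) prevents.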