In [1]:
%pip install "torch>=2.1" numpy --quiet

Note: you may need to restart the kernel to use updated packages.


# Inspect the Shakespeare dataset

In [2]:
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print(f"length of dataset in characters: {len(text)}")

length of dataset in characters: 1115395


In [4]:
print(text[:1000])


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.


# Tokenization

In [5]:
# all the unique chars that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# Implement a simple chaacter level tokenization schema. More sophisticated tokenizers include SentencePiece / tiktoken
stoi = { ch:i for i,ch in enumerate(chars) } # str to int mapping
itos = { i:ch for i,ch in enumerate(chars) } # int to str mapping
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode("transformers"))
print(decode(encode("transformers")))

[58, 56, 39, 52, 57, 44, 53, 56, 51, 43, 56, 57]
transformers


In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115395]) torch.int64
tensor([ 0, 18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43,
        44, 53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52,
        63,  1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,
         1, 57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39,
        49,  6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15,
        47, 58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50,
        50,  1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1,
        58, 53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51,
        47, 57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43,
        42,  8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57,
        58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1,
        63, 53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56,
      

# Train-Test Split

In [8]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
print(len(train_data), len(val_data))

1003855 111540


In [9]:
block_size = 8
train_data[:block_size+1]

tensor([ 0, 18, 47, 56, 57, 58,  1, 15, 47])

In [10]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input content is {context}, the targer is {target}")

When input content is tensor([0]), the targer is 18
When input content is tensor([ 0, 18]), the targer is 47
When input content is tensor([ 0, 18, 47]), the targer is 56
When input content is tensor([ 0, 18, 47, 56]), the targer is 57
When input content is tensor([ 0, 18, 47, 56, 57]), the targer is 58
When input content is tensor([ 0, 18, 47, 56, 57, 58]), the targer is 1
When input content is tensor([ 0, 18, 47, 56, 57, 58,  1]), the targer is 15
When input content is tensor([ 0, 18, 47, 56, 57, 58,  1, 15]), the targer is 47


In [11]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split, batch_size, block_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train', batch_size, block_size)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size):
    for t in range(batch_size):
        context = xb[b, :t+1]
        target =yb[b,t]
        print(f"when input is {context}, target is {target}")

inputs:
torch.Size([4, 8])
tensor([[59, 56,  6,  0, 24, 43, 58,  1],
        [50, 39, 47, 51,  1, 58, 46, 39],
        [47, 52, 45,  1, 58, 53,  1, 57],
        [14, 43, 47, 52, 45,  1, 46, 53]])
targets:
torch.Size([4, 8])
tensor([[56,  6,  0, 24, 43, 58,  1, 61],
        [39, 47, 51,  1, 58, 46, 39, 58],
        [52, 45,  1, 58, 53,  1, 57, 39],
        [43, 47, 52, 45,  1, 46, 53, 50]])
----
when input is tensor([59]), target is 56
when input is tensor([59, 56]), target is 6
when input is tensor([59, 56,  6]), target is 0
when input is tensor([59, 56,  6,  0]), target is 24
when input is tensor([50]), target is 39
when input is tensor([50, 39]), target is 47
when input is tensor([50, 39, 47]), target is 51
when input is tensor([50, 39, 47, 51]), target is 1
when input is tensor([47]), target is 52
when input is tensor([47, 52]), target is 45
when input is tensor([47, 52, 45]), target is 1
when input is tensor([47, 52, 45,  1]), target is 58
when input is tensor([14]), target is 43
w

# Bigram LM

In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # the whole model is simply a square lookup table.
        # For each char (token) a in the whole char set, we maintain the probability of char b appearing after a.
        # So the size is vocab_size x vocab_size

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx)  # (B, T, C) -> Batch=4(batch size), Time=8(block_size), Channel=65(vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.6299, grad_fn=<NllLossBackward0>)


In [13]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))


SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


# Train Bigram LM

In [14]:
def train(model, steps, batch_size, block_size, lr=1e-3):
    optimizer = torch.optim.AdamW(model.parameters(), lr)
    for steps in range(steps):
        xb, yb = get_batch('train', batch_size, block_size)

        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True) # clear grads from the previous step
        loss.backward() # calculate grads for all params
        optimizer.step() # update params

    print(loss.item())

In [15]:
torch.manual_seed(1337)
batch_size = 32
train(m, 5000, batch_size, block_size)

2.572626829147339


In [16]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))



gnd ous fo,

Thathoree w flat ze, dr q? vmy'NCIN stocowinsth,

TI! R:Cak.
EOUe.
T!
PUCENI herd toug


# Deriving self-attention

In [17]:
# consider this toy example batch:

torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

Currently, the bigram model is not communicating / paying attention at the (n-1, n-2, ... , 1)th tokens when predicting the (n+1)th token from the nth. Majority of the context info is lost.

We need to derive a mechanism for the model to attend to previous tokens when predicting the future token.


## Naive aggregation: averaging past tokens (weakest from of "communication")

### Version 1: by naive for loop

In [18]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B, T, C)) # x bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b, t] = torch.mean(xprev, 0)
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [19]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a/ torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'a={a}')
print(f'b={b}')
print(f'c={c}')

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


### Version 2: by matrix multiplication

The operation of x[b,t] = mean_{i<=t} x[b,i] can be simplified & optimized by a row-normalized (each row sums to 1) lower triangular matric @ x

#### Version 2.1: get weights by dividing by row sum

In [20]:
torch.manual_seed(42)
wei = torch.tril(torch.ones(T, T)) # weight - the row-normalized lower triangular matrix
wei = wei / wei.sum(1, keepdim=True)

xbow2 = wei @ x # (B (auto broadcasted by torch), T, T) @ (B, T, C) --> (B, T, C)
print(f"wei: {wei}")
print(f"xbow2: {xbow2}")
torch.allclose(xbow, xbow2)

wei: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
xbow2: tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.008

True

#### Version 2.2: get weights by softmax

by turning all zeros to -inf in a lower tri mat, then applying softmax to row, we can get the same weights
- another advantage of softmax is that it ensures **non negative weights**

In [21]:
wei = torch.tril(torch.ones(T, T)) # T by T lower tri mat
wei = wei.masked_fill(wei==0, float('-inf'))
wei

tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [22]:
wei = F.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

#### Implementing the averaging head module 

In [23]:
class AveragingHead(nn.Module):
    """ one head of naive aggregation """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        B, T, C = x.shape
        wei = torch.tril(torch.ones(T, T))
        wei = F.softmax(wei.masked_fill(wei==0, float('-inf')), dim=1) # T by T lower tri mat
        agg_x = wei @ x # (T, T) @ (B, T, C) --> (B, T, C)
        return agg_x

In [24]:
torch.manual_seed(500)
x = torch.rand(1, 3, 3)

ah = AveragingHead()
agg_x = ah(x)
print(f'x: {x}')
print(f'agg_x: {agg_x}')

x: tensor([[[0.5820, 0.1338, 0.7995],
         [0.3071, 0.6526, 0.6105],
         [0.1575, 0.6983, 0.7883]]])
agg_x: tensor([[[0.5820, 0.1338, 0.7995],
         [0.4446, 0.3932, 0.7050],
         [0.3489, 0.4949, 0.7328]]])


#### Adding aggregation head to Bigram LM

##### A few changes on BigramLanguageModelV2 from BigramLanguageModel
1. add position_embedding_table (along T / block_size axis) to capture positional info
2. parameterize n_embed in embedding tables to configure # dimensions of embedding vectors
3. add AggregationHead to establish the weakest form of communication between upper context of text
4. Now that we have implemented positional embedding, we cannot feed idx longer than block_size, else we will get index out of range when accessing the positional embedding table. idx is cropped to the last block_size block during generate()

In [25]:
class BigramLanguageModelWithAveragingHead(nn.Module):

    def __init__(self, vocab_size, n_embed):
        super().__init__()
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.head = AveragingHead()
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = tok_emb + pos_emb # (B, T, C)
        agg_x = self.head(x) # (B, C, C)
        logits = self.lm_head(agg_x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # crop idx to the last block_size tokens
            idx_cropped = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cropped)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [26]:
m1 = torch.compile(BigramLanguageModelWithAveragingHead(vocab_size, n_embed=32))
m1

OptimizedModule(
  (_orig_mod): BigramLanguageModelWithAveragingHead(
    (token_embedding_table): Embedding(65, 32)
    (position_embedding_table): Embedding(8, 32)
    (head): AveragingHead()
    (lm_head): Linear(in_features=32, out_features=65, bias=True)
  )
)

In [27]:
torch.manual_seed(1337)
train(m1, 5000, batch_size, block_size, lr=1e-3)

2.906489133834839


But the training result is worse than vanilla bigram, indicating averaging is a bad communication mechanism

In [28]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m1.generate(idx, max_new_tokens=100)[0].tolist()))


UCHERMNYA:I,

 rCSRhree
e flaf , Wadh chuvmoa,s t srboowe ntl,

hIa Rmicangeeke.
,,
PUOEAIFUTRN IOWT


### Version 4: self-attention

In [29]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# let's see a single Head performing self-attention
head_size = 16 # output size of the Linear layers
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x

out.shape

torch.Size([4, 8, 32])

In [30]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

#### Implementing self-attention head

In [31]:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(self, n_embed, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False) # (C, head_size)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, C)
        q = self.query(x) # (B, T, C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei *= C**-0.5 # Scaled attention
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=2) # (B, T, T)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

#### Adding self-attention to bigram

In [32]:
class BigramLanguageModelWithSelfAttentionHead(nn.Module):

    def __init__(self, vocab_size, n_embed):
        super().__init__()
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.self_attention_head = Head(n_embed, n_embed, block_size)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = tok_emb + pos_emb # (B, T, C)
        x = self.self_attention_head(x) # single head self-attention (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # crop idx to the last block_size tokens
            idx_cropped = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cropped)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [33]:
m2 = torch.compile(BigramLanguageModelWithSelfAttentionHead(vocab_size, n_embed=32))
m2

OptimizedModule(
  (_orig_mod): BigramLanguageModelWithSelfAttentionHead(
    (token_embedding_table): Embedding(65, 32)
    (position_embedding_table): Embedding(8, 32)
    (self_attention_head): Head(
      (key): Linear(in_features=32, out_features=32, bias=False)
      (query): Linear(in_features=32, out_features=32, bias=False)
      (value): Linear(in_features=32, out_features=32, bias=False)
    )
    (lm_head): Linear(in_features=32, out_features=65, bias=True)
  )
)

In [34]:
torch.manual_seed(1337)
train(m2, 5000, batch_size, block_size, lr=1e-3)

2.420104503631592


training loss is significantly lower than the previous 2 versions

In [35]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m2.generate(idx, max_new_tokens=1000)[0].tolist()))



Prn our fo,
A gat hree w fo f hat dr chevesars t sto the nto,
YTI shancange iem t, thand ther ater
Se m,
S:
Bawe th my:
O:
Ad onosnienkfe wick eanicth xjom
Ange wheso aly, sod pre Gd, sho hall.
Withalle fo.

whes nde shan ld
Th men ro toheroans, t
ISLAShor perre, dh troussa,
s. Wine.

Avised whto ter
He elovereivin ghad for te I!
Fimpgenck meas mecon. Whansat, has by blarnet fine barry iveam famirera matoke, ay bret hatout in ser!
Mous hainisithan meare dooutanth, tharave wan,
O NGRlo I ay degorse ardithatsilyofnom, dis,
Thory;
Ih wat th tuy an faveencolery mivomild, pre matt dee Gr
Wharf take ourgther ot, fi

ot
llbe the ler ary: FiciO Lowrens owrocan dor fr.

That ongatinof thesr pay y Cane you: than
Bind retel;
Gul noume sen:

JUSTor V:
s melove whakis ar tearem Mank sefat, dyot. Whef oubve fathan hpir l:
D We sthithise kned wove thilaw wimard ols the,
BO, st
Jlf the secth I takl feat burcerielert gt oo, plathound t-y mausam ney oorthe.

F: dos wilerdo osidoud eothal yea thirog ti

#### Implementing multi-headed self-attention

In [36]:
class MultiHead(nn.Module):
    """ multiple heads of self-attention """
    def __init__(self, num_heads, n_embed, head_size, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size) for _ in range(num_heads)])


    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

#### Adding multi-headed self-attention to bigram

In [37]:
class BigramLanguageModelWithMultiHeadedSelfAttention(nn.Module):

    def __init__(self, vocab_size, n_embed, num_heads):
        super().__init__()
        
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.self_attention_heads = MultiHead(num_heads, n_embed, n_embed//num_heads, block_size) # output dimension of MultiHead --> num_heads * _embed//num_heads = n_embed
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embed)
        pos_emb = self.position_embedding_table(torch.arange(T)) # (T, n_embed)
        x = tok_emb + pos_emb # (B, T, C)
        x = self.self_attention_heads(x) # single head self-attention (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # crop idx to the last block_size tokens
            idx_cropped = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cropped)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [38]:
m3 = torch.compile(BigramLanguageModelWithMultiHeadedSelfAttention(vocab_size, n_embed=32, num_heads=4))
m3

OptimizedModule(
  (_orig_mod): BigramLanguageModelWithMultiHeadedSelfAttention(
    (token_embedding_table): Embedding(65, 32)
    (position_embedding_table): Embedding(8, 32)
    (self_attention_heads): MultiHead(
      (heads): ModuleList(
        (0-3): 4 x Head(
          (key): Linear(in_features=32, out_features=8, bias=False)
          (query): Linear(in_features=32, out_features=8, bias=False)
          (value): Linear(in_features=32, out_features=8, bias=False)
        )
      )
    )
    (lm_head): Linear(in_features=32, out_features=65, bias=True)
  )
)

In [39]:
torch.manual_seed(1337)
train(m3, 5000, batch_size, block_size, lr=1e-3)

2.334955930709839


In [40]:
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m3.generate(idx, max_new_tokens=1000)[0].tolist()))



PUERENY:
I,
A gat hree wofl fu, Wa:
Thy ves', cresto the now,
Youg hancangeeke.

And lit ther ate wis my
Surcake them sarabud ongingenkfe wick elficty ximm
An powe so all be disgen da shooush hing younto you whis nome hat nor so'se rove heroansten
I nou of perrt, du trousst,
st gine.

Dvired?

So ter
beeel asesivin ghakset. te I!
ENF:
Olom meds! For mast beat, has by bartnet fie? batry if?
man miresa matwes, an bret heacut inf mathus ifh weistuthape hin dooutarth, de Cave want nothtlarenay de of e and thatss coffor, dis,
Thow wore wat ther yiue cke mecelerth honcill, prepeatt dee Gut;
Anfirk'stos Mand you, wither
labe the lis 'sy: UCcie adwe.

K
BET:
By ou fromranto ongatinof thest pay youcen you: thaven homr teod
she mry bus me of and galds meleven hak ande teare.
MBak sefar, delt. Wheito bve fath and in your ge sabithis okned wove thiln cot and of nowe,
sinss lall the se they de lis waist cers lert gon of plath munst-yerdusampeey oortve.

Ford shis bret on dour eot su ye, Ving go i

# Remaining transformer components

## Feedforward layers of transformer block

In [41]:
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, n_embed*4), # following "Attention is all you need" -> hidden layer size = 4 * input size
            nn.ReLU(),
            nn.Linear(n_embed*4, n_embed),
        )
        

    def forward(self, x):
        return self.net(x)

## The transformer block

For each block, we self attend, then feed forward, interspersed.


In [42]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention = MultiHead(n_head, n_embed, head_size, block_size)
        self.feed_forward = FeedForward(n_embed)

    def forward(self, x):
        x = self.self_attention(x)
        x = self.feed_forward(x)
        return x

## Residual Connections

In [43]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention = MultiHead(n_head, n_embed, head_size, block_size)
        self.feed_forward = FeedForward(n_embed)

    def forward(self, x):
        x = x + self.self_attention(x) # residual
        x = x + self.feed_forward(x) # residual
        return x

In [44]:
class MultiHead(nn.Module):
    
    def __init__(self, num_heads, n_embed, head_size, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed) # linear transformation to project self-attention outputs back to the residual pathway


    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

## Layernorm

In [45]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention = MultiHead(n_head, n_embed, head_size, block_size)
        self.feed_forward = FeedForward(n_embed)
        self.layer_norm_1 = nn.LayerNorm(n_embed)
        self.layer_norm_2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.self_attention(self.layer_norm_1(x)) # applies layer norm BEFORE attention (studies suggest this is better than the original architecture)
        x = x + self.feed_forward(self.layer_norm_2(x)) # applies layer norm BEFORE ffwd (studies suggest this is better than the original architecture
        return x

## dropout

In [46]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, n_embed*4), # following "Attention is all you need" -> hidden layer size = 4 * input size
            nn.ReLU(),
            nn.Linear(n_embed*4, n_embed),
        )
        self.dropout = nn.Dropout(dropout)
        

    def forward(self, x):
        return self.net(x)

In [47]:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(self, n_embed, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout) # dropout

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)

        wei = q @ k.transpose(-2, -1)
        wei *= C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=2)
        wei = self.dropout(wei) # dropout
        
        v = self.value(x)
        out = wei @ v
        return out

In [48]:
class MultiHead(nn.Module):
    
    def __init__(self, num_heads, n_embed, head_size, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed) # linear transformation to project self-attention outputs back to the residual pathway


    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

In [49]:
class Block(nn.Module):
    def __init__(self, n_head, n_embed, block_size, dropout):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention = MultiHead(n_head, n_embed, head_size, block_size, dropout)
        self.feed_forward = FeedForward(n_embed, dropout)
        self.layer_norm_1 = nn.LayerNorm(n_embed)
        self.layer_norm_2 = nn.LayerNorm(n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.self_attention(self.layer_norm_1(x)) # applies layer norm BEFORE attention (studies suggest this is better than the original architecture)
        x = self.dropout(x)
        x = x + self.feed_forward(self.layer_norm_2(x)) # applies layer norm BEFORE ffwd (studies suggest this is better than the original architecture
        return x

# Building GPT

In [50]:
class GPT(nn.Module):

    def __init__(self, block_size, n_embed, n_head, n_block, dropout, device):
        super().__init__()
        self.device = device
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.register_buffer('positional_intervals', torch.arange(block_size, device=self.device))
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_head, n_embed, block_size, dropout) for _ in range(n_block)]
        )
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(self.positional_intervals[:T]) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            # reshape logits & targets into 2D to cater for F.corss_entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens): # idx is (B, T) array of indices in the current context
        
        # generate max_new_tokens tokens iteratively by looking at only the last token each time
        for _ in range(max_new_tokens):
            
            # crop idx to the last block_size tokens
            idx_cropped = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cropped)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [51]:
def train(model, steps, batch_size, block_size, lr, eval_interval, eval_iters, device):
    @torch.no_grad()
    def estimate_loss():
        out = {}
        model.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split,  batch_size, block_size)
                X, Y = X.to(device), Y.to(device)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out

    optimizer = torch.optim.AdamW(model.parameters(), lr)
    for iter in range(steps):

        # every once in a while evaluate the loss on train and val sets
        if iter % eval_interval == 0:
            losses = estimate_loss()
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # sample a batch of data
        xb, yb = get_batch('train', batch_size, block_size)
        xb, yb = xb.to(device), yb.to(device)
        
        # evaluate the loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

In [52]:
# hyperparameters
batch_size = 64
block_size = 256
max_iters = 20000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 200
n_embed = 160
n_head = 5
n_block = 5
dropout = 0.2
# ------------

babyGPT = GPT(block_size, n_embed, n_head, n_block, dropout, device)
babyGPT.to(device)
print(sum(p.numel() for p in babyGPT.parameters())/1e6, 'M parameters')


1.606145 M parameters


In [53]:
torch.manual_seed(1337)
train(babyGPT, max_iters, batch_size, block_size, learning_rate, eval_interval, eval_iters, device)

step 0: train loss 4.3397, val loss 4.3301
step 500: train loss 2.4637, val loss 2.4719
step 1000: train loss 2.3547, val loss 2.3718
step 1500: train loss 2.2414, val loss 2.2755
step 2000: train loss 2.1253, val loss 2.1733
step 2500: train loss 2.0056, val loss 2.0844
step 3000: train loss 1.9044, val loss 2.0143
step 3500: train loss 1.8153, val loss 1.9428
step 4000: train loss 1.7483, val loss 1.8954
step 4500: train loss 1.6949, val loss 1.8594
step 5000: train loss 1.6516, val loss 1.8243
step 5500: train loss 1.6170, val loss 1.7879
step 6000: train loss 1.5824, val loss 1.7640
step 6500: train loss 1.5572, val loss 1.7419
step 7000: train loss 1.5296, val loss 1.7198
step 7500: train loss 1.5132, val loss 1.7023
step 8000: train loss 1.4948, val loss 1.6882
step 8500: train loss 1.4728, val loss 1.6734
step 9000: train loss 1.4636, val loss 1.6648
step 9500: train loss 1.4457, val loss 1.6466
step 10000: train loss 1.4340, val loss 1.6336
step 10500: train loss 1.4220, val lo

In [67]:
idx = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(babyGPT.generate(idx, max_new_tokens=1000)[0].tolist()))


broathed's construns of that love, which are--

GLOUCESTER:
And he own cuntest of his hounds, poor country,-
With honest lose, he was it sincer.

NORTLANNE:
How not to says from his lord's guilt.
Nothing munis is toesth heme,
Rounce to nare have cold.

GLOUCESTER:
But how ip, nor morue way dear who spay, and I
Thou, somence that dath knowns, and for it the Claudio;
Thou clook, and, way to merry, hell a Volst's mock of
a grepard and the king of sorrower the treto the fhight
varter's habste son; a sunce of yet combriness are gone stouch,
my poor to-mooth, sir have, oil.
Now officery, proner, we men you come in the vient tower,
Know have a kised in succeite used
To Onch: no childs a parsuest, where your goodlish!

HENRY PELIZABOL:
Ay, a city molece took dayying worse more:
Camilonius long pose it; farewell.

QUEEN MARGARET:
Come, goes my ging trough, from you my head.

KING EDWE RICHARD IV:
Not-none. Boy: shief, my love,
Lord lord, art it is nothier are than he your had,
With hast him
Th

In [58]:
torch.save(babyGPT, f'babyGPT_{max_iters}_steps.pt')

# Closing notes

A few departures from the GPT models used in OpenAI, for simplicity's sake:
- In the feedforward layer, ReLU is used here instead of GELU
- a simple character level tokenizer is used here, while OpenAI uses the more sophisticated tiktoken