In [1]:
import random

import torch
import torch.nn as nn
import numpy as np

## Transformer
![transformer.png](transformer.png)

In [2]:
class FeedForward(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        d_model = config["d_model"]
        d_ff = 4*d_model
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.gelu = nn.GELU()
        
    def forward(self, x):
        x = self.gelu(self.linear_1(x))
        return self.linear_2(x)

In [3]:
class MaskedMultiHeadAttention(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        d_model, n_heads = config["d_model"], config["n_heads"]
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.n_heads = n_heads
        
    def forward(self, x):
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)
        
        n_heads = self.n_heads
        B, N, D = x.shape
        q = q.reshape((B, N, n_heads, D // n_heads)).transpose(1, 2)
        k = k.reshape((B, N, n_heads, D // n_heads)).transpose(1, 2)
        v = v.reshape((B, N, n_heads, D // n_heads)).transpose(1, 2)
        
        # Dot product of keys and queries
        attn = torch.einsum('bhnd,bhdk->bhnk', q, k.transpose(2, 3)) / (D // n_heads)**0.5
        
        # auto-regressive mask
        attn = torch.tril(attn)
        attn[attn==0] = -torch.inf
        
        attn = torch.softmax(attn, dim=-1)
        
        # Dot product of attention matrix and values
        v = torch.einsum('bhnn,bhnd->bhnd', attn, v)
        v = torch.reshape(v, (B, N, D))
        
        return self.W_o(v)

In [4]:
class Residual(nn.Module):
    
    def __init__(self, transform):
        super().__init__()
        self.transform = transform
        self.dropout = nn.Dropout(p=0.1)
    
    def forward(self, x):
        return self.dropout(self.transform(x)) + x
        

In [5]:
class EncoderBlock(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        mha = MaskedMultiHeadAttention(config)
        feed_forward = FeedForward(config)
        
        d_model = config['d_model']
        self.f = nn.Sequential(
                     Residual(mha),
                     nn.LayerNorm(d_model),
                     Residual(feed_forward),
                     nn.LayerNorm(d_model)
                     )
        
    def forward(self, x):
        return self.f(x)

In [6]:
class Embedding(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config["vocab_size"], config["d_model"], padding_idx=0) 
        
    def forward(self, x):
        return self.embedding(x) 

In [7]:
class Transformer(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        self.embedding = Embedding(config)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(config) for _ in range(config["n_layers"])])
        self.linear = nn.Linear(config['d_model'], config['vocab_size'])
        
    def forward(self, x):
        x = self.embedding(x)
        
        for block in self.encoder_blocks:
            x = block(x)
            
        return self.linear(x)
        

## Test with random input

In [8]:
config ={"d_model": 768,
         "n_heads": 12,
         "n_layers": 2,
         "vocab_size": 4000}

N = 512
x = torch.randint(config["vocab_size"], (8, N))

model = Transformer(config)
x.shape, model(x).shape

(torch.Size([8, 512]), torch.Size([8, 512, 4000]))

## Number of parameters

In [9]:
num_of_params = 0
for param in model.parameters():
    num_of_params += param.numel()

print(f"Number of params: {num_of_params:}")

Number of params: 20323744


## Prepare data

In [10]:
with open('train.txt', 'r') as f:
    data = f.read()
    data = data.lower()
    
data = data.split('\n')

# Remove blank lines
data = list(filter(lambda x: x != "", data))

### Character-level data

In [11]:
# get all characters
characters = [" "]
for line in data:
    for word in line.split():
        characters.extend(list(word))
    
special_tokens = ['<PAD>', '<START>', '<END>']
vocab = special_tokens + list(set(characters))
token_to_ids = {token: i for i, token in enumerate(vocab)}
ids_to_token = {i: token for i, token in enumerate(vocab)}

In [12]:
# These are the only characters the model knows about
print(vocab), len(vocab)

['<PAD>', '<START>', '<END>', 'm', 'h', ')', 'e', 'z', 'd', '6', 'w', 'o', '[', 'v', 'x', '8', 'f', '.', 'y', '-', '?', ']', '3', "'", 'u', '"', 'c', 'l', ' ', 'n', '1', 'a', 't', 'k', '2', 'i', '9', '0', 'g', 's', '7', 'b', ',', '4', 'p', '(', '5', 'j', 'r']


(None, 49)

In [13]:
train_data = []
start_token_id = token_to_ids['<START>']
end_token_id = token_to_ids['<END>']
for line in data:
    tokens = list(map(token_to_ids.get, line))
    train_data.append([start_token_id] + tokens + [end_token_id])

In [15]:
MAX_LEN = len(max(train_data, key=len))
for x in train_data:
    x.extend([0] * (MAX_LEN - len(x)))

In [16]:
def get_batch(batch_size):
    "Prepare a batch of training data"
    
    # sample random batch
    idx = random.choices(range(len(train_data)), k=batch_size)
    batch = [train_data[i] for i in idx]
    
    # shift to right to get labels
    batch = torch.LongTensor(batch)
    X = batch[:, :-1]
    Y = batch[:, 1:]
    
    return X, Y

In [17]:
def tokens_to_sentence(tokens):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    return "".join(list(map(ids_to_token.get, tokens)))

In [22]:
# print some sample data
tokens_to_sentence(get_batch(1)[0][0])

'<START>cold inside my arms you are<END><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>'

## Let's train charGPT

In [23]:
config ={"d_model": 128,
         "n_heads": 8,
         "n_layers": 4,
         "vocab_size": len(vocab)}

gpt = Transformer(config)
cross_entropy = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(gpt.parameters(), lr=1e-5)

In [92]:
batch_size = 8
num_steps = 10_000

for step in range(num_steps):
    X, Y = get_batch(batch_size)
    logits = gpt(X)
    loss = cross_entropy(logits.flatten(0, 1), Y.flatten())
    
    if step % 50 == 0:
        # print predicted sequence
        preds = torch.argmax(logits, dim=-1)
        preds = preds[0].tolist()
        preds = preds[:Y[0].tolist().index(2) + 1]

        print(f"Step = {step}, loss = {loss.item():.2f}, {tokens_to_sentence(preds)}")
    
    loss.backward()
    optimizer.step()

Step = 0, loss = 4.11, 82.3]k6lf7gix.?]gl6)2f]g38g22.32j6
Step = 50, loss = 3.74, w<START>2h6ghn g9269n ghn n2h
Step = 100, loss = 3.42,  ?li?l otfu<START> thn t uu 
Step = 150, loss = 3.14,  ks t ntn i u eht usoeu   
Step = 200, loss = 3.09, s u eht uu t u  
Step = 250, loss = 2.85, s eur  tnt n  tnhtr t n et u   t ue eu e
Step = 300, loss = 2.62, s ero u  t er  u t n  tn t ute 
Step = 350, loss = 2.64, int u t ur teunr tu a re  tnt u t ur t en n tn trteeu u  ree
Step = 400, loss = 2.64, so  te true sn 
Step = 450, loss = 2.68, i e   theu  eha  ueae   tu tn a ur aon  te  n  
Step = 500, loss = 2.46, a ne ae trt onee  
Step = 550, loss = 2.52, ahe ta   en  t r      ter   the m erllaeur  tal r   
Step = 600, loss = 2.50, aou teue tn t re  
Step = 650, loss = 2.56, int   etrteeuhu  ree
Step = 700, loss = 2.44, inthe u trt n  u tue  t n  tr  t h etheer  e
Step = 750, loss = 2.56, inht thne t u tou thnt u 
Step = 800, loss = 2.38, s   gtnhtn tourete n h
Step = 850, loss = 2.56, s    tnht e