In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [38]:
device = torch.device("cpu")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")

device

device(type='mps')

In [39]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [40]:
txt = open('input.txt', 'r').read()
len(txt)

1115394

In [41]:
txt[0:500]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor"

In [42]:
chars = list(set(txt))
chars.sort()

ctoi = {c:i for i, c in enumerate(chars)}
itoc = {i:c for i, c in enumerate(chars)}
vocab_size = len(chars)

print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [43]:
i =  math.floor(0.9 * len(txt))
train_txt = txt[0:i]
valid_txt = txt[i+1:]

len(train_txt), len(valid_txt)

(1003854, 111539)

In [44]:
train_tkns = [ctoi[c] for c in train_txt]
valid_tkns = [ctoi[c] for c in valid_txt]

In [45]:
block_size = 256
batch_size = 64

def txt_to_token(t):
    return [ctoi[c] for c in t]
    
# (B, L)
def random_batch(split="train", batch_size=32):
    data = train_tkns if split == "train" else valid_tkns
    
    xi = torch.randint(0, len(data)-block_size, (batch_size,))
    x = torch.tensor([data[i:i+block_size] for i in xi], device=device)
    y = torch.tensor([data[i+1:i+block_size+1] for i in xi], device=device)
    
    x = x.to(device)
    y = y.to(device)
    
    return x, y

x, y = random_batch()
x.shape

torch.Size([32, 256])

In [46]:
x[0]

tensor([46, 63, 57, 43, 50, 44,  1, 57, 54, 43, 39, 49,  1, 61, 43, 50, 50, 10,
         1, 44, 53, 53, 50,  6,  1, 42, 53,  1, 52, 53, 58,  1, 44, 50, 39, 58,
        58, 43, 56,  8,  0, 25, 63,  1, 41, 53, 52, 57, 41, 47, 43, 52, 41, 43,
         1, 46, 39, 58, 46,  1, 39,  1, 58, 46, 53, 59, 57, 39, 52, 42,  1, 57,
        43, 60, 43, 56, 39, 50,  1, 58, 53, 52, 45, 59, 43, 57,  6,  0, 13, 52,
        42,  1, 43, 60, 43, 56, 63,  1, 58, 53, 52, 45, 59, 43,  1, 40, 56, 47,
        52, 45, 57,  1, 47, 52,  1, 39,  1, 57, 43, 60, 43, 56, 39, 50,  1, 58,
        39, 50, 43,  6,  0, 13, 52, 42,  1, 43, 60, 43, 56, 63,  1, 58, 39, 50,
        43,  1, 41, 53, 52, 42, 43, 51, 52, 57,  1, 51, 43,  1, 44, 53, 56,  1,
        39,  1, 60, 47, 50, 50, 39, 47, 52,  8,  0, 28, 43, 56, 48, 59, 56, 63,
         6,  1, 54, 43, 56, 48, 59, 56, 63,  6,  1, 47, 52,  1, 58, 46, 43,  1,
        46, 47, 45, 46,  5, 57, 58,  1, 42, 43, 45, 56, 43, 43,  0, 25, 59, 56,
        42, 43, 56,  6,  1, 57, 58, 43, 

In [47]:
@torch.no_grad()
def estimate_loss(model, batch_size=None):
    model.eval()
    losses = []
    
    for split in ["train", "valid"]:    
        x, y = random_batch(split, batch_size)
        logits = model(x) # (B, L, C)
        B, L, C = logits.shape
        loss = F.cross_entropy(logits.view(B*L, C), y.view(B*L))
        losses.append(loss.item())
        
    model.train()
    return losses

In [48]:
@torch.no_grad()
def sample(model):
    model.eval()
    
    max_len = 500
    tks = [0]

    for i in range(max_len):
        ctx = torch.tensor(tks[i:i+block_size]) # (L)
        ctx = ctx.view(1, -1) # (B, L)

        logits = model(ctx) # (B, L, C)
        probs = F.softmax(logits, dim=2) # (B, L, C)
        probs = probs[0,-1,:] # (C), # the last in the sequence is the newly generated
        yi = torch.multinomial(probs, 1)
        tks.append(yi.item())

    chars = [itoc[t] for t in tks]
    model.train()
    return "".join(chars)

In [49]:
class MultiHeadAttension(nn.Module):    
    
    def __init__(self, head_num, head_size, in_size, out_size):
        super().__init__()
        
        self.head_size = head_size
        self.head_num = head_num        
        self.attn = nn.Linear(in_size, 3 * head_num * head_size, bias=False)
        self.ffn = nn.Linear(head_num * head_size, out_size, bias=False)
    
        
    # x: (B, L, C)  
    # return: (B, L, C')
    def forward(self, x):
        B, L, C = x.shape
        
        z = self.attn(x) # (B, L, 3 * hn * hs)
        k, q, v = torch.split(z, self.head_num * self.head_size, dim=2) # (B, L, hn * hs)
        
        k = k.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3) # (B, hn, L, hs)
        q = q.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3)
        v = v.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3)
        
        q = q.permute(0, 1, 3, 2) # (B, hn, hs, L)
        attn = (k @ q) / self.head_size**0.5 # (B, hn, L, L)
        mask = torch.tril(torch.ones(L, L)) == 0
        mask = mask.to(device)
        attn = attn.masked_fill(mask, -float('inf')) # (B, hn, L, L)
        attn = F.softmax(attn, dim=3)
        
        y = attn @ v # (B, hn, L, hs)
        y = y.permute(0, 2, 1, 3) # (B, L, hn, hs)
        y = y.contiguous().view(B, L, -1) # (B, L, hn * hs)
        y = self.ffn(y) # (B, L, C)
        
        return y 
    
        
x = torch.randn(2, block_size, 9) # (B, L, C)
x = x.to(device)
mh = MultiHeadAttension(5, 3, 9, 7)
mh = mh.to(device)
mh(x).shape

torch.Size([2, 256, 7])

In [50]:
class MLP(nn.Module):
    
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear1 = nn.Linear(in_size, out_size)
        self.linear2 = nn.Linear(out_size, out_size)
    
    # (B, L, C)
    def forward(self, x):
        y = self.linear1(x)
        y = torch.relu(y)
        y = self.linear2(y)
        
        return y

In [51]:
class Block(nn.Module):    
    
    def __init__(self, emb_size, head_size):
        super().__init__()
        
        assert emb_size % head_size == 0
        head_num = emb_size // head_size
        
        self.mha = MultiHeadAttension(head_num, 
                                      head_size, 
                                      in_size=emb_size, 
                                      out_size=emb_size)
        self.lnorm1 = nn.LayerNorm(emb_size)
        self.lnorm2 = nn.LayerNorm(emb_size)
        self.ffn = MLP(emb_size, emb_size)
        
        
    # x: (B, L, emb)
    def forward(self, x):
        y = self.mha(x) + x
        y = self.lnorm1(y)
        y = self.ffn(y) + y
        y = self.lnorm2(y)
        return y
    
# x = torch.randn(3, 4, 10)
# b = Block(10, 2)
# b(x)

In [52]:
# return (L, C)
def pos_encoding(x):
    B, L, C = x.shape
    pos = torch.arange(0, L).view(-1, 1) # (L, 1)
    div = 2 * torch.arange(0, C) / C # (C)
    div = torch.pow(10000, div) # (C)
    e = pos / div
    pe = torch.zeros(L, C)
    pe[:,0::2] = torch.sin(e[:,0::2])
    pe[:,1::2] = torch.cos(e[:,1::2])
    
    pe = pe.to(device)
    return pe

In [53]:
emb_size = 384
head_size = 64

class Transformer(nn.Module):    
    
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.blocks = nn.Sequential(
            Block(emb_size, head_size),
            Block(emb_size, head_size),
            Block(emb_size, head_size),
        )
        self.linear = nn.Linear(emb_size, vocab_size)

    # (B, L) -> (B, L, C)
    def forward(self, x):
        y = self.embed(x) # (B, L, emb)
        y = y + pos_encoding(y) # (B, L, emb)
        y = self.blocks(y) # (B, L, emb)
        y = self.linear(y) # (B, L, vocab)
        
        return y

In [54]:
model = Transformer()
model = model.to(device)
optim = torch.optim.Adam(model.parameters())

count = sum([p.numel() for p in model.parameters()])
print(f"total parameter: {count}")

total parameter: 2711105


In [56]:
%%time

epoch = 50000
eval_interval = 5000
eval_size = 500
lossi = []

model.train()

for i in range(epoch):
    optim.zero_grad()

    xb, yb = random_batch()
    logits = model(xb) # (B, L, C)

    B, L, C = logits.shape
    loss = F.cross_entropy(logits.view(B*L, C), yb.view(B*L))
    loss.backward()
    optim.step()

    if i % eval_interval == 0 or i == epoch-1:
        tr, va = estimate_loss(model, eval_size)
        lossi.append((tr, va))
        print(f"{i:5d}/{epoch}: {tr:.4f}  {va:.4f}")

    break

    0/50000: 3.9321  3.9498
CPU times: user 252 ms, sys: 16.2 s, total: 16.4 s
Wall time: 58.2 s


In [55]:
tr_loss, va_loss = estimate_loss(model, 10000)

print(f"train: {tr_loss:.4f}")
print(f"valid: {va_loss:.4f}")

train: 1.7035
valid: 1.8571


In [67]:
print(sample(model))


See tholad bris.
TARowig.


QUS:
Pe s t r y ist me Ber kespr ald,
INGABad d An:
Whag be m, s wongin.

SCld p anagenolees t tins hor t awe!
CORS:
Win wen:
Wof weiorotisstofo ths gh

M: athely nsw m ig'r blleaxle int pis;
Mys,
Bare hiong, be;
Th sharsand thilifl hy t lay berakng? w calaveer t mend l bararethiver wh ckeee ke my ouindes
An. orrancho thy n wit t wode ce? spinningnaleend re Jus cul le,
Wh w ICKI's pey's s ark s and,
D: Vodil thr m My santoug aid or f br wove
In.
O iowis hedellmsshile 


## Log

- Bi-gram: 2.4716, 2.4755
- Single-head attention: 2.3899, 2.4041
- Multi-head attention, single layer: 2.0820, 2.1165
- Multi-head attention, single layer, positional encoding: 1.8575, 1.9216
- 2-layer transformer (with everything, MHA, positional encoding, layer norm): 1.7155, 1.7952