In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
txt = open('shakespeare.txt', 'r').read()
len(txt)

5447119

In [4]:
txt[0:500]

"                     1\n  From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bud"

In [5]:
chars = list(set(txt))
chars.sort()

ctoi = {c:i for i, c in enumerate(chars)}
itoc = {i:c for i, c in enumerate(chars)}
vocab_size = len(chars)

print("".join(chars))
print(vocab_size)


 !"&'(),-.0123456789:;<>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_`abcdefghijklmnopqrstuvwxyz|}
84


In [6]:
i =  math.floor(0.9 * len(txt))
train_txt = txt[0:i]
valid_txt = txt[i+1:]

len(train_txt), len(valid_txt)

(4902407, 544711)

In [7]:
train_tkns = [ctoi[c] for c in train_txt]
valid_tkns = [ctoi[c] for c in valid_txt]

In [8]:
block_size = 8
batch_size = 32

def txt_to_token(t):
    return [ctoi[c] for c in t]
    
# (B, L)
def random_batch():
    xi = torch.randint(0, len(train_tkns)-block_size, (batch_size,))
    x = torch.tensor([train_tkns[i:i+block_size] for i in xi])
    y = torch.tensor([train_tkns[i+1:i+block_size+1] for i in xi])
    
    return x, y

x, y = random_batch()
x.shape

torch.Size([32, 8])

In [9]:
x[0]

tensor([60,  1, 80, 70, 76,  1, 58, 56])

In [10]:
def eval_split(split, model):
    tkn = train_tkns if split == "train" else valid_tkns
    tkn = torch.tensor(tkn)

    bsize = len(tkn) // block_size
    x = tkn[0:bsize*block_size]
    y = tkn[1:bsize*block_size+1]
    x = x.view(bsize, block_size) # (B, L)
    y = y.view(bsize, block_size) # (B, L)
    
    logits = model(x) # (B, L, C)
    B, L, C = logits.shape
    loss = F.cross_entropy(logits.view(B*L, C), y.view(B*L))
    
    return loss.item()

In [11]:
def sample(model):
    max_len = 500
    tks = [0]*block_size

    for i in range(max_len):
        ctx = torch.tensor(tks[i:i+block_size]) # (L)
        ctx = ctx.view(1, -1) # (B, L)

        logits = model(ctx) # (B, L, C)
        probs = F.softmax(logits, dim=2) # (B, L, C)
        probs = probs[0,-1,:] # (C), # the last in the sequence is the newly generated
        yi = torch.multinomial(probs, 1)
        tks.append(yi.item())

    chars = [itoc[t] for t in tks]
    return "".join(chars)

In [12]:
def fit(model):
    optim = torch.optim.Adam(model.parameters())
    
    for i in range(50000):
        optim.zero_grad()

        xb, yb = random_batch()
        logits = model(xb) # (B, L, C)

        B, L, C = logits.shape
        loss = F.cross_entropy(logits.view(B*L, C), yb.view(B*L))

        loss.backward()
        optim.step()

        if i % 5000 == 0:
            print(f"{loss.item():.4f}")

In [18]:
# Attension mechanism
head_size = 5
x = torch.rand(4, block_size, vocab_size) # (B, L, C)

key = nn.Linear(vocab_size, head_size)
query = nn.Linear(vocab_size, head_size)
value = nn.Linear(vocab_size, head_size)

k = key(x)   # (B, L, C)
q = query(x) # (B, L, C)
v = value(x) # (B, L, C)

q = q.permute(0, 2, 1) # (B, C, L)
w = k @ q # (B, L, L)

B, L, C = k.shape
mask = torch.tril(torch.ones(L, L))
mask = mask == 0
w = w.masked_fill(mask, -float('inf'))

prob = F.softmax(w, dim=2) # (B, L, L) TODO: not sure which dimension
a = prob @ v # (B, L, L) @ # (B, L, C)
a.shape # (B, L, C)

torch.Size([4, 8, 5])

In [16]:
a = torch.randn(3, 3, 4)
b = torch.randn(3, 4, 5)
c = a @ b
c.shape

torch.Size([3, 3, 5])

In [48]:
x = torch.randn(2, 3)
torch.cat((x, x, x), 0).shape

torch.Size([6, 3])

In [62]:
2**-2

0.25

In [69]:
class MultiHeadAttension(nn.Module):    
    
    def __init__(self, config):
        super().__init__()
        self.c = config
        self.attn = nn.Linear(in_size, 3 * self.c.head_num * shead_size, bias=False)
        self.ffn = nn.Linear(head_num * head_size, out_size, bias=False)
    
        
    # x: (B, L, C)  
    # return: (B, L, C')
    def forward(self, x):
        B, L, C = x.shape
        
        z = self.attn(x) # (B, L, 3 * hn * hs)
        k, q, v = torch.split(z, self.head_num * self.head_size, dim=2) # (B, L, hn * hs)
        
        k = k.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3) # (B, hn, L, hs)
        q = q.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3)
        v = v.view(B, L, self.head_num, self.head_size).permute(0, 2, 1, 3)
        
        q = q.permute(0, 1, 3, 2) # (B, hn, hs, L)
        attn = (k @ q) / self.head_size**0.5 # (B, hn, L, L)
        mask = torch.tril(torch.ones(L, L)) == 0
        attn = attn.masked_fill(mask, -float('inf')) # (B, hn, L, L)
        attn = F.softmax(attn, dim=3)
        
        y = attn @ v # (B, hn, L, hs)
        y = y.permute(0, 2, 1, 3) # (B, L, hn, hs)
        y = y.contiguous().view(B, L, -1) # (B, L, hn * hs)
        y = self.ffn(y) # (B, L, C)
        
        return y 
    
        
x = torch.randn(2, block_size, 9) # (B, L, C)
mh = MultiHeadAttension(5, 3, 9, 7)
mh(x).shape

torch.Size([2, 8, 7])

In [22]:
# return (L, C)
def pos_encoding(L, C):
    pos = torch.arange(0, L).view(-1, 1)
    d = torch.pow(10000, 2 * torch.arange(0, C) / C)
    u = pos * d

    pe[:,0::2] = torch.sin(u[:,0::2])
    pe[:,1::2] = torch.cos(u[:,0::2])
    return pe

In [54]:
hidden_size = 100
head_size = 20
head_num = 5

class Transformer(nn.Module):    
    
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.mheads = MultiHeadAttension(head_num, head_size, hidden_size, hidden_size)
        self.ffn1 = nn.Linear(hidden_size, hidden_size)
        self.ffn2 = nn.Linear(hidden_size, vocab_size)

    # (B, L) -> (B, L, C)
    def forward(self, x):
        y = self.embed(x) # (B, L, C)
        y = self.mheads(y) # (B, L, C)
        y = self.ffn1(y)
        y = torch.relu(y)
        y = self.ffn2(y)
        
        return y

In [55]:
tf_model = Transformer()

In [56]:
fit(tf_model)

4.4592
2.2921
2.0884
1.9356
2.2011
1.9967
2.0981
1.9260
2.1591
1.9390


In [None]:
tr_loss = eval_split("train", tf_model)
va_loss = eval_split("valid", tf_model)

print(f"train: {tr_loss:.4f}")
print(f"valid: {va_loss:.4f}")

train: 2.0446
valid: 2.0999


In [80]:
print(sample(tf_model))









     Ask nto w aithard n tobe bovef n oouche mft orourte and e ares.
  THE
AN. F  wowoulld te.
   Or F Co, HE HENESDENA. MA; cacte.


        Ato bth r fif y o; 
 Upenay trsisirvon; t toturunt bt terothif; ir, idigaut watither.
  Howith im vee r ie he r.  fayicar Gorthibree ghtthemsthat wil.
   N.   dithelus  with with cofr Algoolol e thours tis bergrema ht
   willl fathithy h's imang t utay as whansped han wicof inquouaiped t,
   BANDMAND. TIS. I'eavote
      Shino shicalisw-I d bilorese am'nd 


## Log

- Bi-gram: 2.4716, 2.4755
- Single-head attention: 2.3899, 2.4041
- Multi-head attention, single layer: 2.0820, 2.1165
    - block_size = 8
    - hidden_size = 100
    - head_size = 20
    - head_num = 5