In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Load the complete training corpus into memory as one string.
with open("./data/byron.txt", mode="r", encoding='utf-8') as f:
    text = f.read()
len(text)  # corpus length in characters

2238662

In [3]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
vocab = sorted(set(text))
vocab_size = len(vocab)

In [4]:
# Lookup tables between characters and integer ids (an id is the character's
# position in the sorted `vocab`).
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = dict(enumerate(vocab))


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids `l`."""
    return "".join(itos[i] for i in l)


encode('Hi there !')
# NOTE(review): these hard-coded ids do not correspond to 'Hi there !' under
# this vocabulary (the recorded output is '4O ZNKXK !');
# decode(encode('Hi there !')) would demonstrate the round trip instead.
decode([20, 47, 1, 58, 46, 43, 56, 43, 1, 2])

'4O ZNKXK !'

In [5]:
# Encode the whole corpus as a 1-D tensor of integer token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# The first 90% of the tokens are used for training, the rest for validation.
n = int(0.9 * len(data))
data_train, data_val = data[:n], data[n:]

# Sanity check: the two splits together cover the whole corpus.
len(data_train) + len(data_val) == len(data)

True

In [6]:
# --- Earlier, much smaller settings (inactive, kept for reference) ----------
# block_size = 8
# batch_size = 4
# max_iterations = 3000
# eval_interval = 300
# eval_iterations = 200
# learning_rate = 1e-2
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# n_embed = 32
# head_size = 32   (no longer a setting: heads use n_embed // num_heads)
# num_heads = 4
# n_layers = 4
# dropout = 0.2

# --- Active settings ---------------------------------------------------------
# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Batching
block_size = 256        # tokens per training sequence (context length)
batch_size = 64         # sequences per batch

# Optimisation
max_iterations = 5000   # number of optimizer steps in the training loop
learning_rate = 3e-4

# Evaluation
eval_iterations = 200   # batches averaged per split when estimating the loss
eval_interval = 500     # estimate the loss every this many training iterations

# Model size
n_embed = 384           # embedding width
num_heads = 6           # attention heads per block
n_layers = 6            # number of transformer blocks
dropout = 0.2           # dropout probability used throughout the model

In [7]:
# Display which device was selected in the configuration cell.
device

'cuda'

In [8]:
def get_batch(split):
    """Sample a random batch of input/target sequences from one data split.

    Args:
        split: ``'train'`` selects ``data_train``; any other value selects
            ``data_val``.

    Returns:
        A pair ``(x, y)`` of ``(batch_size, block_size)`` tensors on ``device``.
        ``y`` is the same window as ``x`` shifted one position to the right, so
        ``y[b, t]`` is the token that follows ``x[b, t]`` in the source data.
    """
    source = data_train if split == 'train' else data_val
    # Random window starts; the upper bound leaves room for the shifted targets.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)
    

In [9]:
# Draw one training batch to inspect what get_batch returns.
x, y = get_batch('train')

In [10]:
# Show the sampled inputs and targets; each row of y is the matching row of x
# shifted one token to the right.
print(x)
print(y)

tensor([[ 3,  1,  1,  ...,  1, 82, 71],
        [68, 81, 72,  ..., 78,  1, 75],
        [60, 81,  3,  ...,  1, 83, 78],
        ...,
        [68,  1, 64,  ..., 78, 84, 82],
        [81,  1, 68,  ..., 83, 86, 72],
        [82,  1, 70,  ...,  1, 80, 84]], device='cuda:0')
tensor([[ 1,  1,  1,  ..., 82, 71, 81],
        [81, 72, 83,  ...,  1, 75, 64],
        [81,  3, 12,  ..., 83, 78,  1],
        ...,
        [ 1, 64, 81,  ..., 84, 82,  1],
        [ 1, 68, 64,  ..., 86, 72, 77],
        [ 1, 70, 64,  ..., 80, 84, 72]], device='cuda:0')


In [11]:
class SingleHeadAttention(nn.Module):
    """One causal self-attention head with low-rank (latent) keys and values.

    Keys and values are produced in two steps, ``n_embed -> latent_head_size ->
    head_size``; queries are projected directly, ``n_embed -> head_size``.

    Args:
        head_size: Width of the query/key/value vectors and of the head output.
        latent_head_size: Width of the intermediate key/value bottleneck.
    """

    def __init__(self, head_size, latent_head_size):
        super().__init__()
        # Key path: down-projection to the latent, then up-projection.
        self.key_latent = nn.Linear(n_embed, latent_head_size, bias=False)
        self.key = nn.Linear(latent_head_size, head_size, bias=False)

        # Value path: same two-step factorisation as the keys.
        self.value_latent = nn.Linear(n_embed, latent_head_size, bias=False)
        self.value = nn.Linear(latent_head_size, head_size, bias=False)

        self.query = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular mask; a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal attention.

        Args:
            x: Tensor of shape ``(B, T, n_embed)`` with ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, head_size)``.
        """
        T = x.shape[1]

        k = self.key(self.key_latent(x))        # (B, T, H)
        v = self.value(self.value_latent(x))    # (B, T, H)
        q = self.query(x)                       # (B, T, H)

        # Row t holds the scores of query t against every key. The scores are
        # scaled by 1/sqrt(head_size) -- the width of the vectors being dotted --
        # rather than by the embedding width.
        wei = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5   # (B, T, T)
        # Causal mask: position t may only attend to positions <= t.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei.softmax(dim=-1)
        wei = self.dropout(wei)

        return wei @ v                          # (B, T, T) @ (B, T, H) -> (B, T, H)

In [12]:
class MultiheadLatentAttention(nn.Module):
    """Several latent-attention heads run side by side, then mixed by a linear layer.

    Each head has width ``n_embed // num_heads`` and a key/value latent of half
    that width. Head outputs are concatenated on the last dimension and
    projected to ``n_embed``.

    Args:
        num_heads: Number of attention heads.
    """

    def __init__(self, num_heads):
        super().__init__()
        dim_head = n_embed // num_heads
        dim_latent_head = dim_head // 2
        self.multiheads = nn.ModuleList(
            [SingleHeadAttention(dim_head, dim_latent_head) for _ in range(num_heads)]
        )
        # Size the projection from the real concatenated width. It equals n_embed
        # when n_embed is divisible by num_heads, and keeps forward() working
        # when the integer division above rounds down.
        self.proj = nn.Linear(num_heads * dim_head, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected concatenation of all head outputs, after dropout."""
        out = torch.cat([h(x) for h in self.multiheads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [13]:
class Feedforward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embed``, ReLU, project back, dropout."""

    def __init__(self):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.ffwd = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``n_embed``."""
        return self.ffwd(x)

In [14]:
class Block(nn.Module):
    """Transformer block: attention then feed-forward, each behind a LayerNorm
    and wrapped in a residual connection.

    Args:
        num_heads: Number of attention heads, forwarded to the attention module.
    """

    def __init__(self, num_heads):
        super().__init__()
        self.multiheads = MultiheadLatentAttention(num_heads)
        self.ffwd = Feedforward()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # Each sub-layer sees a normalised input; its output is added back onto
        # the un-normalised stream.
        attended = x + self.multiheads(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

In [15]:
class MLA_Bitsy(nn.Module):
    """Character-level language model: token and position embeddings, a stack of
    ``n_layers`` transformer blocks, a final LayerNorm and a linear head that
    produces one logit per vocabulary entry.

    All sizes come from the notebook-level settings ``vocab_size``, ``n_embed``,
    ``block_size``, ``num_heads`` and ``n_layers``.
    """

    # Single checkpoint location shared by persistModel() and loadModel().
    CHECKPOINT_PATH = 'model/mla-bitsy'

    def __init__(self):
        super().__init__()
        self.tok_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding_table = nn.Embedding(block_size, n_embed)
        self.layers = nn.Sequential(*[Block(num_heads) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Integer tensor of token ids, shape ``(B, T)`` with
                ``T <= block_size``.
            targets: Optional integer tensor of shape ``(B, T)``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to ``(B*T, vocab_size)`` and ``loss`` is the
            cross-entropy against the flattened targets.
        """
        B, T = idx.shape
        tok_embed = self.tok_embedding_table(idx)  # (B, T, E)
        # Build the position indices on the input's own device so the model
        # works wherever it has been moved, independent of the global setting.
        pos_embed = self.pos_embedding_table(torch.arange(T, device=idx.device))  # (T, E)
        x = tok_embed + pos_embed
        x = self.layers(x)
        x = self.ln(x)
        logits = self.lm_head(x)  # (B, T, E) @ (E, V) -> (B, T, V)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def estimate_loss(self):
        """Average the loss over ``eval_iterations`` random batches per split.

        Runs in evaluation mode and restores the previous train/eval mode
        afterwards.

        Returns:
            Dict with keys ``'train'`` and ``'val'`` mapping to scalar tensors.
        """
        was_training = self.training
        self.eval()
        out = {}
        try:
            for split in ['train', 'val']:
                losses = torch.zeros(eval_iterations)
                for k in range(eval_iterations):
                    x, y = get_batch(split)
                    _logits, loss = self(x, y)
                    losses[k] = loss.item()
                out[split] = losses.mean()
        finally:
            self.train(was_training)
        return out

    # Original (misspelled) name, kept so existing callers continue to work.
    estiamte_loss = estimate_loss

    @torch.no_grad()
    def generate(self, idx, max_tokens=300):
        """Sample ``max_tokens`` new tokens, appending them to ``idx``.

        Sampling runs in evaluation mode so dropout does not perturb the
        predicted distribution; the previous train/eval mode is restored on exit.

        Args:
            idx: Integer tensor of shape ``(B, T)`` holding the prompt token ids.
            max_tokens: Number of tokens to append.

        Returns:
            Integer tensor of shape ``(B, T + max_tokens)``.
        """
        was_training = self.training
        self.eval()
        try:
            for _step in range(max_tokens):
                # The position table only covers block_size positions, so feed
                # at most the last block_size tokens.
                idx_cond = idx[:, -block_size:]
                logits, _loss = self(idx_cond)
                logits = logits[:, -1, :]  # logits for the last position only
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, idx_next], dim=-1)
        finally:
            self.train(was_training)
        return idx

    def get_num_params(self):
        """Return the number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def persistModel(self):
        """Save the state dict to ``CHECKPOINT_PATH`` and return that path.

        The parent directory is created if it does not exist yet.
        """
        import os

        path = self.CHECKPOINT_PATH
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.state_dict(), path)
        return path

    def loadModel(self):
        """Load the state dict from ``CHECKPOINT_PATH`` into this model.

        Tensors are mapped onto the device the model currently lives on, so a
        checkpoint written on a GPU also loads on a CPU-only machine.
        """
        path = self.CHECKPOINT_PATH
        target_device = next(self.parameters()).device
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.load_state_dict(torch.load(path, map_location=target_device))

In [None]:
# Seed PyTorch's RNG before building the model and sampling batches.
torch.manual_seed(1337)
m = MLA_Bitsy()
model = m.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for k in range(max_iterations):
    # Every eval_interval iterations, report the averaged train/val losses.
    if k % eval_interval == 0:
        out = model.estiamte_loss()
        print(f"iter: {k}, train_loss: {out['train']:.4f}, val_loss: {out['val']:.4f}")

    # One optimisation step on a fresh random training batch.
    x, y = get_batch('train')
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Re-evaluate after the last update so the final report reflects the trained
# model rather than the losses from the last periodic evaluation. The label is
# the number of completed updates, consistent with the in-loop reports.
out = model.estiamte_loss()
print(f"iter: {max_iterations}, train_loss: {out['train']:.4f}, val_loss: {out['val']:.4f}")

model.persistModel()

In [None]:
# Start generation from a single token with id 0, then sample 500 more tokens
# and print them as text.
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx, max_tokens=500)[0].tolist()))

In [70]:
# Preview the lower-triangular matrix used as the causal attention mask.
# NOTE(review): the recorded output below is 8x8, so it predates the current
# block_size = 256 setting; re-run the cell to refresh it.
torch.tril(torch.ones(block_size, block_size))

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [91]:
# Toy walk-through of the masked attention-score computation on small integers.
# NOTE(review): this scratch cell overwrites the notebook-level `x`.
B, T, E = 1, 4, 4
H = 4

# Random integer matrices standing in for the key/query/value projections.
key = torch.randint(0, 3, (E, H))
query = torch.randint(0, 3, (E, H))
value = torch.randint(0, 3, (E, H))

print(f"key -> {key}")
print(f"query -> {query}")

x = torch.randint(0, 3, (B, T, E))
print(f"x -> {x}")

k = torch.matmul(x, key)  # (B, T, E) @ (E, H) -> (B, T, H)
print(f"k -> {k}")
q = torch.matmul(x, query)
v = torch.matmul(x, value)
print(f"q -> {q}")

wei = torch.matmul(k, q.transpose(-2, -1))  # (B, T, H) @ (B, H, T) -> (B, T, T)
print(f"wei -> {wei}")

# Overwrite everything above the diagonal with a large sentinel so the masked
# positions stand out in the printed integer tensor.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, 100000)
print(wei)

key -> tensor([[0, 1, 2, 1],
        [2, 0, 0, 1],
        [2, 2, 0, 2],
        [1, 2, 0, 0]])
query -> tensor([[0, 0, 0, 0],
        [0, 2, 0, 1],
        [0, 0, 0, 2],
        [0, 0, 2, 1]])
x -> tensor([[[2, 1, 2, 1],
         [2, 2, 1, 0],
         [2, 1, 0, 2],
         [0, 1, 1, 1]]])
k -> tensor([[[7, 8, 4, 7],
         [6, 4, 4, 6],
         [4, 6, 4, 3],
         [5, 4, 0, 3]]])
q -> tensor([[[0, 2, 2, 6],
         [0, 4, 0, 4],
         [0, 2, 4, 3],
         [0, 2, 2, 4]]])
wei -> tensor([[[66, 60, 53, 52],
         [52, 40, 42, 40],
         [38, 36, 37, 32],
         [26, 28, 17, 20]]])
tensor([[[    66, 100000, 100000, 100000],
         [    52,     40, 100000, 100000],
         [    38,     36,     37, 100000],
         [    26,     28,     17,     20]]])
