In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
with open("./data/byron.txt", "r", encoding='utf-8') as f:
    text = f.read()
len(text)

1765483

In [3]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)

In [4]:
stoi = {ch:i for i,ch in enumerate(vocab)}
itos = {i:ch for i,ch in enumerate(vocab)}
encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: "".join([itos[i] for i in l])
encode('hi there !')
decode([19, 20, 1, 31, 19, 16, 29, 16, 1, 2])

'hi there !'

In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
n = (int)(len(data) * 0.9)
data_train = data[:n]
data_val = data[n:]
len(data) == len(data_train) + len(data_val)

True

In [6]:
#block_size = 8
#batch_size = 4
#max_iterations = 3000
#eval_interval = 300
#eval_iterations = 200
#learning_rate = 1e-2
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
#n_embed = 32
#head_size = 32
#num_heads = 4
#n_layers = 4
#dropout = 0.2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 256
batch_size = 64
max_iterations = 8000
eval_iterations = 200
eval_interval = 500
learning_rate = 3e-4
n_embed = 384
num_heads = 6
n_layers = 6
dropout = 0.2

In [7]:
device

'cuda'

In [8]:
def get_batch(split):
    data = data_train if split == 'train' else data_val
    ix = torch.randint(len(data) - block_size, [batch_size,])
    xb = torch.stack([data[i:i+block_size] for i in ix])
    yb = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = xb.to(device), yb.to(device)
    return x,y
    

In [9]:
x, y = get_batch('train')

In [10]:
print(x)
print(y)

tensor([[ 1, 24, 16,  ..., 16,  1, 36],
        [13, 32, 31,  ..., 24, 12, 25],
        [30,  1, 31,  ...,  4, 15,  7],
        ...,
        [32, 15,  1,  ..., 19, 12, 15],
        [25, 14, 26,  ..., 32, 23,  1],
        [29,  1, 34,  ..., 23, 30,  6]], device='cuda:0')
tensor([[24, 16,  1,  ...,  1, 36, 26],
        [32, 31,  1,  ..., 12, 25, 25],
        [ 1, 31, 26,  ..., 15,  7, 32],
        ...,
        [15,  1, 26,  ..., 12, 15, 26],
        [14, 26, 32,  ..., 23,  1, 12],
        [ 1, 34, 19,  ..., 30,  6,  0]], device='cuda:0')


In [11]:
class SingleHeadAttention(nn.Module):
    def __init__(self, head_size, latent_head_size):
        super().__init__()
        self.key_latent = nn.Linear(n_embed, latent_head_size, bias=False)
        self.key = nn.Linear(latent_head_size, head_size, bias=False)
        
        self.value_latent = nn.Linear(n_embed, latent_head_size, bias=False)
        self.value = nn.Linear(latent_head_size, head_size, bias=False)

        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x): #BTE
        B,T,E = x.shape

        k_latent = self.key_latent(x) #BTE @ EL -> BTL
        v_latent = self.value_latent(x) #BTE @ EL -> BTL

        k = self.key(k_latent) #BTL @ LH -> BTH
        v = self.value(v_latent) #BTL @ LH -> BTH

        q = self.query(x) #BTH
            
        wei = k @ q.transpose(-2, -1) * E ** -0.5 #BTH @ BHT -> BTT
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei.softmax(dim=-1)
        wei = self.dropout(wei)
            
        out = wei @ v # BTT @ BTH -> BTH
        return out    

In [12]:
class MultiheadLatentAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        dim_head = n_embed // num_heads
        dim_latent_head = dim_head // 2
        self.multiheads  = nn.ModuleList([SingleHeadAttention(dim_head, dim_latent_head) 
                                          for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.multiheads], dim=-1)  
        out = self.dropout(self.proj(out))
        return out

In [13]:
class Feedforward(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffwd = nn.Sequential( 
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout))
        
    def forward(self, x):
        out = self.ffwd(x)
        return out

In [14]:
class Block(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.multiheads = MultiheadLatentAttention(num_heads)
        self.ffwd = Feedforward()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        
    def forward(self, x):
        x = x + self.multiheads(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [15]:
class MLA_Bitsy(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding_table = nn.Embedding(block_size, n_embed)
        self.layers = nn.Sequential(*[Block(num_heads) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, targets=None):
        B,T = idx.shape
        tok_embed = self.tok_embedding_table(idx) #BTE
        pos_embed = self.pos_embedding_table(torch.arange(T, device=device))
        x = tok_embed + pos_embed
        x = self.layers(x)
        x = self.ln(x)
        logits = self.lm_head(x) #BTE @ EV -> BTV
        
        if targets == None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    @torch.no_grad()
    def estiamte_loss(self):
        out = {}
        self.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iterations)
            for k in range(eval_iterations):
                x, y = get_batch(split)
                logits, loss = self(x, y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        self.train()
        return out        
    
    @torch.no_grad()
    def generate(self, idx, max_tokens=300):
        for _ in range(max_tokens):
            idx_cond = idx[:, -block_size:]
            logits, losss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=-1)
        return idx
    
    def get_num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def persistModel(self):
        path = 'model/mla-bitsy'
        torch.save(self.state_dict(), path)
        return path
    
    def loadModel(self):
        path = 'model/mla-bitsy'
        self.load_state_dict(torch.load(path, map_location='cuda'))

In [16]:
torch.manual_seed(1337)
m = MLA_Bitsy()
model = m.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for k in range(max_iterations):
    if k % eval_interval == 0:
        out = model.estiamte_loss()
        print(f"iter: {k}, train_loss: {out['train']:.4f}, val_loss: {out['val']:.4f}")
    x, y = get_batch('train')
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(f"iter: {k}, train_loss: {out['train']:.4f}, val_loss: {out['val']:.4f}")

model.persistModel()

iter: 0, train_loss: 4.0069, val_loss: 4.0085
iter: 500, train_loss: 2.0705, val_loss: 2.0744
iter: 1000, train_loss: 1.7731, val_loss: 1.7741
iter: 1500, train_loss: 1.6358, val_loss: 1.6460
iter: 2000, train_loss: 1.5517, val_loss: 1.5681
iter: 2500, train_loss: 1.4903, val_loss: 1.5255
iter: 3000, train_loss: 1.4437, val_loss: 1.4868
iter: 3500, train_loss: 1.4058, val_loss: 1.4592
iter: 4000, train_loss: 1.3749, val_loss: 1.4437
iter: 4500, train_loss: 1.3485, val_loss: 1.4256
iter: 5000, train_loss: 1.3233, val_loss: 1.4186
iter: 5500, train_loss: 1.3012, val_loss: 1.4075
iter: 6000, train_loss: 1.2818, val_loss: 1.4006
iter: 6500, train_loss: 1.2628, val_loss: 1.3933
iter: 7000, train_loss: 1.2456, val_loss: 1.3933
iter: 7500, train_loss: 1.2258, val_loss: 1.3911
iter: 7999, train_loss: 1.2258, val_loss: 1.3911


'model/mla-bitsy'

In [17]:
idx = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(idx, 1000)[0].tolist()))


enough we not where the seventh none,
from the dust caps, tincto? and all shall not wear
oh! question, more unpleased to thee
glad in the dops of men! but no flame
fortune flame
in sadnesshow dares beside that distant pale;
and here heart, was votice on its thine toil,
and gathler flag the claspion-locks the home
thee were those tide of tears and crief -
a glancing sisteam, glory, glare, or grave,
and therefore we lave alone further cry;
it is on thine falcowers are beauty
that holds in most pleasant men,
when strong agress have smunn'd, as thou,
and gentle to winds a scene riser sound,
and hail'd him and leaf him sleepings alone,
and was the woe to seek by spanish roach's rate
by him stare not both the clumb wand from a bird;
don juan's dyeth intresses past,
and they made him last mercyber known
such as from the nore they had died enough
that glanced where he must raise, and allure his hoary bore
judged he gleam'd he for the bird of his ranson!
trying seat was everybody boughabban;
o