In [107]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random

In [108]:
inp_text = str(open('input.txt').read())
chars = {}
set_chars = set()
for i in inp_text:
    for j in str(i):
        set_chars.add(j)
for i, letter in enumerate(sorted(list(set_chars))):
    chars[letter] = i

In [123]:
keys_chars = list(chars.keys())
vocab_size = len(keys_chars)
batch_size = 4
context_len = 8
emb_neur = 64
epochs = 5000
num_blocks = 4
number_of_heads = 4
dropout_neur = 0.2
lr = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [110]:
def encode(inp):
    ans = []
    for i in inp:
        ans.append(chars[i])
    return ans

def decode(inp):
    ans = ""
    for i in inp:
        ans+=str(keys_chars[i])
    return ans
inp_text = torch.tensor(encode(inp_text), dtype=torch.long)

In [111]:
n1 = int(0.9 * len(inp_text))
train_data = inp_text[:n1]
val_data = inp_text[n1:]


In [112]:
def get_batch(data):
    ixs = torch.randint(0, len(data)-context_len-1, (batch_size,))
    x_batches = torch.stack([data[i:i+context_len] for i in ixs])
    ys = torch.stack([data[i+1:i+context_len+1] for i in ixs])
    x_batches, ys = x_batches.to(device), ys.to(device)
    return x_batches, ys


In [113]:
@torch.no_grad()
def calculate_loss(data):
    losses = []
    m.eval()
    for _ in range(1000):
        xs, ys = get_batch(data)
        logits, loss = m(xs, ys)
        losses.append(loss.item())
    losses = torch.tensor(losses)
    m.train()
    return losses.mean()

In [130]:
torch.manual_seed(1337)

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(emb_neur, head_size, bias=False)
        self.query = nn.Linear(emb_neur, head_size, bias=False)
        self.value = nn.Linear(emb_neur, head_size, bias=False)

    def forward(self, xs):
        B, T, C = xs.shape
        k = self.key(xs)
        q = self.query(xs)
        v = self.value(xs)

        xs = k @ q.transpose(-2, -1) * k.shape[-1]**-0.5
        xs = torch.tril(xs)
        xs = xs.masked_fill(xs == 0, float("-inf"))
        xs = F.softmax(xs, dim=-1)
        
        out = xs @ v
        return out

class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # self.proj = nn.Linear(emb_neur, emb_neur)
        # self.dropout = nn.Dropout(dropout_neur)

    def forward(self, xs):
        xs = torch.cat([head(xs) for head in self.heads], dim=-1)
        # return self.dropout(self.proj(xs))
        return xs

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_neur, 4 * emb_neur),
            nn.ReLU(),
            nn.Linear(4 * emb_neur, emb_neur),
            # nn.Dropout(dropout_neur),
        )

    def forward(self, xs):
        return self.net(xs)
        

class Block(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.multi_heads = MultiHead(num_heads, head_size)
        self.ff = FeedForward()
        # self.ln1 = nn.LayerNorm(emb_neur)
        # self.ln2 = nn.LayerNorm(emb_neur)

    def forward(self, xs):
        # xs = xs + self.multi_heads(self.ln1(xs))
        # xs = xs + self.ff(self.ln2(xs))
        xs = xs + self.multi_heads(xs)
        xs = xs + self.ff(xs)
        return xs
        

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokens_embedding = nn.Embedding(vocab_size, emb_neur)
        self.position_embedding = nn.Embedding(context_len, emb_neur)
        # self.sa_head = Head(emb_neur)
        # self.multi_head = MultiHead(number_of_heads, emb_neur//number_of_heads)
        self.blocks = nn.Sequential( Block(number_of_heads, emb_neur//number_of_heads), Block(number_of_heads, emb_neur//number_of_heads), Block(number_of_heads, emb_neur//number_of_heads) )
        self.ln = nn.LayerNorm(emb_neur)
        self.ll_head = nn.Linear(emb_neur, vocab_size)

    def forward(self, xs_inputs, targets=None):
        B, T = xs_inputs.shape
        
        embedded_tokens = self.tokens_embedding(xs_inputs) # B, T, emb_neur
        embedded_position = self.position_embedding(torch.arange(T, device=device)) # T, emb_neur
        
        xs_inputs = embedded_tokens + embedded_position # B, T, emb_neur
        xs_inputs = self.blocks(xs_inputs)
        # xs_inputs = self.ln(xs_inputs)
        logits = self.ll_head(xs_inputs)
        
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, xs_inputs, max_tokens):
        for _ in range(max_tokens):
            xs_cond = xs_inputs[:, -context_len:]
            logits, loss = self(xs_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_letter = torch.multinomial(probs, num_samples=1)
            xs_inputs = torch.cat((xs_inputs, next_letter), dim=1)
        return xs_inputs
        


In [131]:
m = BigramLanguageModel()
m = m.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
optmizer = torch.optim.Adam(m.parameters(), lr=lr)

0.145153 M parameters


In [133]:
for epoch in range(epochs):
    xs, ys = get_batch(train_data)
    logits, loss = m(xs, ys)
    optmizer.zero_grad(set_to_none=True)
    loss.backward()
    optmizer.step()
    # if epoch %  == 0:
    print(epoch, '/', epochs, "loss:")

0 / 5000 loss:
1 / 5000 loss:
2 / 5000 loss:
3 / 5000 loss:
4 / 5000 loss:
5 / 5000 loss:
6 / 5000 loss:
7 / 5000 loss:
8 / 5000 loss:
9 / 5000 loss:
10 / 5000 loss:
11 / 5000 loss:
12 / 5000 loss:
13 / 5000 loss:
14 / 5000 loss:
15 / 5000 loss:
16 / 5000 loss:
17 / 5000 loss:
18 / 5000 loss:
19 / 5000 loss:
20 / 5000 loss:
21 / 5000 loss:
22 / 5000 loss:
23 / 5000 loss:
24 / 5000 loss:
25 / 5000 loss:
26 / 5000 loss:
27 / 5000 loss:
28 / 5000 loss:
29 / 5000 loss:
30 / 5000 loss:
31 / 5000 loss:
32 / 5000 loss:
33 / 5000 loss:
34 / 5000 loss:
35 / 5000 loss:
36 / 5000 loss:
37 / 5000 loss:
38 / 5000 loss:
39 / 5000 loss:
40 / 5000 loss:
41 / 5000 loss:
42 / 5000 loss:
43 / 5000 loss:
44 / 5000 loss:
45 / 5000 loss:
46 / 5000 loss:
47 / 5000 loss:
48 / 5000 loss:
49 / 5000 loss:
50 / 5000 loss:
51 / 5000 loss:
52 / 5000 loss:
53 / 5000 loss:
54 / 5000 loss:
55 / 5000 loss:
56 / 5000 loss:
57 / 5000 loss:
58 / 5000 loss:
59 / 5000 loss:
60 / 5000 loss:
61 / 5000 loss:


KeyboardInterrupt: 

In [None]:
calculate_loss(val_data)

In [90]:
# print(decode(m.generate(torch.zeros(1, 1, dtype=torch.int32, device=device), 500)[0].tolist()))

In [14]:
# (2.2899)
# tensor(2.1783) ln
# tensor(2.0981)