In [41]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random

In [42]:
# Read the whole training corpus. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open('input.txt') as corpus_file:
    inp_text = corpus_file.read()

# Vocabulary: every distinct character of the corpus, with integer ids
# assigned in sorted character order (char -> id).
set_chars = set(inp_text)
chars = {letter: i for i, letter in enumerate(sorted(set_chars))}

In [43]:
# ---- Vocabulary-derived values ----
# Characters in id order, so keys_chars[id] gives back the character.
keys_chars = list(chars)
vocab_size = len(keys_chars)

# ---- Data / batching ----
batch_size = 64        # sequences per mini-batch
context_len = 256      # tokens per training window; also sizes the position table
block_size = 256       # NOTE(review): not referenced by the visible cells - confirm before removing

# ---- Model shape ----
emb_neur = 384         # embedding width used throughout the network
num_blocks = 6         # number of stacked Block modules
number_of_heads = 6    # attention heads per block (head size = emb_neur // number_of_heads)
dropout_neur = 0.2     # dropout probability

# ---- Optimisation ----
epochs = 5000          # number of optimizer steps taken by the training loop
lr = 3e-4              # Adam learning rate

# Alternative small configuration (kept commented out for quick experiments):
# batch_size = 4
# context_len = 8
# emb_neur = 64
# epochs = 5000
# num_blocks = 4
# number_of_heads = 4
# dropout_neur = 0.2
# lr = 3e-4

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [44]:
def encode(inp):
    """Translate an iterable of characters into a list of integer ids.

    Each element of ``inp`` is looked up in the module-level ``chars``
    mapping; an element missing from the mapping raises ``KeyError``.
    """
    return [chars[symbol] for symbol in inp]

def decode(inp):
    """Translate an iterable of integer ids back into a string.

    Each id indexes the module-level ``keys_chars`` list and the resulting
    pieces are concatenated in order.
    """
    return "".join(str(keys_chars[token_id]) for token_id in inp)
# Replace the raw corpus string with a 1-D int64 tensor of character ids.
# The name `inp_text` is reused for the encoded tensor (later cells slice it),
# so the conversion is guarded: re-running this cell on an already-encoded
# tensor would otherwise fail inside encode() with a KeyError.
if isinstance(inp_text, str):
    inp_text = torch.tensor(encode(inp_text), dtype=torch.long)

In [45]:
# Hold out the final 10% of the encoded corpus for validation; the first 90%
# is used for training. n1 is the index where the split happens.
n1 = int(len(inp_text) * 0.9)
train_data, val_data = inp_text[:n1], inp_text[n1:]


In [46]:
def get_batch(data):
    """Draw a random mini-batch of training windows from ``data``.

    ``batch_size`` start offsets are sampled uniformly. For each offset the
    input is a window of ``context_len`` consecutive tokens and the target is
    the same window shifted one position to the right.

    Returns:
        (x_batches, ys): two tensors of shape (batch_size, context_len),
        both moved to ``device``.
    """
    upper_bound = len(data) - context_len - 1
    starts = torch.randint(0, upper_bound, (batch_size,))

    input_windows = []
    target_windows = []
    for start in starts:
        stop = start + context_len
        input_windows.append(data[start:stop])
        target_windows.append(data[start + 1:stop + 1])

    x_batches = torch.stack(input_windows).to(device)
    ys = torch.stack(target_windows).to(device)
    return x_batches, ys


In [47]:
@torch.no_grad()
def calculate_loss(data, eval_iters=1000):
    """Estimate the mean loss of the global model ``m`` on ``data``.

    Args:
        data: token tensor that ``get_batch`` samples mini-batches from.
        eval_iters: number of random mini-batches to average over
            (defaults to 1000).

    Returns:
        A 0-dim float tensor holding the mean of the per-batch losses.

    The model is put in eval mode while measuring (so dropout is disabled)
    and is afterwards restored to whatever mode it was in on entry, even if
    an exception is raised part-way through.
    """
    was_training = m.training
    m.eval()
    try:
        losses = []
        for _ in range(eval_iters):
            xs, ys = get_batch(data)
            _, loss = m(xs, ys)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode rather than forcing train mode.
        m.train(was_training)
    return torch.tensor(losses).mean()

In [48]:
# Seed PyTorch's random number generator so that the weight initialisation and
# the torch-based batch sampling that follow start from a fixed state.
torch.manual_seed(1337)

class Head(nn.Module):
    """A single causal self-attention head.

    The input is projected to key, query and value vectors of width
    ``head_size``. The values are then mixed using softmax-normalised scores,
    and position ``i`` may only draw from positions ``j <= i``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(emb_neur, head_size, bias=False)
        self.query = nn.Linear(emb_neur, head_size, bias=False)
        self.value = nn.Linear(emb_neur, head_size, bias=False)

    def forward(self, xs):
        """Apply causal attention.

        Args:
            xs: tensor of shape (B, T, C), where C is the input width of the
                three projections.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = xs.shape
        k = self.key(xs)
        q = self.query(xs)
        v = self.value(xs)

        # Scaled dot-product scores. Note that this is k @ q^T, so row i is
        # driven by key(x_i) and column j by query(x_j); the two projections
        # simply swap their conventional roles. Kept as-is so that weights
        # trained with this layout keep their meaning.
        scores = k @ q.transpose(-2, -1) * k.shape[-1]**-0.5

        # Mask future positions with an explicit lower-triangular boolean
        # mask. Testing the scores themselves for `== 0` would also discard
        # legitimate zero scores, and yields NaN when a whole row is zero.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=xs.device))
        scores = scores.masked_fill(~causal, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        return weights @ v

class MultiHead(nn.Module):
    """Several attention heads run on the same input and merged.

    The head outputs are concatenated along the last dimension, passed
    through a linear projection and then through dropout. The projection
    takes ``emb_neur`` input features, so the concatenated width
    (num_heads * head_size) has to equal ``emb_neur``.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_list = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(emb_neur, emb_neur)
        self.dropout = nn.Dropout(dropout_neur)

    def forward(self, xs):
        """Return the projected, dropout-regularised concatenation of all heads."""
        per_head = []
        for head in self.heads:
            per_head.append(head(xs))
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedForward(nn.Module):
    """Per-position MLP: widen to 4x ``emb_neur``, ReLU, narrow back, dropout."""

    def __init__(self):
        super().__init__()
        hidden_width = 4 * emb_neur
        layers = [
            nn.Linear(emb_neur, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, emb_neur),
            nn.Dropout(dropout_neur),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, xs):
        """Run ``xs`` through the MLP and return the result."""
        transformed = self.net(xs)
        return transformed
        

class Block(nn.Module):
    """One transformer block: attention then feed-forward, each with a residual.

    Layer normalisation is applied to the input of each sub-layer (before
    attention and before the feed-forward network), and the sub-layer output
    is added back onto its un-normalised input.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.multi_heads = MultiHead(num_heads, head_size)
        self.ff = FeedForward()
        self.ln1 = nn.LayerNorm(emb_neur)
        self.ln2 = nn.LayerNorm(emb_neur)

    def forward(self, xs):
        """Return ``xs`` after the attention and feed-forward residual updates."""
        attended = self.multi_heads(self.ln1(xs))
        after_attention = xs + attended
        transformed = self.ff(self.ln2(after_attention))
        return after_attention + transformed
        

class BigramLanguageModel(nn.Module):
    """Character-level language model built from a stack of attention blocks.

    Despite the name, this is not a bigram model. Token and learned position
    embeddings are summed, passed through ``num_blocks`` ``Block`` modules and
    a final LayerNorm, and then projected to per-token vocabulary logits.
    """

    def __init__(self):
        super().__init__()
        self.tokens_embedding = nn.Embedding(vocab_size, emb_neur)
        self.position_embedding = nn.Embedding(context_len, emb_neur)
        self.blocks = nn.Sequential(
            *[Block(number_of_heads, emb_neur // number_of_heads) for _ in range(num_blocks)]
        )
        self.ln = nn.LayerNorm(emb_neur)
        self.ll_head = nn.Linear(emb_neur, vocab_size)

    def forward(self, xs_inputs, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            xs_inputs: integer token ids of shape (B, T). T must not exceed
                the size of the position-embedding table.
            targets: optional integer token ids of shape (B, T).

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, vocab)
            and loss is None. With targets, logits is flattened to
            (B*T, vocab) and loss is a scalar tensor.
        """
        B, T = xs_inputs.shape

        embedded_tokens = self.tokens_embedding(xs_inputs)  # B, T, emb
        # Build the position indices on the same device as the input rather
        # than on a module-level global, so the model works wherever it and
        # its inputs have been moved.
        positions = torch.arange(T, device=xs_inputs.device)
        embedded_position = self.position_embedding(positions)  # T, emb

        hidden = embedded_tokens + embedded_position  # B, T, emb
        hidden = self.blocks(hidden)
        hidden = self.ln(hidden)
        logits = self.ll_head(hidden)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, xs_inputs, max_tokens):
        """Autoregressively sample ``max_tokens`` new tokens.

        Args:
            xs_inputs: integer token ids of shape (B, T) used as the prompt.
            max_tokens: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_tokens) containing the prompt followed
            by the sampled tokens.

        Sampling runs without gradient tracking and in eval mode (dropout
        disabled). The module's previous train/eval mode is restored
        afterwards.
        """
        # The longest context the model can see is the size of its own
        # position-embedding table.
        window = self.position_embedding.num_embeddings
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_tokens):
                xs_cond = xs_inputs[:, -window:]
                logits, _ = self(xs_cond)
                logits = logits[:, -1, :]  # distribution for the next token only
                probs = F.softmax(logits, dim=-1)
                next_letter = torch.multinomial(probs, num_samples=1)
                xs_inputs = torch.cat((xs_inputs, next_letter), dim=1)
        finally:
            self.train(was_training)
        return xs_inputs
        


In [49]:
# Build the model on the selected device, report its size in millions of
# parameters, and create the Adam optimizer over all of its parameters.
m = BigramLanguageModel().to(device)
n_params = sum(p.numel() for p in m.parameters())
print(n_params/1e6, 'M parameters')
optmizer = torch.optim.Adam(m.parameters(), lr=lr)

10.788929 M parameters


In [50]:
# Training loop. Each "epoch" here is a single optimizer step on one random
# mini-batch, not a full pass over the training data.
for epoch in range(epochs):
    xs, ys = get_batch(train_data)
    logits, loss = m(xs, ys)

    # Standard update: clear old gradients, backpropagate, apply the step.
    optmizer.zero_grad(set_to_none=True)
    loss.backward()
    optmizer.step()

    # Report the validation loss only on every 2000th step (including step 0).
    if epoch % 2000:
        continue
    print(epoch, '/', epochs, "loss:", calculate_loss(val_data))

0 / 5000 loss: tensor(3.5523)
2000 / 5000 loss: tensor(1.5433)
4000 / 5000 loss: tensor(1.5360)


In [54]:
# Estimate the validation loss of the trained model. Batches are sampled at
# random, so the figure differs slightly between calls. As the last expression
# of the cell, the value is displayed as its output.
calculate_loss(val_data)

tensor(1.5711)

In [52]:
# Sample 500 new characters from the model, starting from a single prompt
# token with id 0, and print the decoded text.
seed_tokens = torch.zeros(1, 1, dtype=torch.int32, device=device)
generated = m.generate(seed_tokens, 500)
print(decode(generated[0].tolist()))


Come, but Margaret!

CAMILLO:
How mildly! where, Hi!

POLIXENES:
Mark, sweet man!
Part an old Isabout have tableness!
Enougle him, to set the power of death!
Who deep authority, butch'd by great pretty!
Come, cousin perheal; measures our toasy.
Busines, thou surest. See, feel when your tensible, are
grace waxt manner of the morning, in the wholes
greet strumpetry, envy on Albion, breat both welch-ruin; peever,
though some varlet, wast we both you,
Pour queen, of the ugleets and full of greet.

L


In [53]:
# Persist the trained model to disk.
# NOTE(review): this pickles the entire module object rather than just its
# weights, so loading it back requires the same class definitions to be
# importable, and the file should only ever be loaded from a trusted source.
# Consider saving m.state_dict() instead if no consumer depends on this format.
torch.save(m, 'model.pth')