In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from tqdm import tqdm

In [50]:
# @dataclass
# class GPTconfig:
    # block_size: int = 128
    # vocab_size = 50257
    # n_embd: int = 768
    # n_layer: int = 12
    # n_head: int = 12
    # dropout: float = 0.0

batch_size: int = 16
block_size: int = 32
epochs = 5000
eval_iters = 100
learning_rate = 1e-3
# vocab_size = 50257
n_embd: int = 64
n_layer: int = 4
n_head: int = 4
dropout: float = 0.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    

In [51]:
with open("input.txt", "r") as f:
    text = f.read()

In [52]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [53]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x = x.to(device)
    y = y.to(device)
    return x, y

In [54]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            x, y = get_batch(split)
            logits, loss = model(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean().item()
    model.train()  
    return out

In [55]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        scale_value = k.shape[-1] ** -0.5
        attention = q @ k.transpose(-2, -1) * scale_value # (B, T, T)

        masked_attention = attention.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B, T, T)
        masked_attention = F.softmax(masked_attention, dim=-1) # (B, T, T)
        masked_attention = self.dropout(masked_attention) # (B, T, T)

        out = masked_attention @ v # (B, T, head_size)

        return out    

In [56]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_head * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(x))


In [57]:
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.ReLU(), 
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.layers(x)

In [58]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.self_attn = MultiHeadAttention(n_head, head_size)
        self.feed_forward = FeedForward(n_embd)
        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.self_attn(self.layer_norm1(x))
        x = x + self.feed_forward(self.layer_norm2(x))
        return x

In [82]:
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, n_embd)
        self.position_embeddings = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embeddings(idx) # (B, T, C)
        pos_emb = self.position_embeddings(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)

        logits = self.head(x)

        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        else:
            loss = None

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
            

In [90]:
model = GPT()
model = model.to(device)
print(sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


209729


  0%|                                                                      | 0/5000 [00:00<?, ?it/s]

  0%|                                                              | 5/5000 [00:01<21:41,  3.84it/s]

epoch 0, train loss 4.302, valid loss 4.271


  2%|█▎                                                          | 105/5000 [00:06<08:57,  9.11it/s]

epoch 100, train loss 3.450, valid loss 3.431


  4%|██▍                                                         | 204/5000 [00:11<12:34,  6.36it/s]

epoch 200, train loss 3.365, valid loss 3.221


  6%|███▋                                                        | 303/5000 [00:18<24:38,  3.18it/s]

epoch 300, train loss 2.877, valid loss 3.042


  8%|████▊                                                       | 402/5000 [00:30<55:08,  1.39it/s]

epoch 400, train loss 2.975, valid loss 2.897


 10%|██████                                                      | 505/5000 [00:44<23:30,  3.19it/s]

epoch 500, train loss 2.761, valid loss 2.792


 12%|███████▎                                                    | 605/5000 [00:53<25:17,  2.90it/s]

epoch 600, train loss 2.679, valid loss 2.722


 14%|████████▍                                                   | 705/5000 [01:01<22:53,  3.13it/s]

epoch 700, train loss 2.696, valid loss 2.661


 16%|█████████▋                                                  | 804/5000 [01:10<17:40,  3.96it/s]

epoch 800, train loss 2.661, valid loss 2.616


 18%|██████████▊                                                 | 903/5000 [01:17<19:33,  3.49it/s]

epoch 900, train loss 2.544, valid loss 2.588


 20%|███████████▊                                               | 1005/5000 [01:23<14:07,  4.72it/s]

epoch 1000, train loss 2.498, valid loss 2.557


 22%|█████████████                                              | 1104/5000 [01:30<13:48,  4.70it/s]

epoch 1100, train loss 2.544, valid loss 2.546


 24%|██████████████▏                                            | 1200/5000 [01:34<03:13, 19.60it/s]

epoch 1200, train loss 2.562, valid loss 2.527


 26%|███████████████▍                                           | 1304/5000 [01:47<12:45,  4.83it/s]

epoch 1300, train loss 2.456, valid loss 2.513


 28%|████████████████▌                                          | 1402/5000 [01:53<16:12,  3.70it/s]

epoch 1400, train loss 2.500, valid loss 2.500


 30%|█████████████████▋                                         | 1503/5000 [02:10<43:22,  1.34it/s]

epoch 1500, train loss 2.555, valid loss 2.485


 32%|██████████████████▉                                        | 1604/5000 [02:23<20:14,  2.80it/s]

epoch 1600, train loss 2.507, valid loss 2.451


 34%|████████████████████                                       | 1703/5000 [02:32<19:04,  2.88it/s]

epoch 1700, train loss 2.562, valid loss 2.451


 36%|█████████████████████▎                                     | 1804/5000 [02:38<10:23,  5.12it/s]

epoch 1800, train loss 2.442, valid loss 2.429


 38%|██████████████████████▍                                    | 1905/5000 [02:44<08:34,  6.01it/s]

epoch 1900, train loss 2.371, valid loss 2.422


 40%|███████████████████████▋                                   | 2003/5000 [02:54<21:31,  2.32it/s]

epoch 2000, train loss 2.460, valid loss 2.403


 42%|████████████████████████▊                                  | 2105/5000 [03:02<10:12,  4.73it/s]

epoch 2100, train loss 2.520, valid loss 2.414


 44%|██████████████████████████                                 | 2205/5000 [03:09<13:02,  3.57it/s]

epoch 2200, train loss 2.235, valid loss 2.399


 46%|███████████████████████████▏                               | 2305/5000 [03:19<13:06,  3.43it/s]

epoch 2300, train loss 2.328, valid loss 2.394


 48%|████████████████████████████▍                              | 2407/5000 [03:26<07:10,  6.02it/s]

epoch 2400, train loss 2.307, valid loss 2.369


 50%|█████████████████████████████▌                             | 2504/5000 [03:32<09:31,  4.37it/s]

epoch 2500, train loss 2.335, valid loss 2.367


 52%|██████████████████████████████▋                            | 2605/5000 [03:39<08:44,  4.57it/s]

epoch 2600, train loss 2.345, valid loss 2.364


 54%|███████████████████████████████▉                           | 2704/5000 [03:45<07:36,  5.03it/s]

epoch 2700, train loss 2.405, valid loss 2.351


 56%|█████████████████████████████████                          | 2805/5000 [03:52<08:17,  4.41it/s]

epoch 2800, train loss 2.209, valid loss 2.349


 58%|██████████████████████████████████▎                        | 2904/5000 [03:59<07:19,  4.77it/s]

epoch 2900, train loss 2.436, valid loss 2.331


 60%|███████████████████████████████████▍                       | 3005/5000 [04:07<11:25,  2.91it/s]

epoch 3000, train loss 2.375, valid loss 2.328


 62%|████████████████████████████████████▋                      | 3105/5000 [04:13<06:41,  4.72it/s]

epoch 3100, train loss 2.380, valid loss 2.324


 64%|█████████████████████████████████████▊                     | 3204/5000 [04:20<06:17,  4.75it/s]

epoch 3200, train loss 2.361, valid loss 2.325


 66%|███████████████████████████████████████                    | 3306/5000 [04:26<05:37,  5.01it/s]

epoch 3300, train loss 2.211, valid loss 2.307


 68%|████████████████████████████████████████▏                  | 3404/5000 [04:32<05:18,  5.01it/s]

epoch 3400, train loss 2.291, valid loss 2.300


 70%|█████████████████████████████████████████▎                 | 3503/5000 [04:39<08:03,  3.09it/s]

epoch 3500, train loss 2.206, valid loss 2.297


 72%|██████████████████████████████████████████▌                | 3602/5000 [04:48<13:27,  1.73it/s]

epoch 3600, train loss 2.242, valid loss 2.290


 74%|███████████████████████████████████████████▋               | 3703/5000 [04:57<07:31,  2.87it/s]

epoch 3700, train loss 2.278, valid loss 2.285


 76%|████████████████████████████████████████████▉              | 3805/5000 [05:04<04:34,  4.35it/s]

epoch 3800, train loss 2.185, valid loss 2.279


 78%|██████████████████████████████████████████████             | 3905/5000 [05:11<03:57,  4.62it/s]

epoch 3900, train loss 2.180, valid loss 2.278


 80%|███████████████████████████████████████████████▏           | 4004/5000 [05:17<03:30,  4.74it/s]

epoch 4000, train loss 2.314, valid loss 2.259


 82%|████████████████████████████████████████████████▍          | 4106/5000 [05:24<03:14,  4.61it/s]

epoch 4100, train loss 2.167, valid loss 2.251


 84%|█████████████████████████████████████████████████▌         | 4205/5000 [05:31<02:50,  4.67it/s]

epoch 4200, train loss 2.184, valid loss 2.242


 86%|██████████████████████████████████████████████████▊        | 4304/5000 [05:37<02:28,  4.70it/s]

epoch 4300, train loss 2.283, valid loss 2.254


 88%|███████████████████████████████████████████████████▉       | 4406/5000 [05:44<02:00,  4.92it/s]

epoch 4400, train loss 2.235, valid loss 2.242


 90%|█████████████████████████████████████████████████████▏     | 4505/5000 [05:51<01:46,  4.66it/s]

epoch 4500, train loss 2.270, valid loss 2.239


 92%|██████████████████████████████████████████████████████▎    | 4606/5000 [05:57<01:25,  4.60it/s]

epoch 4600, train loss 2.249, valid loss 2.235


 94%|███████████████████████████████████████████████████████▌   | 4705/5000 [06:04<01:01,  4.80it/s]

epoch 4700, train loss 2.090, valid loss 2.224


 96%|████████████████████████████████████████████████████████▋  | 4804/5000 [06:10<00:40,  4.87it/s]

epoch 4800, train loss 2.200, valid loss 2.212


 98%|█████████████████████████████████████████████████████████▊ | 4903/5000 [06:18<00:30,  3.16it/s]

epoch 4900, train loss 2.186, valid loss 2.206


100%|███████████████████████████████████████████████████████████| 5000/5000 [06:24<00:00, 13.00it/s]

epoch 4999, train loss 2.256, valid loss 2.220





In [100]:
for epoch in tqdm(range(1000), ncols=100):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  

    if epoch % eval_iters == 0 or epoch == epochs - 1:
        losses = estimate_loss()
        print(f'epoch {epoch}, train loss {loss.item():.3f}, valid loss {losses["val"]:.3f}')

  0%|▏                                                             | 4/1000 [00:01<05:49,  2.85it/s]

epoch 0, train loss 2.023, valid loss 2.098


 11%|██████▎                                                     | 106/1000 [00:06<01:37,  9.14it/s]

epoch 100, train loss 1.934, valid loss 2.090


 21%|████████████▎                                               | 206/1000 [00:10<01:28,  8.98it/s]

epoch 200, train loss 2.089, valid loss 2.084


 31%|██████████████████▍                                         | 307/1000 [00:15<01:13,  9.43it/s]

epoch 300, train loss 2.072, valid loss 2.081


 41%|████████████████████████▍                                   | 407/1000 [00:19<01:08,  8.67it/s]

epoch 400, train loss 2.005, valid loss 2.078


 51%|██████████████████████████████▍                             | 507/1000 [00:24<00:56,  8.69it/s]

epoch 500, train loss 2.005, valid loss 2.083


 61%|████████████████████████████████████▍                       | 607/1000 [00:28<00:46,  8.48it/s]

epoch 600, train loss 2.041, valid loss 2.077


 71%|██████████████████████████████████████████▎                 | 706/1000 [00:33<00:35,  8.36it/s]

epoch 700, train loss 1.945, valid loss 2.063


 81%|████████████████████████████████████████████████▎           | 806/1000 [00:38<00:25,  7.53it/s]

epoch 800, train loss 2.075, valid loss 2.078


 90%|██████████████████████████████████████████████████████▏     | 904/1000 [00:44<00:23,  4.04it/s]

epoch 900, train loss 2.042, valid loss 2.078


100%|███████████████████████████████████████████████████████████| 1000/1000 [00:48<00:00, 20.68it/s]


In [101]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=1000)[0].tolist()))


Fethy ruls, for, qut.

EY VINCER:
Pow, bemres!
Thou mure ore rump hy carept, and anlive.
As him pante thyee ofrunt. Renges:
I she il me desparse tuing have say Setrored,-
O'GLIOLIUKE-
DWHUKINA:
Til, delly go thim
Marbes camendy!
Sewor' him dond a sereiid day.
Thon your madady me thou thas toundentle.

DUCTALIUCHBES:

IGRGAO:
F deacl in, love.

KING TIIVM:
I powita wor not chis is, the! the ut I'lCYoon

FHORULAPE:

Diest u.
JULARIS:
Th wing wour that kere! Thou los this tha me turcke sorough.

POM! INGTEDW:
But It ard in that diall known sities flaqied;
Aund, cope wethe and, by have
I'e are me diths manse bed
Sith, all toin cissmand, my:
sup wich caingh'd sire of
Lor thee is kis nor well.
QTON'US:
Yoave ay the meake think:
I thak
I sthereedserpad, this leese;
See with lictimenk exsgh mer bepte knot thootir. Save grent bearfulm? word,
What a and in lamefner'nt, this do tuby.
The weweld she firme lo, tho agu'g, trow rentu,--

TUCES:
Thas ha me?
What rast'd men thorthe sullfaing awit
Thee