In [216]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [217]:
data = open('input.txt').read()
alphabet = set([c for c in data])
stoi = {s:i for i,s in enumerate(alphabet)}
itos = {value: key for key, value in stoi.items()}
vocab_size = len(alphabet)
print(f"{vocab_size=}")
def encode(s):
    return [stoi[c] for c in s]

def decode(encoded):
    return ''.join([itos[i] for i in encoded])

encoded = torch.tensor(encode(data), dtype=torch.long)
encoded[:20]

vocab_size=65


tensor([31, 60, 50, 54, 64,  1, 21, 60, 64, 60, 37, 49, 19, 44, 41, 16, 49, 63,
        53, 50])

In [218]:
x = int(0.9 * len(data))
training_data = encoded[:x]
validation_data = encoded[x:]

In [219]:
torch.manual_seed(42)
batch_size = 32
block_size = 8
embedding_size = 64


def get_batch(split):
    data = training_data if split == 'train' else validation_data
    indices = torch.randint(0, data.shape[0] - 1 - block_size, (batch_size,))
    xs = torch.stack([data[i:i+block_size] for i in indices])
    ys = torch.stack([data[i+1:i+block_size+1] for i in indices])
    return xs, ys

xbatch, ybatch = get_batch('train')

query, key, value


In [220]:
# single self attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(embedding_size, head_size)
        self.key = nn.Linear(embedding_size, head_size)
        self.value = nn.Linear(embedding_size, head_size)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        # print(f'{x.shape=}, {vocab_size=}, {head_size=}')
        q = self.query(x) # B, T, H
        k = self.key(x) # B, T, H

        # print(f'{q.shape=}, {k.shape=}')

        wei = q @ k.transpose(-1, -2) * (q.shape[2] ** -0.5) # B, T, T
        # print(f'{wei.shape=}, {self.mask.shape=}')
        T = wei.shape[1]
        wei = wei.masked_fill(self.mask[:T, :T] == 0, -float('inf')) # B, T, T
        wei = F.softmax(wei, dim=-1) # B, T, T
        v = self.value(x) # B, T, H
        # print(f'{wei.shape=}\n\n {v.shape=}')
        return wei @ v # B T T @ B T H -> B T H

        

In [221]:
class FeedForward(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(head_size, 4 * head_size),
            nn.ReLU(),
            nn.Linear(4 * head_size, head_size),
        )

    def forward(self, x):
        return self.net(x)


In [222]:
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(block_size, embedding_size)
        self.head = Head(embedding_size)
        self.ff = FeedForward(embedding_size)
        self.final = nn.Linear(embedding_size, vocab_size)
    

    def forward(self, x, targets=None):
        tokens_embeddings = self.token_embeddings(x) 
        position_embeddings = self.position_embeddings(torch.arange(x.shape[1]))
        # print(f'{tokens_embeddings.shape=}, {position_embeddings.shape=}')
        x = tokens_embeddings + position_embeddings

        h = self.head(x) # (batch_size, block_size, head_size)
        # print(f'{h.shape=} {x.shape=}')
        x = x + h
        x = x + self.ff(x) # (batch_size, block_size, head_size)

        # print(x)
        logits = self.final(x) # (batch_size, block_size, vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(
                logits.view(B * T, C),
                targets.view(B * T)
            )
        # print(f'{x.shape=}')
        return loss, logits
    
    def generate(self, length):
        result = ''
        context = torch.randint(0, vocab_size, (1,1))
        for _ in range(length):
            # print(f'{context.shape=}')
            _, logits = self(context)
            # print(f'{logits.shape=}')
            probs = F.softmax(logits, dim=-1)
            # print(f'{probs.shape=}')
            sample = torch.multinomial(probs[:, -1], 1, replacement=True)
            # print(f'{sample.shape=}, {sample.item()}')
            context = torch.cat([context, sample.view(1, 1)], dim=1)
            if context.shape[1] >= block_size:
                context = context[:, 1:]
            result += itos[sample.item()]
        return result

In [223]:
m = LanguageModel()
m.generate(10)

'LSYCjUziHL'

In [224]:
model = LanguageModel()
learning_rate = 0.001
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
losses = []

for step in range(20000):
    xbatch, ybatch = get_batch('train')
    loss, _ = model(xbatch, targets=ybatch)

    optimizer.zero_grad()
    loss.backward()
    losses.append(loss.item())
    optimizer.step()

    if step % 1000 == 0:
        print(f'step {step} loss {loss.item()}')

step 0 loss 4.539271354675293
step 1000 loss 2.3351659774780273
step 2000 loss 2.2629077434539795
step 3000 loss 2.236481189727783
step 4000 loss 2.033236026763916
step 5000 loss 1.9371086359024048
step 6000 loss 2.062751531600952
step 7000 loss 1.9947153329849243
step 8000 loss 1.9462963342666626
step 9000 loss 1.9698054790496826
step 10000 loss 1.947810411453247
step 11000 loss 1.9600310325622559
step 12000 loss 1.7514913082122803
step 13000 loss 1.8486696481704712
step 14000 loss 1.8105295896530151
step 15000 loss 1.8423959016799927
step 16000 loss 2.106238603591919
step 17000 loss 2.083470344543457
step 18000 loss 1.79603910446167
step 19000 loss 1.9162648916244507


In [225]:
print(model.generate(1000))

ETER:
How, beance wake how
dutes, by your wilkecond bed come.
-save, king Rome uman: our that sly wervand him me me sinen envion:
Ere, thee to thy lontagelast the did vaze me hem yo, saguends, ther bedard?

SICAPULINA:
Well world Cadamine
Theaved give!
I at's place!

Look bogged-lerse out,
As I mustould to make nect fear a and bears way gono.
Fort no motheir
of this purt. Tlook conce be that have thou to to a llice.
Whaties my lose, as igonsuriefter highne,
Brother, swooXson'd stage,
Ippose wormentent lead compain'd will and to ye lies none in pride had
Sore to cusenquon this sulserdly
To starmy lords;
Wheds. Here them I me!
Be eoverss rece stie.

LORD:
Onchope.

KING RICHORS:
Pompon so
no in the entsted sle! thought the but imps, should if thearstagento him:
Comord's, o surves,
March this a bany theave a daid bloods, will again, take swomnerys eno 'tward, lisheir of E VI:
I at the by that fors conders upiting
I my see grace,
And womper: to lore endepeak: hanereou!
He schoservy
expassu

In [226]:
eval_iters = 200

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            loss, _ = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [227]:
estimate_loss()

{'train': tensor(1.8799), 'val': tensor(2.0228)}