In [1]:
import os
import requests
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# download the tiny shakespeare dataset
input_file_path = 'shakespeare.txt'
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

In [4]:
stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
itos = {i: ch for ch, i in stoi.items()}

In [5]:
train_data_encoded = np.array([stoi[ch] for ch in train_data], dtype=np.uint16)
val_data_encoded = np.array([stoi[ch] for ch in val_data], dtype=np.uint16)
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [6]:
vocab_size = len(stoi)
n_embd = 32
batch_size = 640
block_size = 8
head_size = 16

In [7]:
def get_batch(split):
    if split == 'train':
        data = train_data_encoded
    elif split == 'val':
        data = val_data_encoded
    else:
        raise ValueError('split must be either train or val')
    start_idx = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in start_idx])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in start_idx])
    return x.to(device), y.to(device)


In [8]:
get_batch('train')

(tensor([[27, 27, 28,  ..., 25, 43, 52],
         [57, 43, 52,  ..., 58, 46, 43],
         [ 1, 47, 51,  ..., 42, 43, 52],
         ...,
         [40, 59, 58,  ..., 52, 43, 61],
         [58,  1, 58,  ..., 50, 43,  8],
         [58, 46, 47,  ..., 39, 58,  1]]),
 tensor([[27, 28, 10,  ..., 43, 52,  1],
         [43, 52, 42,  ..., 46, 43,  1],
         [47, 51, 54,  ..., 43, 52, 58],
         ...,
         [59, 58,  1,  ..., 43, 61,  1],
         [ 1, 58, 47,  ..., 43,  8,  1],
         [46, 47, 57,  ..., 58,  1, 16]]))

In [20]:
class Head(torch.nn.Module):
    def __init__(self, input_size, head_size):
        super().__init__()
        self.key = torch.nn.Linear(input_size, head_size)
        self.query = torch.nn.Linear(input_size, head_size)
        self.value = torch.nn.Linear(input_size, head_size)
        self.register_buffer('mask', torch.tril(torch.ones(block_size, block_size)) == 0)
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        wei = q @ k.transpose(-2, -1) * (head_size**-0.5)
        wei = wei.masked_fill(self.mask[:T, :T], float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v        

In [21]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, n_heads, head_size, input_size):
        super().__init__()
        self.heads = torch.nn.ModuleList([Head(input_size, head_size) for _ in range(n_heads)])
        self.proj = torch.nn.Linear(n_heads*head_size, input_size)
    
    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

In [22]:
class FeedForward(torch.nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(n_embd, 4 * n_embd),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * n_embd, n_embd)
        )
    
    def forward(self, x):
        return self.net(x)

In [28]:
class Block(torch.nn.Module):
    def __init__(self, n_embd, n_heads):
        super().__init__()
        head_size = n_embd // n_heads
        self.sa = MultiHeadAttention(n_heads, head_size, n_embd)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = torch.nn.LayerNorm(n_embd)
        self.ln2 = torch.nn.LayerNorm(n_embd)
        
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [30]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = torch.nn.Embedding(block_size, n_embd)
        self.blocks = torch.nn.Sequential(
            Block(n_embd, 4),
            Block(n_embd, 4),
            Block(n_embd, 4),
            torch.nn.LayerNorm(n_embd)
        )
        self.lm_head = torch.nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for i in range(max_new_tokens):
            logits, _ = self(idx[:, -block_size:])
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            new_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, new_token], dim=-1)
        return idx

In [31]:
lm = Model().to(device)
optimizer = torch.optim.AdamW(lm.parameters(), lr=0.001)

In [32]:
for step in range(5000):
    x, y = get_batch('train')
    logits, loss = lm(x, y)
    if step % 100 == 0:
        print(f'step {step}, loss {loss.item()}')
    lm.zero_grad()
    loss.backward()
    optimizer.step()

step 0, loss 4.653528213500977
step 100, loss 2.597411632537842
step 200, loss 2.380962610244751
step 300, loss 2.2899081707000732
step 400, loss 2.1950089931488037
step 500, loss 2.149674892425537
step 600, loss 2.1492626667022705
step 700, loss 2.097557544708252
step 800, loss 2.0751748085021973
step 900, loss 2.0033926963806152
step 1000, loss 1.9987818002700806
step 1100, loss 1.9997889995574951
step 1200, loss 1.9755010604858398
step 1300, loss 1.9807460308074951
step 1400, loss 1.9936994314193726
step 1500, loss 1.951111078262329
step 1600, loss 1.965682029724121
step 1700, loss 1.9140479564666748
step 1800, loss 1.976563811302185
step 1900, loss 1.9149105548858643
step 2000, loss 1.9191391468048096
step 2100, loss 1.9168260097503662
step 2200, loss 1.8981072902679443
step 2300, loss 1.9315135478973389
step 2400, loss 1.8563764095306396
step 2500, loss 1.8911774158477783
step 2600, loss 1.876099944114685
step 2700, loss 1.880812644958496
step 2800, loss 1.875939130783081
step 290

In [16]:
x = torch.zeros((1, 1), dtype=torch.int64).to(device)
y_ = lm.generate(x, 1000)
print(decode(y_[0].tolist()))


BENVOLIO:
He unk of Rarrand,
That tere beseep's of good father; the long-ate a woman him;
pounter's of in 'Ther him!
To this by where weeth his scow shut this belse part,
Have my spain,
faumel reman:
He him,
We as think die in him.

POMPEY:
When war?

ISABELLA:
That is,
To himself are ocley'd fear you hate faulted and with to gentleman:
This one what the down it,
Age mad is what numechanns of sorrow thous' they flown's mishes the cannori, I shall them wit,
There shade per:
Servess a back good on thus: how it in thee he bows his content. I draw are are farew thee: other
My four tone,
Make old escome the lak:
Ah, commarch'd Goding Mypearius.

LADY CAPULET:
We have day must-atter you.
What I we the down the And fear it to did me, love an thy hach curta is are his by count and my he' sweek heaven,
And by come me, and since with corsciolane hadst,
I secross nots,
If this, put unfan so day.

ROMNGHPEY:

ESCALUS:
The braise with not this,
In yours:
That not is the kingness a to winder:
She, 

In [19]:
x = torch.tensor([encode('Hi M')]).to(device)
y_ = lm.generate(x, 30)
print(decode(y_[0].tolist()))

Hi More's are of Romeo and made yo
