In [1]:
import os
import requests
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [3]:
# Download the tiny shakespeare dataset once and cache it as a local text file.
input_file_path = 'shakespeare.txt'
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch BEFORE opening the cache file: opening it first would leave an empty
    # file behind if the request failed, and every later run would then skip the
    # download and silently train on nothing.
    response = requests.get(data_url, timeout=30)
    # Refuse to cache an HTTP error page as if it were the dataset.
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()

# The first 90% of the characters are used for training, the rest for validation.
n = len(data)
split_idx = int(n * 0.9)
train_data = data[:split_idx]
val_data = data[split_idx:]

In [4]:
# Character-level vocabulary: every distinct character of the text, in sorted
# order, gets a consecutive integer id. `stoi` maps character -> id and `itos`
# maps id -> character.
vocab_chars = sorted(set(data))
stoi = dict(zip(vocab_chars, range(len(vocab_chars))))
itos = dict(enumerate(vocab_chars))

In [5]:
# Turn both text splits into arrays of token ids, stored compactly as uint16.
train_data_encoded = np.array(list(map(stoi.__getitem__, train_data)), dtype=np.uint16)
val_data_encoded = np.array(list(map(stoi.__getitem__, val_data)), dtype=np.uint16)
def encode(s):
    """Encoder: take a string, output a list of integer token ids.

    Each character is looked up in `stoi`; a character that is not in the
    vocabulary raises KeyError.
    """
    return list(map(stoi.__getitem__, s))
def decode(l):
    """Decoder: take a list of integer token ids, output the string they spell.

    Each id is looked up in `itos`; an unknown id raises KeyError.
    """
    return ''.join(map(itos.__getitem__, l))

In [6]:
# Model and batching hyperparameters.
vocab_size = len(stoi)  # number of distinct characters in the vocabulary
n_embd = 32             # embedding width
block_size = 8          # context length: tokens per sampled window
batch_size = 640        # windows sampled per batch
head_size = 16          # attention head size setting

In [7]:
def get_batch(split):
    """Sample a random batch of (input, target) token windows from one split.

    Args:
        split: 'train' or 'val', selecting which encoded array to sample from.

    Returns:
        A pair of int64 tensors on `device`, each of shape
        (batch_size, block_size). The target window is the input window
        shifted one position to the right.

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.
    """
    if split not in ('train', 'val'):
        raise ValueError('split must be either train or val')
    tokens = train_data_encoded if split == 'train' else val_data_encoded

    def window(start):
        # One block_size-long slice of the token array, widened to int64.
        return torch.from_numpy(tokens[start:start + block_size].astype(np.int64))

    # Upper bound leaves room for the target window, which starts one token later.
    starts = torch.randint(len(tokens) - block_size, (batch_size,))
    inputs = torch.stack([window(s) for s in starts])
    targets = torch.stack([window(s + 1) for s in starts])
    return inputs.to(device), targets.to(device)


In [8]:
# Sanity check: draw one training batch and display the (inputs, targets) pair.
get_batch(split='train')

(tensor([[46, 58,  1,  ..., 59,  1, 42],
         [56, 57, 43,  ...,  0, 18, 30],
         [50,  1, 54,  ...,  1, 63, 53],
         ...,
         [58, 53,  1,  ..., 49, 43,  1],
         [53, 52, 39,  ..., 44, 47, 52],
         [44, 39, 51,  ..., 57,  1, 47]], device='cuda:0'),
 tensor([[58,  1, 63,  ...,  1, 42, 47],
         [57, 43, 10,  ..., 18, 30, 21],
         [ 1, 54, 59,  ..., 63, 53, 59],
         ...,
         [53,  1, 51,  ..., 43,  1, 46],
         [52, 39, 11,  ..., 47, 52, 42],
         [39, 51, 53,  ...,  1, 47, 57]], device='cuda:0'))

In [9]:
class Head(torch.nn.Module):
    """A single head of causal (masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.

    The input width (`n_embd`) and the largest supported sequence length
    (`block_size`) are read from the module-level settings at construction.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = torch.nn.Linear(n_embd, head_size)
        self.query = torch.nn.Linear(n_embd, head_size)
        self.value = torch.nn.Linear(n_embd, head_size)
        # True strictly above the diagonal: the positions a token must not
        # attend to because they lie in its future.
        self.register_buffer('mask', torch.tril(torch.ones(block_size, block_size)) == 0)

    def forward(self, x):
        """Attend over `x` of shape (B, T, n_embd); returns (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Scale by this head's own width, taken from the projection output.
        # Reading a module-level `head_size` here applies the wrong factor
        # whenever a head is built with a different size than that setting.
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        wei = wei.masked_fill(self.mask[:T, :T], float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v

In [57]:
class MultiHeadAttention(torch.nn.Module):
    """Several attention heads run on the same input, then projected to `n_embd`.

    Args:
        n_heads: Number of `Head` modules to create.
        head_size: Output width of each head.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        head_modules = []
        for _ in range(n_heads):
            head_modules.append(Head(head_size))
        self.heads = torch.nn.ModuleList(head_modules)
        # Maps the concatenated head outputs back to the embedding width.
        self.proj = torch.nn.Linear(n_heads * head_size, n_embd)

    def forward(self, x):
        """Concatenate every head's output along the last axis and project it."""
        per_head = [head(x) for head in self.heads]
        return self.proj(torch.cat(per_head, dim=-1))

In [58]:
class FeedForward(torch.nn.Module):
    """Position-wise MLP: expand to 4x the width, apply ReLU, project back.

    Args:
        n_embd: Width of both the input and the output features.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        expand = torch.nn.Linear(n_embd, hidden_width)
        activation = torch.nn.ReLU()
        contract = torch.nn.Linear(hidden_width, n_embd)
        self.net = torch.nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the MLP to the last axis of `x`."""
        out = self.net(x)
        return out

In [59]:
class Block(torch.nn.Module):
    """Transformer block: residual self-attention, then a residual feed-forward.

    Args:
        n_emnd: Embedding width. The misspelled parameter name is kept so that
            existing callers (positional or keyword) keep working.
        n_heads: Number of attention heads; each head gets
            `n_emnd // n_heads` features.

    NOTE: the attention layers in this file size their input and output
    projections from the module-level `n_embd`, so the width passed here has
    to agree with that setting.
    """

    def __init__(self, n_emnd, n_heads):
        super().__init__()
        # Bind the argument to a correctly spelled local. Because of the typo
        # in the parameter name, `n_embd` below used to resolve to the
        # module-level setting and the caller's value was silently ignored.
        n_embd = n_emnd
        head_size = n_embd // n_heads
        self.sa = MultiHeadAttention(n_heads, head_size)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        """Add the attention output, then the feed-forward output, to `x`."""
        x = x + self.sa(x)
        x = x + self.ffwd(x)
        return x

In [60]:
class Model(torch.nn.Module):
    """Small character-level transformer language model.

    Token and position embeddings are summed, passed through a stack of
    `Block` modules and projected to vocabulary logits. The vocabulary size,
    embedding width and context length come from the module-level settings
    `vocab_size`, `n_embd` and `block_size`.

    Args:
        n_heads: Attention heads per block (default 4).
        n_layers: Number of stacked blocks (default 3).
    """

    def __init__(self, n_heads=4, n_layers=3):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = torch.nn.Embedding(block_size, n_embd)
        self.blocks = torch.nn.Sequential(*[Block(n_embd, n_heads) for _ in range(n_layers)])
        self.lm_head = torch.nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids; T must not exceed
                `block_size` because positions index the position table.
            targets: Optional (B, T) integer tensor of next-token ids.

        Returns:
            `(logits, loss)`. Without targets, logits have shape
            (B, T, vocab_size) and loss is None. With targets, logits are
            flattened to (B*T, vocab_size) and loss is the scalar
            cross-entropy over all positions.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits against (N,) class ids.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: (B, T) integer tensor holding the prompt token ids.
            max_new_tokens: Number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit in the context window.
            logits, _loss = self(idx[:, -block_size:])
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            new_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, new_token], dim=-1)
        return idx

In [61]:
# Build the model on the selected device and attach an AdamW optimizer to its parameters.
lm = Model()
lm = lm.to(device)
optimizer = torch.optim.AdamW(params=lm.parameters(), lr=1e-3)

In [62]:
# Optimise the model on random training batches, printing the loss of the
# current batch at regular intervals.
max_steps = 10000
log_interval = 1000
for step in range(max_steps):
    x, y = get_batch('train')
    logits, loss = lm(x, y)
    should_log = step % log_interval == 0
    if should_log:
        print(f'step {step}, loss {loss.item()}')
    # Clear stale gradients, backpropagate this batch's loss, then update the weights.
    lm.zero_grad()
    loss.backward()
    optimizer.step()

step 0, loss 4.578716278076172
step 1000, loss 1.9970357418060303
step 2000, loss 1.9055207967758179
step 3000, loss 1.8674724102020264
step 4000, loss 1.8498947620391846
step 5000, loss 1.8113723993301392
step 6000, loss 1.79720139503479
step 7000, loss 1.757796287536621
step 8000, loss 1.8209480047225952
step 9000, loss 1.7407855987548828


In [63]:
# Start from a single token with id 0 and sample 1000 further tokens.
x = torch.zeros((1, 1), dtype=torch.long, device=device)
y_ = lm.generate(x, 1000)
generated_ids = y_[0].tolist()
print(decode(generated_ids))


HENRY VI
How in her,
It him ass
I am you has, my four coswere of Dishort but can daughter as in and loal my both, harm?
Our tendle thou, stay, and Senfock chused mine,
Wasted sweet with with shaments but I dead: subquartten the ewo order neess;
Not of direin court ABETH:
Uthere-movine us have all vate ever, maid renefore fly stummanden the for'l quato thus will timold better, More him?

Shecties.
Trough and but ourseen outh,
And riden us been.

QUEER:
Welcomas.

YENIUS:
Now,
For that wear guife.

BRUTABENVOLIXENES:
'Gre do sleep alpatchman:
Herern! Seventence ruestrain you; now unneds?

GLOUCESTER:
Thou with our have have poil of chout with the shake to his women France hod, Vole ordreator gries; thou? all can I your look's his king?

SYONNRENCE:
For VI:
Before, patol; alloa part.

KING HENRY BOLINGBROMEO:
A pursme despire, when stay,
And cut?
'Tis bless, oness of coxtently lord; thousand too; and deck the part.

LEONTES:
That prosys, who childly been, a sold good, keep whrone.

MARCI

In [56]:
# Continue a short text prompt for 30 sampled tokens.
prompt_ids = encode('Hi M')
x = torch.tensor([prompt_ids], device=device)
y_ = lm.generate(x, 30)
print(decode(y_[0].tolist()))

Hi Maraant bre in by be bode,
And 
