In [10]:
import os
import requests
import numpy as np
import torch
import torch.nn.functional as F

In [4]:

# download the tiny shakespeare dataset
input_file_path = 'shakespeare.txt'
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

In [11]:
stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
itos = {i: ch for ch, i in stoi.items()}

In [28]:
train_data_encoded = np.array([stoi[ch] for ch in train_data], dtype=np.uint16)
val_data_encoded = np.array([stoi[ch] for ch in val_data], dtype=np.uint16)
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [29]:
vocab_size = len(stoi)
n_embd = 32
batch_size = 8
block_size = 8
steps = 100

In [39]:
def get_batch(split):
    if split == 'train':
        data = train_data_encoded
    elif split == 'val':
        data = val_data_encoded
    else:
        raise ValueError('split must be either train or val')
    start_idx = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in start_idx])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in start_idx])
    return x, y


In [57]:
get_batch('train')

(tensor([[58,  6,  1, 39, 52, 42,  1, 57],
         [27, 52, 43,  1, 46, 43, 39, 60],
         [ 1, 43, 52, 42, 59, 56, 43,  1],
         [53,  1, 61, 46, 47, 54,  1, 58],
         [63, 53, 59, 56,  1, 41, 53, 59],
         [ 1, 46, 53, 50, 63,  1, 20, 39],
         [ 1, 24, 53, 56, 42,  1, 13, 52],
         [ 1, 61, 53, 56, 58, 46, 47, 50]]),
 tensor([[ 6,  1, 39, 52, 42,  1, 57, 53],
         [52, 43,  1, 46, 43, 39, 60, 43],
         [43, 52, 42, 59, 56, 43,  1, 58],
         [ 1, 61, 46, 47, 54,  1, 58, 46],
         [53, 59, 56,  1, 41, 53, 59, 52],
         [46, 53, 50, 63,  1, 20, 39, 56],
         [24, 53, 56, 42,  1, 13, 52, 45],
         [61, 53, 56, 58, 46, 47, 50, 63]]))

In [104]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = torch.nn.Embedding(block_size, n_embd)
        self.lm_head = torch.nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for i in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            new_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, new_token], dim=-1)
        return idx

In [105]:
lm = Model()
optimizer = torch.optim.Adam(lm.parameters(), lr=0.001)

In [110]:
for step in range(10000):
    x, y = get_batch('train')
    logits, loss = lm(x, y)
    if step % 1000 == 0:
        print(f'step {step}, loss {loss.item()}')
    lm.zero_grad()
    loss.backward()
    optimizer.step()

step 0, loss 2.2891223430633545
step 1000, loss 2.351369619369507
step 2000, loss 2.670475482940674
step 3000, loss 2.4176013469696045
step 4000, loss 2.1189658641815186
step 5000, loss 2.4586892127990723
step 6000, loss 2.5530014038085938
step 7000, loss 2.399681806564331
step 8000, loss 2.3939037322998047
step 9000, loss 2.6133534908294678


In [111]:
x = torch.zeros((1, 1), dtype=torch.int64)
y_ = lm.generate(x, 1000)
decode(y_[0].tolist())

"\nIIZESithahe benssiakee ptaby pof n:\nThaighik, hyounod ie gune, ofamavins th?\nFofe than?\nWhen te he, is, s, aved iasenspiour yot flfinkewaviceour je ct th aig batiopr, s l o me m:\nTalend d brades hetthencor, TExe han tu her bopr't, won on T:\nH:\nFin t h, fo caraisthathm\nINCAl t ed; ton S:\nME:\n\nHERI notan wo th ar KI savin olvely IOM s fusthe vintof d d hout:\n\nWhe g ERGof l, theO ounofor by ir t sus n, t y m eseathaigand ce ye mintharis t fedor\nBers, lode gh?-\n\nCKE ancoulth ddellise-\nFe,\n\nFrutheclotore:\n\nUnothamas t ns hieanh apusovome to e n, EENCHed;\ns the wove.\nBRTr meavelll. g; ro thes h thas, w n RYCENESwh amecorot ncal thate, wong an ay me anos:\n\nNToro beaithitath s, peanoit buthas mpe s thenty mowhan torou ofantos. f tthouou myo--m che, lawofa wout fuceaus,\nsy.\nPrailgaveexatherous ast h pl othee, fr. these nlisherecou\n\nGLI t fothan toun w Wicofsh sthonoloran, ngonanthinesinaithalor se?\nTond imid\nLEENUCHe thefond IULAno h I amy he y tour.\nButhriowre