In [3]:
from pathlib import Path
import torch
import sys
sys.path.append('..')
from src.model import GPTLanguageModel

In [4]:
# File name of the training split; it is resolved under ../data/raw relative to the working directory.
dataset_name = "TinyStories-train.txt"

In [5]:
# Read the whole training corpus into memory as a single string.
# The encoding is pinned to UTF-8 so that decoding (and therefore the vocabulary
# derived from the text) does not depend on the platform's default locale encoding.
file_path = Path.cwd() / '..' / 'data' / 'raw' / dataset_name
with open(file_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Explore the dataset

In [6]:
# Size of the training corpus, in characters.
print(len(text))

1922767089


In [7]:
# Peek at the first 1000 characters of the corpus.
print(text[:1000])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.
Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."
Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.
<|endoftext|>
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.
One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were

# Tokenize the dataset
We will use characters as tokens, just as a baseline.

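Note: special tokens already exist in the text, like `<|endoftext|>`. We won't treat them specially for now: they are split into individual characters like any other text, so generation has no natural stopping point and simply runs for as many tokens as we request.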

Note: there appear to be several languages, symbols, and emojis in the data (I can recognize Chinese characters, for example). We'll ignore this for now.

In [8]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Show the vocabulary on one line and its size on the next.
print(''.join(chars), vocab_size, sep='\n')

	
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]_`abcdefghijklmnopqrstuvwxyz{|}~ ¡¢£§«­°´·»¿ÂÉßàáâåèéêíïñóöúāİœɪʏʙʜіғᴀᴄᴅᴇᴏᴛᴜᴡᴢ   ​‌‎‐‑‒–—―‘’‚“”„…  ‪′€™−─❤　。」一了些他但保個們兒兩分到剛又和在天奮她己巴度很恩應把整是時會獨玉田留當的童答米給自興艾莉裡這過難高ﬁﬂ️﻿，￼�𝑐🌴🌹🍌🍞🎓💖🙂🤩
243


In [9]:
class TokenizerBase:
    """Interface shared by tokenizers: text to token ids, and back."""

    def encode(self, s):
        """Turn the string ``s`` into token ids; subclasses must override."""
        raise NotImplementedError

    def decode(self, t):
        """Turn the token ids ``t`` back into a string; subclasses must override."""
        raise NotImplementedError


class CharacterTokenizer(TokenizerBase):
    """Tokenizer whose tokens are single characters.

    A character's id is its position in the ``chars`` sequence given at
    construction time.
    """

    def __init__(self, chars):
        # Two lookup tables over the same enumeration: char -> id and id -> char.
        self.stoi = dict((ch, idx) for idx, ch in enumerate(chars))
        self.itos = dict(enumerate(chars))

    def encode(self, s):
        """Return the list of ids for the characters of ``s``.

        Raises:
            KeyError: if ``s`` contains a character that is not in the vocabulary.
        """
        ids = []
        for ch in s:
            ids.append(self.stoi[ch])
        return ids

    def decode(self, l):
        """Return the string spelled out by the ids in ``l``.

        Raises:
            KeyError: if ``l`` contains an id that is not in the vocabulary.
        """
        return ''.join(self.itos[idx] for idx in l)

In [10]:
# Sanity check: encode a short sample and decode it back.
# The second printed line should reproduce the input string.
tokenizer = CharacterTokenizer(chars)
encoded = tokenizer.encode("Hii there")
print(encoded)
decoded = tokenizer.decode(encoded)
print(decoded)

[42, 74, 74, 2, 85, 73, 70, 83, 70]
Hii there


In [11]:
# Encode the entire training corpus and store the ids as an int64 tensor.
# NOTE(review): tokenizer.encode first builds a Python list with one entry per
# character, so peak memory is well above the tensor's own size — confirm this
# fits in RAM for the full corpus.
train_data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
# Show the first 1000 token ids.
print(train_data[:1000])

tensor([49, 79, 70,  2, 69, 66, 90, 14,  2, 66,  2, 77, 74, 85, 85, 77, 70,  2,
        72, 74, 83, 77,  2, 79, 66, 78, 70, 69,  2, 46, 74, 77, 90,  2, 71, 80,
        86, 79, 69,  2, 66,  2, 79, 70, 70, 69, 77, 70,  2, 74, 79,  2, 73, 70,
        83,  2, 83, 80, 80, 78, 16,  2, 53, 73, 70,  2, 76, 79, 70, 88,  2, 74,
        85,  2, 88, 66, 84,  2, 69, 74, 71, 71, 74, 68, 86, 77, 85,  2, 85, 80,
         2, 81, 77, 66, 90,  2, 88, 74, 85, 73,  2, 74, 85,  2, 67, 70, 68, 66,
        86, 84, 70,  2, 74, 85,  2, 88, 66, 84,  2, 84, 73, 66, 83, 81, 16,  2,
        46, 74, 77, 90,  2, 88, 66, 79, 85, 70, 69,  2, 85, 80,  2, 84, 73, 66,
        83, 70,  2, 85, 73, 70,  2, 79, 70, 70, 69, 77, 70,  2, 88, 74, 85, 73,
         2, 73, 70, 83,  2, 78, 80, 78, 14,  2, 84, 80,  2, 84, 73, 70,  2, 68,
        80, 86, 77, 69,  2, 84, 70, 88,  2, 66,  2, 67, 86, 85, 85, 80, 79,  2,
        80, 79,  2, 73, 70, 83,  2, 84, 73, 74, 83, 85, 16,  1, 46, 74, 77, 90,
         2, 88, 70, 79, 85,  2, 85, 80, 

# Create a dataloader

In [12]:
# Number of tokens per window; this value is also passed to the model as seq_len.
context_length = 8

In [13]:
class DataLoader:
    """Draws random fixed-length windows from a token tensor.

    Every batch pairs an input window with a target window that is the same
    slice of ``data`` shifted one position to the right.
    """

    def __init__(self, context_length, batch_size, data):
        self.data = data
        self.batch_size = batch_size
        self.context_length = context_length

    def get_batch(self):
        """Sample ``batch_size`` windows and return them as ``(x, y)``.

        Both results are built with ``torch.stack``, so the windows are laid
        out along a new leading dimension.
        """
        span = self.context_length
        # Start offsets are drawn so the shifted target window still ends inside the data.
        starts = torch.randint(len(self.data) - span, (self.batch_size,))
        inputs, targets = [], []
        for start in starts:
            inputs.append(self.data[start:start + span])
            targets.append(self.data[start + 1:start + span + 1])
        return torch.stack(inputs), torch.stack(targets)

In [14]:
# Draw one training batch and spell out every (context -> target) pair it contains.
torch.manual_seed(1337)  # fixed seed so the sampled batch is reproducible
# Use the shared context_length setting instead of a second hard-coded 8, so the
# sampled windows stay in step with the seq_len the model is built with.
dl_train = DataLoader(context_length=context_length, batch_size=4, data=train_data)
xb, yb = dl_train.get_batch()
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

# For each sequence, every prefix of the input predicts the token that follows it.
for b in range(dl_train.batch_size):
    for t in range(dl_train.context_length):
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[72, 70, 85,  2, 73, 86, 83, 85],
        [68, 76,  2, 84, 80, 78, 70,  2],
        [70, 84, 85,  2, 71, 83, 74, 70],
        [ 2, 66,  2, 88, 66, 77, 76, 16]])
targets:
torch.Size([4, 8])
tensor([[70, 85,  2, 73, 86, 83, 85,  2],
        [76,  2, 84, 80, 78, 70,  2, 85],
        [84, 85,  2, 71, 83, 74, 70, 79],
        [66,  2, 88, 66, 77, 76, 16,  1]])
----
when input is [72] the target: 70
when input is [72, 70] the target: 85
when input is [72, 70, 85] the target: 2
when input is [72, 70, 85, 2] the target: 73
when input is [72, 70, 85, 2, 73] the target: 86
when input is [72, 70, 85, 2, 73, 86] the target: 83
when input is [72, 70, 85, 2, 73, 86, 83] the target: 85
when input is [72, 70, 85, 2, 73, 86, 83, 85] the target: 2
when input is [68] the target: 76
when input is [68, 76] the target: 2
when input is [68, 76, 2] the target: 84
when input is [68, 76, 2, 84] the target: 80
when input is [68, 76, 2, 84, 80] the target: 78
when input is [68, 

In [16]:
# Load the validation split the same way as the training split.
val_set_name = 'TinyStories-valid.txt'
val_path = Path.cwd() / '..' / 'data' / 'raw' / val_set_name
# Pin the encoding to UTF-8 so decoding does not depend on the platform's default
# locale encoding.
with open(val_path, 'r', encoding='utf-8') as f:
    val_text = f.read()

# Encode with the tokenizer built from the training text; a character that never
# occurs in the training text raises KeyError here.
val_data = torch.tensor(tokenizer.encode(val_text), dtype=torch.long)

# Use the shared context_length setting instead of a second hard-coded 8, so the
# validation windows match the seq_len the model is built with.
dl_val = DataLoader(context_length=context_length, batch_size=4, data=val_data)

In [23]:
@torch.no_grad()
def estimate_loss(model, dl_train, dl_val, eval_iters):
    """Estimate the mean loss on the training and validation data.

    Args:
        model: module called as ``model(X, Y)`` and returning ``(logits, loss)``.
        dl_train: loader with a ``get_batch()`` method for the training data.
        dl_val: loader with a ``get_batch()`` method for the validation data.
        eval_iters: number of random batches averaged per split.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each a 0-dim tensor holding
        the mean loss over ``eval_iters`` batches.

    The model is put in eval mode for the duration of the call and is then
    returned to whichever mode (train or eval) it was in on entry, even if
    evaluation raises.
    """
    was_training = model.training
    # Evaluate on the device that holds the weights; a parameter-free model
    # gets its batches unchanged.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    out = {}
    model.eval()
    try:
        for name, dl in (('train', dl_train), ('val', dl_val)):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = dl.get_batch()
                if target_device is not None:
                    X, Y = X.to(target_device), Y.to(target_device)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[name] = losses.mean()
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)
    return out

In [29]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build a small GPT model over the character vocabulary.
# NOTE(review): the exact meaning of each hyperparameter (d_model, d_k, d_v,
# n_heads, ...) is defined by src.model.GPTLanguageModel, which is not shown
# here — confirm against that module.
model = GPTLanguageModel(
    vocab_size=vocab_size,
    d_model=16,
    seq_len=context_length,
    n_layers=4,
    d_k=8,
    d_v=8,
    n_heads=4,
    device=device
)
# `m` is whatever model.to(device) returns (for a torch.nn.Module that is the same object).
m = model.to(device)

Count the number of parameters

In [30]:
# Total number of parameter elements, reported in millions.
print(sum(p.numel() for p in m.parameters()) / 1e6, 'M parameters')

0.025139 M parameters


In [31]:
# Learning rate handed to the optimizer below.
learning_rate = 1e-3

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

Train the model.

In [32]:
eval_iters = 200      # batches averaged per split at each evaluation
max_iters = 5000      # total number of optimisation steps
eval_interval = 100   # steps between evaluations

# The loop variable is `step` rather than `iter` so the built-in iter() is not
# shadowed in the notebook's global namespace.
for step in range(max_iters):
    # Report train/val loss periodically and once more on the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss(model, dl_train, dl_val, eval_iters)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # The data tensor was created without a device argument, so move each batch
    # to the model's device before the forward pass (a no-op when they already match).
    xb, yb = dl_train.get_batch()
    xb, yb = xb.to(device), yb.to(device)

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 5.6260, val loss 5.6227
step 100: train loss 3.5634, val loss 3.5697
step 200: train loss 3.0214, val loss 3.0535
step 300: train loss 2.8662, val loss 2.8639
step 400: train loss 2.7410, val loss 2.7486
step 500: train loss 2.6371, val loss 2.6446
step 600: train loss 2.6040, val loss 2.6114
step 700: train loss 2.5575, val loss 2.5482
step 800: train loss 2.4675, val loss 2.4920
step 900: train loss 2.4801, val loss 2.4892
step 1000: train loss 2.4732, val loss 2.4239
step 1100: train loss 2.4282, val loss 2.4311
step 1200: train loss 2.3685, val loss 2.4133
step 1300: train loss 2.3909, val loss 2.3778
step 1400: train loss 2.3518, val loss 2.3900
step 1500: train loss 2.3422, val loss 2.3148
step 1600: train loss 2.3422, val loss 2.3081
step 1700: train loss 2.3001, val loss 2.3426
step 1800: train loss 2.3503, val loss 2.3204
step 1900: train loss 2.2905, val loss 2.2734
step 2000: train loss 2.2482, val loss 2.2306
step 2100: train loss 2.2572, val loss 2.2353


Generate from the model.

In [35]:
# Prompt the model with one window of token id 0 and sample 2000 further tokens.
# The prompt length comes from context_length instead of a hard-coded 8, so it
# stays in step with the seq_len the model was built with.
context = torch.zeros((1, context_length), dtype=torch.long, device=device)
with torch.no_grad():  # sampling needs no gradients
    generated = m.generate(context, max_new_tokens=2000)[0].tolist()
decoded = tokenizer.decode(generated)
print(decoded)

								ang.
Shey thew hite was ron said rikeans. She pether salan inchbesy nak the teu blnincgem lor sllcedote both, feche in sanaay diry anke sok aw ma bily greooo iiy fonr in. The buuir thto: es fuly wirlyes and pnekt dald sinay eve qpimed int. Shan giny, gou."
Bin.
<|e, hendole thapy ol fa inky, hip was mont a fer isted tul was Dhe meatt racig foun.
The rut roks faen frow. Wusted conn, Si. the holp darook a and!"Hs a a drooks. As ven nop sopit fon ske yom a and mar ig x and and hrop Than whas shink says in worghre wik be ug enbam ink tk, yom ol wanghe sain to clo"
Oner and sorin, saw hagan angrd""
He ou, thith shalt go thaid ay "Oin dimu this thur omom astead oun sbut saree far, che.Iat, in sor tpichh oon wisk is fore a wut was sman goLp and wing askid nnan das cpily woane fit tham san to goy he pive. Heoors inds."" The daidsude thappy, a put nukk to the shin
Tin man, hot the the hatigis. The ry, kwis. Soy tom shere foft on'ce to yhe pit noned a to to col dand she or llors min, the