In [46]:
import torch
import os
import tiktoken
import math

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from dataclasses import dataclass

In [38]:
dataset_dir = Path("dataset/harry-potter-books")
os.listdir(dataset_dir)

['01 Harry Potter and the Sorcerers Stone.txt',
 '02 Harry Potter and the Chamber of Secrets.txt',
 '03 Harry Potter and the Prisoner of Azkaban.txt',
 '04 Harry Potter and the Goblet of Fire.txt',
 '05 Harry Potter and the Order of the Phoenix.txt',
 '06 Harry Potter and the Half-Blood Prince.txt',
 '07 Harry Potter and the Deathly Hallows.txt']

In [39]:
enc = tiktoken.get_encoding("gpt2")
enc.encode("hello, my name is mr. robot")

[31373, 11, 616, 1438, 318, 285, 81, 13, 9379]

In [40]:
@dataclass
class GPTConfig:
    block_size: int = 128
    n_embd: int = 64
    n_head: int = 4
    vocab_size: int = 50304
    n_layers: int = 4 # or change to 6 for betterment

In [41]:
class MHA(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.qkv_mat = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.register_buffer('bias', torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv_mat(x)
        q, k, v = torch.split(qkv, self.config.n_embd, dim=-1)
        q = q.view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2)
        k = k.view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2)
        v = v.view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.shape[-1])
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, value=float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v

        y = y.transpose(1, 2).view(B, T, C)
        y = self.c_proj(y)
        return y

In [42]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.n_embd, 2 * config.n_embd)
        self.gelu = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(2 * config.n_embd, config.n_embd)
    
    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        return x

In [43]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mha = MHA(config)
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
    
    def forward(self, x):
        x = x + self.ln_1(self.mha(x))
        x = x + self.ln_2(self.mlp(x))
        return x

In [81]:
class hpGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            w_tok_emb = nn.Embedding(config.vocab_size, config.n_embd),
            w_pos_emb = nn.Embedding(config.block_size, config.n_embd),
            blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)]),
            ln_f = nn.LayerNorm(config.n_embd)
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None): # idx : encoded text
        print("finding tok_emb")
        tok_emb = self.transformer.w_tok_emb(idx) # B, T, C
        print("finding pos_emb")
        pos_emb = self.transformer.w_pos_emb(torch.arange(self.config.block_size, device=idx.device))

        print(tok_emb.shape, pos_emb.shape)
        x = tok_emb + pos_emb

        for block in self.transformer.blocks:
            x = block(x)
        x = self.transformer.ln_f(x)
        
        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        
        return logits, loss
    
    def generate(self, start_string: str, num_return_sequences: int, max_length: int, device: torch.device):
        assert max_length <= self.config.block_size
        enc = tiktoken.get_encoding("gpt2")
        tokens = enc.encode(start_string)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)

        while xgen.shape[1] < max_length:
            with torch.no_grad():
                logits, _ = self(xgen)
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                top50_values, top50_indices = torch.topk(probs, k=50, dim=-1)
                prob_indices = torch.multinomial(top50_values, num_samples=1)
                attached_tensor = torch.gather(top50_indices, dim=-1, index=prob_indices)
                xgen = torch.cat((xgen, attached_tensor), dim=1)
        
        for row_t in xgen:
            print(enc.decode(row_t.tolist()))


In [45]:
# ROUGH WORK
# 
# probz = torch.randn(3, 4, 5)
# probz[:, -1, :].shape

# a = torch.arange(15).reshape(3, 1, -1)
# a

# top3_vals, top3_idx = torch.topk(a, k=3)
# top3_vals, top3_idx
# top3_idx.shape

# sample_idx = torch.multinomial(top3_vals.view(top3_vals.shape[0], -1).type(torch.float), num_samples=1)
# sample_idx

# torch.gather(top3_idx.squeeze(1), dim=1, index=sample_idx)

hpGPT:
embedding matrix x2
blocks list
    blocks:
        mha + ln
        mlp + ln
mlp
softmax

In [None]:
# class HPDataset(Dataset):
#     def __init__(self, texts, tokenizer_name="gpt2", max_length=128):
#         self.texts = texts
#         self.tokenizer = tiktoken.get_encoding(tokenizer_name)
#         self.max_length = max_length
    
#     def __len__(self)
#         return len(self.texts)
    
#     def 

In [84]:
def load_tokens(filename):
    with open(filename, "r") as f:
        f_content = f.read()
    
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(f_content)
    tokens = torch.tensor(tokens, dtype=torch.long)
    return tokens


class HPDataloaderLite:
    def __init__(self, B, T, split):
        self.B = B
        self.T = T
        assert split in {"train", "val"}

        data_root = dataset_dir
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        
        self.reset()
    
    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        if len(buf) != B*T + 1:
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = 0
            buf = self.tokens[self.current_position : self.current_position + B*T + 1]
            
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        print(f"inside dataloader: {x.shape, y.shape}, shapes of x, y")

        self.current_position += B*T

        # if self.current_position == (self.current_parent_batch + 1) * 32 * B * T:
        #     if self.current_shard + 1 == len(self.shards):
        #         self.current_parent_batch = (self.current_parent_batch + 1) % 381
        #     self.current_shard = (self.current_shard + 1) % len(self.shards)
        #     self.tokens = load_tokens(self.shards[self.current_shard])
        #     self.current_position = self.current_parent_batch * 32 * B * T
            
        return x, y

    def reset(self):
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = 0
        self.current_parent_batch = 0

    def reset_from_config(self, current_shard, current_parent_batch):
        self.current_shard = current_shard
        self.current_parent_batch = current_parent_batch
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = current_parent_batch * 32 * self.B * self.T

In [74]:
book_list = os.listdir(dataset_dir)
for i in range(7):
    book_list[i] = dataset_dir / book_list[i]


for book in book_list:
    with open(book, 'r') as f:
        f_content = f.read()
    f_content_encoding = enc.encode(f_content)
    print(len(f_content_encoding))

124336
137248
175913
289032
367905
269790
305068


batch size = 4096
B = 8

grad accumulation for grad_accum_steps = 4098 / 8

accumulate and then take average

design a cosine dropping learning rate


In [62]:
max_lr = 3e-4
min_lr = 0.1 * max_lr
warmup_steps = 16
max_steps = 128

def get_lr(step):
    if step < warmup_steps:
        return min_lr + 0.9 * max_lr * step / warmup_steps
    else:
        coeff = (max_steps - step) / (max_steps - warmup_steps)
        coeff = math.sin(coeff * math.pi / 2)
        return min_lr + coeff * (max_lr - min_lr)

In [63]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [65]:
model = hpGPT(GPTConfig)
model.to(device)
model = torch.compile(model)

In [None]:
batch_size = 4096 * 2
B = 8 # can increase in case of gpu training
T = GPTConfig.block_size
grad_accum_steps = batch_size //(B*T)

train_loader = HPDataloaderLite(B, T, "train")
val_loader = HPDataloaderLite(B, T, "val")

optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)

iters = []
loss_list = []

for step in range(max_steps):
    val_loss = 0.0
    if step % 8 == 0:
        model.eval()
        with torch.no_grad():
            x, y = val_loader.next_batch()
            x, y = x.to(device), y.to(device)

            print("inside train loop:", {x.shape, y.shape}, "shapes of x, y")

            logits, val_loss = model(x, y)
        print(f"val loss: {val_loss.item():.4f}")
        print()

    model.train()
    loss_accum = 0.0
    optimizer.zero_grad()
    for grad_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        logits, loss = model(x, y)
        loss /= grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
    optimizer.step()
    iters.apend(step)
    loss_list.append(loss_accum.cpu().item())

    if step % 4 == 0 or step == max_steps-1:
        print(f"step: {step}, avg_loss: {loss_accum.item():.4f}, lr: {lr:.4e}")
    
    if step % 16 == 0 or step == max_steps-1:
        model.eval()
        model.generate("Wingardium Leviosa!", 3, 50, device)
        print()
        print()
        

E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5] failed while attempting to run meta for aten.view.default
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5] Traceback (most recent call last):
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5]   File "/home/lurix/code/ml/.venv/lib64/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5]     r = func(*args, **kwargs)
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5]   File "/home/lurix/code/ml/.venv/lib64/python3.13/site-packages/torch/_ops.py", line 756, in __call__
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5]     return self._op(*args, **kwargs)
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py:2431] [0/5]            ~~~~~~~~^^^^^^^^^^^^^^^^^
E0705 23:39:55.392000 13957 torch/_subclasses/fake_tensor.py

torch.Size([8, 128]) torch.Size([8, 128])


TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method view(*(FakeTensor(..., size=(8, 128, 4, 16)), 8, 128, 64), **{}): got ValueError('Cannot view a tensor with shape torch.Size([8, 128, 4, 16]) and strides (8192, 16, 2048, 1) as a tensor with shape (8, 128, 64)!')

from user code:
   File "/tmp/ipykernel_13957/2311210325.py", line 19, in forward
    x = block(x)
  File "/tmp/ipykernel_13957/168379530.py", line 10, in forward
    x = x + self.ln_1(self.mha(x))
  File "/tmp/ipykernel_13957/3830721539.py", line 22, in forward
    y = y.transpose(1, 2).view(B, T, C)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [1]:
import os
os.getcwd()

'/home/lurix/code/ml/projects/nanogpt-self'

In [6]:
from typing import List


class BPETokenizer:
    def __init__(self):
        self.itos = {}
        self.stoi = {}
        self.vocab_size = 0
        self.max_key = 0
        self.merges = {}
    
    def train(self, text_corpus: str, vocab_size: int):
        """
        Takes in the corpus of text to train on and the vocabulary size.
        
        Input: 
            text_corpus: str - The corpus of text to train the tokenizer on.

            vocab_size: int - The size of the vocabulary of the tokenizer.
        """
        self.vocab_size = vocab_size
        text_corpus_set = set(text_corpus)
        text_corpus_unique = sorted(text_corpus_set)

        assert len(text_corpus_unique) <= vocab_size, f"vocab_size ({vocab_size}) must be greater than or equal to the number of unique chars ({len(text_corpus_unique)}) in the text corpus"

        for ch in text_corpus_unique:
            self.stoi[ch] = self.max_key
            self.itos[self.max_key] = ch
            self.max_key += 1

        text_idx = [self.stoi[ch] for ch in text_corpus]

        while self.max_key < vocab_size:
            pair_count = {}
            for i1, i2 in zip(text_idx, text_idx[1:]):
                pair_count[(i1, i2)] = pair_count.get((i1, i2), 0) + 1

            max_pair = max(pair_count, key=lambda pair: pair_count[pair])
            new_text_idx = []

            i = 0
            while i < len(text_idx)-1:
                if (text_idx[i], text_idx[i+1]) == max_pair:
                    new_text_idx.append(self.max_key)
                    i += 2
                else:
                    new_text_idx.append(text_idx[i])
                    i += 1
            
            merged_pair = self.itos[max_pair[0]] + self.itos[max_pair[1]]
            self.itos[self.max_key] = merged_pair
            self.stoi[merged_pair] = self.max_key

            self.merges[max_pair] = self.max_key
            self.max_key += 1
            text_idx = new_text_idx

        self.final_vocab = sorted(self.stoi.items(), key=lambda item: len(item[0]), reverse=True)
        print("Tokenizer successfully trained!")
        return
    
    def encode(self, text: str):
        """
        Function to encode a given string.

        Input:
            text: str - The text (string) which is to be encoded.
        """
        encoded_idx = []
        rem_text = text

        while rem_text:
            print(f"rem_text: {rem_text}")
            for substr, subkey in self.final_vocab:
                if rem_text.startswith(substr):
                    encoded_idx.append(subkey)
                    rem_text = rem_text[len(substr):]
                    break
        
        return encoded_idx

    def decode(self, idx: List[int]):
        """
        Function to decode a given list of encoded indices.

        Input:
            idx: List[int] - The list of encoded indices (encoding) which is to be decoded.
        """
        return "".join(self.itos[i] for i in idx)


In [7]:
enc = BPETokenizer()
enc.train("abab", 3)

enc.final_vocab

Tokenizer successfully trained!


[('ab', 2), ('a', 0), ('b', 1)]

In [8]:
enc.encode("ababc")

rem_text: ababc
rem_text: abc
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: c
rem_text: 

KeyboardInterrupt: 