In [None]:
import torch 
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from jaxtyping import Float, Int
from transformers import AutoTokenizer # type: ignore
from transformer import TransformerEncoder
from utils import Model, Batch, ModelOutput
from typing import Literal
from dataclasses import dataclass

# Reload imported modules automatically (so you don't have to restart the kernel when changing .py files)
%load_ext autoreload
%autoreload 2

# Pretrained "gpt2" tokenizer loaded via transformers' AutoTokenizer; it is passed to
# both TokenizedTextDataset instances and to the GPT2 model (vocab size, generate()).
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
@dataclass
class GPT2Batch(Batch):
    """One next-token-prediction batch of token ids.

    Fields are positional in the order (x, y); GPTDataLoader constructs
    batches as GPT2Batch(x, y).

    x: input token ids, annotated with shape (B, T).
    y: target token ids, annotated with shape (B, T). GPTDataLoader builds y
       as x shifted left by one position, so y[:, t] is the token that
       follows x[:, t] in the source text.
    """
    x: Int[Tensor, "B T"]
    y: Int[Tensor, "B T"] 
 
@dataclass 
class GPT2Output(ModelOutput):
    """Model output for one batch, as returned by GPT2.get_output.

    The fields declared here are, in order, (logits, loss). Construct with
    keyword arguments (GPT2Output(logits=..., loss=...)) so the two tensors
    cannot be swapped by a positional call.
    """
    # Next-token logits, annotated (B, T, C); C is the vocabulary size (lm_head's output width).
    logits: Float[Tensor, "B T C"]
    # Scalar (0-dim) loss tensor.
    loss: Float[Tensor, ""]

class TokenizedTextDataset(Dataset):
    """Token-id view of a UTF-8 text file, restricted to a contiguous train or val slice.

    The whole file is tokenized with ``tokenizer.encode``. The first
    ``int(n_tokens * train_ratio)`` tokens form the "train" split and the rest
    form the "val" split. Indexing (including slicing) returns entries of the
    1-D ``torch.long`` token tensor for the chosen split.

    Args:
        text_path: Path to a UTF-8 encoded text file.
        tokenizer: Object exposing ``encode(str) -> list[int]`` (e.g. a HF tokenizer).
        train_ratio: Fraction of tokens, in [0, 1], assigned to the train split.
        split: Which split this dataset exposes: "train" or "val".

    Raises:
        ValueError: If ``split`` is not "train"/"val" or ``train_ratio`` is outside [0, 1].
    """

    def __init__(self, text_path: str, tokenizer: "AutoTokenizer", train_ratio=0.9, split: Literal["train", "val"]="train"):
        # Validate the cheap arguments before the potentially slow read + tokenize.
        if split not in ("train", "val"):
            raise ValueError("split must be 'train' or 'val'")
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

        self.tokenizer = tokenizer
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()
        print("Tokenizing text...")
        all_tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long) # type: ignore
        print(f"Total tokens: {len(all_tokens)}")

        split_idx = int(len(all_tokens) * train_ratio)
        self.tokens = all_tokens[:split_idx] if split == "train" else all_tokens[split_idx:]
        print(f"Tokens in split: {len(self.tokens)}")

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        # Supports both integer indices and slices (GPTDataLoader slices windows).
        return self.tokens[idx]
    
class GPTDataLoader:
    """Yields next-token batches built from every length-``n_ctx`` window of a token dataset.

    A sample starting at position ``start`` pairs ``x = tokens[start : start + n_ctx]``
    with the one-step-shifted targets ``y = tokens[start + 1 : start + n_ctx + 1]``.
    Windows overlap (every start position is used once per epoch). With
    ``shuffle=True`` the start order comes from ``torch.randperm``. The last batch
    may contain fewer than ``batch_size`` samples.
    """

    def __init__(self, dataset: TokenizedTextDataset, n_ctx=512, batch_size=4, shuffle=True):
        self.dataset = dataset
        self.n_ctx = n_ctx
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.tokenizer = dataset.tokenizer

    def _num_windows(self):
        # A window at `start` reads n_ctx + 1 tokens (inputs plus the final target),
        # so valid starts are 0 .. len - n_ctx - 1.
        return max(0, len(self.dataset) - self.n_ctx)

    def __iter__(self):
        total = self._num_windows()
        if self.shuffle:
            order = torch.randperm(total).tolist()
        else:
            order = list(range(total))

        inputs, targets = [], []
        for start in order:
            window = self.dataset[start:start + self.n_ctx + 1]
            inputs.append(window[:-1])
            targets.append(window[1:])
            if len(inputs) == self.batch_size:
                yield GPT2Batch(torch.stack(inputs), torch.stack(targets))
                inputs, targets = [], []

        # Flush the trailing partial batch, if any.
        if inputs:
            yield GPT2Batch(torch.stack(inputs), torch.stack(targets))

    def __len__(self):
        # Ceiling division: a trailing partial batch still counts as one batch.
        total = self._num_windows()
        return (total + self.batch_size - 1) // self.batch_size

# Both splits come from the same file; each dataset reads and tokenizes it independently.
SHAKESPEARE_PATH = "data/shakespeare.txt"
# Context length and batch size for the loaders; keep N_CTX in sync with GPT2.n_ctx.
N_CTX, BATCH_SIZE = 512, 4

dataset_train = TokenizedTextDataset(SHAKESPEARE_PATH, tokenizer, split="train") # type: ignore
dataset_val = TokenizedTextDataset(SHAKESPEARE_PATH, tokenizer, split="val") # type: ignore
train_loader = GPTDataLoader(dataset_train, n_ctx=N_CTX, batch_size=BATCH_SIZE)
val_loader = GPTDataLoader(dataset_val, n_ctx=N_CTX, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class GPT2(nn.Module):
    """Causal language model: a masked-attention TransformerEncoder plus a weight-tied LM head.

    Hyperparameters are fixed here: 12 layers, 12 heads, d_emb=768, d_mlp=3072,
    d_head=64, dropout 0.1, and a context of n_ctx=512 tokens. The vocabulary
    size comes from ``tokenizer.vocab_size``.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        n_vocab = self.tokenizer.vocab_size # type: ignore
        d_emb = 768  # single source for the embedding width used by the encoder and lm_head
        self.n_ctx = 512
        self.transformer = TransformerEncoder(
            n_vocab=n_vocab,
            n_ctx=self.n_ctx,
            n_layers=12,
            n_head=12,
            d_emb=d_emb,
            d_mlp=3072,
            d_head=64,
            p_dropout=0.1,
            masked_attention=True
        )
        self.lm_head = nn.Linear(d_emb, n_vocab, bias=False)
        # Weight tying: the output projection reuses the token-embedding matrix.
        # Assumes transformer.tok_emb.weight is (n_vocab, d_emb), matching nn.Linear's
        # (out_features, in_features) layout — TODO confirm against TransformerEncoder.
        self.lm_head.weight = self.transformer.tok_emb.weight # type: ignore

    def forward(self, x: Int[Tensor, "B T"]) -> Float[Tensor, "B T C"]:
        """Map token ids (B, T) to next-token logits (B, T, n_vocab)."""
        return self.lm_head(self.transformer(x))

    def get_output(self, batch: GPT2Batch) -> GPT2Output:
        """Run the model on ``batch.x`` and score it against ``batch.y`` with cross-entropy.

        Returns a GPT2Output whose ``logits`` are (B, T, n_vocab) and whose ``loss``
        is the mean token-level cross-entropy (F.cross_entropy's default reduction).
        """
        logits = self.forward(batch.x)
        # Flatten (B, T, C) -> (B*T, C) and (B, T) -> (B*T,) for token-level loss.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch.y.reshape(-1))
        # Keyword arguments: GPT2Output declares `logits` before `loss`, so a positional
        # (loss, logits) call would silently swap the two fields.
        return GPT2Output(logits=logits, loss=loss)

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int=100, temperature: float=1.0) -> str:
        """Autoregressively sample a continuation of ``prompt``.

        Args:
            prompt: Text to condition on; must encode to at least one token.
            max_new_tokens: Number of tokens to sample after the prompt.
            temperature: Softmax temperature (> 0); lower is more deterministic.

        Returns:
            The decoded prompt tokens followed by the sampled tokens.

        Raises:
            ValueError: If ``temperature <= 0`` or the prompt encodes to no tokens.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # encode() returns a plain list of token ids (encode_plus would return a dict-like
        # BatchEncoding, which torch.tensor cannot consume).
        prompt_ids = self.tokenizer.encode(prompt) # type: ignore
        if len(prompt_ids) == 0:
            raise ValueError("prompt must encode to at least one token")

        device = next(self.parameters()).device
        context = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)
        result_tokens = list(prompt_ids)

        # Sample with dropout disabled, then restore the caller's train/eval mode.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The model is built for n_ctx positions: keep only the most recent n_ctx tokens.
                context = context[:, -self.n_ctx:]
                logits = self.forward(context)[:, -1, :] / temperature  # (1, n_vocab)
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
                result_tokens.append(next_token.item())
                context = torch.cat((context, next_token), dim=1)
        finally:
            self.train(was_training)
        return self.tokenizer.decode(result_tokens) # type: ignore

In [None]:
SEED = 42
# Seed the global torch RNG before construction so weight init is reproducible.
torch.manual_seed(SEED)

model = GPT2(tokenizer)
model

In [None]:
print(len(train_loader), len(val_loader))

# Sanity-check the model on one training batch. No gradients are needed here, so run
# under no_grad to avoid keeping the autograd graph for the (B, T, vocab) logits.
# NOTE: the model is still in training mode, so dropout is active in this pass.
batch = next(iter(train_loader))
with torch.no_grad():
    batch_logits = model.get_output(batch).logits
batch_logits.var()