In [187]:
import os
import torch
import tiktoken
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader

In [191]:
def encode(data):
    """Tokenise ``data`` and return the token ids as a ``torch.long`` tensor.

    Relies on the module-level ``backbone`` tokenizer, which must be assigned
    before this function is called.
    """
    token_ids = backbone.encode(data)
    return torch.tensor(token_ids, dtype=torch.long)
def decode(data):
    """Turn token ids back into text using the module-level ``backbone``.

    ``data`` is handed to ``backbone.decode`` unchanged, so it must already be
    in whatever form that tokenizer accepts.
    """
    text = backbone.decode(data)
    return text

In [192]:
class HarryPotterDataset(Dataset):
    """Sliding-window next-token dataset over an encoded text corpus.

    Item ``i`` is the pair ``(x, y)`` where ``x`` holds ``context_len``
    consecutive token ids starting at position ``i`` and ``y`` holds the same
    window shifted one position to the right (the next-token targets).
    """

    def __init__(self, data: str, context_len: int):
        """Encode ``data`` once up front and remember the window length."""
        self.encoded_data = encode(data)
        self.context_len = context_len

    def __getitem__(self, idx):
        """Return the ``(input, target)`` windows that start at ``idx``.

        Both tensors are 1-D with exactly ``context_len`` elements.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Raising here (instead of returning a short or empty slice) keeps
        # every sample the same length and lets plain iteration terminate.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        # One slice of context_len + 1 tokens covers both the input and the
        # shifted target, so the two can never disagree in length.
        window = self.encoded_data[idx:idx + self.context_len + 1]
        return window[:-1], window[1:]

    def __len__(self):
        """Number of start positions that yield a full input/target pair."""
        # The target window reaches one token past the input window, so the
        # last valid start is len(encoded_data) - context_len - 1.
        return max(len(self.encoded_data) - self.context_len, 0)

In [179]:
class BigramModel(nn.Module):
    """Bigram language model backed by a single embedding table.

    The embedding has ``vocab_size`` rows of width ``vocab_size``: looking up
    a token id directly yields the logits used to predict the following token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: integer tensor of token ids.
            targets: optional integer tensor of target ids. When provided,
                ``x`` must be 2-D so that the logits are 3-D.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits keep the
            embedding output shape and ``loss`` is ``None``. With ``targets``
            the logits are flattened to ``(B*C, V)`` and ``loss`` is the
            cross-entropy against the flattened targets.
        """
        logits = self.token_emb_table(x)
        if targets is None:
            return logits, None

        B, C, V = logits.shape
        # cross_entropy wants (N, num_classes) logits and (N,) targets.
        # reshape (rather than view) also accepts non-contiguous inputs.
        logits = logits.reshape(B * C, V)
        loss = nn.functional.cross_entropy(logits, targets.reshape(B * C))
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates; skip autograd bookkeeping
    def generate(self, idx, max_tokens):
        """Sample ``max_tokens`` new tokens and append them to ``idx``.

        Args:
            idx: 2-D integer tensor of token ids; new tokens are concatenated
                along dimension 1.
            max_tokens: number of tokens to sample.

        Returns:
            ``idx`` extended by ``max_tokens`` sampled columns.
        """
        for _ in range(max_tokens):
            logits, _ = self(idx)
            # Only the final position's logits decide the next token.
            last_logits = logits[:, -1, :]
            probs = nn.functional.softmax(last_logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

In [None]:
if __name__ == "__main__":
    # Training hyper-parameters.
    epochs = 1000
    val_interval = 100  # print a generated sample every this many epochs
    batch_size = 32
    context_len = 16
    lr = 1e-3
    # Fall back to the CPU so the script still runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # The tokenizer has to exist before its vocabulary size is read and before
    # encode()/decode() are called, since they use this module-level name.
    backbone = tiktoken.get_encoding('gpt2')
    vocab_size = backbone.n_vocab

    with open("/content/all_books.txt", "r", encoding="utf-8") as file:
        data = file.read()

    model = BigramModel(vocab_size).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    # Make dataset and dataloader
    hpd = HarryPotterDataset(data, context_len)
    hpl = DataLoader(hpd, batch_size=batch_size, shuffle=True)

    for epx in tqdm(range(epochs)):
        if epx % val_interval == 0:
            # Start generation from a single token with id 0.
            idx = torch.zeros((1, 1), dtype=torch.long, device=device)
            # The token budget is passed positionally: generate() names this
            # parameter `max_tokens`, not `max_new_tokens`.
            generated = model.generate(idx, 500)
            # generate() returns a (1, N) tensor; decode() needs the flat list
            # of ids from that single row, not a nested list.
            generated_text = decode(generated[0].tolist())
            print(f"epoch: {epx+1}\t|\ttext: {generated_text}")

        for x, y in hpl:
            # Batches come off the DataLoader on the CPU; the model may not be.
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, y)
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()