In [None]:
# Shell out to nvidia-smi to check which GPU(s), if any, this runtime can see.
!nvidia-smi

In [21]:
from mingpt.bpe import BPETokenizer
from mingpt.model import GPT

import torch
from datasets import load_dataset
import pandas as pd

In [22]:
class StoryDataset:
    """In-memory dataset of tokenized stories for next-token prediction.

    Each record's ``"text"`` field is tokenized once in ``__init__`` and kept in
    ``self.data1``. Indexing returns ``(x, y)`` where ``y`` is ``x`` shifted left
    by one token, i.e. the target at every position is the following token.
    """

    def __init__(self, data, tokenizer, block_size=128):
        """
        Args:
            data: Iterable of records supporting ``record["text"]``.
            tokenizer: Callable invoked as shown in ``format_example``.
            block_size: Forwarded to the tokenizer as ``max_length``.
        """
        self.tokenizer = tokenizer
        self.block_size = block_size
        # Tokenize eagerly so that __getitem__ only has to slice.
        self.data1 = [
            self.format_example(data_point["text"]) for data_point in data
        ]

    def format_example(self, text):
        """Tokenize one story with padding/truncation requested up to ``block_size``.

        ``squeeze(0)`` removes the leading dimension when it has size 1
        (presumably the tokenizer's batch axis -- confirm against the tokenizer).
        """
        tokens = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=self.block_size)
        return tokens.squeeze(0)

    def __len__(self):
        """Number of tokenized stories held in memory."""
        return len(self.data1)

    @staticmethod
    def _shifted_pair(tokens):
        """Return independent long-tensor copies ``(tokens[:-1], tokens[1:])``.

        ``torch.tensor(existing_tensor)`` triggers PyTorch's copy-construct
        warning; ``detach().clone()`` is the recommended equivalent and still
        guarantees the caller cannot mutate the stored tokens.
        """
        tokens = torch.as_tensor(tokens, dtype=torch.long)
        return tokens[:-1].detach().clone(), tokens[1:].detach().clone()

    def __getitem__(self, idx):
        """Fetch one ``(x, y)`` pair, or lists of them for a batch of indices.

        Args:
            idx: An ``int`` (or 0-d integer tensor) for a single example, or a
                ``list`` / 1-d ``torch.Tensor`` of indices for several.

        Returns:
            Single index: ``(x, y)`` long tensors, ``x = tokens[:-1]`` and
            ``y = tokens[1:]``.
            Batch of indices: ``(x_batch, y_batch)``, two lists of such tensors
            (left unstacked).

        Raises:
            TypeError: If ``idx`` is none of the supported index types.
        """
        # A 0-d tensor cannot be iterated, so treat it as a plain scalar index.
        if isinstance(idx, torch.Tensor) and idx.ndim == 0:
            idx = idx.item()

        if isinstance(idx, int):  # Single index
            return self._shifted_pair(self.data1[idx])

        if isinstance(idx, (list, torch.Tensor)):  # Batch case
            pairs = [self._shifted_pair(self.data1[i]) for i in idx]
            x_batch = [x for x, _ in pairs]
            y_batch = [y for _, y in pairs]
            return x_batch, y_batch

        raise TypeError(f"Invalid index type: {type(idx)}")


In [23]:
# Load TinyStories from the Hugging Face Hub; the "train" split is used below.
data = load_dataset('roneneldan/TinyStories')
# Use the dataset's own pandas conversion instead of handing the whole Dataset
# object to the pandas constructor. The frame is only used for inspection.
data_frame = data["train"].to_pandas()

In [None]:
# Display the training split as a DataFrame to eyeball its contents.
data_frame

In [None]:
# Full text of the first training story.
data_frame['text'].iloc[0]

In [26]:
# Tokenizer used to encode the stories.
tokenizer = BPETokenizer()

# Fine-tune on the first 100 training stories only, with a 512-token block size.
train_dataset = StoryDataset(
    data=data["train"].select(range(100)),
    tokenizer=tokenizer,
    block_size=512,
)

In [None]:
model_type = 'gpt2'
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained checkpoint named by model_type and move it to device.
model = GPT.from_pretrained(model_type)
model.to(device)

In [28]:
def generate(prompt='', num_samples=1, steps=20, do_sample=True):
    """Sample continuations of ``prompt`` from the notebook's ``model`` and print them.

    Relies on the notebook-level ``model`` and ``device``.

    Args:
        prompt: Text to condition on. An empty string starts generation from a
            single ``<|endoftext|>`` token, which is left out of the printed text.
        num_samples: Number of continuations to draw from the same prompt.
        steps: Number of new tokens requested per sample (``max_new_tokens``).
        do_sample: Forwarded to ``model.generate``.

    Returns:
        None. Each sample is printed under an 80-dash rule, cut at the first
        ``<|endoftext|>`` found after the prompt, if there is one.
    """
    tokenizer = BPETokenizer()
    endoftext_token_id = tokenizer.encoder.encoder['<|endoftext|>']

    if prompt == '':
        x = torch.tensor([[endoftext_token_id]], dtype=torch.long)
    else:
        # Encode the whole prompt. Truncating to the whitespace word count would
        # drop trailing tokens, since a word can map to several BPE tokens.
        # NOTE(review): assumes padding/truncation/max_length are optional
        # keywords of this tokenizer -- confirm.
        x = tokenizer(prompt, return_tensors='pt')

    # Both branches must end up on the model's device, not only the prompt one.
    x = x.to(device).expand(num_samples, -1)
    prompt_len = x.size(1)

    # Sample in eval mode, then restore whatever mode the model was in so a
    # later training run is unaffected.
    was_training = model.training
    model.eval()
    try:
        y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)
    finally:
        model.train(was_training)

    # With an empty prompt, position 0 is the synthetic seed token; skip it.
    first_shown = 1 if prompt == '' else 0

    for i in range(num_samples):
        row = y[i].cpu()
        token_ids = row.tolist()

        # Look for <|endoftext|> only among the newly generated tokens; searching
        # from position 0 would match the seed token and yield an empty string.
        end_pos = len(token_ids)
        generated = token_ids[prompt_len:]
        if endoftext_token_id in generated:
            end_pos = prompt_len + generated.index(endoftext_token_id)

        out = tokenizer.decode(row[first_shown:end_pos])

        print('\n' + '-' * 80)
        print(out)

# Before finetuning

In [None]:
# Two 50-token continuations of a story-style opening.
generate(
    prompt='One day, a little girl named Lily found a',
    num_samples=2,
    steps=50,
)

In [None]:
# Two 50-token continuations of a very short opening.
generate(
    prompt='One day,',
    num_samples=2,
    steps=50,
)

# Finetuning

In [None]:
from mingpt.trainer import Trainer

# Start from the Trainer's default configuration and override a few fields.
train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-5
train_config.batch_size = 2
# presumably the number of iterations after which Trainer.run() stops -- confirm in mingpt.trainer
train_config.max_iters = 50
# presumably forwarded to the DataLoader (0 = load in the main process) -- confirm in mingpt.trainer
train_config.num_workers = 0
def collate_fn(batch):
    """Combine ``(x, y)`` samples into two 2-D batch tensors.

    Args:
        batch: Sequence of ``(x, y)`` pairs of 1-D tensors.

    Returns:
        ``(x_batch, y_batch)``, each with one row per sample. When every sample
        has the same length this equals ``torch.stack``. Shorter samples are
        right-padded to the longest one in the batch instead of raising.
    """
    x_batch, y_batch = zip(*batch)
    # Inputs are padded with token id 0.
    # NOTE(review): assumes the model's loss ignores targets equal to -1, as
    # upstream minGPT's cross-entropy does (ignore_index=-1), so the padded
    # positions add nothing to the loss -- confirm for the model in use.
    x_padded = torch.nn.utils.rnn.pad_sequence(list(x_batch), batch_first=True, padding_value=0)
    y_padded = torch.nn.utils.rnn.pad_sequence(list(y_batch), batch_first=True, padding_value=-1)
    return x_padded, y_padded

# Attach the batching function to the config, then build the trainer from the
# config, the model and the training dataset.
# NOTE(review): confirm the installed Trainer actually reads config.collate_fn.
train_config.collate_fn = collate_fn
trainer = Trainer(train_config, model, train_dataset)

In [None]:
# Logging interval in iterations. It must be smaller than train_config.max_iters
# (50 in this notebook), otherwise only iteration 0 is ever reported.
log_every = 10
losses = []


def batch_end_callback(trainer):
    """Print timing and loss every ``log_every`` iterations and record the loss.

    Args:
        trainer: The running trainer; ``iter_num``, ``iter_dt`` and ``loss``
            are read from it.
    """
    if trainer.iter_num % log_every == 0:
        loss_value = trainer.loss.item()
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {loss_value:.5f}")
        losses.append(loss_value)


trainer.set_callback('on_batch_end', batch_end_callback)

trainer.run()

print(losses)

# After finetuning

In [None]:
# Two 50-token continuations of a story-style opening.
generate(
    prompt='One day, a little girl named Lily found a',
    num_samples=2,
    steps=50,
)

In [None]:
# Two 50-token continuations of a very short opening.
generate(
    prompt='One day,',
    num_samples=2,
    steps=50,
)