In [None]:
!git clone https://github.com/karpathy/minGPT.git
!pip install -e ./minGPT

In [1]:
import torch
from torch.utils.data import Dataset

from mingpt.model import GPT
from mingpt.trainer import Trainer

from datasets import load_dataset
from transformers import GPT2TokenizerFast



In [2]:
model_type = 'gpt-nano'
model_config = GPT.get_default_config()
model_config.model_type = model_type
model_config.vocab_size = 50257 # openai's model vocabulary
model_config.block_size = 1024  # openai's model block_size (i.e. input context length)
model = GPT(model_config)

number of parameters: 2.55M


In [3]:
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

batch_size = 64
max_length = 512
train_dataset = load_dataset("wikitext", "wikitext-103-v1", split='train[:5%]')
train_dataset = train_dataset.map(
    lambda batch: tokenizer(batch["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt'), 
    remove_columns=['text'], 
    batch_size=batch_size,
    batched=True, 
)
train_dataset = train_dataset.with_format("torch")
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 90068
})

In [4]:
class WikitextDataset(Dataset):
    def __init__(self, dataset):
        x = dataset['input_ids']
        size, _ = x.shape
        y = torch.empty((size, max_length))
        y[:,:-1] = x[:,1:]
        y[:,-1] = torch.ones(size) * 50256
        self.data = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), dim=2)
        self.data = self.data.reshape((size, 2, max_length))
        self.data = self.data.to(torch.int64)
        print(self.data.shape)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x, y

dataset = WikitextDataset(train_dataset)

torch.Size([90068, 2, 512])


In [9]:
import gc
gc.collect()

import psutil
process = psutil.Process()

In [10]:
train_config = Trainer.get_default_config()
train_config.learning_rate = 6e-4
train_config.max_iters = 60_000
train_config.batch_size = 16
trainer = Trainer(train_config, model, dataset)
def on_batch_end(t):
    print(f'{t.iter_time} | DT: {t.iter_dt}, iter: {t.iter_num}, loss: {t.loss}, mem: {process.memory_info().rss}')
trainer.add_callback('on_batch_end', on_batch_end)
trainer.run()

running on device cpu
1698332923.6657112 | DT: 0.0, iter: 0, loss: 0.04953203722834587, mem: 5228290048
1698332927.0223153 | DT: 3.3566040992736816, iter: 1, loss: 0.057481199502944946, mem: 5183942656
1698332929.613135 | DT: 2.590819835662842, iter: 2, loss: 0.49202868342399597, mem: 5190336512
1698332932.1125753 | DT: 2.4994401931762695, iter: 3, loss: 0.2675166428089142, mem: 5190610944
1698332934.9692745 | DT: 2.856699228286743, iter: 4, loss: 0.07993975281715393, mem: 5221822464
1698332937.5030942 | DT: 2.5338196754455566, iter: 5, loss: 0.4100649058818817, mem: 5190447104
1698332940.013746 | DT: 2.5106518268585205, iter: 6, loss: 0.012333880178630352, mem: 5196398592
1698332942.7196143 | DT: 2.7058682441711426, iter: 7, loss: 0.01463429443538189, mem: 5234278400
1698332945.373458 | DT: 2.653843641281128, iter: 8, loss: 0.10016204416751862, mem: 5233557504
1698332947.997038 | DT: 2.623579978942871, iter: 9, loss: 0.01660393364727497, mem: 5193203712
1698332951.0798109 | DT: 3.0827

KeyboardInterrupt: 