In [None]:
!git clone https://github.com/karpathy/minGPT.git
!pip install -e ./minGPT

In [1]:
# Set TOKENIZERS_PARALLELISM for the kernel process (read by the Hugging Face tokenizers backend).
# NOTE(review): tokenization in this notebook runs lazily via set_transform, and the training
# cell uses worker processes (num_workers > 0) — confirm this setting does not cause
# fork-related hangs in those workers.
%env TOKENIZERS_PARALLELISM=true

env: TOKENIZERS_PARALLELISM=true


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer

from datasets import load_dataset
from transformers import GPT2TokenizerFast



In [3]:
# Instantiate a minGPT `GPT` model from the library's default config with a few overrides.
# The cell output reports the resulting size ("number of parameters: 2.55M").
model_type = 'gpt-nano'
model_config = GPT.get_default_config()
model_config.model_type = model_type
model_config.vocab_size = 50257 # OpenAI's GPT-2 vocabulary size
# NOTE(review): inputs are padded/truncated to max_length (512) elsewhere in this notebook,
# so only part of this 1024-token context is exercised — confirm this is intended.
model_config.block_size = 1024  # OpenAI's model block_size (i.e. input context length)
model = GPT(model_config)

number of parameters: 2.55M


In [4]:
# Load the pretrained 'gpt2' fast tokenizer.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
# Reuse the end-of-sequence token for padding; padding='max_length' is requested
# when the text is tokenized in get_wikitext.
tokenizer.pad_token = tokenizer.eos_token
# NOTE: this value is reassigned to 32 in the training cell before the Trainer / DataLoader use it.
batch_size = 64
# Number of tokens every example is padded / truncated to.
max_length = 512

In [5]:
# dataset = dataset.map(
#     lambda batch: tokenizer(batch["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt'), 
#     remove_columns=['text'], 
#     batch_size=batch_size,
#     batched=True, 
# )
# return dataset.with_format("torch")

In [6]:
def get_wikitext(split, tokenizer, max_length=512):
    """Load a WikiText-103 split and attach on-the-fly tokenization.

    Args:
        split: split expression forwarded to `load_dataset` (e.g. 'validation').
        tokenizer: callable applied to the 'text' column of each accessed batch.
        max_length: length every sequence is padded / truncated to.

    Returns:
        The dataset with empty-text rows removed and the tokenizing
        transform installed via `set_transform`.
    """
    def _has_text(example):
        # Keep only rows whose 'text' field is non-empty.
        return len(example['text']) > 0

    def _tokenize(examples):
        # Pads and truncates to `max_length` and returns PyTorch tensors.
        return tokenizer(
            examples["text"],
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
        )

    non_empty = load_dataset("wikitext", "wikitext-103-v1", split=split).filter(_has_text)
    non_empty.set_transform(_tokenize)
    return non_empty


# Train on the 'train[:50%]' slice of the training split; evaluate on the 'validation' split.
train_dataset = get_wikitext('train[:50%]', tokenizer, max_length=max_length)
eval_dataset = get_wikitext('validation', tokenizer, max_length=max_length)
# Display both datasets as the cell output.
train_dataset, eval_dataset

(Dataset({
     features: ['text'],
     num_rows: 582510
 }),
 Dataset({
     features: ['text'],
     num_rows: 2461
 }))

In [7]:
class WikitextDataset(Dataset):
    """Next-token-prediction view over a tokenized dataset.

    Wraps an indexable dataset whose items expose an 'input_ids' tensor and
    yields `(x, y)` pairs, where `y` is `x` shifted left by one position with
    `eos_token_id` written into the final slot.

    Args:
        dataset: indexable object; `dataset[idx]['input_ids']` must be an
            integer tensor, 1-D for a single item or 2-D for a batch of items.
        max_length: expected sequence length. Stored for reference; the
            target tensor takes its shape from the returned 'input_ids'.
        eos_token_id: id placed in the last target position. Defaults to
            50256, the value that was previously hard-coded.
    """

    def __init__(self, dataset, max_length=512, eos_token_id=50256):
        self.dataset = dataset
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return `(input_ids, targets)` for an int, slice or tensor index."""
        x = self.dataset[idx]['input_ids']

        # Targets are always int64 and shaped like x, so single-item and
        # batched indexing share one code path and no notebook-level global
        # is needed to size the buffer.
        y = torch.full(x.shape, self.eos_token_id, dtype=torch.long, device=x.device)
        # Shift left along the sequence axis; the last slot keeps the EOS id.
        y[..., :-1] = x[..., 1:]

        return x, y

# Wrap the tokenized splits so they yield (input_ids, shifted targets) pairs.
tr_dataset = WikitextDataset(train_dataset, max_length)
# Evaluation reads the held-out validation split, not the training data.
ev_dataset = WikitextDataset(eval_dataset, max_length)

In [8]:
# class WikitextDataset(Dataset):
#     def __init__(self, dataset, max_length=512):
#         x = dataset['input_ids']
#         size, _ = x.shape
#         y = torch.empty((size, max_length))
#         y[:,:-1] = x[:,1:]
#         y[:,-1] = torch.ones(size) * 50256
#         self.data = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), dim=2)
#         self.data = self.data.reshape((size, 2, max_length))
#         self.data = self.data.to(torch.int64)
#         print(self.data.shape)

#     def __len__(self):
#         return self.data.shape[0]

#     def __getitem__(self, idx):
#         x, y = self.data[idx]
#         return x, y

# tr_dataset = WikitextDataset(train_dataset, max_length)
# ev_dataset = WikitextDataset(eval_dataset, max_length)

In [9]:
import gc
gc.collect()  # run a garbage-collection pass before the memory reading taken in this cell

import psutil
process = psutil.Process()  # psutil handle used by get_mem(); no pid given, so presumably the kernel process itself

def get_mem():
    """Return the resident set size (RSS) reported by psutil for `process`."""
    memory = process.memory_info()
    return memory.rss

get_mem()

651210752

In [12]:
# Batch size shared by the trainer and the evaluation loader
# (replaces the batch_size = 64 assigned in the tokenizer cell).
batch_size = 32

# Start from minGPT's default trainer config and override selected fields.
train_config = Trainer.get_default_config()
train_config.learning_rate = 6e-4
train_config.max_iters = 60_000
train_config.batch_size = batch_size
train_config.num_workers = 4
trainer = Trainer(train_config, model, tr_dataset)

# Evaluation loader: fixed order (no shuffling), a single worker process.
ev_loader = DataLoader(
    ev_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=1,
)

@torch.no_grad()
def custom_evaluate(model, device, loader=None):
    """Compute the mean per-batch loss of `model` over an evaluation loader.

    Args:
        model: module called as `model(x, y)` and returning `(logits, loss)`.
        device: device each batch tensor is moved to before the forward pass.
        loader: iterable of `(x, y)` batches. Defaults to the notebook-level
            `ev_loader` when omitted.

    Returns:
        The unweighted mean of the per-batch losses as a float, or NaN when
        the loader yields no batches.
    """
    if loader is None:
        loader = ev_loader

    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()
    losses = []
    try:
        for batch in loader:
            x, y = [t.to(device) for t in batch]
            _, loss = model(x, y)
            losses.append(loss.item())
    finally:
        # Restore the original mode even if a batch raises.
        model.train(was_training)

    if not losses:
        return float('nan')
    return sum(losses) / len(losses)


benchmark = []

def on_batch_end(t):
    """Trainer callback: print and record metrics for the finished iteration.

    Args:
        t: the trainer firing the callback; `iter_dt`, `iter_num`, `loss`
            and `iter_time` are read from it.

    Appends one dict per call to the module-level `benchmark` list.
    """
    # Evaluation is currently disabled; eval_loss is a placeholder.
    # To enable it: eval_loss = custom_evaluate(t.model, device=t.device)
    eval_loss = 0
    # t.loss is presumably a 0-dim tensor (TODO confirm against the Trainer);
    # float() turns it into a plain number so `benchmark` does not keep one
    # tensor object alive per iteration, which would distort the memory
    # numbers recorded here.
    train_loss = float(t.loss)
    mem = get_mem()
    print(f'DT: {t.iter_dt:.3f}, iter: {t.iter_num:05d}, train_loss: {train_loss:.4f}, eval_loss: {eval_loss:.4f}, mem: {mem / (1024 * 1024):.2f} MB')
    benchmark.append({'iter': t.iter_num, 'train_loss': train_loss, 'eval_loss': eval_loss, 'time': t.iter_time, 'mem': mem})

# Register on_batch_end under the trainer's 'on_batch_end' event, then start the training run.
trainer.add_callback('on_batch_end', on_batch_end)
trainer.run()

running on device cpu


: 

In [8]:
# Persist the model's state_dict (not the whole module object) to ./model.pt.
torch.save(model.state_dict(), './model.pt')