In [None]:
!git clone https://github.com/karpathy/minGPT.git
!pip install -e ./minGPT

In [2]:
%env TOKENIZERS_PARALLELISM=true

env: TOKENIZERS_PARALLELISM=true


In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer

from datasets import load_dataset
from transformers import GPT2TokenizerFast



In [3]:
model_type = 'gpt-nano'
model_config = GPT.get_default_config()
model_config.model_type = model_type
model_config.vocab_size = 50257 # openai's model vocabulary
model_config.block_size = 1024  # openai's model block_size (i.e. input context length)
model = GPT(model_config)

number of parameters: 2.55M


In [4]:
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
batch_size = 16
max_length = 256

In [5]:
# dataset = dataset.map(
#     lambda batch: tokenizer(batch["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt'), 
#     remove_columns=['text'], 
#     batch_size=batch_size,
#     batched=True, 
# )
# return dataset.with_format("torch")

In [15]:
def get_wikitext(split, tokenizer, max_length=512):
    dataset = load_dataset("wikitext", "wikitext-103-v1", split=split)
    dataset = dataset.filter(lambda x: len(x['text']) > 0)
    # dataset = dataset.map(lambda x: {'text': x['text'], 'length': [len(y) for y in x['text']] }, batched=True, batch_size=64)
    # dataset = dataset.sort('length')
    # def encode(batch):
    #     return tokenizer(batch["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt')
    # dataset.set_transform(encode)
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt'), 
        remove_columns=['text'], 
        batch_size=batch_size,
        batched=True, 
    )
    dataset = dataset.remove_columns(['attention_mask'])
    dataset = dataset.with_format("torch")
    return dataset


train_dataset = get_wikitext('train[:50%]', tokenizer, max_length=max_length)
eval_dataset = get_wikitext('validation', tokenizer, max_length=max_length)
train_dataset, eval_dataset

(Dataset({
     features: ['input_ids'],
     num_rows: 582510
 }),
 Dataset({
     features: ['input_ids'],
     num_rows: 2461
 }))

In [16]:
class WikitextDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)
    
    # x = tokenizer(self.dataset[idx]["text"], padding='max_length', max_length=max_length, truncation=True, return_tensors='pt')
    # x = x['input_ids']
    
    # y = torch.empty(x.shape)
    # y[:,:-1] = x[:,1:]
    # y[:,-1] = torch.ones(x.shape[0]) * 50256

    # =========================================================

    # x = self.dataset[idx]['input_ids']
    # y = None
    # if isinstance(idx, torch.Tensor) or isinstance(idx, slice):
    #     y = torch.empty(x.shape)
    #     y[:,:-1] = x[:,1:]
    #     y[:,-1] = torch.ones(x.shape[0]) * 50256
    # else:
    #     y = torch.empty(x.shape, dtype=torch.long)
    #     y[:-1] = x[1:]
    #     y[-1] = 50256

    def __getitem__(self, idx):
        x = self.dataset[idx]['input_ids']
        y = torch.empty(x.shape, dtype=x.dtype)
        y[:-1] = x[1:]
        y[-1] = 50256
        return x, y

tr_dataset = WikitextDataset(train_dataset)
ev_dataset = WikitextDataset(train_dataset)

In [20]:
tr_dataset[:2], tr_dataset[0]

((tensor([[  796,   569, 18354,  7496, 17740,  6711,   796,   220,   198, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           50256, 50256, 502

In [8]:
# class WikitextDataset(Dataset):
#     def __init__(self, dataset, max_length=512):
#         x = dataset['input_ids']
#         size, _ = x.shape
#         y = torch.empty((size, max_length))
#         y[:,:-1] = x[:,1:]
#         y[:,-1] = torch.ones(size) * 50256
#         self.data = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), dim=2)
#         self.data = self.data.reshape((size, 2, max_length))
#         self.data = self.data.to(torch.int64)
#         print(self.data.shape)

#     def __len__(self):
#         return self.data.shape[0]

#     def __getitem__(self, idx):
#         x, y = self.data[idx]
#         return x, y

# tr_dataset = WikitextDataset(train_dataset, max_length)
# ev_dataset = WikitextDataset(eval_dataset, max_length)

In [21]:
import gc
gc.collect()

import psutil
process = psutil.Process()

def get_mem():
    return process.memory_info().rss 

get_mem()

2382630912

In [22]:
train_config = Trainer.get_default_config()
train_config.learning_rate = 6e-4
train_config.max_iters = 3600
train_config.batch_size = batch_size
train_config.num_workers = 4
trainer = Trainer(train_config, model, tr_dataset)

ev_loader = DataLoader(
    ev_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=1,
)

@torch.no_grad()
def custom_evaluate(model, device):
    model.eval()
    losses = []
    for batch in ev_loader:
        batch = [t.to(device) for t in batch]
        x, y = batch
        logits, loss = model(x, y)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)


benchmark = []

def on_batch_end(t):
    # if t.iter_num % 2 == 0: 
    eval_loss = 0 # custom_evaluate(t.model, device=t.device)
    mem = get_mem()
    print(f'DT: {t.iter_dt:.3f}, iter: {t.iter_num:05d}, train_loss: {t.loss:.4f}, eval_loss: {eval_loss:.4f}, mem: {mem / (1024 * 1024):.2f} MB')
    benchmark.append({'iter': t.iter_num, 'train_loss': t.loss, 'eval_loss': eval_loss, 'time': t.iter_time, 'mem': mem})
    gc.collect()

trainer.add_callback('on_batch_end', on_batch_end)
trainer.run()

running on device cpu
DT: 0.000, iter: 00000, train_loss: 10.9032, eval_loss: 0.0000, mem: 3281.75 MB
DT: 1.184, iter: 00001, train_loss: 10.4913, eval_loss: 0.0000, mem: 3287.70 MB
DT: 1.253, iter: 00002, train_loss: 10.4378, eval_loss: 0.0000, mem: 3467.96 MB
DT: 1.008, iter: 00003, train_loss: 10.3889, eval_loss: 0.0000, mem: 3575.82 MB
DT: 0.972, iter: 00004, train_loss: 10.1820, eval_loss: 0.0000, mem: 3683.86 MB
DT: 0.979, iter: 00005, train_loss: 10.1112, eval_loss: 0.0000, mem: 3683.82 MB
DT: 0.979, iter: 00006, train_loss: 10.1139, eval_loss: 0.0000, mem: 3731.95 MB
DT: 1.024, iter: 00007, train_loss: 10.0356, eval_loss: 0.0000, mem: 3756.01 MB
DT: 8.943, iter: 00008, train_loss: 10.1388, eval_loss: 0.0000, mem: 3792.07 MB
DT: 2.230, iter: 00009, train_loss: 9.9840, eval_loss: 0.0000, mem: 3839.73 MB
DT: 1.914, iter: 00010, train_loss: 9.8615, eval_loss: 0.0000, mem: 3923.80 MB
DT: 1.787, iter: 00011, train_loss: 9.8537, eval_loss: 0.0000, mem: 4043.74 MB
DT: 1.878, iter: 0001

: 

In [8]:
torch.save(model.state_dict(), './model.pt')

In [11]:
benchmark

[{'iter': 0,
  'train_loss': tensor(10.8628, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593017.1018848,
  'mem': 1785602048},
 {'iter': 1,
  'train_loss': tensor(10.5546, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593018.4563336,
  'mem': 2024230912},
 {'iter': 2,
  'train_loss': tensor(10.4861, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593019.8006442,
  'mem': 2070908928},
 {'iter': 3,
  'train_loss': tensor(10.2586, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593020.986921,
  'mem': 2083426304},
 {'iter': 4,
  'train_loss': tensor(10.2835, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593021.920062,
  'mem': 2158997504},
 {'iter': 5,
  'train_loss': tensor(10.2894, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593022.7859998,
  'mem': 2247143424},
 {'iter': 6,
  'train_loss': tensor(10.3551, grad_fn=<NllLossBackward0>),
  'eval_loss': 0,
  'time': 1698593023.662099,
  'mem': 23225057