In [1]:
import numpy as np
import torchvision
import torch
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

from mingpt.utils import set_seed
set_seed(42)

In [2]:
from datasets import CopyDataset

ds_train_params = dict(
    n_episodes=1,
    colors=2,
    fig_w=2,
    fig_h=2,
    field_w=10,
    field_h=10,
    n_figs_on_field=3
)

ds_test_params = dict(
    n_episodes=1,
    colors=2,
    fig_w=3,
    fig_h=3,
    field_w=10,
    field_h=10,
    n_figs_on_field=2
)

train_data = CopyDataset(1000_000, **ds_train_params)
test_data = CopyDataset(10_000, **ds_test_params)

maximum size ~ 4096000000
maximum size ~ 2621440000


In [3]:
test_data.ds[100,0]

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [4]:
train_data.ds[100,0]

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [5]:
from torch.utils.data import Dataset

class ImageCopyDataset(Dataset):
    """
    wrap up the pytorch CIFAR-10 dataset into our own, which will convert images into sequences of integers
    """
    
    def __init__(self, ds):
        self.ds = ds.ds
        self.vocab_size = ds.colors
        self.block_size = ds.field_w * ds.field_h * 2 - 1 #duplicate field
        
    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        x = self.ds[idx][0].flatten() #take only first episode
        perm_context = x #this context is permanent for all seqs
        not_perm_context = x[:-1] #this context will be masked in attention
        y = x
        x = np.concatenate((perm_context, not_perm_context))
        return x, y # always just predict the next one in the sequence
    
train_dataset = ImageCopyDataset(train_data)
test_dataset = ImageCopyDataset(test_data)

In [6]:
train_dataset[0]

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0,

In [7]:
from mingpt.model import GPT, GPTConfig, GPT1Config
from mingpt.trainer import Trainer, TrainerConfig

# we'll do something a bit smaller
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  embd_pdrop=0.0, resid_pdrop=0.1, attn_pdrop=0.1,
                  n_layer=12, n_head=8, n_embd=256)
model = GPT(mconf)

tokens_per_epoch = len(train_dataset) * train_dataset.block_size
train_epochs = 1 # todo run a bigger model and longer, this is tiny

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=train_epochs, batch_size=128, learning_rate=3e-3,
                      betas = (0.9, 0.95), weight_decay=0,
                      lr_decay=True, warmup_tokens=tokens_per_epoch, final_tokens=train_epochs*tokens_per_epoch,
                      ckpt_path='cifar10_model.pt',
                      num_workers=4, tolerance_lim=1000)

trainer = Trainer(model, train_dataset, test_dataset, tconf)

10/20/2020 17:30:53 - INFO - mingpt.model -   number of parameters: 9.529600e+06


In [8]:
trainer.train()

epoch 1 iter 2077: train loss 0.01124. lr 4.009809e-04:  27%|██▋       | 2078/7813 [11:45<32:27,  2.94it/s]10/20/2020 17:42:42 - INFO - mingpt.trainer -   
early stopping: best_loss %f, loss %f
10/20/2020 17:42:42 - INFO - mingpt.trainer -   saving cifar10_model.pt
epoch 1 iter 2077: train loss 0.01124. lr 4.009809e-04:  27%|██▋       | 2078/7813 [11:45<32:27,  2.95it/s]
10/20/2020 17:42:50 - INFO - mingpt.trainer -   test loss: 0.014688
10/20/2020 17:42:50 - INFO - mingpt.trainer -   saving cifar10_model.pt


In [10]:
checkpoint = torch.load('cifar10_model.pt')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [25]:
indices = np.random.randint(0, len(test_dataset), 10)
test = np.array([test_dataset[i][1] for i in indices])
test = torch.from_numpy(test).to(trainer.device)

In [26]:
test.shape

torch.Size([10, 100])

In [23]:
from mingpt.utils import sample

In [27]:
res = sample(model, test, None)

In [32]:
res[9].reshape(2,10,10)

tensor([[[0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], device='cuda:0')