In [7]:
import torch

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling

%matplotlib inline

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
def generate_chunks(lst, n):
    """Lazily split a sliceable sequence into consecutive pieces of length ``n``.

    Every piece has exactly ``n`` items except possibly the last, which holds
    whatever remains. Slices keep the input's type (list -> lists, str -> strs).
    """
    chunk_starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in chunk_starts)

def generate_images(feature_extractor, model, batch_size, num_iters, path_to_generated_images, device):
    """Sample ``num_iters`` images in batches, checkpointing progress to disk.

    Args:
        feature_extractor: ImageGPT feature extractor, forwarded to ``generate_images_batch``.
        model: ImageGPT causal model, forwarded to ``generate_images_batch``.
        batch_size: maximum number of images sampled per ``generate_images_batch`` call (>= 1).
        num_iters: total number of images to generate.
        path_to_generated_images: ``.npz`` path written with ``np.savez``. After every batch it
            holds all images generated so far; at the end it holds the full array.
        device: torch device the model lives on.

    Returns:
        ``np.ndarray`` of shape ``(num_iters, 3, 32, 32)`` (float64 holding integer pixel
        values as produced by ``generate_images_batch``).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Zeros rather than random noise: any slot not yet filled is obviously empty.
    images = np.zeros((num_iters, 3, 32, 32))
    batches = list(generate_chunks(list(range(num_iters)), batch_size))
    for batch_indices in tqdm(batches):
        images[batch_indices] = generate_images_batch(feature_extractor, model, len(batch_indices), device)
        # Checkpoint everything generated so far; +1 so the last image of this batch is included.
        np.savez(path_to_generated_images, images[:batch_indices[-1] + 1])
    np.savez(path_to_generated_images, images)
    return images
    

@torch.no_grad()
def generate_images_batch(feature_extractor, model, batch_size, device, temperature=1.0, top_k=40):
    """Sample ``batch_size`` images from an ImageGPT model, starting from the SOS token.

    Args:
        feature_extractor: object exposing ``clusters``, an array indexed by token id that
            yields one RGB colour (3 values) per token.
        model: ImageGPT causal model exposing ``config.vocab_size``, ``config.n_positions``
            and ``generate``.
        batch_size: number of images to sample.
        device: torch device the model lives on.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k: top-k sampling cut-off forwarded to ``model.generate``.

    Returns:
        List of ``batch_size`` ``np.uint8`` arrays, each shaped ``(3, side, side)``, where
        ``side * side`` is the number of generated tokens (32 when ``n_positions == 1024``).

    Raises:
        ValueError: if the number of generated tokens is not a perfect square.
    """
    # The SOS token is the last vocabulary id (512 when vocab_size == 513).
    sos_id = model.config.vocab_size - 1
    context = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=device)

    output = model.generate(
        input_ids=context,
        max_length=model.config.n_positions + 1,  # SOS + one token per pixel
        temperature=temperature,
        do_sample=True,
        top_k=top_k,
    )

    # Drop the SOS column; each remaining token is an index into the colour clusters.
    samples = output[:, 1:].cpu().numpy()
    seq_len = samples.shape[1]
    side = int(round(seq_len ** 0.5))
    if side * side != seq_len:
        raise ValueError(f"expected a square number of generated tokens, got {seq_len}")

    clusters = np.asarray(feature_extractor.clusters)
    # Map cluster values (presumably in [-1, 1]) to [0, 255] pixel intensities for all samples at once.
    pixels = np.rint(127.5 * (clusters[samples] + 1.0)).astype(np.uint8)
    pixels = pixels.reshape(len(samples), side, side, 3)
    # Channel-last (H, W, C) -> channel-first (C, H, W).
    return [img.transpose(2, 0, 1) for img in pixels]

def plot_image_batch(imgs):
    """Display a batch of channel-first images side by side in a single row, axes hidden.

    Args:
        imgs: sequence of arrays shaped ``(C, H, W)``, e.g. as returned by
            ``generate_images_batch``.
    """
    imgs = [img.transpose(1, 2, 0) for img in imgs]

    # squeeze=False keeps `axes` 2-D even when there is only one image; otherwise
    # subplots(1, 1) returns a bare Axes and zip() over it raises TypeError.
    f, axes = plt.subplots(1, len(imgs), dpi=300, squeeze=False)

    for img, ax in zip(imgs, axes[0]):
        ax.axis('off')
        ax.imshow(img)
        

In [10]:
SIZE = 'small'
PRETRAINED_MODEL = f'openai/imagegpt-{SIZE}'
NUM_IMAGES_TO_GENERATE = 10000
BATCH_SIZE = 16
SEED = 42

# Sampling uses do_sample=True, so seed the RNGs to make runs repeatable
# (bit-exact repeatability on GPU may additionally need deterministic kernels).
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

feature_extractor = ImageGPTFeatureExtractor.from_pretrained(PRETRAINED_MODEL)
model = ImageGPTForCausalImageModeling.from_pretrained(PRETRAINED_MODEL)
model.to(device)

ImageGPTForCausalImageModeling(
  (transformer): ImageGPTModel(
    (wte): Embedding(513, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x ImageGPTBlock(
        (ln_1): ImageGPTLayerNorm()
        (attn): ImageGPTAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): ImageGPTLayerNorm()
        (mlp): ImageGPTMLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): QuickGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): ImageGPTLayerNorm()
  )
  (lm_head): Linear(in_features=512, out_features=512, bias=False)
)

In [5]:
# Full sampling run. The logged attempt took ~144 s per batch of 16 (estimated ~50 h for
# 10k images), so it is opt-in to keep Restart & Run All feasible.
RUN_FULL_GENERATION = False

if RUN_FULL_GENERATION:
    imgs = generate_images(
        feature_extractor=feature_extractor,
        model=model,
        batch_size=BATCH_SIZE,
        num_iters=NUM_IMAGES_TO_GENERATE,
        device=device,
        # Name the output after the model size actually loaded.
        path_to_generated_images=f'sampled-{SIZE}.npz',
    )

  0%|                                                  | 0/1250 [00:00<?, ?it/s]This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (1024). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
  0%|                                      | 2/1250 [04:47<49:49:48, 143.74s/it]


KeyboardInterrupt: 

In [12]:
batch_size = 8
# Every row starts with the SOS token, the last vocabulary id (vocab_size - 1, i.e. 512).
# Build it directly as a long tensor on the target device; no torch.Tensor(...) re-wrap needed.
context = torch.full((batch_size, 1), model.config.vocab_size - 1, dtype=torch.long, device=device)
    

In [13]:
# Sanity check: one row per sample, each holding the SOS id (model.config.vocab_size - 1).
context

tensor([[512],
        [512],
        [512],
        [512],
        [512],
        [512],
        [512],
        [512]], device='cuda:0')

In [14]:
# Exploratory sample with a tighter top-k (10) than the 40 used by generate_images_batch.
output = model.generate(
    input_ids=context,
    max_length=model.config.n_positions + 1,
    temperature=1.0,
    do_sample=True,
    top_k=10,
)

In [15]:
# Expected (batch_size, model.config.n_positions + 1): the SOS column plus the sampled tokens.
output.shape

torch.Size([8, 1025])

In [None]:
# FIXME(review): neither `dataset` nor `tokenize_function` is defined in any cell visible here,
# so this cell raises NameError on a fresh kernel. It appears unrelated to ImageGPT sampling;
# remove it or add the cells that define these names.
tokenized_datasets = dataset.map(tokenize_function, batched=True)