In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (such as the local `src` package used below) are picked up live,
# without restarting the kernel. Comments must stay on their own lines here:
# a trailing comment on a magic line would be passed to the magic as arguments.
%load_ext autoreload
%autoreload 2

In [3]:
from src.dataset import CocoDataset
from src.models import MLPMappingNetwork, TransformerMappingNetwork, ClipCapModel
from src.train import train
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer



In [4]:
# Load the pretrained 'gpt2' tokenizer; it is passed to CocoDataset in the next cell.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 does not have a pad token by default, so the end-of-sequence token is
# reused for padding. Padded positions therefore carry the EOS id (the repeated
# 50256 entries visible in the sample outputs further down).
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# Reminder: the embeddings stored in coco_train2014_image_embeddings.pt are
# already normalised.
#
# Build the caption dataset from the precomputed image embeddings and the
# caption annotations. max_length=50 matches the 50-entry token_ids rows shown
# in the sample outputs (presumably captions are padded/truncated to it).
dataset = CocoDataset(
    tokenizer=tokenizer,
    max_length=50,
    annotations_path="data/coco_train2014_captions.json",
    embeddings_path="data/coco_train2014_image_embeddings.pt",
)

Dataset ready: 414113 captions for 82783 images.


In [6]:
# Inspect the first sample: a dict holding a 'token_ids' tensor and a
# 'clip_embedding' tensor (see the output below).
dataset[0]

{'token_ids': tensor([   32,   845,  3424,   290,   880, 24789,  6565, 12436, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]),
 'clip_embedding': tensor([-2.0582e-02,  5.1814e-02, -1.1414e-02, -2.8239e-03,  1.4750e-02,
          2.4026e-02,  3.4900e-02, -2.4130e-02, -2.3070e-02,  5.2361e-04,
         -1.5503e-02, -3.9511e-02, -1.9680e-02, -2.1942e-02,  6.5903e-03,
         -6.5856e-03, -4.0837e-02,  1.2718e-02,  4.7172e-03,  1.3094e-02,
          5.8578e-02, -3.2071e-04,  4.9159e-02,  5.4241e-02, -1.5497e-02,
          3.0319e-02, -1.4742e-02, -1.4111e-02,  1.1611e-02, -1.3766e-04,
          1.6698e-02,  2.0703e-02, -2.6310e-02, -5.1049e-03, -3.2191e-03,
          2.0171e-02, -1.3052e-02, -2.5989e-02,  4.378

In [7]:
# Mini-batch size; the same value is handed to train() further down.
batch_size = 4

# Shuffled loader over the caption dataset, used below to inspect one batch.
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [8]:
# View one batch.
# Fetch exactly one batch with next(iter(...)) instead of a for-loop that
# breaks on its first pass; an empty loader now fails loudly (StopIteration)
# rather than silently leaving `batch` undefined.
# The loader collates the per-sample dicts into a single dict whose values
# carry a leading batch dimension (see the printed output).
# Original author's note: the DataLoader will batch non-tensor items as well
# (not visible in this output -- verify against the collate behaviour).
batch = next(iter(dataloader))
print(batch)

{'token_ids': tensor([[   64,  1402,  2933,   318,  4769,   257,  4077,   290,  7872, 16162,
         32680, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [43559,  3180,   317, 41153,  3963, 41670, 11651,   367, 15173,  2751,
           317, 18671,  3069,  9370, 11651, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [   32,  2415,  4769,   257,  7480,   351,   257,  3704,   286, 14256,
           319,   340,    13, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256,

In [10]:
# Build the MLPMappingNetwork with a prefix length of 10.
mapping_network = MLPMappingNetwork(prefix_length=10)

# Wrap the mapping network in a ClipCapModel.
model = ClipCapModel(mapping_network=mapping_network)

# Train for a single epoch on the caption dataset, reusing the batch size
# chosen when the preview DataLoader was created.
train(
    model=model,
    dataset=dataset,
    num_epochs=1,
    batch_size=batch_size,
)

Batch 0 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Epoch [1/1], Step [1/103529], Loss: 8.0623
Batch 1 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Batch 2 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Batch 3 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Batch 4 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Batch 5 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])
 attention_mask shape: torch.Size([4, 50])
Batch 6 contents:
 token_ids shape: torch.Size([4, 50])
 clip_embedding shape: torch.Size([4, 512])

KeyboardInterrupt: 