In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `src` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

from src.dataset import CocoDataset
from src.models import ClipCapModel, MLPMappingNetwork
from src.train import train

In [None]:
# GPT-2 does not define a pad token by default, so the end-of-sequence
# token is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# NOTE: the embeddings stored in coco_train2014_image_embeddings.pt are
# already normalised.

# Constructor arguments gathered in one place so they are easy to tweak.
dataset_kwargs = dict(
    embeddings_path="data/coco_train2014_image_embeddings.pt",
    annotations_path="data/coco_train2014_captions.json",
    tokenizer=tokenizer,
    max_length=50,
)

dataset = CocoDataset(**dataset_kwargs)

In [None]:
# Display the first sample as a quick sanity check of what the dataset returns.
dataset[0]

In [None]:
# Samples per batch; the same value is handed to train() further down.
batch_size = 4

# Shuffled loader over the full dataset, used below to preview a batch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# How many batches one full pass over the dataloader yields
num_batches = len(dataloader)
print(f"Number of batches in dataloader: {num_batches}")

# Peek at a single batch: the loop is left right after its first iteration
for batch_idx, batch in enumerate(dataloader):
    # The default collate step batches the non-tensor fields (captions, ids)
    # too — presumably as lists rather than stacked tensors; TODO confirm.
    batch_summary = (
        f"Batch {batch_idx} contents:"
        f"\n token_ids shape: {batch['token_ids'].shape}"
        f"\n clip_embedding shape: {batch['clip_embedding'].shape}"
        f"\n attention_mask shape: {batch['attention_mask'].shape}"
        f"\n captions: {batch['caption_text']}"
        f"\n image ids: {batch['image_id']}"
    )
    print(batch_summary)
    break

In [None]:
# Batch count shown as the cell's output (same value the preview cell prints).
len(dataloader)

In [None]:
# Mapping network configured with a prefix of length 10
# (presumably the number of prefix embeddings fed to the language model —
# confirm in src.models).
mapping_network = MLPMappingNetwork(
    prefix_length=10,
)

# Captioning model built around that mapping network.
model = ClipCapModel(
    mapping_network=mapping_network,
)

# Run a single training epoch, reusing the batch size chosen for the preview loader.
train(
    train_dataset=dataset,
    model=model,
    batch_size=batch_size,
    num_epochs=1,
)