In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import torch
from torch.utils.data import DataLoader

from src.dataset import CocoDataset
from src.models import ClipCapModel, MLPMappingNetwork
from src.train import train
from src.test import generate_test_caption_predictions

In [18]:
# Take note: the embeddings in coco_train2014_image_embeddings.pt are already
# normalised, so no further normalisation is applied here.
dataset = CocoDataset(
    embeddings_path="data/coco_train2014_image_embeddings.pt",
    annotations_path="data/coco_train2014_captions.json",
    max_length=50,
)

# Dataloader used for inspecting batches in this notebook.
# A seeded generator drives the shuffling so that the sampled order (and hence
# the batch previewed in the next cell) is the same after Restart & Run All.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

Dataset ready: 414113 captions for 82783 images.


In [24]:
# How many batches one full pass over the dataloader yields
print(f"Number of batches in dataloader: {len(dataloader)}")

# Peek at the first batch only, then stop iterating
for batch in dataloader:
    print(
        "\n".join(
            (
                "Batch contents:",
                f"    token_ids shape: {batch['token_ids'].shape}",
                f"    clip_embedding shape: {batch['clip_embedding'].shape}",
                f"    attention_mask shape: {batch['attention_mask'].shape}",
                f"    captions: {batch['caption_text']}",
                f"    image ids: {batch['image_id']}",
            )
        )
    )
    # The default collation batches the non-tensor fields too: in the recorded
    # output the captions come back as a list of strings and the image ids as
    # a tensor.
    break

Number of batches in dataloader: 103529
Batch contents:
    token_ids shape: torch.Size([4, 50])
    clip_embedding shape: torch.Size([4, 512])
    attention_mask shape: torch.Size([4, 50])
    captions: ['A man riding some skis down a snowy trail.', 'a young boy is riding on a board outside', 'A man riding skis down the side of a snow covered ski slope.', 'Man holding a cat that is wearing a costume.']
    image ids: tensor([143306, 493888, 538187, 478675])


In [20]:
# Device: run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Models: build the MLP mapping network (prefix length 10), hand it to the
# captioning model, then move the captioning model to the selected device
mapping_network = MLPMappingNetwork(prefix_length=10)
model = ClipCapModel(mapping_network=mapping_network)
model = model.to(device)

In [25]:
# Show the module tree of the captioning model: the MLP mapping network
# followed by the GPT-2 language model, as listed in the output below.
print(model)

ClipCapModel(
  (mapping_network): MLPMappingNetwork(
    (model): Sequential(
      (0): Linear(in_features=512, out_features=3840, bias=True)
      (1): Tanh()
      (2): Linear(in_features=3840, out_features=7680, bias=True)
    )
  )
  (gpt): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGEL

In [None]:
# Train the image captioning model: one epoch over `dataset` with batches of 4,
# running on the device selected earlier in the notebook
train(model=model, train_dataset=dataset, device=device, num_epochs=1, batch_size=4)