In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2

## Setup

In [9]:
import random

import numpy as np
import torch

from src.dataset import CocoDataset, split_coco_annotations
from src.models import ImageCaptioningModel, TransformerMappingNetwork
from src.train import train

In [10]:
# Device: use the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Bare expression so the notebook displays the selected device.
DEVICE

device(type='cuda')

In [11]:
# Seed
# Single source of truth for every random number generator used in this
# notebook; it is also forwarded to `split_coco_annotations` below.
SEED = 42

# Seed each relevant library so that a "Restart Kernel -> Run All" reproduces
# the same weight initialisation, dropout masks and any RNG-driven shuffling.
random.seed(SEED)  # Python standard-library RNG
np.random.seed(SEED)  # NumPy legacy global RNG
torch.manual_seed(SEED)  # PyTorch default generator (all devices)

## Prepare Datasets

In [12]:
# Directory prefixes (trailing slash included, file names are appended directly).
ANNOTATIONS_PATH = "data/coco/annotations/"
EMBEDDINGS_PATH = "data/coco/embeddings/"

# Passed to every CocoDataset below as `max_length`.
MAX_CAPTION_LENGTH = 50

In [None]:
# Carve a new training/validation pair out of the original COCO 2014 training
# annotations. The split files are written next to the source annotations,
# and SEED is forwarded to the splitter.
split_coco_annotations(
    seed=SEED,
    split_ratio=0.8,
    output_dir=ANNOTATIONS_PATH,
    annotations_path=f"{ANNOTATIONS_PATH}captions_train2014.json",
)

Splitting: 66226 Train images, 16557 Val images.
Created:
- data/coco/annotations/train_split.json
- data/coco/annotations/val_split.json


In [16]:
# All three datasets skip normalization because the `.pt` files already
# contain normalized embeddings, and share the same caption length limit.

# Training split (taken from the orig. COCO 2014 TRAIN set)
train_dataset = CocoDataset(
    annotations_path=f"{ANNOTATIONS_PATH}train_split.json",
    embeddings_path=f"{EMBEDDINGS_PATH}train_val_clip_embeddings.pt",
    normalize_embeddings=False,
    max_length=MAX_CAPTION_LENGTH,
)

# Validation split (also taken from the orig. COCO 2014 TRAIN set, hence the
# same embeddings file as the training split)
val_dataset = CocoDataset(
    annotations_path=f"{ANNOTATIONS_PATH}val_split.json",
    embeddings_path=f"{EMBEDDINGS_PATH}train_val_clip_embeddings.pt",
    normalize_embeddings=False,
    max_length=MAX_CAPTION_LENGTH,
)

# Held-out test set (orig. COCO 2017 Val) with its own embeddings file
test_dataset = CocoDataset(
    annotations_path=f"{ANNOTATIONS_PATH}captions_val2017.json",
    embeddings_path=f"{EMBEDDINGS_PATH}test_clip_embeddings.pt",
    normalize_embeddings=False,
    max_length=MAX_CAPTION_LENGTH,
)

Dataset ready: 331287 captions.
Dataset ready: 82826 captions.
Dataset ready: 25014 captions.


## Prepare Model

In [None]:
# Mapping network: projects a CLIP embedding into a prefix for GPT-2.
mapping_network = TransformerMappingNetwork(
    gpt_dim=768,  # GPT-2 embedding dimension
    embed_dim=512,  # CLIP embedding dimension
    hidden_length=40,
    prefix_length=40,
)

# Captioning model: GPT-2 stays frozen, so only the mapping network is
# fine-tuned during training.
model = ImageCaptioningModel(
    freeze_gpt_weights=True,
    mapping_network=mapping_network,
)
model = model.to(DEVICE)

# Show the module structure.
print(model)



ImageCaptioningModel(
  (mapping_network): TransformerMappingNetwork(
    (linear): Linear(in_features=512, out_features=30720, bias=True)
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-7): 8 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (linear1): Linear(in_features=768, out_features=3072, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=3072, out_features=768, bias=True)
          (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (gpt): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(

## Train Model

In [None]:
# Fit the image captioning model on the training split, running on DEVICE.
train(
    model=model,
    train_dataset=train_dataset,
    device=DEVICE,
    num_epochs=1,
    batch_size=64,
)