In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel (mode 2 reloads all
# modules before each cell execution).
%load_ext autoreload
%autoreload 2

## Setup

In [2]:
# Add project root (parent directory) to system path (for module imports).
# The path is relative to the kernel's working directory, so the notebook is
# expected to be launched from its own folder.
import sys

PROJECT_ROOT = "../"

# Guard against duplicates: re-running this cell must not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
import os
import random

import numpy as np
import torch

from src.dataset import CocoDataset, split_coco_annotations
from src.models import RetrievalAugmentedTransformer, TransformerMappingNetwork, ImageCaptioningModel
from src.train import train, train_rat
from src.database.image_store import create_objectbox_store

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [None]:
# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cpu')

In [None]:
# Seed
SEED = 42

# Seed every random number generator this notebook can reach so that runs are
# repeatable after "Restart Kernel -> Run All".
# torch.manual_seed seeds the CPU generator and all CUDA devices.
# The trailing semicolon hides the Generator object it returns.
torch.manual_seed(SEED);
random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's legacy global RNG

## Prepare Datasets

In [None]:
MAX_CAPTION_LENGTH = 50

# Root directory of the COCO data. Set the COCO_DATA_ROOT environment variable
# to run the notebook on another machine; the default is the original author's
# local path, so behaviour is unchanged when the variable is not set.
COCO_DATA_ROOT = os.environ.get(
    "COCO_DATA_ROOT",
    "/mnt/c/Users/hoxia/Documents/NLDeeznuts/gpt2-image-captioning/data/data/coco",
).rstrip("/")

# Both paths keep a trailing slash because file names are appended to them by
# plain string concatenation in the cells below.
EMBEDDINGS_PATH = COCO_DATA_ROOT + "/embeddings/"
ANNOTATIONS_PATH = COCO_DATA_ROOT + "/annotations/"

In [None]:
# Carve a new training / validation pair out of the original COCO 2014
# training annotations. The split files are written next to the source
# annotations (into ANNOTATIONS_PATH).
full_train_annotations = ANNOTATIONS_PATH + "captions_train2014.json"

split_coco_annotations(
    annotations_path=full_train_annotations,
    output_dir=ANNOTATIONS_PATH,
    split_ratio=0.8,
    seed=SEED,
)

Splitting: 66226 Train images, 16557 Val images.
Created:
- /mnt/c/Users/hoxia/Documents/NLDeeznuts/gpt2-image-captioning/data/data/coco/annotations/train_split.json
- /mnt/c/Users/hoxia/Documents/NLDeeznuts/gpt2-image-captioning/data/data/coco/annotations/val_split.json


In [None]:
# Training Dataset (orig. COCO 2014 TRAIN)
# NOTE(review): this loads the full `captions_train2014.json`, not the
# `train_split.json` produced by the split cell - confirm this is intended
# before enabling the validation dataset, which is drawn from the same images.
train_embeddings_file = EMBEDDINGS_PATH + "train_clip_embeddings.pt"
train_annotations_file = ANNOTATIONS_PATH + "captions_train2014.json"

train_dataset = CocoDataset(
    embeddings_path=train_embeddings_file,
    annotations_path=train_annotations_file,
    max_length=MAX_CAPTION_LENGTH,
    # `.pt` files already contain normalized embeddings
    normalize_embeddings=False,
)

# # Validation Dataset (orig. COCO 2014 TRAIN)
# val_dataset = CocoDataset(
#     embeddings_path=EMBEDDINGS_PATH + "train_val_clip_embeddings.pt",
#     annotations_path=ANNOTATIONS_PATH + "val_split.json",
#     max_length=MAX_CAPTION_LENGTH,
#     normalize_embeddings=False,
# )

# # Test Dataset (orig. COCO 2017 Val)
# test_dataset = CocoDataset(
#     embeddings_path=EMBEDDINGS_PATH + "test_clip_embeddings.pt",
#     annotations_path=ANNOTATIONS_PATH + "captions_val2017.json",
#     max_length=MAX_CAPTION_LENGTH,
#     normalize_embeddings=False,
# )

Dataset ready: 414113 captions.


## Prepare Model

In [None]:
# Models
clip_dim = 512  # CLIP embedding dimension
gpt2_dim = 768  # GPT-2 embedding dimension

# Mapping network that feeds CLIP image embeddings to GPT-2.
mapping_network = TransformerMappingNetwork(
    embed_dim=clip_dim,
    gpt_dim=gpt2_dim,
    prefix_length=40,
    hidden_length=40,
)

# We only fine-tune the mapping network during training, hence the frozen GPT weights.
model = RetrievalAugmentedTransformer(
    embed_dim=clip_dim,
    mapping_network=mapping_network,
    freeze_gpt_weights=True,
)
model = model.to(DEVICE)

print(model)



RetrievalAugmentedTransformer(
  (mapping_network): TransformerMappingNetwork(
    (linear): Linear(in_features=512, out_features=30720, bias=True)
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-7): 8 x EncoderLayer(
          (transformer_layer): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (linear1): Linear(in_features=768, out_features=3072, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=3072, out_features=768, bias=True)
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (gpt): GPT2LMHeadModel(
  

## Train Model

In [None]:
# Only for RAT: the ObjectBox database lives in `objectbox_db_fast` inside the
# current user's home directory; create the store there.
home_dir = os.path.expanduser("~")
DB_STORE_PATH = os.path.join(home_dir, "objectbox_db_fast")

db_store = create_objectbox_store(db_directory=DB_STORE_PATH)

In [None]:
# Train image captioning model.
# Hyperparameters shared by both training entry points, defined once so the
# two branches cannot drift apart.
BATCH_SIZE = 64
NUM_EPOCHS = 1

if isinstance(model, RetrievalAugmentedTransformer):
    # Retrieval-augmented training needs the ObjectBox store created above.
    # NOTE(review): the recorded run of this cell failed with
    # "Unsupported type for 'NEAREST_NEIGHBOR': <class 'torch.Tensor'>".
    # Presumably `train_rat` hands a torch.Tensor to the ObjectBox
    # nearest-neighbour query - verify and convert inside `src.train`.
    train_rat(
        train_dataset=train_dataset,
        model=model,
        db_store=db_store,
        top_k=2,
        top_i=4,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
    )

elif isinstance(model, ImageCaptioningModel):
    train(
        train_dataset=train_dataset,
        model=model,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
    )

else:
    # Fail loudly instead of silently skipping training for an unknown model.
    raise TypeError(
        f"Unsupported model type for training: {type(model).__name__}"
    )


Epoch 1/1:   0%|          | 0/6471 [00:00<?, ?it/s]


Exception: Unsupported type for 'NEAREST_NEIGHBOR': <class 'torch.Tensor'>