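# Model Training

This notebook trains a `RetrievalAugmentedTransformer` image-captioning model on COCO 2017. It uses precomputed CLIP image embeddings, a GPT-2 tokenizer/decoder, and an ObjectBox vector store for retrieval. The model is evaluated on the validation split, where the best validation CIDEr score is reported.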

---

## Setup


In [1]:
# Move up to the project root directory (parent of this notebook's folder) so that
# the `src.*` modules imported below can be found.
# Guarded so the cell is idempotent: re-running it (or running it when the kernel
# already starts at the project root) must not keep climbing the directory tree.
# The project root is recognised by the presence of the `src/` package.
import os

if not os.path.isdir("src"):
    os.chdir("../")

# Current working directory should now be project root
print("Current working directory:", os.getcwd())

Current working directory: /mnt/c/Users/hoxia/Documents/NLDeeznuts/gpt2-image-captioning


In [None]:
# Imports

import torch
from transformers import set_seed

from src.dataset import CocoDataset
from src.models import RetrievalAugmentedTransformer, TransformerMappingNetwork
from src.train import train_rat
from src.utils import load_gpt2_tokenizer, get_max_workers
from src.database.image_store import create_objectbox_store

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Device: use CUDA when PyTorch reports an available GPU, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

In [None]:
# Seed: one fixed value, applied through transformers' helper, which seeds all
# the relevant RNG libraries in a single call.
SEED = 42
set_seed(seed=SEED)

---

## Prepare Datasets


In [None]:
# Data and checkpoint locations, all relative to the project root.
DATA_DIR = "coco_data/"
EMBEDDINGS_PATH = f"{DATA_DIR}embeddings/"
ANNOTATIONS_PATH = f"{DATA_DIR}annotations/"
CHECKPOINTS_PATH = "checkpoints/"

# Maximum caption length handed to the datasets (and reused for eval generation).
MAX_CAPTION_LENGTH = 50

In [None]:
# Tokenizer
gpt2_tokenizer = load_gpt2_tokenizer()


def _load_coco_split(split: str) -> CocoDataset:
    """Build a CocoDataset for one COCO 2017 split (e.g. "train" or "val").

    Reads `<split>_clip_embeddings.pt` from EMBEDDINGS_PATH and
    `captions_<split>2017.json` from ANNOTATIONS_PATH. Captions are tokenized with
    the shared GPT-2 tokenizer and limited by MAX_CAPTION_LENGTH.

    Args:
        split: Split name used in both the embeddings and the annotations filenames.

    Returns:
        The constructed CocoDataset for that split.
    """
    return CocoDataset(
        embeddings_path=f"{EMBEDDINGS_PATH}{split}_clip_embeddings.pt",
        annotations_path=f"{ANNOTATIONS_PATH}captions_{split}2017.json",
        tokenizer=gpt2_tokenizer,
        max_length=MAX_CAPTION_LENGTH,
        normalize_embeddings=False,  # `.pt` files already contain normalized embeddings
    )


# Training and validation datasets share every setting except the split name.
train_dataset = _load_coco_split("train")
val_dataset = _load_coco_split("val")

# Note: No test dataset as test annotations are not publicly available

---

## Prepare Vector Database

In [None]:
# Directory of the ObjectBox-backed image store used for retrieval (relative to project root).
DB_STORE_PATH = "vector_db"
db_store = create_objectbox_store(
    db_directory=DB_STORE_PATH,
)

---

## Prepare Model


In [None]:
# Transformer Mapping Network Params
EMBED_DIM = 512  # Image embedding dimension; must match the CLIP `.pt` embedding files
GPT_DIM = 768  # GPT-2 embedding dimension
PREFIX_LENGTH = 40  # passed to the mapping network as `prefix_length` — semantics defined there
HIDDEN_LENGTH = 40  # passed to the mapping network as `hidden_length` — semantics defined there

# Image Captioning Model Params
# True keeps the GPT-2 weights frozen; set False to fine-tune GPT-2 alongside the mapping network.
FREEZE_GPT_WEIGHTS = True
# Optional text passed to the model as `prefix_task_prompt`; None (presumably) means no prompt.
PREFIX_TASK_PROMPT: str | None = None

# Batch size used to size the worker pool below. The training cell's batch size
# should stay in sync with this value so MAX_WORKERS matches the batches actually used.
BATCH_SIZE = 64
MAX_WORKERS = get_max_workers(BATCH_SIZE)

print(f"Using max workers: {MAX_WORKERS}")

In [None]:
# Models
# Mapping network between the image-embedding space (EMBED_DIM) and GPT-2's space (GPT_DIM),
# built from the hyperparameters defined above.
mapping_network = TransformerMappingNetwork(
    prefix_length=PREFIX_LENGTH,
    hidden_length=HIDDEN_LENGTH,
    embed_dim=EMBED_DIM,
    gpt_dim=GPT_DIM,
)

# Full captioning model wrapping the mapping network and the GPT-2 tokenizer,
# then moved onto the selected device.
model = RetrievalAugmentedTransformer(
    mapping_network=mapping_network,
    tokenizer=gpt2_tokenizer,
    embed_dim=EMBED_DIM,
    prefix_task_prompt=PREFIX_TASK_PROMPT,
    freeze_gpt_weights=FREEZE_GPT_WEIGHTS,
    max_workers=MAX_WORKERS,
)
model = model.to(DEVICE)

print(model)

---

## Train Model (with Validation Evaluation)


In [None]:
# Training Params
# Reuse BATCH_SIZE (which sized MAX_WORKERS in the model-params cell) instead of
# repeating the literal, so the two values cannot silently drift apart.
TRAIN_BATCH_SIZE = BATCH_SIZE
NUM_EPOCHS = 1
NUM_WORKERS = 4  # passed to train_rat as `num_workers` (independent of MAX_WORKERS)
LEARNING_RATE = 2e-5
NUM_WARMUP_STEPS = 0  # passed to train_rat as `num_warmup_steps`; 0 disables warmup (presumably)
SAVE_EVERY_EPOCH = 1  # checkpoint save interval in epochs, written under CHECKPOINTS_PATH

# Validation Params
EVAL_EVERY_EPOCH = 1  # run validation evaluation every N epochs
EVAL_BATCH_SIZE = 64
EVAL_MAX_CAPTION_LENGTH = MAX_CAPTION_LENGTH  # keep generation length consistent with training
EVAL_TEMPERATURE = 1.0  # sampling temperature used when generating validation captions
EVAL_TOP_P = 0.9  # nucleus (top-p) sampling threshold used when generating validation captions

In [None]:
# Train image captioning model with validation evaluation

# Retrieval params, previously inline magic numbers in the call below.
# NOTE(review): the exact meaning of `top_k` vs. `top_i` is defined by `train_rat`;
# confirm there before tuning these values.
RETRIEVAL_TOP_K = 2
RETRIEVAL_TOP_I = 10

# Same annotations file the validation dataset was built from; the evaluator needs it too.
VAL_ANNOTATIONS_PATH = ANNOTATIONS_PATH + "captions_val2017.json"

train_history = train_rat(
    # Training params
    train_dataset=train_dataset,
    model=model,
    db_store=db_store,
    top_k=RETRIEVAL_TOP_K,
    top_i=RETRIEVAL_TOP_I,
    batch_size=TRAIN_BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    num_workers=NUM_WORKERS,
    learning_rate=LEARNING_RATE,
    num_warmup_steps=NUM_WARMUP_STEPS,
    save_every_epoch=SAVE_EVERY_EPOCH,
    device=DEVICE,
    outputs_dir=CHECKPOINTS_PATH,
    # Eval params
    val_dataset=val_dataset,
    val_annotations_path=VAL_ANNOTATIONS_PATH,
    eval_every_epoch=EVAL_EVERY_EPOCH,
    eval_batch_size=EVAL_BATCH_SIZE,
    eval_max_length=EVAL_MAX_CAPTION_LENGTH,
    eval_temperature=EVAL_TEMPERATURE,
    eval_top_p=EVAL_TOP_P,
)

print(
    f"\nBest validation CIDEr: {train_history['best_val_cider']:.4f} at epoch {train_history['best_epoch']}"
)

In [None]:
# Display the full history object returned by `train_rat` (includes the
# 'best_val_cider' and 'best_epoch' entries printed above).
train_history