# Model Training and Evaluation

---

## Setup


In [None]:
# Move up to the project root (parent directory) so the `src` package is importable.
# Guarded so that re-running this cell does not keep climbing the directory tree:
# we only step up one level when `src/` is not already in the working directory.
import os

if not os.path.isdir("src"):
    os.chdir(os.pardir)

# Current working directory should now be the project root
print("Current working directory:", os.getcwd())

In [None]:
# Imports

import json
import torch
from transformers import set_seed
import tempfile
import matplotlib.pyplot as plt

from src.visualize import create_captioning_dataset

from src.dataset import CocoDataset
from src.models import ImageCaptioningModel, TransformerMappingNetwork
from src.train import train

from src.eval import compute_caption_metrics, evaluate_captions, evaluate_epoch

In [None]:
# Device: use the GPU whenever PyTorch reports CUDA as available, else fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

In [None]:
# Reproducibility: transformers' `set_seed` helper seeds all relevant
# libraries in a single call.
SEED = 42
set_seed(SEED)

---

## Prepare Datasets


In [None]:
# Maximum caption length handed to the datasets below
MAX_CAPTION_LENGTH = 50

# Directory layout, relative to the project root (every path keeps a trailing "/")
DATA_DIR = "coco_data/"
CHECKPOINTS_PATH = "checkpoints/"
EMBEDDINGS_PATH = f"{DATA_DIR}embeddings/"
ANNOTATIONS_PATH = f"{DATA_DIR}annotations/"
VAL_PATH = f"{DATA_DIR}val2017/"  # For FiftyOne visualization

In [None]:
# Both splits read precomputed embeddings from `.pt` files. Those files already
# hold normalized embeddings, so re-normalization is switched off for each split.

# Training split: train2017 CLIP embeddings + train2017 captions
train_dataset = CocoDataset(
    annotations_path=f"{ANNOTATIONS_PATH}captions_train2017.json",
    embeddings_path=f"{EMBEDDINGS_PATH}train_clip_embeddings.pt",
    normalize_embeddings=False,
    max_length=MAX_CAPTION_LENGTH,
)

# Validation split: val2017 CLIP embeddings + val2017 captions
val_dataset = CocoDataset(
    annotations_path=f"{ANNOTATIONS_PATH}captions_val2017.json",
    embeddings_path=f"{EMBEDDINGS_PATH}val_clip_embeddings.pt",
    normalize_embeddings=False,
    max_length=MAX_CAPTION_LENGTH,
)

# No test split is built: COCO test annotations are not publicly available.

---

## Prepare Model


In [None]:
# Mapping network from 512-d input embeddings to GPT-2's 768-d embedding space.
# NOTE(review): the exact roles of prefix_length / hidden_length are defined in
# src.models.TransformerMappingNetwork — confirm there before changing them.
mapping_network = TransformerMappingNetwork(
    embed_dim=512,
    gpt_dim=768,
    prefix_length=40,
    hidden_length=40,
)

# Captioning model with frozen GPT-2 weights: only the mapping network is
# fine-tuned during training. Moved to DEVICE before it is printed.
model = ImageCaptioningModel(
    freeze_gpt_weights=True,
    mapping_network=mapping_network,
)
model = model.to(DEVICE)

print(model)

---

## Train Model (with Validation Evaluation)


In [None]:
# Train the captioning model. Because a validation set is supplied, train()
# also evaluates on it every epoch (eval_every_epoch=1). The returned history
# exposes the best validation CIDEr score and the epoch that achieved it.
# NOTE(review): evaluation references come from `val_split.json`, whereas
# val_dataset above is built from `captions_val2017.json` — confirm both cover
# the same validation images.
history = train(
    model=model,
    train_dataset=train_dataset,
    device=DEVICE,
    num_epochs=1,
    batch_size=64,
    outputs_dir=CHECKPOINTS_PATH,
    val_dataset=val_dataset,
    val_annotations_path=f"{ANNOTATIONS_PATH}val_split.json",
    eval_every_epoch=1,
)

print(f"\nBest validation CIDEr: {history['best_val_cider']:.4f} at epoch {history['best_epoch']}")