# AVQA Zero-shot & Fine-tuned Evaluation

This notebook demonstrates how to load a small subset of the [AVQA](https://github.com/AlyssaYoung/AVQA) dataset, instantiate the multimodal QA model implemented in `captionqa.qa`, and run both zero-shot and fine-tuned inference.

Update the dataset path below to match your local environment (see the project README for the recommended `CAPTIONQA_DATASETS` layout).

In [None]:
from pathlib import Path

from captionqa.qa import load_avqa_subset

# Location of the local AVQA dataset directory.
DATASET_ROOT = Path("D:/CaptionQA/data/avqa/AVQA")  # TODO: customise for your machine

# Request a 4-item subset of the "val" split from the project loader.
dataset = load_avqa_subset(
    DATASET_ROOT,
    split="val",
    subset_size=4,
)
num_items = len(dataset)
print(f"Loaded {num_items} AVQA items from {DATASET_ROOT}")

In [None]:
class SimpleTokenizer:
    """Minimal lower-casing, whitespace-splitting tokenizer built from samples.

    The vocabulary starts with the special tokens ``<pad>``, ``<bos>`` and
    ``<eos>`` (ids 0, 1 and 2) and is extended, in first-seen order, with every
    whitespace-separated word found in the samples' questions and answers.

    NOTE: there is no dedicated unknown-word token; ``encode`` maps words that
    are not in the vocabulary to ``pad_token_id``.
    """

    def __init__(self, samples):
        """Build the vocabulary.

        Args:
            samples: Iterable of mappings. Each must provide a ``"question"``
                string; a truthy ``"answer"`` string, when present, is added
                to the vocabulary as well.
        """
        self.special_tokens = ["<pad>", "<bos>", "<eos>"]
        self.token_to_id = {tok: idx for idx, tok in enumerate(self.special_tokens)}
        self.id_to_token = list(self.special_tokens)
        for sample in samples:
            self._add_text(sample["question"])
            answer = sample.get("answer")
            if answer:
                self._add_text(answer)
        self.pad_token_id = self.token_to_id["<pad>"]
        self.bos_token_id = self.token_to_id["<bos>"]
        self.eos_token_id = self.token_to_id["<eos>"]

    def _add_text(self, text):
        """Register every not-yet-seen lower-cased word of ``text``."""
        for token in text.lower().split():
            if token not in self.token_to_id:
                # The next free id is always the current vocabulary size.
                self.token_to_id[token] = len(self.id_to_token)
                self.id_to_token.append(token)

    def encode(self, text, add_special_tokens=True):
        """Convert ``text`` into a list of token ids.

        Args:
            text: String to tokenize (lower-cased, split on whitespace).
            add_special_tokens: When true, wrap the ids in ``<bos>`` / ``<eos>``.

        Returns:
            List of integer ids. Out-of-vocabulary words become ``pad_token_id``.
        """
        ids = [
            self.token_to_id.get(token, self.pad_token_id)
            for token in text.lower().split()
        ]
        if add_special_tokens:
            return [self.bos_token_id, *ids, self.eos_token_id]
        return ids

    def decode(self, token_ids, skip_special_tokens=True):
        """Convert token ids back into a space-joined string.

        Args:
            token_ids: Iterable of integer ids.
            skip_special_tokens: When true, ``<pad>``/``<bos>``/``<eos>`` are
                omitted from the output.

        Returns:
            The decoded words joined by single spaces. Ids outside
            ``[0, len(self))`` are ignored.
        """
        num_special = len(self.special_tokens)
        vocab_size = len(self.id_to_token)
        words = []
        for token_id in token_ids:
            # Check both bounds: a negative id would otherwise wrap around via
            # Python's negative indexing and emit an unrelated vocabulary word.
            if not 0 <= token_id < vocab_size:
                continue
            if skip_special_tokens and token_id < num_special:
                continue
            words.append(self.id_to_token[token_id])
        return " ".join(words).strip()

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self.id_to_token)

# Build the vocabulary from the loaded samples and report its size.
tokenizer = SimpleTokenizer(dataset)
print(f"Tokenizer vocab size: {len(tokenizer)}")

In [None]:
import torch

from captionqa.qa import AVQAModel, run_zero_shot

# The model is constructed without loading a checkpoint. Seed PyTorch's RNG first
# so that weight initialisation (presumably random — confirm against AVQAModel)
# and therefore the printed predictions are repeatable on Restart & Run All.
torch.manual_seed(0)

# Vocabulary size and special-token ids come from the tokenizer built above so the
# model and tokenizer agree on the id space.
model = AVQAModel(
    video_dim=512,
    audio_dim=256,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Run inference on CPU with max_length=8, then print each prediction next to its
# reference answer.
results = run_zero_shot(model, dataset, tokenizer, device="cpu", max_length=8)
for item in results:
    print(f"{item.question_id}: {item.prediction} (answer: {item.reference})")

In [None]:
# Uncomment and provide a checkpoint file to evaluate a fine-tuned model.
# The example below passes device="cuda"; change it to "cpu" on a machine without a GPU.
# from captionqa.qa import run_fine_tuned
# checkpoint_path = Path("path/to/avqa_checkpoint.pt")
# ft_results = run_fine_tuned(checkpoint_path, dataset, tokenizer, device="cuda")
# for item in ft_results:
#     print(f"{item.question_id}: {item.prediction} (answer: {item.reference})")