# Initial configuration

In [1]:
# Reload imported modules before executing each cell so edits to local source files
# take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [3]:
# Show the available GPU(s) and their current memory usage before loading the VLM
!nvidia-smi

Thu Feb  8 08:28:31 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A40                     On  | 00000000:23:00.0 Off |                    0 |
|  0%   28C    P8              26W / 300W |      7MiB / 46068MiB |      0%   E. Process |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# SuccessVQA Baseline Script

In [None]:
import argparse
from data.captaincook4d.constants import ANNOTATIONS_DIR, VIDEO_DIR, ERROR_CATEGORIES
import os
from pprint import pprint
import random
import torch
from tqdm import tqdm
import yaml

from travel.constants import DATA_CACHE_DIR, MODEL_CACHE_DIR, RESULTS_DIR
from travel.model import MISTAKE_DETECTION_STRATEGIES, heuristic_cutoff_time, HEURISTIC_TARGET_FRAMES_PROPORTION
from travel.model.vqa import VQAOutputs, VQAResponse, SUCCESSVQA_PROMPT_TEMPLATES, get_vqa_response_token_ids
from travel.data import MistakeDetectionTasks, get_cutoff_time_by_proportion
from travel.data.captaincook4d import CaptainCook4DDataset

os.environ['HF_HOME'] = MODEL_CACHE_DIR
from transformers import AutoProcessor, AutoModelForVision2Seq

# Command-line arguments. argparse ignores `default` on a *required* positional argument,
# so the two trailing positionals are made optional (nargs="?") for their defaults to take
# effect. Only trailing positionals can be optional without making parsing ambiguous, and
# every previously valid invocation still parses the same way.
parser = argparse.ArgumentParser(description="Run the SuccessVQA baseline for mistake detection.")
parser.add_argument("task", type=str, choices=[task.value for task in MistakeDetectionTasks],
                    help="Mistake detection task (e.g. captaincook4d).")
parser.add_argument("eval_partition", type=str, choices=["val", "test"],
                    help="Dataset partition to evaluate on.")
parser.add_argument("vlm_name", type=str, nargs="?", default="llava-hf/llava-1.5-7b-hf",
                    choices=list(SUCCESSVQA_PROMPT_TEMPLATES.keys()),
                    help="Hugging Face model name for the VLM; must have a SuccessVQA prompt template.")
parser.add_argument("mistake_detection_strategy", type=str, nargs="?", default="heuristic",
                    choices=list(MISTAKE_DETECTION_STRATEGIES.keys()),
                    help="Mistake detection strategy applied to the VQA outputs.")
args = parser.parse_args()

# Load mistake detection dataset
# TODO: implement cache for pre-loaded dataset?
# NOTE(review): CaptainCook4D is always loaded regardless of `args.task`, and
# `debug_n_examples_per_class=20` presumably restricts evaluation to a small debug
# subset — confirm before reporting results
eval_dataset = CaptainCook4DDataset(data_split=args.eval_partition,
                                    debug_n_examples_per_class=20)

# Some mistake detection strategies only use a specific proportion of each clip's frames;
# for those, inference on the remaining frames can be skipped
target_frames_proportion = (
    HEURISTIC_TARGET_FRAMES_PROPORTION
    if args.mistake_detection_strategy == "heuristic"
    else None
)

# Load VLM
vlm_processor = AutoProcessor.from_pretrained(args.vlm_name)
# NOTE: when loading in 8bit, batch inference may output nans
# NOTE(review): the processor uses the default (HF_HOME-based) cache while the model weights
# are cached in DATA_CACHE_DIR — confirm this split is intentional
vlm = AutoModelForVision2Seq.from_pretrained(args.vlm_name, cache_dir=DATA_CACHE_DIR, load_in_8bit=True)
pprint(vlm.config)
# TODO: ensure zero temperature

prompt_template = SUCCESSVQA_PROMPT_TEMPLATES[args.vlm_name]
response_token_ids = get_vqa_response_token_ids(vlm_processor)

# TODO: perform inference in batches?
# TODO: cache VQA outputs for models?
# Query the VLM frame by frame. `vqa_outputs` gets one entry per example; each entry is a
# list holding, for every frame, a single-element list of VQAOutputs.
vqa_outputs = []
for example in tqdm(eval_dataset, desc="running inference on clips"):
    step_id = example.procedure_id
    # The question depends only on the step description, so it is formatted once per clip
    prompt = prompt_template.format(step=example.procedure_description)
    expected_answer = VQAResponse["Yes"]

    # Frames earlier than the cutoff time are not prompted (see target_frames_proportion)
    cutoff_time = (
        None
        if target_frames_proportion is None
        else get_cutoff_time_by_proportion(example, target_frames_proportion)
    )

    frame_outputs = []
    for frame, frame_time in zip(example.frames, example.frame_times):
        skip_frame = cutoff_time is not None and frame_time < cutoff_time
        if skip_frame:
            # Placeholder zero logits since the VLM is not prompted for this frame.
            # NOTE(review): sized by tokenizer.vocab_size, which may differ from the model's
            # logits dimension (e.g. when tokens were added) — confirm downstream code is fine with that
            frame_logits = torch.zeros(vlm_processor.tokenizer.vocab_size, dtype=torch.float32)
        else:
            with torch.no_grad():
                inputs = vlm_processor(text=prompt, images=frame, return_tensors="pt").to(vlm.device)
                # Logits at the last position of the first (only) sequence: shape (vocab size,)
                frame_logits = vlm(**inputs).logits[0, -1].detach().cpu()

        frame_outputs.append([
            VQAOutputs(
                step_id,
                frame,
                prompt,
                expected_answer,
                response_token_ids,
                frame_logits,
            )
        ])

    vqa_outputs.append(frame_outputs)

# TODO: add DET curve to evaluator based on confidence - improve heuristic approach better?
# Score the per-frame VQA outputs with the selected mistake detection strategy
evaluator = MISTAKE_DETECTION_STRATEGIES[args.mistake_detection_strategy](eval_dataset.examples, vqa_outputs)
metrics = evaluator.get_mistake_detection_metrics()
print("Mistake Detection Metrics:")
pprint(metrics)

# Output names derived from the VLM name (without its Hugging Face org prefix), the strategy
# and the partition. Single quotes are used inside the f-string: reusing the enclosing double
# quotes is a SyntaxError before Python 3.12.
results_folder = f"SuccessVQA_{args.vlm_name.split('/')[-1]}"
metrics_filename = f"metrics_{args.mistake_detection_strategy}_{args.eval_partition}"