# Setup

In [None]:
%load_ext autoreload
# Reload modules automatically before executing code, so edits to the `travel` package are picked up
# without restarting the kernel
%autoreload 2

import os
# Run from the TRAVEl repository root. NOTE(review): hard-coded cluster path; adjust on other machines.
# Presumably this must happen before `from travel import ...` so the package resolves from the repo root; confirm.
os.chdir("/nfs/turbo/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl")
from travel import init_travel
# Project-specific initialization; its exact effects are defined inside `travel` (not visible here)
init_travel()

import argparse
from collections import defaultdict, Counter
from copy import deepcopy
from itertools import product
import json
import numpy as np
import os
import pickle
from PIL import Image
from pprint import pprint
import shutil
import spacy
import time
import torch
from tqdm import tqdm
from transformers import AutoModelForVision2Seq, AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig, AutoModelForSequenceClassification, AutoTokenizer, PhrasalConstraint           

from travel.constants import RESULTS_DIR, IMAGES_CHUNK_SIZE, HF_TOKEN, CONFIG_PATH
from travel.data.captaincook4d import CaptainCook4DDataset
from travel.data.ego4d import Ego4DMistakeDetectionDataset
from travel.data.mistake_detection import MistakeDetectionTasks
from travel.data.vqa import VQAResponse, get_vqa_response_token_ids, VQAOutputs, DIALOG_START_TOKENS, IMAGE_TOKENS, USER_START_TOKENS, USER_END_TOKENS, ASSISTANT_START_TOKENS, ASSISTANT_END_TOKENS, IVQA_PREAMBLE, IVQA_SUCCESS_QUESTION
from travel.data.vqg import generate_vqg_prompt_icl
from travel.model import simple_lm_prompt_beam_search, simple_vlm_prompt_beam_search, compute_completion_log_likelihoods, compute_completion_log_likelihoods_encoder_decoder, compute_completion_log_likelihoods_vlm
from travel.model.metrics import question_coherence_metrics_nli, question_coherence_metrics_vlm, generate_det_curve, compile_accuracy_and_coherence_metrics, generate_3d_overview_graph
from travel.model.mistake_detection import MISTAKE_DETECTION_THRESHOLDS
from travel.model.nli import NLI_MODEL_PATH, NLI_BATCH_SIZE
from travel.model.vqa import run_vqa_with_visual_filter
from travel.model.vqg import cleanup_generated_question

Load model:

In [None]:
from transformers import AutoModelForVision2Seq, AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig, AutoModelForSequenceClassification, AutoTokenizer, PhrasalConstraint

# 4-bit NF4 quantization (with double quantization) and bfloat16 compute, shared by the VLM and the NLI model.
# NOTE(review): the llm_int8_* options appear to apply only to 8-bit loading; presumably ignored here; confirm.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

VLM_NAME = "llava-hf/llava-1.5-7b-hf"

# Load VLM - some VLMs may be under AutoModelForVision2Seq, some may be under AutoModelForCausalLM
try:
    vlm = AutoModelForVision2Seq.from_pretrained(VLM_NAME, quantization_config=bnb_config, trust_remote_code=True, token=HF_TOKEN)
except Exception as e:
    print("Encountered exception when trying to load model with AutoModelForVision2Seq:")
    pprint(e)

    vlm = AutoModelForCausalLM.from_pretrained(VLM_NAME, quantization_config=bnb_config, trust_remote_code=True, token=HF_TOKEN)
vlm_processor = AutoProcessor.from_pretrained(VLM_NAME, trust_remote_code=True, token=HF_TOKEN)
# Left padding so batched generation continues directly from the end of each prompt
vlm_processor.tokenizer.padding_side = "left"
response_token_ids = get_vqa_response_token_ids(vlm_processor.tokenizer)

# We'll use the VLM's language model directly (text-only) to generate questions.
# Compare against None explicitly: the truthiness of a torch module is not a meaningful "is present" check.
lm = getattr(vlm, "language_model", None)
if lm is None:
    lm = vlm
tokenizer = vlm_processor.tokenizer
# Pad with EOS (overrides whatever pad token the processor shipped with)
tokenizer.pad_token_id = tokenizer.eos_token_id

# NLI model to score consistency and verifiability
nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_PATH, quantization_config=bnb_config)
nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_PATH)

Load data:

In [None]:
# Load appropriate evaluation dataset, retrying a few times in case loading fails
# (NOTE(review): presumably transient filesystem errors; confirm)
N_DATASET_LOAD_TRIES = 5
DATASET_LOAD_RETRY_WAIT_SECONDS = 60

dataset = None
last_load_error = None  # kept so the final error can report the underlying cause
for retry in range(N_DATASET_LOAD_TRIES):
    print(f"Loading evaluation dataset (try {retry})...")
    try:
        dataset = Ego4DMistakeDetectionDataset(data_split="val", 
                                                mismatch_augmentation=True,
                                                multi_frame=False,
                                                debug_n_examples_per_class=100)
        break
    except Exception as e:
        print("Encountered error during data loading:")
        pprint(e)
        last_load_error = e
        # Only wait if another attempt follows; sleeping after the last failure just delays the error
        if retry < N_DATASET_LOAD_TRIES - 1:
            time.sleep(DATASET_LOAD_RETRY_WAIT_SECONDS)
if dataset is None:
    raise ValueError("Could not load dataset after retrying!") from last_load_error

Other global args:

In [None]:
MAX_ITERATIONS = 10         # maximum number of question-answer turns per example
N_ICL_DEMONSTRATIONS = 20
UNSURE_RANGE = 0.1          # answers with |P(Yes) - 0.5| below this are recorded as "Unsure"

_question_tokenizer = vlm_processor.tokenizer


def _token_ids(text):
    """Return the token ids of `text` without special (BOS/EOS) tokens."""
    return _question_tokenizer(text, add_special_tokens=False).input_ids


# Constrain question generation: every generated question must contain the final token of "Is it blue?" (its "?")
question_generation_constraints = [
    PhrasalConstraint([_token_ids("Is it blue?")[-1]]),
]

# Questions may only start with the leading token of one of these typical yes/no question openers
_YES_NO_QUESTION_TEMPLATES = (
    "Is it blue?",
    "Was it blue?",
    "Are they blue?",
    "Were they blue?",
    "Does it look blue?",
    "Do they look blue?",
    "Did they look blue?",
    "Has the oven turned on?",
    "Have the eggs boiled?",
    "Had the eggs boiled?",
)
yes_no_q_tokens = [_token_ids(template)[0] for template in _YES_NO_QUESTION_TEMPLATES]

# Every other token id in [0, vocab_size) is suppressed at the first generation step.
# NOTE(review): `vocab_size` presumably excludes tokens added on top of the base vocabulary; confirm.
_allowed_first_tokens = set(yes_no_q_tokens)
begin_suppress_tokens = [
    token_id for token_id in range(_question_tokenizer.vocab_size)
    if token_id not in _allowed_first_tokens
]

# Banned outputs: the second token of "Yes or no?" (presumably "or"), plus words generated questions should avoid
bad_words_ids = [[_token_ids("Yes or no?")[1]]] + [
    _token_ids(word) for word in ("successful", "successfully", "completed", "procedure")
]

generation_kwargs = dict(
    do_sample=False,
    num_beams=8,
    num_return_sequences=4,
    constraints=question_generation_constraints,
    begin_suppress_tokens=begin_suppress_tokens,
    bad_words_ids=bad_words_ids,
    pad_token_id=tokenizer.eos_token_id,
    length_penalty=1.0,
)

# Likelihood-Based Ranking (Vanilla)

In [None]:
import time

# Modify below values to tuned values for each experiment
EARLY_STOP_DELTA = 0.1  # stop when P(success) moves less than this across two consecutive turns
CONFIDENT_RANGE = 0.1   # stop when P(success) is within this distance of 0 or 1

# Efficiency statistics: one entry per example
n_iterations_taken = []
time_taken = []
for batch_idx, batch_example in tqdm(enumerate(dataset.get_batches(1, 
                                                                    n_workers=1, 
                                                                    worker_index=0,
                                                                    load_frames=False)), 
                                                desc="running iterative VQA inference"):

    batch_examples = [batch_example]
    # perf_counter is monotonic, which makes it the right clock for elapsed-time measurement
    start = time.perf_counter()

    # Take first frame (expect there to only be one frame)
    batch_procedures = [example.procedure_description for example in batch_examples]
    batch_frames = [Image.open(example.frames[0]) for example in batch_examples]

    this_batch_size = len(batch_examples)

    # Running dialog per example; each turn appends a question and its answer
    prompts = [
        f'{DIALOG_START_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}{IVQA_PREAMBLE.format(procedure=procedure)}' 
        for procedure in batch_procedures
    ]
    # Per-example histories (outer index = position in batch, inner index = turn)
    questions = [[] for _ in range(this_batch_size)]
    frames = [[] for _ in range(this_batch_size)]
    candidate_questions = [[] for _ in range(this_batch_size)]
    candidate_questions_scores = [[] for _ in range(this_batch_size)]
    candidate_questions_sources = [[] for _ in range(this_batch_size)]
    scores = [[] for _ in range(this_batch_size)]
    answer_probs = [[] for _ in range(this_batch_size)] 
    answers = [[] for _ in range(this_batch_size)]
    success_probs = [[] for _ in range(this_batch_size)]
    success_probs_negated = [[] for _ in range(this_batch_size)]

    # Iteratively generate questions
    for question_idx in tqdm(range(MAX_ITERATIONS), desc="running iterative QA"):

        # Generate a question (with beam search so we have several candidates).
        # The image token is stripped because questions come from the text-only LM.
        prompts_q = [prompt + f"{ASSISTANT_END_TOKENS[VLM_NAME] if question_idx != 0 else USER_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q:" for prompt in prompts]
        lm_prompts_q = [prompt.replace(IMAGE_TOKENS[VLM_NAME], "") for prompt in prompts_q]
        new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                        tokenizer,
                                                        lm_prompts_q,
                                                        max_new_tokens=20,
                                                        batch_size=1,
                                                        generation_kwargs=generation_kwargs)

        new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in new_questions]
        new_questions_sources = [["vlm"] * len(beam_search_questions) for beam_search_questions in new_questions]

        # Remove duplicate candidates (keep the first occurrence of each question)
        keep_idxs = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if question not in beam_search_outputs[:cand_idx]] for beam_search_outputs in new_questions]

        # Try to remove any candidates that we've seen before (if we've seen all the candidates before, don't remove any)
        keep_idxs_filtered = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if cand_idx in keep_idxs[batch_sub_idx] and question not in questions[batch_sub_idx]] for batch_sub_idx, beam_search_outputs in enumerate(new_questions)]
        keep_idxs = [keep_idxs_filtered[batch_sub_idx] if len(keep_idxs_filtered[batch_sub_idx]) > 0 else keep_idxs[batch_sub_idx] for batch_sub_idx in range(this_batch_size)]

        # Apply kept indices to new questions and their sources
        new_questions = [[new_questions[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]
        new_questions_sources = [[new_questions_sources[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]

        # Save all candidates from beam search
        for batch_sub_idx in range(this_batch_size):
            candidate_questions[batch_sub_idx].append(new_questions[batch_sub_idx])
            candidate_questions_sources[batch_sub_idx].append(new_questions_sources[batch_sub_idx])

        # Score every candidate by its completion log-likelihood under the LM
        generation_scores = compute_completion_log_likelihoods(lm, tokenizer, lm_prompts_q, new_questions, batch_size=1)

        # Select the most likely candidate (highest log-likelihood; ties go to the earliest candidate)
        selected_questions = []
        new_scores = []
        for batch_sub_idx, (beam_search_questions, beam_search_scores) in enumerate(zip(new_questions, generation_scores)):
            assert len(beam_search_questions) == len(beam_search_scores), "Expected candidate questions and their scores to have the same shape!"

            # Save all candidate scores
            candidate_questions_scores[batch_sub_idx].append(beam_search_scores)

            best_candidate = max(range(len(beam_search_questions)), key=lambda cand_idx: beam_search_scores[cand_idx])
            selected_questions.append(beam_search_questions[best_candidate])
            new_scores.append(beam_search_scores[best_candidate])

        new_questions = selected_questions

        # Save scores for best questions and the questions themselves
        for batch_sub_idx in range(this_batch_size):
            scores[batch_sub_idx].append(new_scores[batch_sub_idx])
            questions[batch_sub_idx].append(new_questions[batch_sub_idx])

        # Dialog prompt including this turn's question (used to continue the dialog)
        prompts_a = [prompt + f' {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for prompt, question in zip(prompts_q, new_questions)]

        # Effective prompt for VQA excludes dialog history: only the image and the current question
        use_prompts_a = [f'{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for question in new_questions]

        new_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                        vlm=vlm, 
                                                        batch_examples=batch_examples, 
                                                        batch_frames=batch_frames, 
                                                        prompts_a=use_prompts_a, 
                                                        new_questions=new_questions, 
                                                        question_idx=question_idx,
                                                        batch_size=1,
                                                        visual_filter=None,
                                                        nlp=None,
                                                        visual_filter_mode=None,
                                                        frame_cache_dir=None,
                                                        is_encoder_decoder="-t5-" in VLM_NAME.lower())

        # Gather up VQA outputs (which automatically calculates answer probabilities from logits).
        # NOTE(review): the recorded prompt is the dialog prompt, not `use_prompts_a` actually sent to the VLM; confirm intended.
        new_answers = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(new_answers_logits, batch_examples, prompts_a, new_questions)
        ]
        # Low-margin answers are recorded as "Unsure"
        new_answers_str = [output.predicted_answer.name if np.abs(output.answer_probs[VQAResponse.Yes] - 0.5) >= UNSURE_RANGE else "Unsure" for output in new_answers]

        # Save answers and their probabilities
        for batch_sub_idx in range(this_batch_size):
            answer_probs[batch_sub_idx].append([round(float(new_answers[batch_sub_idx].answer_probs[VQAResponse(answer_idx)]), 6) for answer_idx in range(2)])
            answers[batch_sub_idx].append(new_answers_str[batch_sub_idx])

        # Extend the dialog with this turn's answer
        prompts = [prompt + " " + output for prompt, output in zip(prompts_a, new_answers_str)]

        # Ask VLM probability of success
        questions_success = [
            IVQA_SUCCESS_QUESTION.format(procedure=procedure)
            for procedure in batch_procedures
        ]
        prompts_success = [
            prompt + f'{ASSISTANT_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A: '
            for prompt, question in zip(prompts, questions_success)
        ]

        success_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                            vlm=vlm, 
                                                            batch_examples=batch_examples, 
                                                            batch_frames=batch_frames, 
                                                            prompts_a=prompts_success, 
                                                            new_questions=questions_success, 
                                                            question_idx=f"{question_idx}_success",
                                                            batch_size=1,
                                                            visual_filter=None,
                                                            nlp=None,
                                                            visual_filter_mode=None,
                                                            frame_cache_dir=None,
                                                            is_encoder_decoder="-t5-" in VLM_NAME.lower(),
                                                            ignore_frames=False)
        # Record the success query's own prompt and question (not this turn's generated question)
        success_vqa_outputs = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(success_answers_logits, batch_examples, prompts_success, questions_success)
        ]

        # Save success probability for this turn
        for batch_sub_idx in range(this_batch_size):
            success_probs[batch_sub_idx].append(
                round(float(success_vqa_outputs[batch_sub_idx].answer_probs[VQAResponse.Yes]), 6)
            )

        # Clear out VQA outputs now because they occupy a lot of memory
        # (the raw logits lists must go too, otherwise they keep the logits alive)
        del new_answers, new_answers_logits
        del success_vqa_outputs, success_answers_logits

        # Check if we can stop based on early stopping criteria (batch size is 1, hence index 0)
        # if success score doesn't change enough over 3 turns, stop incorporating questions
        # (we still run inference across all questions for efficiency and simplicity, but later can make a proper demo script)
        if question_idx >= 2:
            if np.abs(success_probs[0][question_idx-1] - success_probs[0][question_idx-2]) < EARLY_STOP_DELTA and np.abs(success_probs[0][question_idx] - success_probs[0][question_idx-1]) < EARLY_STOP_DELTA:
                n_iterations_taken.append(question_idx+1)
                print("Early stop!")
                break
        # OR if success score is within confident delta, stop
        if success_probs[0][-1] < CONFIDENT_RANGE or 1.0 - success_probs[0][-1] < CONFIDENT_RANGE:
            n_iterations_taken.append(question_idx+1)
            print("Early stop!")
            break
        # If it's the last iteration, record
        if question_idx == MAX_ITERATIONS-1:
            n_iterations_taken.append(MAX_ITERATIONS)

    time_taken.append(time.perf_counter() - start)

print("Avg. # iterations:", np.mean(n_iterations_taken))
print("Std. # iterations:", np.std(n_iterations_taken))
print("Avg. time (sec.):", np.mean(time_taken))
print("Std. time (sec.):", np.std(time_taken))
print("Avg. runtime per iteration (sec.):", np.mean([t / i for i, t in zip(n_iterations_taken, time_taken)]))

# + Coherence-Based Ranking

In [None]:
import time

# Modify below values to tuned values for each experiment
EARLY_STOP_DELTA = 0.1  # stop when P(success) moves less than this across two consecutive turns
CONFIDENT_RANGE = 0.1   # stop when P(success) is within this distance of 0 or 1

# Which NLI metric is used for each ranking strategy (this cell ranks by "coherence")
ranking_key_mapping = {
    "relevance": "relevance_marginal",
    "informativeness": "informativeness_marginal",
    "coherence": "informativeness_marginal_x_relevance_marginal",
}

# Efficiency statistics: one entry per example
n_iterations_taken = []
time_taken = []
for batch_idx, batch_example in tqdm(enumerate(dataset.get_batches(1, 
                                                                    n_workers=1, 
                                                                    worker_index=0,
                                                                    load_frames=False)), 
                                                desc="running iterative VQA inference"):

    batch_examples = [batch_example]
    # perf_counter is monotonic, which makes it the right clock for elapsed-time measurement
    start = time.perf_counter()

    # Take first frame (expect there to only be one frame)
    batch_procedures = [example.procedure_description for example in batch_examples]
    batch_frames = [Image.open(example.frames[0]) for example in batch_examples]

    this_batch_size = len(batch_examples)

    # Running dialog per example; each turn appends a question and its answer
    prompts = [
        f'{DIALOG_START_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}{IVQA_PREAMBLE.format(procedure=procedure)}' 
        for procedure in batch_procedures
    ]
    # Per-example histories (outer index = position in batch, inner index = turn)
    questions = [[] for _ in range(this_batch_size)]
    frames = [[] for _ in range(this_batch_size)]
    candidate_questions = [[] for _ in range(this_batch_size)]
    candidate_questions_scores = [[] for _ in range(this_batch_size)]
    candidate_questions_sources = [[] for _ in range(this_batch_size)]
    scores = [[] for _ in range(this_batch_size)]
    answer_probs = [[] for _ in range(this_batch_size)] 
    answers = [[] for _ in range(this_batch_size)]
    success_probs = [[] for _ in range(this_batch_size)]
    success_probs_negated = [[] for _ in range(this_batch_size)]

    # Iteratively generate questions
    for question_idx in tqdm(range(MAX_ITERATIONS), desc="running iterative QA"):

        # Generate a question (with beam search so we have several candidates).
        # The image token is stripped because questions come from the text-only LM.
        prompts_q = [prompt + f"{ASSISTANT_END_TOKENS[VLM_NAME] if question_idx != 0 else USER_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q:" for prompt in prompts]
        new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                        tokenizer,
                                                        [prompt.replace(IMAGE_TOKENS[VLM_NAME], "") for prompt in prompts_q],
                                                        max_new_tokens=20,
                                                        batch_size=1,
                                                        generation_kwargs=generation_kwargs)

        new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in new_questions]
        new_questions_sources = [["vlm"] * len(beam_search_questions) for beam_search_questions in new_questions]

        # Remove duplicate candidates (keep the first occurrence of each question)
        keep_idxs = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if question not in beam_search_outputs[:cand_idx]] for beam_search_outputs in new_questions]

        # Try to remove any candidates that we've seen before (if we've seen all the candidates before, don't remove any)
        keep_idxs_filtered = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if cand_idx in keep_idxs[batch_sub_idx] and question not in questions[batch_sub_idx]] for batch_sub_idx, beam_search_outputs in enumerate(new_questions)]
        keep_idxs = [keep_idxs_filtered[batch_sub_idx] if len(keep_idxs_filtered[batch_sub_idx]) > 0 else keep_idxs[batch_sub_idx] for batch_sub_idx in range(this_batch_size)]

        # Apply kept indices to new questions and their sources
        new_questions = [[new_questions[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]
        new_questions_sources = [[new_questions_sources[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]

        # Save all candidates from beam search
        for batch_sub_idx in range(this_batch_size):
            candidate_questions[batch_sub_idx].append(new_questions[batch_sub_idx])
            candidate_questions_sources[batch_sub_idx].append(new_questions_sources[batch_sub_idx])

        # Calculate coherence metrics for each candidate question. Inputs are flattened across the batch
        # (one entry per candidate); prior turns answered "Unsure" are excluded from the dialog context.
        nli_outputs = question_coherence_metrics_nli(
            nli_tokenizer, 
            nli_model,
            tokenizer,
            lm,
            [procedure for procedure, beam_search_questions in zip(batch_procedures, new_questions) for _ in beam_search_questions],
            [question for beam_search_questions in new_questions for question in beam_search_questions],
            previous_questions=[[q for qi, q in enumerate(batch_idx_questions) if batch_idx_answers[qi] != "Unsure"] for batch_idx_questions, batch_idx_answers, beam_search_questions in zip(questions, answers, new_questions) for _ in beam_search_questions],
            previous_answers=[[a for a in batch_idx_answers if a != "Unsure"] for batch_idx_answers, beam_search_questions in zip(answers, new_questions) for _ in beam_search_questions],
            rephrase_batch_size=10
        )

        # Select best candidate based on coherence metrics
        selected_questions = []
        new_scores = []
        parallel_idx = 0  # offset of this example's candidates within the flattened NLI outputs
        for batch_sub_idx, beam_search_questions in enumerate(new_questions):
            # Per-candidate metrics: numeric values rounded to 3 decimals, string values kept as-is
            this_nli_outputs = [
                {k: nli_outputs[k][i] if isinstance(nli_outputs[k][i], str) else round(float(nli_outputs[k][i]), 3) for k in nli_outputs}
                for i in range(parallel_idx, parallel_idx + len(beam_search_questions))
            ]
            candidate_questions_scores[batch_sub_idx].append(this_nli_outputs)
            parallel_idx += len(beam_search_questions)

            # Use marginal relevance (consistency) and expected informativeness (verifiability) to rank candidates
            candidate_scores = np.array(
                [candidate_metrics[ranking_key_mapping["coherence"]] for candidate_metrics in this_nli_outputs]
            )

            best_candidate = int(np.argmax(candidate_scores))
            selected_questions.append(beam_search_questions[best_candidate])
            new_scores.append(round(float(candidate_scores[best_candidate]), 6))

        new_questions = selected_questions

        # Save scores for best questions and the questions themselves
        for batch_sub_idx in range(this_batch_size):
            scores[batch_sub_idx].append(new_scores[batch_sub_idx])
            questions[batch_sub_idx].append(new_questions[batch_sub_idx])

        # Dialog prompt including this turn's question (used to continue the dialog)
        prompts_a = [prompt + f' {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for prompt, question in zip(prompts_q, new_questions)]

        # Effective prompt for VQA excludes dialog history: only the image and the current question
        use_prompts_a = [f'{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for question in new_questions]

        new_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                        vlm=vlm, 
                                                        batch_examples=batch_examples, 
                                                        batch_frames=batch_frames, 
                                                        prompts_a=use_prompts_a, 
                                                        new_questions=new_questions, 
                                                        question_idx=question_idx,
                                                        batch_size=1,
                                                        visual_filter=None,
                                                        nlp=None,
                                                        visual_filter_mode=None,
                                                        frame_cache_dir=None,
                                                        is_encoder_decoder="-t5-" in VLM_NAME.lower())

        # Gather up VQA outputs (which automatically calculates answer probabilities from logits).
        # NOTE(review): the recorded prompt is the dialog prompt, not `use_prompts_a` actually sent to the VLM; confirm intended.
        new_answers = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(new_answers_logits, batch_examples, prompts_a, new_questions)
        ]
        # Low-margin answers are recorded as "Unsure"
        new_answers_str = [output.predicted_answer.name if np.abs(output.answer_probs[VQAResponse.Yes] - 0.5) >= UNSURE_RANGE else "Unsure" for output in new_answers]

        # Save answers and their probabilities
        for batch_sub_idx in range(this_batch_size):
            answer_probs[batch_sub_idx].append([round(float(new_answers[batch_sub_idx].answer_probs[VQAResponse(answer_idx)]), 6) for answer_idx in range(2)])
            answers[batch_sub_idx].append(new_answers_str[batch_sub_idx])

        # Extend the dialog with this turn's answer
        prompts = [prompt + " " + output for prompt, output in zip(prompts_a, new_answers_str)]

        # Ask VLM probability of success
        questions_success = [
            IVQA_SUCCESS_QUESTION.format(procedure=procedure)
            for procedure in batch_procedures
        ]
        prompts_success = [
            prompt + f'{ASSISTANT_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A: '
            for prompt, question in zip(prompts, questions_success)
        ]

        success_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                            vlm=vlm, 
                                                            batch_examples=batch_examples, 
                                                            batch_frames=batch_frames, 
                                                            prompts_a=prompts_success, 
                                                            new_questions=questions_success, 
                                                            question_idx=f"{question_idx}_success",
                                                            batch_size=1,
                                                            visual_filter=None,
                                                            nlp=None,
                                                            visual_filter_mode=None,
                                                            frame_cache_dir=None,
                                                            is_encoder_decoder="-t5-" in VLM_NAME.lower(),
                                                            ignore_frames=False)
        # Record the success query's own prompt and question (not this turn's generated question)
        success_vqa_outputs = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(success_answers_logits, batch_examples, prompts_success, questions_success)
        ]

        # Save success probability for this turn
        for batch_sub_idx in range(this_batch_size):
            success_probs[batch_sub_idx].append(
                round(float(success_vqa_outputs[batch_sub_idx].answer_probs[VQAResponse.Yes]), 6)
            )

        # Clear out VQA outputs now because they occupy a lot of memory
        # (the raw logits lists must go too, otherwise they keep the logits alive)
        del new_answers, new_answers_logits
        del success_vqa_outputs, success_answers_logits

        # Check if we can stop based on early stopping criteria (batch size is 1, hence index 0)
        # if success score doesn't change enough over 3 turns, stop incorporating questions
        # (we still run inference across all questions for efficiency and simplicity, but later can make a proper demo script)
        if question_idx >= 2:
            if np.abs(success_probs[0][question_idx-1] - success_probs[0][question_idx-2]) < EARLY_STOP_DELTA and np.abs(success_probs[0][question_idx] - success_probs[0][question_idx-1]) < EARLY_STOP_DELTA:
                n_iterations_taken.append(question_idx+1)
                print("Early stop!")
                break
        # OR if success score is within confident delta, stop
        if success_probs[0][-1] < CONFIDENT_RANGE or 1.0 - success_probs[0][-1] < CONFIDENT_RANGE:
            n_iterations_taken.append(question_idx+1)
            print("Early stop!")
            break
        # If it's the last iteration, record
        if question_idx == MAX_ITERATIONS-1:
            n_iterations_taken.append(MAX_ITERATIONS)

    time_taken.append(time.perf_counter() - start)

print("Avg. # iterations:", np.mean(n_iterations_taken))
print("Std. # iterations:", np.std(n_iterations_taken))
print("Avg. time (sec.):", np.mean(time_taken))
print("Std. time (sec.):", np.std(time_taken))
print("Avg. runtime per iteration (sec.):", np.mean([t / i for i, t in zip(n_iterations_taken, time_taken)]))

# + In-Context Learning

In [None]:
import time

# Benchmark iterative VQA with in-context-learning (ICL) candidate injection.
#
# For each example, up to MAX_ITERATIONS question/answer turns are run. Each turn:
#   1. Candidate questions come from beam search over the running dialog prompt (source "vlm")
#      plus beam search over an ICL VQG prompt (source "icl").
#   2. Duplicate candidates (and, when possible, already-asked questions) are removed, and the
#      candidate with the highest "informativeness_marginal_x_relevance_marginal" NLI score is kept.
#   3. The VLM answers the kept question from the image alone (no dialog history); answers whose
#      P(Yes) is strictly within UNSURE_RANGE of 0.5 are recorded as "Unsure".
#   4. The VLM is asked IVQA_SUCCESS_QUESTION given the full dialog, and its P(Yes) is recorded.
# The turn loop stops early when that success probability changes by less than EARLY_STOP_DELTA
# between each of the last three turns, or comes within CONFIDENT_RANGE of 0 or 1.
# Per-example iteration counts and wall-clock times are summarized at the end.
#
# Uses objects expected to be defined by earlier cells (e.g. dataset, lm, tokenizer, vlm,
# vlm_processor, nli_model, nli_tokenizer, generation_kwargs, response_token_ids, VLM_NAME,
# MAX_ITERATIONS, N_ICL_DEMONSTRATIONS, UNSURE_RANGE).

# Modify below values to tuned values for each experiment
EARLY_STOP_DELTA = 0.1  # success-prob. change below which consecutive turns count as "no change"
CONFIDENT_RANGE = 0.1   # stop once success prob. is within this distance of 0 or 1

# Metric names used to rank candidate questions (only "coherence" is used below)
ranking_key_mapping = {
    "relevance": "relevance_marginal",
    "informativeness": "informativeness_marginal",
    "coherence": "informativeness_marginal_x_relevance_marginal",
}

n_iterations_taken = []  # number of QA turns used for each example
time_taken = []          # wall-clock seconds spent on each example
for batch_idx, batch_example in tqdm(enumerate(dataset.get_batches(1,
                                                                    n_workers=1,
                                                                    worker_index=0,
                                                                    load_frames=False)),
                                     desc="running iterative VQA inference"):

    batch_examples = [batch_example]
    # perf_counter is monotonic and high-resolution (unlike time.time()), so it is the
    # appropriate clock for measuring elapsed time
    start = time.perf_counter()

    # Take first frame (expect there to only be one frame)
    batch_procedures = [example.procedure_description for example in batch_examples]
    batch_frames = [Image.open(example.frames[0]) for example in batch_examples]

    this_batch_size = len(batch_examples)

    # Running dialog prompt per example; extended with each Q/A exchange below
    prompts = [
        f'{DIALOG_START_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}{IVQA_PREAMBLE.format(procedure=procedure)}'
        for procedure in batch_procedures
    ]
    questions = [[] for _ in range(this_batch_size)]
    frames = [[] for _ in range(this_batch_size)]
    candidate_questions = [[] for _ in range(this_batch_size)]
    candidate_questions_scores = [[] for _ in range(this_batch_size)]
    candidate_questions_sources = [[] for _ in range(this_batch_size)]
    scores = [[] for _ in range(this_batch_size)]
    answer_probs = [[] for _ in range(this_batch_size)]
    answers = [[] for _ in range(this_batch_size)]
    success_probs = [[] for _ in range(this_batch_size)]
    success_probs_negated = [[] for _ in range(this_batch_size)]

    # Iteratively generate questions
    for question_idx in tqdm(range(MAX_ITERATIONS), desc="running iterative QA"):

        # Generate a question (with beam search so we have several candidates).
        # On the first turn the user's preamble turn is closed; afterwards the assistant's answer turn is.
        prompts_q = [prompt + f"{ASSISTANT_END_TOKENS[VLM_NAME] if question_idx != 0 else USER_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q:" for prompt in prompts]
        # Image tokens are stripped from the prompt given to the question-generation LM
        new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                        tokenizer,
                                                        [prompt.replace(IMAGE_TOKENS[VLM_NAME], "") for prompt in prompts_q],
                                                        max_new_tokens=20,
                                                        batch_size=1,
                                                        generation_kwargs=generation_kwargs)

        new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in new_questions]
        new_questions_sources = [["vlm"] * len(beam_search_questions) for beam_search_questions in new_questions]

        # Optionally inject more candidates from original VQG ICL code
        icl_prompts = [generate_vqg_prompt_icl(procedure, N_ICL_DEMONSTRATIONS, include_answers=False) for procedure in batch_procedures]  # Create ICL prompt
        # Add some previous questions if possible (take last 2 that were asked).
        # NOTE(review): listed questions are numbered "1 ", "2 " (no period, always starting at 1)
        # while the open slot is numbered f"{len(previous_questions) + 1}. " -- confirm this matches
        # the list format produced by generate_vqg_prompt_icl (changing it changes results).
        icl_prompts = [
            prompt + '\n'.join([str(pqi+1) + ' ' + pq for pqi, pq in enumerate(previous_questions[-2:])]) + ("\n" if len(previous_questions) > 0 else "") + f"{len(previous_questions) + 1}. "
            for prompt, previous_questions in zip(icl_prompts, questions)
        ]
        icl_new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                            tokenizer,
                                                            icl_prompts,
                                                            max_new_tokens=20,
                                                            batch_size=1,
                                                            generation_kwargs=generation_kwargs)

        icl_new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in icl_new_questions]

        for batch_sub_idx in range(this_batch_size):
            new_questions[batch_sub_idx] += icl_new_questions[batch_sub_idx]
            new_questions_sources[batch_sub_idx] += ["icl"] * len(icl_new_questions[batch_sub_idx])

        # Remove duplicate candidates (keep the first occurrence of each).
        # `cand_idx` indexes candidates; it is distinct from the turn counter `question_idx`.
        keep_idxs = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if question not in beam_search_outputs[:cand_idx]] for beam_search_outputs in new_questions]

        # Try to remove any candidates that we've seen before (if we've seen all the candidates before, don't remove any)
        keep_idxs_filtered = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if cand_idx in keep_idxs[batch_sub_idx] and question not in questions[batch_sub_idx]] for batch_sub_idx, beam_search_outputs in enumerate(new_questions)]
        keep_idxs = [keep_idxs_filtered[batch_sub_idx] if len(keep_idxs_filtered[batch_sub_idx]) > 0 else keep_idxs[batch_sub_idx] for batch_sub_idx in range(this_batch_size)]

        # Apply kept indices to new questions and their sources
        new_questions = [[new_questions[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]
        new_questions_sources = [[new_questions_sources[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]

        # Save all candidates from beam search
        for batch_sub_idx in range(len(candidate_questions)):
            candidate_questions[batch_sub_idx].append(new_questions[batch_sub_idx])
            candidate_questions_sources[batch_sub_idx].append(new_questions_sources[batch_sub_idx])

        # Select best candidate question from pool.
        # Coherence metrics are computed on flattened inputs (one entry per candidate); previous
        # Q/A pairs whose answer was "Unsure" are excluded from the context.
        nli_outputs = question_coherence_metrics_nli(
            nli_tokenizer,
            nli_model,
            tokenizer,
            lm,
            [procedure for procedure, beam_search_questions in zip(batch_procedures, new_questions) for _ in beam_search_questions],
            [question for beam_search_questions in new_questions for question in beam_search_questions],
            previous_questions=[[q for qi, q in enumerate(batch_idx_questions) if batch_idx_answers[qi] != "Unsure"] for batch_idx_questions, batch_idx_answers, beam_search_questions in zip(questions, answers, new_questions) for _ in beam_search_questions],
            previous_answers=[[a for a in batch_idx_answers if a != "Unsure"] for batch_idx_answers, beam_search_questions in zip(answers, new_questions) for _ in beam_search_questions],
            rephrase_batch_size=10
        )

        # Select best candidate based on coherence metrics
        selected_questions = []
        new_scores = []
        parallel_idx = 0  # offset of this example's candidates within the flattened NLI outputs
        for batch_sub_idx, beam_search_questions in enumerate(new_questions):
            this_nli_outputs = [{k: nli_outputs[k][i] if isinstance(nli_outputs[k][i], str) else round(float(nli_outputs[k][i]), 3) for k in nli_outputs} for i in range(parallel_idx, parallel_idx + len(beam_search_questions))]
            candidate_questions_scores[batch_sub_idx].append(this_nli_outputs)
            parallel_idx += len(beam_search_questions)

            # Use marginal relevance (consistency) and expected informativeness (verifiability) to rank candidates
            candidate_scores = np.array(
                [candidate_metrics[ranking_key_mapping["coherence"]] for candidate_metrics in this_nli_outputs]
            )

            best_candidate = np.argmax(candidate_scores)
            selected_questions.append(beam_search_questions[best_candidate])
            new_scores.append(round(float(candidate_scores[best_candidate]), 6))

        new_questions = selected_questions

        # Save scores and the selected questions
        for batch_sub_idx in range(this_batch_size):
            scores[batch_sub_idx].append(new_scores[batch_sub_idx])
            questions[batch_sub_idx].append(new_questions[batch_sub_idx])

        # Run VQA with generated questions (and optional spatial filter).
        # prompts_a continues the running dialog and becomes the next dialog history.
        prompts_a = [prompt + f' {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for prompt, question in zip(prompts_q, new_questions)]

        # Effective prompt for VQA excludes dialog history: only the image and the current question
        use_prompts_a = [f'{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for question in new_questions]

        # No visual filter is applied, so no NLP pipeline is passed (nlp=None, as in the success query below)
        new_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor,
                                                        vlm=vlm,
                                                        batch_examples=batch_examples,
                                                        batch_frames=batch_frames,
                                                        prompts_a=use_prompts_a,
                                                        new_questions=new_questions,
                                                        question_idx=question_idx,
                                                        batch_size=1,
                                                        visual_filter=None,
                                                        nlp=None,
                                                        visual_filter_mode=None,
                                                        frame_cache_dir=None,
                                                        is_encoder_decoder="-t5-" in VLM_NAME.lower())

        # Gather up VQA outputs (which automatically calculates answer probabilities from logits)
        new_answers = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(new_answers_logits, batch_examples, prompts_a, new_questions)
        ]
        # Answers with P(Yes) strictly within UNSURE_RANGE of 0.5 are recorded as "Unsure"
        new_answers_str = [output.predicted_answer.name if np.abs(output.answer_probs[VQAResponse.Yes] - 0.5) >= UNSURE_RANGE else "Unsure" for output in new_answers]

        # Save answers and their probabilities
        for batch_sub_idx in range(this_batch_size):
            answer_probs[batch_sub_idx].append([round(float(new_answers[batch_sub_idx].answer_probs[VQAResponse(answer_idx)]), 6) for answer_idx in range(2)])
            answers[batch_sub_idx].append(new_answers_str[batch_sub_idx])

        # Extend the dialog history with this turn's answer
        prompts = [prompt + " " + output for prompt, output in zip(prompts_a, new_answers_str)]

        # Ask VLM probability of success
        questions_success = [
            IVQA_SUCCESS_QUESTION.format(procedure=procedure)
            for procedure in batch_procedures
        ]
        prompts_success = [
            prompt + f'{ASSISTANT_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A: '
            for prompt, question in zip(prompts, questions_success)
        ]

        success_vqa_outputs = run_vqa_with_visual_filter(vlm_processor=vlm_processor,
                                                         vlm=vlm,
                                                         batch_examples=batch_examples,
                                                         batch_frames=batch_frames,
                                                         prompts_a=prompts_success,
                                                         new_questions=questions_success,
                                                         question_idx=f"{question_idx}_success",
                                                         batch_size=1,
                                                         visual_filter=None,
                                                         nlp=None,
                                                         visual_filter_mode=None,
                                                         frame_cache_dir=None,
                                                         is_encoder_decoder="-t5-" in VLM_NAME.lower(),
                                                         ignore_frames=False)
        # Wrap the success-query logits, recording the success prompt/question they were
        # computed from (not this turn's VQA prompt/question)
        success_vqa_outputs = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(success_vqa_outputs, batch_examples, prompts_success, questions_success)
        ]

        # Save success probability for this turn
        for batch_sub_idx in range(this_batch_size):
            success_probs[batch_sub_idx].append(
                round(float(success_vqa_outputs[batch_sub_idx].answer_probs[VQAResponse.Yes]), 6)
            )

        # Clear out VQA outputs and their raw logits now because they occupy a lot of memory
        del new_answers
        del new_answers_logits
        del success_vqa_outputs

        # Check if we can stop based on early stopping criteria (batch size is 1, so only example 0 is checked):
        # if the success probability changed by less than EARLY_STOP_DELTA between each of the last 3 turns, stop
        if question_idx >= 2:
            if np.abs(success_probs[0][question_idx-1] - success_probs[0][question_idx-2]) < EARLY_STOP_DELTA and np.abs(success_probs[0][question_idx] - success_probs[0][question_idx-1]) < EARLY_STOP_DELTA:
                n_iterations_taken.append(question_idx+1)
                print("Early stop!")
                break
        # OR if success score is within CONFIDENT_RANGE of 0 or 1, stop
        if success_probs[0][-1] < CONFIDENT_RANGE or 1.0 - success_probs[0][-1] < CONFIDENT_RANGE:
            n_iterations_taken.append(question_idx+1)
            print("Early stop!")
            break
        # If it's the last iteration, record
        if question_idx == MAX_ITERATIONS-1:
            n_iterations_taken.append(MAX_ITERATIONS)

    end = time.perf_counter()
    time_taken.append(end-start)

print("Avg. # iterations:", np.mean(n_iterations_taken))
print("Std. # iterations:", np.std(n_iterations_taken))
print("Avg. time (sec.):", np.mean(time_taken))
print("Std. time (sec.):", np.std(time_taken))
print("Avg. runtime per iteration (sec.):", np.mean([t / i for i, t in zip(n_iterations_taken, time_taken)]))

# + DPO Adapter

In [None]:
# Path to the trained (DPO) VQG adapter directory (placeholder: replace before running).
VQG_ADAPTER_PATH = "path/to/trained/adapter/directory"
# Attach the adapter to the question-generation LM under the name "vqg".
# NOTE(review): if an adapter named "vqg" is already loaded in this kernel, load_adapter is
# expected to reject the duplicate name -- confirm against the installed transformers/peft version.
lm.load_adapter(VQG_ADAPTER_PATH, adapter_name="vqg")
print(f"Loaded VQG adapter at {VQG_ADAPTER_PATH}")
# Show which adapter(s) are currently active on the LM
print(lm.active_adapters())

In [None]:
import time

# Benchmark iterative VQA where dialog-based question generation uses the loaded VQG adapter.
#
# For each example, up to MAX_ITERATIONS question/answer turns are run. Each turn:
#   1. Candidate questions come from beam search over the running dialog prompt with the LM's
#      adapters enabled (source "vlm"), plus beam search over an ICL VQG prompt with adapters
#      disabled (source "icl").
#   2. Duplicate candidates (and, when possible, already-asked questions) are removed, and the
#      candidate with the highest "informativeness_marginal_x_relevance_marginal" NLI score is kept.
#   3. The VLM answers the kept question from the image alone (no dialog history); answers whose
#      P(Yes) is strictly within UNSURE_RANGE of 0.5 are recorded as "Unsure".
#   4. The VLM is asked IVQA_SUCCESS_QUESTION given the full dialog, and its P(Yes) is recorded.
# The turn loop stops early when that success probability changes by less than EARLY_STOP_DELTA
# between each of the last three turns, or comes within CONFIDENT_RANGE of 0 or 1.
# Per-example iteration counts and wall-clock times are summarized at the end.
#
# Uses objects expected to be defined by earlier cells (e.g. dataset, lm with a loaded adapter,
# tokenizer, vlm, vlm_processor, nli_model, nli_tokenizer, generation_kwargs, response_token_ids,
# VLM_NAME, MAX_ITERATIONS, N_ICL_DEMONSTRATIONS, UNSURE_RANGE).

# Modify below values to tuned values for each experiment
EARLY_STOP_DELTA = 0.1  # success-prob. change below which consecutive turns count as "no change"
CONFIDENT_RANGE = 0.1   # stop once success prob. is within this distance of 0 or 1

# Metric names used to rank candidate questions (only "coherence" is used below)
ranking_key_mapping = {
    "relevance": "relevance_marginal",
    "informativeness": "informativeness_marginal",
    "coherence": "informativeness_marginal_x_relevance_marginal",
}

n_iterations_taken = []  # number of QA turns used for each example
time_taken = []          # wall-clock seconds spent on each example
for batch_idx, batch_example in tqdm(enumerate(dataset.get_batches(1,
                                                                    n_workers=1,
                                                                    worker_index=0,
                                                                    load_frames=False)),
                                     desc="running iterative VQA inference"):

    batch_examples = [batch_example]
    # perf_counter is monotonic and high-resolution (unlike time.time()), so it is the
    # appropriate clock for measuring elapsed time
    start = time.perf_counter()

    # Take first frame (expect there to only be one frame)
    batch_procedures = [example.procedure_description for example in batch_examples]
    batch_frames = [Image.open(example.frames[0]) for example in batch_examples]

    this_batch_size = len(batch_examples)

    # Running dialog prompt per example; extended with each Q/A exchange below
    prompts = [
        f'{DIALOG_START_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}{IVQA_PREAMBLE.format(procedure=procedure)}'
        for procedure in batch_procedures
    ]
    questions = [[] for _ in range(this_batch_size)]
    frames = [[] for _ in range(this_batch_size)]
    candidate_questions = [[] for _ in range(this_batch_size)]
    candidate_questions_scores = [[] for _ in range(this_batch_size)]
    candidate_questions_sources = [[] for _ in range(this_batch_size)]
    scores = [[] for _ in range(this_batch_size)]
    answer_probs = [[] for _ in range(this_batch_size)]
    answers = [[] for _ in range(this_batch_size)]
    success_probs = [[] for _ in range(this_batch_size)]
    success_probs_negated = [[] for _ in range(this_batch_size)]

    # Iteratively generate questions
    for question_idx in tqdm(range(MAX_ITERATIONS), desc="running iterative QA"):

        # Generate a question (with beam search so we have several candidates).
        # On the first turn the user's preamble turn is closed; afterwards the assistant's answer turn is.
        prompts_q = [prompt + f"{ASSISTANT_END_TOKENS[VLM_NAME] if question_idx != 0 else USER_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q:" for prompt in prompts]

        # The VQG adapter is enabled only for this dialog-based generation (not for ICL generation or
        # NLI scoring). try/finally guarantees it is disabled again even if generation raises or is
        # interrupted, so the adapter cannot leak into later LM calls or later notebook cells.
        lm.enable_adapters()
        try:
            # Image tokens are stripped from the prompt given to the question-generation LM
            new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                            tokenizer,
                                                            [prompt.replace(IMAGE_TOKENS[VLM_NAME], "") for prompt in prompts_q],
                                                            max_new_tokens=20,
                                                            batch_size=1,
                                                            generation_kwargs=generation_kwargs)
        finally:
            lm.disable_adapters()

        new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in new_questions]
        new_questions_sources = [["vlm"] * len(beam_search_questions) for beam_search_questions in new_questions]

        # Optionally inject more candidates from original VQG ICL code
        icl_prompts = [generate_vqg_prompt_icl(procedure, N_ICL_DEMONSTRATIONS, include_answers=False) for procedure in batch_procedures]  # Create ICL prompt
        # Add some previous questions if possible (take last 2 that were asked).
        # NOTE(review): listed questions are numbered "1 ", "2 " (no period, always starting at 1)
        # while the open slot is numbered f"{len(previous_questions) + 1}. " -- confirm this matches
        # the list format produced by generate_vqg_prompt_icl (changing it changes results).
        icl_prompts = [
            prompt + '\n'.join([str(pqi+1) + ' ' + pq for pqi, pq in enumerate(previous_questions[-2:])]) + ("\n" if len(previous_questions) > 0 else "") + f"{len(previous_questions) + 1}. "
            for prompt, previous_questions in zip(icl_prompts, questions)
        ]
        icl_new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                            tokenizer,
                                                            icl_prompts,
                                                            max_new_tokens=20,
                                                            batch_size=1,
                                                            generation_kwargs=generation_kwargs)

        icl_new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in icl_new_questions]

        for batch_sub_idx in range(this_batch_size):
            new_questions[batch_sub_idx] += icl_new_questions[batch_sub_idx]
            new_questions_sources[batch_sub_idx] += ["icl"] * len(icl_new_questions[batch_sub_idx])

        # Remove duplicate candidates (keep the first occurrence of each).
        # `cand_idx` indexes candidates; it is distinct from the turn counter `question_idx`.
        keep_idxs = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if question not in beam_search_outputs[:cand_idx]] for beam_search_outputs in new_questions]

        # Try to remove any candidates that we've seen before (if we've seen all the candidates before, don't remove any)
        keep_idxs_filtered = [[cand_idx for cand_idx, question in enumerate(beam_search_outputs) if cand_idx in keep_idxs[batch_sub_idx] and question not in questions[batch_sub_idx]] for batch_sub_idx, beam_search_outputs in enumerate(new_questions)]
        keep_idxs = [keep_idxs_filtered[batch_sub_idx] if len(keep_idxs_filtered[batch_sub_idx]) > 0 else keep_idxs[batch_sub_idx] for batch_sub_idx in range(this_batch_size)]

        # Apply kept indices to new questions and their sources
        new_questions = [[new_questions[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]
        new_questions_sources = [[new_questions_sources[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]

        # Save all candidates from beam search
        for batch_sub_idx in range(len(candidate_questions)):
            candidate_questions[batch_sub_idx].append(new_questions[batch_sub_idx])
            candidate_questions_sources[batch_sub_idx].append(new_questions_sources[batch_sub_idx])

        # Select best candidate question from pool.
        # Coherence metrics are computed on flattened inputs (one entry per candidate); previous
        # Q/A pairs whose answer was "Unsure" are excluded from the context. Adapters are disabled here.
        nli_outputs = question_coherence_metrics_nli(
            nli_tokenizer,
            nli_model,
            tokenizer,
            lm,
            [procedure for procedure, beam_search_questions in zip(batch_procedures, new_questions) for _ in beam_search_questions],
            [question for beam_search_questions in new_questions for question in beam_search_questions],
            previous_questions=[[q for qi, q in enumerate(batch_idx_questions) if batch_idx_answers[qi] != "Unsure"] for batch_idx_questions, batch_idx_answers, beam_search_questions in zip(questions, answers, new_questions) for _ in beam_search_questions],
            previous_answers=[[a for a in batch_idx_answers if a != "Unsure"] for batch_idx_answers, beam_search_questions in zip(answers, new_questions) for _ in beam_search_questions],
            rephrase_batch_size=10
        )

        # Select best candidate based on coherence metrics
        selected_questions = []
        new_scores = []
        parallel_idx = 0  # offset of this example's candidates within the flattened NLI outputs
        for batch_sub_idx, beam_search_questions in enumerate(new_questions):
            this_nli_outputs = [{k: nli_outputs[k][i] if isinstance(nli_outputs[k][i], str) else round(float(nli_outputs[k][i]), 3) for k in nli_outputs} for i in range(parallel_idx, parallel_idx + len(beam_search_questions))]
            candidate_questions_scores[batch_sub_idx].append(this_nli_outputs)
            parallel_idx += len(beam_search_questions)

            # Use marginal relevance (consistency) and expected informativeness (verifiability) to rank candidates
            candidate_scores = np.array(
                [candidate_metrics[ranking_key_mapping["coherence"]] for candidate_metrics in this_nli_outputs]
            )

            best_candidate = np.argmax(candidate_scores)
            selected_questions.append(beam_search_questions[best_candidate])
            new_scores.append(round(float(candidate_scores[best_candidate]), 6))

        new_questions = selected_questions

        # Save scores and the selected questions
        for batch_sub_idx in range(this_batch_size):
            scores[batch_sub_idx].append(new_scores[batch_sub_idx])
            questions[batch_sub_idx].append(new_questions[batch_sub_idx])

        # Run VQA with generated questions (and optional spatial filter).
        # prompts_a continues the running dialog and becomes the next dialog history.
        prompts_a = [prompt + f' {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for prompt, question in zip(prompts_q, new_questions)]

        # Effective prompt for VQA excludes dialog history: only the image and the current question
        use_prompts_a = [f'{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for question in new_questions]

        # No visual filter is applied, so no NLP pipeline is passed (nlp=None, as in the success query below)
        new_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor,
                                                        vlm=vlm,
                                                        batch_examples=batch_examples,
                                                        batch_frames=batch_frames,
                                                        prompts_a=use_prompts_a,
                                                        new_questions=new_questions,
                                                        question_idx=question_idx,
                                                        batch_size=1,
                                                        visual_filter=None,
                                                        nlp=None,
                                                        visual_filter_mode=None,
                                                        frame_cache_dir=None,
                                                        is_encoder_decoder="-t5-" in VLM_NAME.lower())

        # Gather up VQA outputs (which automatically calculates answer probabilities from logits)
        new_answers = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(new_answers_logits, batch_examples, prompts_a, new_questions)
        ]
        # Answers with P(Yes) strictly within UNSURE_RANGE of 0.5 are recorded as "Unsure"
        new_answers_str = [output.predicted_answer.name if np.abs(output.answer_probs[VQAResponse.Yes] - 0.5) >= UNSURE_RANGE else "Unsure" for output in new_answers]

        # Save answers and their probabilities
        for batch_sub_idx in range(this_batch_size):
            answer_probs[batch_sub_idx].append([round(float(new_answers[batch_sub_idx].answer_probs[VQAResponse(answer_idx)]), 6) for answer_idx in range(2)])
            answers[batch_sub_idx].append(new_answers_str[batch_sub_idx])

        # Extend the dialog history with this turn's answer
        prompts = [prompt + " " + output for prompt, output in zip(prompts_a, new_answers_str)]

        # Ask VLM probability of success
        questions_success = [
            IVQA_SUCCESS_QUESTION.format(procedure=procedure)
            for procedure in batch_procedures
        ]
        prompts_success = [
            prompt + f'{ASSISTANT_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A: '
            for prompt, question in zip(prompts, questions_success)
        ]

        success_vqa_outputs = run_vqa_with_visual_filter(vlm_processor=vlm_processor,
                                                         vlm=vlm,
                                                         batch_examples=batch_examples,
                                                         batch_frames=batch_frames,
                                                         prompts_a=prompts_success,
                                                         new_questions=questions_success,
                                                         question_idx=f"{question_idx}_success",
                                                         batch_size=1,
                                                         visual_filter=None,
                                                         nlp=None,
                                                         visual_filter_mode=None,
                                                         frame_cache_dir=None,
                                                         is_encoder_decoder="-t5-" in VLM_NAME.lower(),
                                                         ignore_frames=False)
        # Wrap the success-query logits, recording the success prompt/question they were
        # computed from (not this turn's VQA prompt/question)
        success_vqa_outputs = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(success_vqa_outputs, batch_examples, prompts_success, questions_success)
        ]

        # Save success probability for this turn
        for batch_sub_idx in range(this_batch_size):
            success_probs[batch_sub_idx].append(
                round(float(success_vqa_outputs[batch_sub_idx].answer_probs[VQAResponse.Yes]), 6)
            )

        # Clear out VQA outputs and their raw logits now because they occupy a lot of memory
        del new_answers
        del new_answers_logits
        del success_vqa_outputs

        # Check if we can stop based on early stopping criteria (batch size is 1, so only example 0 is checked):
        # if the success probability changed by less than EARLY_STOP_DELTA between each of the last 3 turns, stop
        if question_idx >= 2:
            if np.abs(success_probs[0][question_idx-1] - success_probs[0][question_idx-2]) < EARLY_STOP_DELTA and np.abs(success_probs[0][question_idx] - success_probs[0][question_idx-1]) < EARLY_STOP_DELTA:
                n_iterations_taken.append(question_idx+1)
                print("Early stop!")
                break
        # OR if success score is within CONFIDENT_RANGE of 0 or 1, stop
        if success_probs[0][-1] < CONFIDENT_RANGE or 1.0 - success_probs[0][-1] < CONFIDENT_RANGE:
            n_iterations_taken.append(question_idx+1)
            print("Early stop!")
            break
        # If it's the last iteration, record
        if question_idx == MAX_ITERATIONS-1:
            n_iterations_taken.append(MAX_ITERATIONS)

    end = time.perf_counter()
    time_taken.append(end-start)

print("Avg. # iterations:", np.mean(n_iterations_taken))
print("Std. # iterations:", np.std(n_iterations_taken))
print("Avg. time (sec.):", np.mean(time_taken))
print("Std. time (sec.):", np.std(time_taken))
print("Avg. runtime per iteration (sec.):", np.mean([t / i for i, t in zip(n_iterations_taken, time_taken)]))

# DPO Adapter Without ICL and Coherence Ranking

In [None]:
# Path to the trained VQG adapter directory (placeholder: replace before running).
VQG_ADAPTER_PATH = "/path/to/trained/adapter/directory"
# Attach the adapter to the question-generation LM under the name "vqg".
# NOTE(review): if an adapter named "vqg" is already loaded in this kernel (e.g. by an earlier
# adapter cell), load_adapter is expected to reject the duplicate name -- restart the kernel or
# remove the existing adapter first; confirm against the installed transformers/peft version.
lm.load_adapter(VQG_ADAPTER_PATH, adapter_name="vqg")
print(f"Loaded VQG adapter at {VQG_ADAPTER_PATH}")
# Show which adapter(s) are currently active on the LM
print(lm.active_adapters())

In [None]:
import time

# Iterative VQA with the DPO-trained VQG adapter, without in-context examples and without
# coherence-based candidate ranking. Each turn:
#   1. beam-search candidate questions with the LM adapters enabled,
#   2. keep the candidate with the highest log-likelihood under the LM with adapters disabled,
#   3. answer it with the VLM (current question only, no dialog history),
#   4. ask the VLM whether the procedure succeeded, given the dialog so far.
# Stops early when P(success) plateaus or becomes confident; records the number of turns
# taken and the wall-clock time for every example.

# Modify below values to tuned values for each experiment
EARLY_STOP_DELTA = 0.1  # plateau threshold on |change in P(success)| between consecutive turns
CONFIDENT_RANGE = 0.1  # stop once P(success) is within this distance of 0 or 1

n_iterations_taken = []  # QA turns used per example
time_taken = []  # wall-clock seconds per example
for batch_idx, batch_example in tqdm(enumerate(dataset.get_batches(1, 
                                                                    n_workers=1, 
                                                                    worker_index=0,
                                                                    load_frames=False)), 
                                                desc="running iterative VQA inference"):

    batch_examples = [batch_example]
    start = time.time()

    # Take first frame (expect there to only be one frame)
    batch_procedures = [example.procedure_description for example in batch_examples]
    batch_frames = [Image.open(example.frames[0]) for example in batch_examples]

    this_batch_size = len(batch_examples)

    # Running dialog history per example; extended with each Q/A turn below
    prompts = [
        f'{DIALOG_START_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}{IVQA_PREAMBLE.format(procedure=procedure)}' 
        for procedure in batch_procedures
    ]
    questions = [[] for _ in range(this_batch_size)]
    frames = [[] for _ in range(this_batch_size)]
    candidate_questions = [[] for _ in range(this_batch_size)]
    candidate_questions_scores = [[] for _ in range(this_batch_size)]
    candidate_questions_sources = [[] for _ in range(this_batch_size)]
    scores = [[] for _ in range(this_batch_size)]
    answer_probs = [[] for _ in range(this_batch_size)] 
    answers = [[] for _ in range(this_batch_size)]
    success_probs = [[] for _ in range(this_batch_size)]
    success_probs_negated = [[] for _ in range(this_batch_size)]

    # Iteratively generate questions
    for question_idx in tqdm(range(MAX_ITERATIONS), desc="running iterative QA"):

        # Close the previous turn (assistant turn, or the initial user turn) and open a new question
        prompts_q = [prompt + f"{ASSISTANT_END_TOKENS[VLM_NAME] if question_idx != 0 else USER_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q:" for prompt in prompts]
        # The LM only sees text, so image placeholder tokens are stripped (computed once, used twice)
        text_prompts_q = [prompt.replace(IMAGE_TOKENS[VLM_NAME], "") for prompt in prompts_q]

        # Generate a question (with beam search so we have several candidates) with the adapters on.
        # try/finally guarantees the adapters are switched off again even if generation fails or is
        # interrupted, so candidate scoring below (and later cells) never silently run with them on.
        lm.enable_adapters()
        try:
            new_questions, _ = simple_lm_prompt_beam_search(lm,
                                                            tokenizer,
                                                            text_prompts_q,
                                                            max_new_tokens=20,
                                                            batch_size=1,
                                                            generation_kwargs=generation_kwargs)
        finally:
            lm.disable_adapters()

        new_questions = [[cleanup_generated_question(question) for question in beam_search_questions] for beam_search_questions in new_questions]                                
        new_questions_sources = [["vlm"] * len(beam_search_questions) for beam_search_questions in new_questions]

        # Remove duplicate candidates (keep the first occurrence of each question).
        # Indices are named `cand_idx` so they cannot be confused with the outer `question_idx` turn counter.
        keep_idxs = [
            [cand_idx for cand_idx, cand in enumerate(beam_search_outputs) if cand not in beam_search_outputs[:cand_idx]]
            for beam_search_outputs in new_questions
        ]

        # Try to remove any candidates that we've seen before (if we've seen all the candidates before, don't remove any)
        keep_idxs_filtered = [
            [cand_idx for cand_idx, cand in enumerate(beam_search_outputs) if cand_idx in keep_idxs[batch_sub_idx] and cand not in questions[batch_sub_idx]]
            for batch_sub_idx, beam_search_outputs in enumerate(new_questions)
        ]
        keep_idxs = [keep_idxs_filtered[batch_sub_idx] if len(keep_idxs_filtered[batch_sub_idx]) > 0 else keep_idxs[batch_sub_idx] for batch_sub_idx in range(this_batch_size)]

        # Apply kept indices to new questions and their sources
        new_questions = [[new_questions[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]
        new_questions_sources = [[new_questions_sources[batch_sub_idx][cand_idx] for cand_idx in this_keep_idxs] for batch_sub_idx, this_keep_idxs in enumerate(keep_idxs)]

        # Save all candidates from beam search
        for batch_sub_idx in range(this_batch_size):
            candidate_questions[batch_sub_idx].append(new_questions[batch_sub_idx])
            candidate_questions_sources[batch_sub_idx].append(new_questions_sources[batch_sub_idx])

        # Score every remaining candidate by its log-likelihood under the LM (adapters are disabled here)
        generation_scores = compute_completion_log_likelihoods(lm, tokenizer, text_prompts_q, new_questions, batch_size=1)

        # Select the highest-scoring candidate (on ties, the earliest candidate wins)
        selected_questions = []
        new_scores = []
        for batch_sub_idx, (beam_search_questions, beam_search_scores) in enumerate(zip(new_questions, generation_scores)):                    
            assert len(beam_search_questions) == len(beam_search_scores), "Expected candidate questions and their scores to have the same shape!"

            # Save all candidate scores
            candidate_questions_scores[batch_sub_idx].append(beam_search_scores)

            best_candidate = max(range(len(beam_search_questions)), key=lambda cand_idx: beam_search_scores[cand_idx])
            selected_questions.append(beam_search_questions[best_candidate])
            new_scores.append(beam_search_scores[best_candidate])

        new_questions = selected_questions

        # Save scores and text of the selected questions
        for batch_sub_idx in range(this_batch_size):
            scores[batch_sub_idx].append(new_scores[batch_sub_idx])
            questions[batch_sub_idx].append(new_questions[batch_sub_idx])

        # Extend the dialog history with the selected question and open the assistant's answer turn
        prompts_a = [prompt + f' {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for prompt, question in zip(prompts_q, new_questions)]

        # The VQA call itself sees only the current question (dialog history is excluded).
        # NOTE(review): unlike `prompts`, this omits DIALOG_START_TOKENS — confirm that is intended.
        use_prompts_a = [f'{USER_START_TOKENS[VLM_NAME]}{IMAGE_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A:' for question in new_questions]

        # Run VQA with the selected questions (no visual filter in this experiment).
        # `nlp=None` matches the success-query call below (previously the truthy `NotImplemented`).
        new_answers_logits = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                        vlm=vlm, 
                                                        batch_examples=batch_examples, 
                                                        batch_frames=batch_frames, 
                                                        prompts_a=use_prompts_a, 
                                                        new_questions=new_questions, 
                                                        question_idx=question_idx,
                                                        batch_size=1,
                                                        visual_filter=None,
                                                        nlp=None,
                                                        visual_filter_mode=None,
                                                        frame_cache_dir=None,
                                                        is_encoder_decoder="-t5-" in VLM_NAME.lower())

        # Gather up VQA outputs (which automatically calculates answer probabilities from logits)
        new_answers = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(new_answers_logits, batch_examples, prompts_a, new_questions)
        ]
        # Answers whose P(Yes) lies within UNSURE_RANGE of 0.5 are recorded as "Unsure"
        new_answers_str = [output.predicted_answer.name if np.abs(output.answer_probs[VQAResponse.Yes] - 0.5) >= UNSURE_RANGE else "Unsure" for output in new_answers]

        # Save answers and their probabilities
        for batch_sub_idx in range(this_batch_size):
            answer_probs[batch_sub_idx].append([round(float(new_answers[batch_sub_idx].answer_probs[VQAResponse(answer_idx)]), 6) for answer_idx in range(2)])
            answers[batch_sub_idx].append(new_answers_str[batch_sub_idx])

        # Append the answer to the dialog history
        prompts = [prompt + " " + output for prompt, output in zip(prompts_a, new_answers_str)]

        # Ask VLM probability of success
        questions_success = [
            IVQA_SUCCESS_QUESTION.format(procedure=procedure)
            for procedure in batch_procedures
        ]
        prompts_success = [
            prompt + f'{ASSISTANT_END_TOKENS[VLM_NAME]}{USER_START_TOKENS[VLM_NAME]}Q: {question}{USER_END_TOKENS[VLM_NAME]}{ASSISTANT_START_TOKENS[VLM_NAME]}A: '
            for prompt, question in zip(prompts, questions_success)
        ]

        success_vqa_outputs = run_vqa_with_visual_filter(vlm_processor=vlm_processor, 
                                                            vlm=vlm, 
                                                            batch_examples=batch_examples, 
                                                            batch_frames=batch_frames, 
                                                            prompts_a=prompts_success, 
                                                            new_questions=questions_success, 
                                                            question_idx=f"{question_idx}_success",
                                                            batch_size=1,
                                                            visual_filter=None,
                                                            nlp=None,
                                                            visual_filter_mode=None,
                                                            frame_cache_dir=None,
                                                            is_encoder_decoder="-t5-" in VLM_NAME.lower(),
                                                            ignore_frames=False)
        # Record the success query's own prompt and question (previously these were mislabeled
        # with the per-turn `prompts_a` / `new_questions`)
        success_vqa_outputs = [
            VQAOutputs(
                task_name=MistakeDetectionTasks("ego4d_single"),
                example_id=example.example_id,
                procedure_id=example.procedure_id,
                frame=example.frames[0],
                prompt=prompt,
                expected_answer=None,
                response_token_ids=response_token_ids,
                logits=logits,
                question=question,
            ) for logits, example, prompt, question in zip(success_vqa_outputs, batch_examples, prompts_success, questions_success)
        ]

        # Save success probability for this turn
        for batch_sub_idx in range(this_batch_size):
            success_probs[batch_sub_idx].append(
                round(float(success_vqa_outputs[batch_sub_idx].answer_probs[VQAResponse.Yes]), 6)
            )

        # Clear out VQA outputs now because they occupy a lot of memory.
        # `new_answers_logits` holds the same logits objects passed into `new_answers`, so it must be
        # dropped as well or those logits stay alive until the next turn.
        del new_answers_logits
        del new_answers
        del success_vqa_outputs

        # Check if we can stop based on early stopping criteria
        # if success score doesn't change enough over 3 turns, stop incorporating questions
        # (we still run inference across all questions for efficiency and simplicity, but later can make a proper demo script)
        if question_idx >= 2:
            if np.abs(success_probs[0][question_idx-1] - success_probs[0][question_idx-2]) < EARLY_STOP_DELTA and np.abs(success_probs[0][question_idx] - success_probs[0][question_idx-1]) < EARLY_STOP_DELTA:
                n_iterations_taken.append(question_idx+1)
                print("Early stop!")
                break
        # OR if success score is within confident delta, stop
        if success_probs[0][-1] < CONFIDENT_RANGE or 1.0 - success_probs[0][-1] < CONFIDENT_RANGE:
            n_iterations_taken.append(question_idx+1)
            print("Early stop!")
            break
        # If it's the last iteration, record
        if question_idx == MAX_ITERATIONS-1:
            n_iterations_taken.append(MAX_ITERATIONS)

    end = time.time()
    time_taken.append(end-start)

print("Avg. # iterations:", np.mean(n_iterations_taken))
print("Std. # iterations:", np.std(n_iterations_taken))
print("Avg. time (sec.):", np.mean(time_taken))
print("Std. time (sec.):", np.std(time_taken))