In [4]:
%load_ext autoreload
%autoreload 2

import os
os.chdir("/nfs/turbo/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl")
from travel import init_travel
init_travel()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import torch
from transformers import BitsAndBytesConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForVision2Seq, AutoProcessor
import spacy

from travel.model.nli import NLI_MODEL_PATH

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_PATH, quantization_config=bnb_config)
nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_PATH)
nlp = spacy.load("en_core_web_lg")

VLM_NAME = "llava-hf/llava-1.5-7b-hf"
vlm = AutoModelForVision2Seq.from_pretrained(VLM_NAME, 
                                            quantization_config=bnb_config)
vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
vlm_processor.tokenizer.padding_side = "left"
vlm_processor.tokenizer.pad_token_id = vlm_processor.tokenizer.eos_token_id

`low_cpu_mem_usage` was None, now set to True since model is quantized.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
lm = vlm.language_model
tokenizer = vlm_processor.tokenizer

# Tuning bottom-up number of iterations

Problem: not having the early stopping significantly hurts performance because it lowers accuracy, thus lowering consistency and verifiability.

In [4]:
import json
import os

this_results_dir = "/home/sstorks/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl/saved_results_222/vqa_mistake_detection/ego4d_single_debug250/llava-1.5-7b-hf/IterativeVQA_q10_ego4d_single_debug250_llava-1.5-7b-hf_beam8-4_likelihood_nohistory_20240815204213"
# this_results_dir = "/home/sstorks/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl/saved_results_222/vqa_mistake_detection/ego4d_single_debug250/llava-1.5-7b-hf/IterativeVQA_topdown_q10_ego4d_single_debug250_llava-1.5-7b-hf_beam8-4_likelihood_nohistory_20240817105952"

with open(os.path.join(this_results_dir, "outputs_val.json"), "r") as f:
    all_results_dicts = json.load(f)
with open(os.path.join(this_results_dir, "metrics_coherence_nli_val.json"), "r") as f:
    all_coherence_metrics_base = json.load(f)

In [None]:
import numpy as np
from tqdm import tqdm

from travel.model.metrics import question_coherence_metrics_nli, compile_accuracy_and_coherence_metrics
from travel.model.mistake_detection import MISTAKE_DETECTION_THRESHOLDS

for turn_idx in tqdm(range(10)):
    
    allturns_results_dir = os.path.join(this_results_dir, f"allturns{turn_idx}")
    if not os.path.exists(allturns_results_dir):
        os.makedirs(allturns_results_dir)
    else:
        continue    
    
    all_probs = [results_dict['success_probs'][turn_idx] for results_dict in all_results_dicts.values()]
    all_labels = [results_dict['mistake_type'] for results_dict in all_results_dicts.values()]
    
    all_chosen_questions = [question for results_dict in all_results_dicts.values() for question in range(turn_idx+1)]
    all_previous_questions = [[q for qi, q in enumerate(results_dict['questions'][:question_idx]) if results_dict['answers'][qi] != "Unsure"] for results_dict in all_results_dicts.values() for question_idx in range(turn_idx+1)]

    label_answer_mapping = {0: "No", 1: "Yes"}
    all_predicted_answers = [label_answer_mapping[np.argmax(answer_probs)] for results_dict in all_results_dicts.values() for answer_probs in range(turn_idx+1)]
    all_previous_answers = [[a for a in results_dict['answers'][:question_idx] if a != "Unsure"] for results_dict in all_results_dicts.values() for question_idx in range(turn_idx+1)]

    
    all_coherence_metrics = question_coherence_metrics_nli(nli_tokenizer,
                                                                    nli_model,
                                                                    tokenizer,
                                                                    lm,                                         
                                                                    [results_dict['procedure'] for results_dict in all_results_dicts.values() for _ in range(turn_idx+1)],
                                                                    all_chosen_questions,
                                                                    answers=all_predicted_answers,
                                                                    previous_questions=all_previous_questions,
                                                                    previous_answers=all_previous_answers,
                                                                    mistake_labels=[results_dict['mistake'] for results_dict in all_results_dicts.values() for _ in range(turn_idx+1)],
                                                                    rephrase_batch_size=120)
    
    this_results_dicts = {k: v | {"final_turn": turn_idx} for k, v in all_results_dicts.items()}
    accuracy_metrics_by_threshold, coherence_metrics = compile_accuracy_and_coherence_metrics(all_labels, all_probs, all_coherence_metrics, this_results_dicts, MISTAKE_DETECTION_THRESHOLDS, 0.1)
    
    json.dump(accuracy_metrics_by_threshold, 
            open(os.path.join(allturns_results_dir, f"metrics_accuracy_val.json"), "w"),
            indent=4)

    json.dump(coherence_metrics, 
            open(os.path.join(allturns_results_dir, f"metrics_coherence_nli_val.json"), "w"),
            indent=4)

    json.dump(this_results_dicts, 
            open(os.path.join(allturns_results_dir, f"outputs_val.json"), "w"),
            indent=4)


In [12]:
max_accuracy = None
max_accuracy_turn = None

max_verifiability = None
max_verifiability_turn = None

for turn_idx in tqdm(range(10)):

    allturns_results_dir = os.path.join(this_results_dir, f"allturns{turn_idx}")
    accuracy_metrics_by_threshold = json.load(open(os.path.join(allturns_results_dir, f"metrics_accuracy_val.json"), "r"))
    coherence_metrics = json.load(open(os.path.join(allturns_results_dir, f"metrics_coherence_nli_val.json"), "r"))
    
    this_accuracy = accuracy_metrics_by_threshold['best_metrics']['accuracy']
    this_verifiability = max(list(coherence_metrics['metrics_by_threshold'].values()), key=lambda x: x['verifiability'])['verifiability']
    
    if max_accuracy is None or this_accuracy > max_accuracy:
        max_accuracy = this_accuracy
        max_accuracy_turn = turn_idx
        
    if max_verifiability is None or this_verifiability > max_verifiability:
        max_verifiability = this_verifiability
        max_verifiability_turn = turn_idx

100%|██████████| 10/10 [00:00<00:00, 40.95it/s]


In [13]:
from pprint import pprint

lines = []
lines.append(f"Max accuracy: {max_accuracy} at turn {max_accuracy_turn}")
lines.append(f"Max verifiability: {max_verifiability} at turn {max_verifiability_turn}")

pprint(lines)
with open(os.path.join(this_results_dir, "allturns_results.txt"), "w") as f:
    f.write("\n".join(lines))

['Max accuracy: 0.634 at turn 0',
 'Max verifiability: 0.043835601961904 at turn 9']


# Top-down selection of max iterations

In [6]:
import json
import os

# this_results_dir = "/home/sstorks/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl/saved_results_222/vqa_mistake_detection/ego4d_single_debug250/llava-1.5-7b-hf/IterativeVQA_topdown_q10_ego4d_single_debug250_llava-1.5-7b-hf_beam8-4_likelihood_nohistory_20240817105952"
this_results_dir = "/home/sstorks/coe-chaijy/sstorks/simulation_informed_pcr4nlu/TRAVEl/saved_results_222/vqa_mistake_detection/ego4d_single_debug250/Phi-3-vision-128k-instruct/IterativeVQA_topdown_q10_ego4d_single_debug250_Phi-3-vision-128k-instruct_beam8-4_likelihood_nohistory_20240823121212"

with open(os.path.join(this_results_dir, "outputs_val.json"), "r") as f:
    all_results_dicts = json.load(f)
with open(os.path.join(this_results_dir, "metrics_coherence_nli_val.json"), "r") as f:
    all_coherence_metrics_base = json.load(f)
coherence_metrics_by_turn = all_coherence_metrics_base['metrics_by_turn']

In [7]:
from collections import defaultdict
from copy import deepcopy
import numpy as np
from pprint import pprint
from tqdm import tqdm

from travel.model.metrics import question_coherence_metrics_nli, compile_accuracy_and_coherence_metrics, mistake_detection_metrics
from travel.model.mistake_detection import MISTAKE_DETECTION_THRESHOLDS

for turn_idx in tqdm(range(10)):
    
    allturns_results_dir = os.path.join(this_results_dir, f"allturns{turn_idx}")
    if not os.path.exists(allturns_results_dir):
        os.makedirs(allturns_results_dir)
    else:
        continue    
    
    coherence_metrics_by_example = defaultdict(list)
    for k in coherence_metrics_by_turn:
        for example_idx, results_dict in enumerate(list(all_results_dicts.values())):
            this_metrics = []
            for this_turn_idx in range(turn_idx+1):
                this_metrics.append(coherence_metrics_by_turn[k][example_idx][this_turn_idx])
            
             # We'll usually just take the mean across all turns, but we take max marginal informativeness across dialog
            if k != "informativeness_marginal" and k != "informativeness_marginal_ref":
                example_metric = round(float(np.mean(this_metrics)), 6)
            else:
                example_metric = round(float(np.max(this_metrics)), 6)    

            coherence_metrics_by_example[k.replace("_by_turn", "_by_example")].append(example_metric)
        
    all_probs = [results_dict['success_probs'][turn_idx] for results_dict in all_results_dicts.values()]
    all_labels_binary = [results_dict['mistake'] for results_dict in all_results_dicts.values()]
    
    best_metrics = None
    best_threshold = None
    accuracy_metrics_by_threshold = {}
    coherence_metrics_by_threshold = {}
    for threshold in MISTAKE_DETECTION_THRESHOLDS:
        preds = [1.0 - p >= threshold for p in all_probs] # Have to do 1.0 - probability since we got "success" probability from VLM
        assert len(preds) == len(all_probs) == len(all_labels_binary), "Expected same number of preds, probs, and labels."
        this_metrics = mistake_detection_metrics(all_labels_binary, preds)
        accuracy_metrics_by_threshold[threshold] = this_metrics

        # Calculate consistency and verifiability for this example, which are conditional on correctness
        verifiability = np.mean([
            coherence_metrics_by_example['informativeness_marginal_ref_by_example'][i] * coherence_metrics_by_example['relevance_marginal_by_example'][i] if preds[i] == all_labels_binary[i] else 0.0 
            for i in range(len(preds))
        ])
        consistency = np.mean([coherence_metrics_by_example['relevance_marginal_by_example'][i] if preds[i] == all_labels_binary[i] else 0.0 for i in range(len(preds))])
        coherence_metrics_by_threshold[threshold] = {"verifiability": verifiability, "consistency": consistency,}
        
        if best_metrics is None or (this_metrics['false_positive_rate'] + this_metrics['false_negative_rate']) < (best_metrics['false_positive_rate'] + best_metrics['false_negative_rate']):
            best_metrics = this_metrics
            best_threshold = threshold

    accuracy_metrics_by_threshold['best_metrics'] = best_metrics
    accuracy_metrics_by_threshold['best_threshold'] = best_threshold

    all_coherence_metrics_base['metrics_by_turn']['relevance_by_turn']

    coherence_metrics = deepcopy(all_coherence_metrics_base) | {"metrics_by_example": coherence_metrics_by_example, "metrics_by_threshold": coherence_metrics_by_threshold}
    json.dump(accuracy_metrics_by_threshold, 
            open(os.path.join(allturns_results_dir, f"metrics_accuracy_val.json"), "w"),
            indent=4)

    coherence_metrics = deepcopy(all_coherence_metrics_base) | {"metrics_by_example": coherence_metrics_by_example, "metrics_by_threshold": coherence_metrics_by_threshold}
    json.dump(coherence_metrics, 
            open(os.path.join(allturns_results_dir, f"metrics_coherence_nli_val.json"), "w"),
            indent=4)

    this_results_dicts = {k: v | {"final_turn": turn_idx} for k, v in all_results_dicts.items()}
    json.dump(this_results_dicts, 
            open(os.path.join(allturns_results_dir, f"outputs_val.json"), "w"),
            indent=4)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

In [8]:
max_accuracy = None
max_accuracy_turn = None

max_verifiability = None
max_verifiability_turn = None

for turn_idx in tqdm(range(10)):

    allturns_results_dir = os.path.join(this_results_dir, f"allturns{turn_idx}")
    accuracy_metrics_by_threshold = json.load(open(os.path.join(allturns_results_dir, f"metrics_accuracy_val.json"), "r"))
    coherence_metrics = json.load(open(os.path.join(allturns_results_dir, f"metrics_coherence_nli_val.json"), "r"))
    
    this_accuracy = accuracy_metrics_by_threshold['best_metrics']['accuracy']
    this_verifiability = max(list(coherence_metrics['metrics_by_threshold'].values()), key=lambda x: x['verifiability'])['verifiability']
    
    if max_accuracy is None or this_accuracy > max_accuracy:
        max_accuracy = this_accuracy
        max_accuracy_turn = turn_idx
        
    if max_verifiability is None or this_verifiability > max_verifiability:
        max_verifiability = this_verifiability
        max_verifiability_turn = turn_idx

100%|██████████| 10/10 [00:00<00:00, 59.61it/s]


In [None]:
from pprint import pprint

lines = []
lines.append(f"Max accuracy: {max_accuracy} at turn {max_accuracy_turn}")
lines.append(f"Max verifiability: {max_verifiability} at turn {max_verifiability_turn}")

pprint(lines)
with open(os.path.join(this_results_dir, "allturns_results.txt"), "w") as f:
    f.write("\n".join(lines))