In [20]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

import pickle
import random
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import joblib
import logging
from datetime import datetime
from argparse import ArgumentParser
import src.utils as utils
from src.utils import normalize_answer, find_subsequence, exact_match_score
import json
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

In [21]:
# Results file produced by the generation run (greedy + consistency-selected answers).
RESULTS_PATH = 'results/llama-2-7b-hf/HotpotQA-2000-2per-seed0-cons3.parquet'
data = utils.load_parquet_data(RESULTS_PATH)


def report_em_f1(records, pred_key, title, gold_key='gold_ans'):
    """Print EM and F1 of ``pred_key`` answers against ``gold_key`` answers.

    Args:
        records: the loaded results (whatever ``utils.load_parquet_data`` returns).
        pred_key: column holding the predicted answer to score.
        title: section label printed in the header line.
        gold_key: column holding the gold answer(s).

    Returns:
        (em, f1) exactly as returned by ``utils.evaluate_em`` / ``utils.evaluate_f1``
        (fractions in [0, 1], judging by the x100 percentages they print as).
    """
    print(f'------------------{title}------------------')
    em, _, _ = utils.evaluate_em(records, gold_key, pred_key)
    f1, _, _ = utils.evaluate_f1(records, gold_key, pred_key)
    # Format to 2 decimals: the raw product can print float noise such as
    # 42.370000000000005 for 0.4237 * 100.
    print(f"EM: {em * 100:.2f}")
    print(f"F1: {f1 * 100:.2f}\n")
    return em, f1


# Baseline: greedy decoding.
greedy_em, greedy_f1 = report_em_f1(data, 'greedy_ans', 'Greedy')
# Ours: answer selected by the consistency procedure.
ours_em, ours_f1 = report_em_f1(data, 'pred_ans', 'Consistency')

------------------Greedy------------------
EM: 43.22
F1: 53.61

------------------Consistency------------------
EM: 42.370000000000005
F1: 53.59



In [22]:
# Hallucination detection via self-consistency.
# - Ground truth "positive": pred_ans is NOT an exact match of the gold answer.
# - Predicted "positive": greedy_ans and pred_ans disagree (inconsistent).
# NOTE(review): only item['gold_ans'][0] is scored, i.e. gold_ans is assumed to be a
# list whose first entry is the reference. The resulting correct count (43 + 7 = 50
# of 118) agrees with the Consistency EM above (42.37%), but confirm this matches
# utils.evaluate_em if examples can carry multiple gold aliases.
correct = []
incorrect = []

correct_consistency = []
incorrect_consistency = []

for item in data:
    # Raw string equality (no normalization) between the two decoded answers.
    is_consistent = item['greedy_ans'] == item['pred_ans']
    if exact_match_score(item['gold_ans'][0], item['pred_ans']):
        correct.append(item)
        correct_consistency.append(is_consistent)
    else:
        incorrect.append(item)
        incorrect_consistency.append(is_consistent)

# Confusion matrix, computed once and reused for both printing and metrics.
TN = sum(correct_consistency)                # consistent & correct
FP = len(correct_consistency) - TN           # inconsistent & correct
FN = sum(incorrect_consistency)              # consistent & incorrect
TP = len(incorrect_consistency) - FN         # inconsistent & incorrect

print('-' * 50)
print(f"Total: {len(data)}")
print('-' * 50)
print(f"(O, TN) Consistent when correct: {TN}")
print(f"(X, FP) Inconsistent when correct: {FP}")
print(f"(X, FN) Consistent when incorrect: {FN}")
print(f"(O, TP) Inconsistent when incorrect: {TP}")

# Every item lands in exactly one of correct/incorrect, so len(data) == TP+TN+FP+FN.
# Guard the empty case instead of raising ZeroDivisionError.
detection_accuracy = (TP + TN) / len(data) if len(data) > 0 else 0.0
print('-' * 50)
print(f"Detection Accuracy: {detection_accuracy:.4f}")

precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

print(f"Detection Precision: {precision:.4f}")
print(f"Detection Recall: {recall:.4f}")
print(f"Detection F1-score: {f1:.4f}")

--------------------------------------------------
Total: 118
--------------------------------------------------
(O, TN) Consistent when correct: 43
(X, FP) Inconsistent when correct: 7
(X, FN) Consistent when incorrect: 28
(O, TP) Inconsistent when incorrect: 40
--------------------------------------------------
Detection Accuracy: 0.7034
Detection F1-score: 0.6957
