In [38]:
import numpy as np
import string
import re
from collections import Counter
import re
from utils import load_file, postprocess_answer_option_conditioned, parse_path
from emoji import emojize
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Tuple, Literal
import os
from functools import partial

## Metrics

In [39]:
def normalize_answer(s):
    """Canonicalize an answer string before token-level comparison.

    Steps, in order: underscores become spaces, the text is lower-cased,
    every punctuation character (ASCII punctuation plus a few quote/accent
    marks) is replaced by a space, the standalone words "a", "an" and "the"
    are blanked out, and runs of whitespace are collapsed to single spaces.

    s: str
    returns: str
    """
    punctuation = set(string.punctuation) | {u"‘", u"’", u"´", u"`"}

    text = s.replace('_', ' ').lower()
    # Punctuation is swapped for a space (not deleted) so "a,b" stays two tokens.
    text = ''.join(' ' if ch in punctuation else ch for ch in text)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split()).strip()


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one reference answer.

    Both strings go through ``normalize_answer`` and are split on whitespace;
    the overlap is the multiset intersection of the two token lists.

    prediction: str
    ground_truth: str
    returns: 0 when no token is shared, otherwise the harmonic mean of
             precision and recall as a float.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Counter & Counter keeps, per token, the smaller of the two counts.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``.

    prediction: str
    ground_truth: str
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of ``prediction`` against any of the reference answers.

    metric_fn: callable taking (prediction: str, ground_truth: str)
    prediction: str
    ground_truths: str or iterable of str; a single string (including any
        ``str`` subclass) is treated as a one-element list.

    Blank / whitespace-only references are ignored. If nothing is left to
    compare against, 0 is returned instead of letting ``max`` raise on an
    empty sequence.
    """
    # isinstance (not type ==) so str subclasses are not iterated char by char.
    if isinstance(ground_truths, str):
        ground_truths = [ground_truths]
    scores = [
        metric_fn(prediction, ground_truth)
        for ground_truth in ground_truths
        if ground_truth.strip() != ""
    ]
    return max(scores, default=0)


# def accuracy(preds, labels):
#     match_count = 0
#     for pred, label in zip(preds, labels):
#         target = label[0]
#         if pred == target:
#             match_count += 1

#     return 100 * (match_count / len(preds))

# def accuracy(pred, label):
#     '''
#     akari claimed acc
#     ===
#     pred: str
#     label: str or list(str)
#     '''
#     if type(label) == list:
#         assert len(label) == 1
#         label = label[0]

#     return pred.lower()==label.lower()

def loose_acc(pred, label):
    """Lenient accuracy: 1 if the prediction equals, or starts with, the label.

    Both strings are lower-cased, punctuation is replaced by spaces and
    whitespace is collapsed. The result is 1 when the cleaned strings are
    equal or when the first whitespace token of the cleaned prediction equals
    the cleaned label, and 0 otherwise (including whenever the cleaned
    prediction is shorter than the cleaned label).

    pred: str
    label: str
    returns: int (0 or 1)
    """
    punctuation = set(string.punctuation + "".join([u"‘", u"’", u"´", u"`"]))

    def _clean(text):
        spaced = ''.join(' ' if ch in punctuation else ch for ch in text.lower())
        return ' '.join(spaced.split()).strip()

    pred_clean = _clean(pred)
    label_clean = _clean(label)

    if pred_clean == label_clean:
        return 1
    if len(pred_clean) < len(label_clean):
        return 0

    tokens = pred_clean.split()
    return 1 if tokens and tokens[0] == label_clean else 0

def loose_match(prediction, ground_truth):
    """Return 1 if ``ground_truth`` occurs verbatim inside ``prediction``, else 0.

    The check is a plain, case-sensitive substring test; no normalization
    is applied to either string.

    prediction: str
    ground_truth: str
    """
    return int(ground_truth in prediction)

## Scores

In [40]:
def compute_confidence(log_probs:List):
    """Arithmetic mean of the per-token probabilities.

    log_probs: List[float] of natural-log probabilities; each is
        exponentiated before averaging.
    """
    token_probs = np.exp(log_probs)
    return np.mean(token_probs)

def compute_inv_perplexity(log_probs:List):
    """Exponential of the mean log-probability.

    This equals the geometric mean of the per-token probabilities.

    log_probs: List[float] of natural-log probabilities.
    """
    mean_log_prob = np.mean(log_probs)
    return np.exp(mean_log_prob)

In [41]:
def false_invperplexity(ret_ind):
    """Sum of ``id_log_probs`` divided by the number of ``token_ids`` (min 1).

    Despite the name, no exponential is taken: the value is a length-normalized
    log-probability. The normalizer here is ``len(token_ids)``.
    NOTE(review): presumably "false" because ``token_ids`` can differ in
    length from ``id_log_probs`` -- confirm against the generation code.

    ret_ind: dict with keys 'id_log_probs' and 'token_ids'.
    """
    total_log_prob = np.sum(ret_ind['id_log_probs'])
    token_count = max(1, len(ret_ind['token_ids']))
    return total_log_prob / token_count

def true_invperplexity(ret_ind):
    """Mean of ``id_log_probs`` (sum divided by its own length, min 1).

    Despite the name, no exponential is taken: the value is the average
    log-probability over the entries of ``id_log_probs``.

    ret_ind: dict with key 'id_log_probs'.
    """
    log_probs = ret_ind['id_log_probs']
    total_log_prob = np.sum(log_probs)
    return total_log_prob / max(1, len(log_probs))

def recompute_score_inv_perplexity(ret_res, is_selfrag):
    """Re-normalize the 'score' of retrieval_0 .. retrieval_4.

    Pass ``res['retrieval_res']``: a dict whose 'retrieval_k' entries are dicts
    with 'score', 'id_log_probs' and 'token_ids'.

    is_selfrag=True : the ``false_invperplexity`` term is subtracted from the
                      stored score and ``true_invperplexity`` is added, leaving
                      the rest of the score unchanged.
    is_selfrag=False: the stored score is expected to equal
                      ``false_invperplexity`` (asserted within 1e-5) and is
                      replaced by ``true_invperplexity``.

    Returns a new dict. Each per-retrieval dict is copied as well, so the
    caller's ``ret_res`` is left untouched and calling this twice on the same
    input does not apply the correction twice.

    Raises KeyError if any of the five 'retrieval_k' keys is missing.
    """
    # A plain ret_res.copy() would share the inner dicts with the caller.
    res = {
        key: dict(value) if isinstance(value, dict) else value
        for key, value in ret_res.items()
    }
    ret_prefix = ['retrieval_0', 'retrieval_1', 'retrieval_2', 'retrieval_3', 'retrieval_4']
    for prefix in ret_prefix:
        entry = ret_res[prefix]
        if is_selfrag:
            score = entry['score'] - false_invperplexity(entry) + true_invperplexity(entry)
        else:
            # abs(): a one-sided check would pass for any stored score larger
            # than the expected value.
            assert abs(false_invperplexity(entry) - entry['score']) < 1e-5, (
                f"{prefix}: stored score does not match the token_ids-normalized log-prob"
            )
            score = true_invperplexity(entry)
        res[prefix]['score'] = score
    return res
        

In [42]:
def recompute_score(ret_res, recompute_fn, is_closed, is_selfrag):
    """Re-score every example's retrieval branches and re-pick the best one.

    ret_res: list of dicts, each with a 'retrieval_res' dict mapping a branch
        key to a dict holding 'pred', 'score', 'token_ids', 'id_log_probs'.
    recompute_fn: callable ``(retrieval_res, is_selfrag) -> retrieval_res``
        that returns the branches with updated 'score' values.
    is_closed: if True, branch scores are first summed per answer (first
        whitespace token of ``postprocess_answer_option_conditioned(pred)``),
        the highest-scoring answer wins, and the best branch is chosen among
        branches whose post-processed prediction starts with that answer.
        If False, the branch with the highest score is chosen directly.
    is_selfrag: forwarded to ``recompute_fn``.

    Returns a new list; each record gains/overwrites 'best_one', 'retrieval',
    'retrieval_token_ids' and 'retrieval_log_probs'. The input is deep-copied
    first, so ``ret_res`` itself is never modified (a shallow list copy would
    still share -- and mutate -- the caller's records).
    """
    import copy  # local import keeps this cell runnable on its own

    ret_res = copy.deepcopy(ret_res)
    for i in ret_res:
        i['retrieval_res'] = recompute_fn(i['retrieval_res'], is_selfrag)
        candidates = i['retrieval_res']

        if is_closed:
            # Vote: accumulate branch scores per (first-token) answer.
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result["score"]
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            # Only branches that produced the winning answer stay eligible.
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates

        path2score = {key: item["score"] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best = candidates[best_path]

        i["best_one"] = best
        i["retrieval"] = best["pred"]
        i["retrieval_token_ids"] = best["token_ids"]
        i["retrieval_log_probs"] = best["id_log_probs"]
    return ret_res
        

## Result Analysis

In [43]:
def emo_print(s):
    """Print ``s`` after passing it through ``emojize`` (e.g. ":bear:")."""
    rendered = emojize(s)
    print(rendered)


emo_print(":bear:")

🐻


In [44]:
def score_by_magnitude(numbers):
    """Rank values in ascending order, 1-based, with ties sharing a rank.

    The smallest value gets rank 1. Equal values all receive the rank of the
    first of them in sorted order, and the next distinct value's rank is its
    1-based sorted position (so ranks can skip, e.g. [5, 5, 9] -> [1, 1, 3]).

    numbers: sequence of mutually comparable values
    returns: list of int ranks, aligned with the input positions
    """
    values = list(numbers)
    # Stable sort of positions by value; ties keep their input order.
    order = sorted(range(len(values)), key=lambda pos: values[pos])

    ranks = [0] * len(numbers)
    rank = 1
    last_value = None
    for sorted_pos, original_pos in enumerate(order, start=1):
        value = values[original_pos]
        if value != last_value:
            rank = sorted_pos
        ranks[original_pos] = rank
        last_value = value
    return ranks

In [56]:
def select_by_confidence(res, is_closed=False):
    """Choose, per example, the retrieval branch with the highest confidence.

    res: list of dicts with 'gold' and 'retrieval_res' (branch key -> dict
        with 'pred' and 'id_log_probs').
    is_closed: True -> correctness uses ``loose_acc`` and selection first
        votes over answers (first token of the post-processed prediction);
        False -> correctness uses ``loose_match`` and the top-scoring branch
        is taken directly.

    Side effect: each branch's 'score' in ``res`` is overwritten with
    ``compute_confidence(id_log_probs)``.

    Prints how often the chosen branch was correct relative to the number of
    examples where at least one branch was correct. When no example has a
    correct branch (including empty ``res``) the percentage is reported as
    0.0 instead of raising ZeroDivisionError.

    Returns ``(preds, None)``.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    preds = []
    has_hit = 0
    right_choice = 0
    for i in res:
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = compute_confidence(item['id_log_probs'])

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)

    percentage = right_choice / has_hit * 100 if has_hit else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    print(f"Percentage of right choice: {percentage}%")
    return preds, None

def select_by_ctx(res, is_closed=False):
    """Choose, per example, the retrieval branch with the highest 'ctx_score'.

    res: list of dicts with 'gold' and 'retrieval_res' (branch key -> dict
        with 'pred' and 'ctx_score').
    is_closed: True -> correctness uses ``loose_acc`` and selection first
        votes over answers (first token of the post-processed prediction);
        False -> correctness uses ``loose_match`` and the top-scoring branch
        is taken directly.

    Side effect: each branch's 'score' in ``res`` is overwritten with
    ``float(ctx_score)``.

    Prints the right-choice statistics; when no example has a correct branch
    (including empty ``res``) the percentage is reported as 0.0 instead of
    raising ZeroDivisionError.

    Returns ``(preds, None)``.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    preds = []
    has_hit = 0
    right_choice = 0
    for i in res:
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = float(item['ctx_score'])

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)

    percentage = right_choice / has_hit * 100 if has_hit else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    print(f"Percentage of right choice: {percentage}%")
    return preds, None

def select_by_selfeval(res, eval_dc, is_closed=False):
    """Choose, per example, the branch that the self-evaluation scored highest.

    res: list of dicts with 'gold' and 'retrieval_res' (keys of the form
        'retrieval_<k>', each a dict with 'pred').
    eval_dc: indexable by example position; ``eval_dc[ind]`` maps a 1-based
        branch label (anything ``int()`` accepts) to a score. Label ``n``
        selects 'retrieval_<n-1>'.
    is_closed: True -> correctness uses ``loose_acc``; False -> ``loose_match``.

    Prints the right-choice statistics; when no example has a correct branch
    (including empty ``res``) the percentage is reported as 0.0 instead of
    raising ZeroDivisionError.

    Returns ``(preds, None)``.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    preds = []
    has_hit = 0
    right_choice = 0
    for ind, i in enumerate(res):
        # Oracle: is there any branch whose prediction is correct?
        for item in i['retrieval_res'].values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        answer2score = eval_dc[ind]
        top_label = sorted(
            answer2score.items(), key=lambda x: x[1], reverse=True
        )[0][0]
        # Labels are 1-based, branch keys are 0-based.
        branch_index = int(top_label) - 1
        best_option = i['retrieval_res'][f"retrieval_{branch_index}"]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)

    percentage = right_choice / has_hit * 100 if has_hit else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    print(f"Percentage of right choice: {percentage}%")
    return preds, None

def select_selfrag(res, is_closed=False, ret_p = None, thres = 0.2):
    """Select a branch using the stored score minus its log-prob component.

    res: list of dicts with 'gold', 'no_retrieval' (used only when adaptive
        retrieval rejects retrieval) and 'retrieval_res' (branch key -> dict
        with 'pred', 'score', 'id_log_probs', 'token_ids').
    is_closed: True -> correctness uses ``loose_acc`` and selection first
        votes over answers; False -> ``loose_match`` and plain top score.
    ret_p: optional per-example values; when ``ret_p[ind] < thres`` the
        adaptive prediction falls back to ``i['no_retrieval']``.
    thres: threshold compared against ``ret_p``.

    Side effect / hazard: each branch's 'score' in ``res`` is replaced by
    ``score - false_invperplexity(branch)``. Because the new value is derived
    from the old one, calling this function again on the same ``res`` subtracts
    the term a second time -- reload ``res`` before re-running.

    Prints the right-choice and retrieval statistics; both percentages are
    reported as 0.0 (instead of raising ZeroDivisionError) when their
    denominator is zero, e.g. for empty ``res``.

    Returns ``(non_adp_preds, preds)``: always-retrieve predictions first,
    adaptive predictions second.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    print(f"#Adp threshold: {thres}")
    preds = []
    non_adp_preds = []
    has_hit = 0
    right_choice = 0
    no_ret_cnt = 0
    ret_cnt = 0
    for ind, i in enumerate(res):
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = item['score'] - false_invperplexity(item)

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")

        if ret_p is not None and ret_p[ind] < thres:
            preds.append(i['no_retrieval'])
            no_ret_cnt += 1
        else:
            preds.append(best_option)
            ret_cnt += 1
        non_adp_preds.append(best_option)

    hit_pct = round(right_choice / has_hit * 100, 1) if has_hit else 0.0
    total = ret_cnt + no_ret_cnt
    ret_pct = round(ret_cnt / total * 100, 1) if total else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit} ({hit_pct}%)")
    print(f"Ret: {ret_cnt}, NoRet: {no_ret_cnt}, Ret rate: {ret_pct}%")
    return non_adp_preds, preds

def select_by_cfdnctx(res, is_closed=False):
    """Select per example by ``compute_confidence(id_log_probs) + ctx_score``.

    res: list of dicts with 'gold' and 'retrieval_res' (branch key -> dict
        with 'pred', 'id_log_probs', 'ctx_score').
    is_closed: True -> correctness uses ``loose_acc`` and selection first
        votes over answers; False -> ``loose_match`` and plain top score.

    Side effect: each branch's 'score' in ``res`` is overwritten with the
    combined value. Prints the right-choice counts.

    Returns the list of chosen predictions.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    preds = []
    has_hit = 0
    right_choice = 0
    for i in res:
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = compute_confidence(item['id_log_probs']) + float(item['ctx_score'])

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    return preds

def select_by_cfdnlogctx(res, is_closed=False):
    """Select per example by ``compute_confidence(id_log_probs) + log(ctx_score)``.

    res: list of dicts with 'gold' and 'retrieval_res' (branch key -> dict
        with 'pred', 'id_log_probs', 'ctx_score').
    is_closed: True -> correctness uses ``loose_acc`` and selection first
        votes over answers; False -> ``loose_match`` and plain top score.

    Side effect: each branch's 'score' in ``res`` is overwritten with the
    combined value. Prints the right-choice counts.

    Returns the list of chosen predictions.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    preds = []
    has_hit = 0
    right_choice = 0
    for i in res:
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = compute_confidence(item['id_log_probs']) + np.log(float(item['ctx_score']))

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    return preds



def select_by_cfdnlogctx_adp(res, is_closed=False, ret_p = None, thres = 0.2):
    """Adaptive variant of confidence + log(ctx_score) selection.

    For an example with ``ret_p[ind] < thres`` the 'no_retrieval' answer is
    used and the example is skipped entirely (no scoring, not counted in the
    right-choice statistics). Every other example is scored with
    ``compute_confidence(id_log_probs) + log(ctx_score)`` and the best branch
    is picked (with answer voting first when ``is_closed``).

    res: list of dicts with 'gold', 'no_retrieval' and 'retrieval_res'.
    ret_p: optional per-example values compared against ``thres``.

    Side effect: 'score' of every scored branch in ``res`` is overwritten.
    Prints the threshold, the retrieve / no-retrieve counts and the
    right-choice counts.

    Returns the list of predictions.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    print(f"#Adp threshold: {thres}")
    preds = []
    has_hit = 0
    right_choice = 0
    ret_cnt = 0
    noret_cnt = 0
    for ind, i in enumerate(res):
        skip_retrieval = ret_p is not None and ret_p[ind] < thres
        if skip_retrieval:
            preds.append(i['no_retrieval'])
            noret_cnt += 1
            continue

        ret_cnt += 1
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = compute_confidence(item['id_log_probs']) + np.log(float(item['ctx_score']))

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        if acc_ > acc:
            print("something wired...")
        preds.append(best_option)
    print(f"Ret: {ret_cnt}, NoRet: {noret_cnt}")
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    return preds

def identity(x):
    """Return ``x`` unchanged (stand-in for a missing transform function)."""
    return x

def select_by_cfd_ctx_selfeval_adp(res, eval_dc, paras=[1,1,1], fns=[None, None, None], is_closed=False, ret_p = None, thres = 0.2):
    """Select by a weighted mix of confidence, context score and self-eval score.

    Branch score =
        paras[0] * fns[0](confidence) + paras[1] * fns[1](ctx_score)
        + paras[2] * fns[2](eval_score)
    where a ``None`` entry in ``fns`` means ``identity``.

    res: list of dicts with 'gold', 'no_retrieval' and 'retrieval_res'
        (branch key ending in the branch index -> dict with 'pred',
        'id_log_probs', 'ctx_score').
    eval_dc: indexable by example position; ``eval_dc[ind]`` maps the 1-based
        branch index as a string ("1" for 'retrieval_0', ...) to a score.
    paras / fns: weights and transforms; neither list is modified.
    ret_p, thres: when ``ret_p[ind] < thres`` the adaptive prediction is
        ``i['no_retrieval']`` instead of the selected branch.

    Side effect: each branch's 'score' in ``res`` is overwritten.

    Returns ``(preds, non_adp_preds, desc, para_dc, right_choice_rate)``.
    ``right_choice_rate`` is 0.0 when no example has a correct branch
    (instead of raising ZeroDivisionError).
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    fns = [identity if i is None else i for i in fns]
    desc = f"calculated formula: {paras[0]} * {fns[0].__name__}(cfd_score) + {paras[1]} * {fns[1].__name__}(ctx_score) + {paras[2]} * {fns[2].__name__}(eval_score)"
    para_dc = {'paras': paras, 'fns': [i.__name__ for i in fns]}
    preds = []
    non_adp_preds = []
    has_hit = 0
    right_choice = 0
    for ind, i in enumerate(res):
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for key, item in candidates.items():
            cfd_score = compute_confidence(item['id_log_probs'])
            ctx_score = float(item['ctx_score'])
            # Read all trailing digits so 'retrieval_10' maps to "11", not "1".
            branch_index = int(re.search(r'\d+$', key).group())
            eval_score = eval_dc[ind][str(branch_index + 1)]
            item['score'] = paras[0] * fns[0](cfd_score) + paras[1] * fns[1](ctx_score) + paras[2] * fns[2](eval_score)

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")

        if ret_p is not None and ret_p[ind] < thres:
            preds.append(i['no_retrieval'])
        else:
            preds.append(best_option)
        non_adp_preds.append(best_option)

    right_choice_rate = right_choice / has_hit if has_hit else 0.0
    return preds, non_adp_preds, desc, para_dc, right_choice_rate

def select(res, eval_dc, paras=[1,1,1], fns=[None, None, None], is_closed=False, ret_p = None, thres = 0.2):
    """Verbose selection by a weighted mix of confidence, context and self-eval.

    Branch score =
        paras[0] * fns[0](confidence) + paras[1] * fns[1](ctx_score)
        + paras[2] * fns[2](eval_score)
    where a ``None`` entry in ``fns`` means ``identity``.

    res: list of dicts with 'gold', 'no_retrieval' and 'retrieval_res'
        (branch key ending in the branch index -> dict with 'pred',
        'id_log_probs', 'ctx_score').
    eval_dc: indexable by example position; ``eval_dc[ind]`` maps the 1-based
        branch index as a string ("1" for 'retrieval_0', ...) to a score.
    paras / fns: weights and transforms; neither list is modified.
    ret_p, thres: when ``ret_p[ind] < thres`` the adaptive prediction is
        ``i['no_retrieval']`` instead of the selected branch.

    Side effect: each branch's 'score' in ``res`` is overwritten.

    Prints the threshold, the formula, and the right-choice / retrieval
    statistics; percentages are reported as 0.0 (instead of raising
    ZeroDivisionError) when their denominator is zero, e.g. for empty ``res``.

    Returns ``(non_adp_preds, preds)``: always-retrieve predictions first,
    adaptive predictions second.
    """
    def _hit(pred, gold):
        metric_fn = loose_acc if is_closed else loose_match
        return metric_max_over_ground_truths(metric_fn, pred, gold)

    fns = [identity if i is None else i for i in fns]
    print(f"#Adp threshold: {thres}")
    desc = f"calculated formula: {paras[0]} * {fns[0].__name__}(cfd_score) + {paras[1]} * {fns[1].__name__}(ctx_score) + {paras[2]} * {fns[2].__name__}(eval_score)"
    print(desc)
    preds = []
    non_adp_preds = []
    has_hit = 0
    right_choice = 0
    no_ret_cnt = 0
    ret_cnt = 0
    for ind, i in enumerate(res):
        candidates = i['retrieval_res']

        # Oracle: is there any branch whose prediction is correct?
        for item in candidates.values():
            acc = _hit(item['pred'], i['gold'])
            if acc > 0:
                has_hit += 1
                break

        for key, item in candidates.items():
            cfd_score = compute_confidence(item['id_log_probs'])
            ctx_score = float(item['ctx_score'])
            # Read all trailing digits so 'retrieval_10' maps to "11", not "1".
            branch_index = int(re.search(r'\d+$', key).group())
            eval_score = eval_dc[ind][str(branch_index + 1)]
            item['score'] = paras[0] * fns[0](cfd_score) + paras[1] * fns[1](ctx_score) + paras[2] * fns[2](eval_score)

        if is_closed:
            answer2score = {}
            for result in candidates.values():
                answer = postprocess_answer_option_conditioned(result["pred"])
                if len(answer.split()) > 0:
                    answer = answer.split()[0]
                answer2score[answer] = answer2score.get(answer, 0) + result['score']
            top_answer = sorted(
                answer2score.items(), key=lambda x: x[1], reverse=True)[0][0]
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(top_answer)
            }
        else:
            pool = candidates
        path2score = {key: item['score'] for key, item in pool.items()}
        best_path = sorted(path2score.items(),
                           key=lambda x: x[1], reverse=True)[0][0]
        best_option = candidates[best_path]["pred"]

        acc_ = _hit(best_option, i['gold'])
        if acc_ > 0:
            right_choice += 1
        # The chosen branch cannot beat the oracle unless something is off.
        if acc_ > acc:
            print("something wired...")

        if ret_p is not None and ret_p[ind] < thres:
            preds.append(i['no_retrieval'])
            no_ret_cnt += 1
        else:
            preds.append(best_option)
            ret_cnt += 1
        non_adp_preds.append(best_option)

    hit_pct = round(right_choice / has_hit * 100, 1) if has_hit else 0.0
    total = ret_cnt + no_ret_cnt
    ret_pct = round(ret_cnt / total * 100, 1) if total else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit} ({hit_pct}%)")
    print(f"Ret: {ret_cnt}, NoRet: {no_ret_cnt}, Ret rate: {ret_pct}%")
    return non_adp_preds, preds

def _trailing_index(key):
    """Return the integer formed by the digits that end ``key`` (``'retrieval_12'`` -> 12)."""
    match = re.search(r'\d+$', key)
    if match is None:
        raise ValueError(f"retrieval key {key!r} does not end with an index")
    return int(match.group())


def select_categorical_score(res, eval_dc, paras=[1,1,1], fns=[None, None, None], is_closed=False, ret_p = None, thres = 0.2):
    """Pick one retrieval answer per sample from a weighted sum of ranked scores.

    For every candidate in ``sample['retrieval_res']`` three raw values are
    collected: ``compute_confidence(id_log_probs)``, ``float(ctx_score)`` and the
    self-evaluation score ``eval_dc[ind][str(index + 1)]``. Each list is passed
    through ``score_by_magnitude`` (defined elsewhere; presumably turns raw values
    into rank-like scores -- TODO confirm) and combined as
    ``paras[0]*fns[0](cfd) + paras[1]*fns[1](ctx) + paras[2]*fns[2](eval)``.
    The result is stored in place under ``candidate['score']``.

    Args:
        res: list of sample dicts with ``'retrieval_res'``, ``'gold'`` and
            ``'no_retrieval'`` entries.
        eval_dc: per-sample mapping from 1-based candidate index (as a string)
            to a self-evaluation score.
        paras: three weights for the confidence, context and eval terms.
        fns: three transforms applied to the ranked scores; ``None`` entries
            are replaced by ``identity``.
        is_closed: when True, candidates first vote on the leading token of
            their post-processed answer and ``loose_acc`` is used as metric;
            otherwise the top-scored candidate wins and ``loose_match`` is used.
        ret_p: optional per-sample retrieval probabilities. When given, samples
            with ``ret_p[ind] < thres`` fall back to the no-retrieval answer.
        thres: threshold applied to ``ret_p``.

    Returns:
        ``(non_adp_preds, preds)``: the always-retrieve predictions and the
        adaptive predictions.
    """
    fns = [identity if fn is None else fn for fn in fns]
    print(f"#Adp threshold: {thres}")
    desc = f"calculated formula: {paras[0]} * {fns[0].__name__}(cfd_score) + {paras[1]} * {fns[1].__name__}(ctx_score) + {paras[2]} * {fns[2].__name__}(eval_score)"
    print(desc)
    metric_fn = loose_acc if is_closed else loose_match
    preds = []
    non_adp_preds = []
    has_hit = 0
    right_choice = 0
    no_ret_cnt = 0
    ret_cnt = 0
    for ind, sample in enumerate(res):
        candidates = sample['retrieval_res']
        golds = sample['gold']

        # Oracle check: does any candidate answer the question correctly?
        # Reset per sample so a stale value from the previous sample is never compared.
        acc = 0
        for item in candidates.values():
            acc = metric_max_over_ground_truths(metric_fn, item['pred'], golds)
            if acc > 0:
                has_hit += 1
                break

        # Raw scores are gathered in iteration order and read back by the same
        # position, so the pairing stays correct whatever the key order is.
        keys = list(candidates)
        cfd_scores = score_by_magnitude(
            [compute_confidence(candidates[key]['id_log_probs']) for key in keys])
        ctx_scores = score_by_magnitude(
            [float(candidates[key]['ctx_score']) for key in keys])
        # eval_dc is keyed by the 1-based candidate index; parse all trailing
        # digits so indices >= 10 are handled too.
        eval_scores = score_by_magnitude(
            [eval_dc[ind][str(_trailing_index(key) + 1)] for key in keys])
        for pos, key in enumerate(keys):
            candidates[key]['score'] = (
                paras[0] * fns[0](cfd_scores[pos])
                + paras[1] * fns[1](ctx_scores[pos])
                + paras[2] * fns[2](eval_scores[pos])
            )

        if is_closed:
            # Election: sum the scores of candidates sharing the same leading
            # answer token, then keep only candidates of the winning option.
            answer2score = {}
            for item in candidates.values():
                answer = postprocess_answer_option_conditioned(item["pred"])
                tokens = answer.split()
                if tokens:
                    answer = tokens[0]
                answer2score[answer] = answer2score.get(answer, 0) + item['score']
            voted_option = max(answer2score, key=answer2score.get)
            pool = {
                key: item for key, item in candidates.items()
                if postprocess_answer_option_conditioned(item["pred"]).startswith(voted_option)
            }
        else:
            pool = candidates
        # max() returns the first maximal entry, matching a stable descending sort.
        best_path = max(pool, key=lambda key: pool[key]['score'])
        best_option = candidates[best_path]["pred"]

        acc_ = metric_max_over_ground_truths(metric_fn, best_option, golds)
        if acc_ > 0:
            right_choice += 1
        if acc_ > acc:
            # The chosen answer scores better than the oracle loop's last check.
            print("something wired...")

        if ret_p is not None and ret_p[ind] < thres:
            preds.append(sample['no_retrieval'])
            no_ret_cnt += 1
        else:
            preds.append(best_option)
            ret_cnt += 1
        non_adp_preds.append(best_option)

    # Guard the percentages: both denominators are zero for empty input, and
    # has_hit is zero whenever no candidate of any sample is correct.
    right_rate = round(right_choice / has_hit * 100, 1) if has_hit else 0.0
    print(f"#Right choice rate: {right_choice} out of {has_hit} ({right_rate}%)")
    total = ret_cnt + no_ret_cnt
    ret_rate = round(ret_cnt / total * 100, 1) if total else 0.0
    print(f"Ret: {ret_cnt}, NoRet: {no_ret_cnt}, Ret rate: {ret_rate}%")
    return non_adp_preds, preds

def _select_top_scored_no_election(res, score_fn, is_closed=False):
    """Shared driver: score every candidate with ``score_fn`` and keep the best one.

    ``score_fn`` receives one entry of ``sample['retrieval_res']`` and returns a
    number, which is stored in place under ``entry['score']``. No answer voting
    is done; the single highest-scored candidate wins (first one on ties).

    Prints how many samples had at least one correct candidate and how often
    the selected candidate was correct, then returns the list of selected
    predictions.
    """
    metric_fn = loose_acc if is_closed else loose_match
    preds = []
    has_hit = 0
    right_choice = 0
    for sample in res:
        candidates = sample['retrieval_res']
        golds = sample['gold']

        # Oracle check: stop at the first candidate that matches a gold answer.
        for item in candidates.values():
            acc = metric_max_over_ground_truths(metric_fn, item['pred'], golds)
            if acc > 0:
                has_hit += 1
                break

        for item in candidates.values():
            item['score'] = score_fn(item)

        # max() keeps the first maximal key, like a stable descending sort.
        best_path = max(candidates, key=lambda key: candidates[key]['score'])
        best_option = candidates[best_path]["pred"]

        acc_ = metric_max_over_ground_truths(metric_fn, best_option, golds)
        if acc_ > 0:
            right_choice += 1
        if acc_ > acc:
            # The chosen answer scores better than the oracle loop's last check.
            print("something wired...")
        preds.append(best_option)
    print(f"#Right choice rate: {right_choice} out of {has_hit}")
    return preds


def select_by_confidence_no_election(res, is_closed=False):
    """Select the candidate with the highest ``compute_confidence(id_log_probs)``.

    Returns a plain list of predictions (not a pair), so it cannot be passed to
    ``Analyzer.metric_evaluation``, which unpacks two return values.
    """
    return _select_top_scored_no_election(
        res, lambda item: compute_confidence(item['id_log_probs']), is_closed)


def select_by_ctx_no_election(res, is_closed=False):
    """Select the candidate with the highest ``float(ctx_score)``.

    Returns a plain list of predictions (not a pair).
    """
    return _select_top_scored_no_election(
        res, lambda item: float(item['ctx_score']), is_closed)


def select_by_cfdnctx_no_election(res, is_closed=False):
    """Select the candidate maximising confidence plus context score.

    Returns a plain list of predictions (not a pair).
    """
    return _select_top_scored_no_election(
        res,
        lambda item: compute_confidence(item['id_log_probs']) + float(item['ctx_score']),
        is_closed)




def select_optimal_top(res, is_closed=False):
    """Oracle selection over individual retrieval candidates.

    For each sample the candidates in ``sample['retrieval_res']`` are scored
    against the gold answers (``loose_acc`` when ``is_closed`` else
    ``loose_match``); the score is stored in place under ``candidate['score']``.
    The first candidate with a positive score is used, otherwise the sample
    falls back to ``sample['no_retrieval']``.

    Returns:
        ``(preds, None)``, matching the two-value shape that
        ``Analyzer.metric_evaluation`` unpacks.
    """
    metric_fn = loose_acc if is_closed else loose_match
    preds = []
    cnt = 0
    for sample in res:
        for item in sample['retrieval_res'].values():
            item['score'] = metric_max_over_ground_truths(metric_fn, item['pred'], sample['gold'])
            if item['score'] > 0:
                preds.append(item['pred'])
                cnt += 1
                break
        else:
            # No candidate matched: fall back to the answer generated without retrieval.
            preds.append(sample['no_retrieval'])
    # One prediction is appended per sample; guard the rate for empty input.
    rate = cnt / len(preds) if preds else 0.0
    print(f"#retrieval rate: {rate}")
    return preds, None

def select_optimal_all(res, is_closed=False):
    """Oracle choice between the all-documents answer and the no-retrieval answer.

    For each sample, ``sample['all_doc_retrieval']`` is used when it scores
    above zero against the gold answers (``loose_acc`` when ``is_closed`` else
    ``loose_match``); otherwise ``sample['no_retrieval']`` is used.

    Returns:
        ``(None, preds)``, matching the two-value shape that
        ``Analyzer.metric_evaluation`` unpacks.
    """
    metric_fn = loose_acc if is_closed else loose_match
    preds = []
    cnt = 0
    for sample in res:
        use_all_score = metric_max_over_ground_truths(
            metric_fn, sample['all_doc_retrieval'], sample['gold'])
        if use_all_score > 0:
            preds.append(sample['all_doc_retrieval'])
            cnt += 1
        else:
            preds.append(sample['no_retrieval'])
    # One prediction is appended per sample; guard the rate for empty input.
    rate = cnt / len(preds) if preds else 0.0
    print(f"#retrieval rate: {rate}")

    return None, preds


def select_right_ans(res, is_closed=False):
    """Oracle selection that always answers from a retrieval candidate.

    Each candidate of a sample is scored against the gold answers
    (``loose_acc`` when ``is_closed`` else ``loose_match``) and the score is
    written to ``candidate['score']``. The first candidate with a positive
    score is chosen; when none matches, the prediction of ``'retrieval_0'``
    is kept.

    Returns ``(None, preds)``.
    """
    scorer = loose_acc if is_closed else loose_match
    chosen = []
    n_right = 0
    for sample in res:
        candidates = sample['retrieval_res']
        # Default answer when no candidate turns out to be correct.
        answer = candidates['retrieval_0']['pred']
        for candidate in candidates.values():
            candidate['score'] = metric_max_over_ground_truths(
                scorer, candidate['pred'], sample['gold'])
            if candidate['score'] > 0:
                answer = candidate['pred']
                n_right += 1
                break
        chosen.append(answer)
    print(f"#Optimal right rate: {n_right} out of {len(res)}")
    return None, chosen

In [46]:
class Analyzer:
    """Load one result file and evaluate answer-selection strategies on it.

    The model/dataset/task/args are derived from the file path via
    ``parse_path`` (defined in ``utils``). The loaded samples are kept in
    ``self.dc``.
    """

    def __init__(self, json_path):
        model_name, dataset_name, task, args = parse_path(json_path)
        print('#'*20)
        print(f"#Model: {model_name}")
        print(f"#Dataset: {dataset_name}")
        print(f"#Task: {task}")
        print(f"#Args: {args}")
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.task = task
        self.args = args

        self.dc = self.process_res(json_path)

        self.no_ret_preds = []
        self.use_all_preds = []
        self.select_top_preds = []

    def process_res(self, json_path):
        """Read the result JSON and return its list of samples.

        The file is either a bare list of samples or a dict holding them under
        ``'results'``. When samples carry flat ``ret_0`` .. ``ret_4`` fields,
        those are checked against ``retrieval_res`` and their ``_ctx``,
        ``_ctx_score`` and ``_scores`` companions are copied into the matching
        ``retrieval_res['retrieval_i']`` entry.
        """
        with open(json_path) as f:
            json_file = json.load(f)

        res = json_file if isinstance(json_file, list) else json_file['results']

        # `res and ...` keeps an empty result file from raising IndexError.
        if res and 'ret_0' in res[0]:
            for data in res:
                for i in range(5):
                    ret = f'ret_{i}'
                    entry = data["retrieval_res"][f'retrieval_{i}']
                    # The flat fields must agree with the nested ones before merging.
                    assert entry['pred'] == data[ret]
                    assert entry['id_log_probs'] == data[ret+'_log_probs']
                    entry['ctx'] = data[ret+'_ctx']
                    entry['ctx_score'] = data[ret+'_ctx_score']
                    entry['scores'] = data[ret+'_scores']

        return res

    def _is_closed(self):
        """True for the datasets evaluated with ``loose_acc`` ('arc' / 'health')."""
        return 'arc' in self.dataset_name or 'health' in self.dataset_name

    def _metric_fn(self):
        """Return the metric used for this dataset."""
        return loose_acc if self._is_closed() else loose_match

    @staticmethod
    def _mean_metric(metric_fn, preds, golds):
        """Mean of ``metric_fn`` (max over gold answers) across samples."""
        return np.mean([metric_max_over_ground_truths(metric_fn, pred, gold)
                        for pred, gold in zip(preds, golds)])

    @staticmethod
    def _fn_name(select_fn):
        """Name of a selection function; ``functools.partial`` objects expose it via ``.func``."""
        name = getattr(select_fn, '__name__', None)
        return name if name is not None else select_fn.func.__name__

    def metric_evaluation(self, select_fns=None, retrieval_methods=('no_ret', 'use_all', 'select_top')):
        """Print the dataset metric for each requested retrieval method.

        Args:
            select_fns: selection callables used by the ``'select_top'`` method.
                Each is called as ``fn(self.dc, is_closed=...)`` and must return
                a pair ``(preds, adp_preds)``; either element may be falsy to
                skip its report. Defaults to ``[select_by_confidence]``,
                resolved at call time.
            retrieval_methods: any of ``'no_ret'``, ``'use_all'``,
                ``'select_top'``; other values are ignored.
        """
        if select_fns is None:
            select_fns = [select_by_confidence]
        metric_fn = self._metric_fn()

        golds = [i['gold'] for i in self.dc]
        print('#'*20)
        print(f"#Total number of samples: {len(self.dc)}")
        for method in retrieval_methods:
            print(f"#====== {method} ======#")

            if method == 'no_ret':
                preds = [i['no_retrieval'] for i in self.dc]
                print(f"#No retrieval {metric_fn.__name__}: {self._mean_metric(metric_fn, preds, golds)}")
                self.no_ret_preds = preds
            elif method == 'use_all':
                preds = [i['all_doc_retrieval'] for i in self.dc]
                print(f"#Use all {metric_fn.__name__}: {self._mean_metric(metric_fn, preds, golds)}")
                self.use_all_preds = preds
            elif method == 'select_top':
                for select_fn in select_fns:
                    print(f"#Method: {self._fn_name(select_fn)}")
                    preds, adp_preds = select_fn(self.dc, is_closed=self._is_closed())
                    if preds:
                        print(f"#All ret {metric_fn.__name__}: {self._mean_metric(metric_fn, preds, golds)}")
                    if adp_preds:
                        print(f"#Adp {metric_fn.__name__}: {self._mean_metric(metric_fn, adp_preds, golds)}")

                    print('# ')

    def hyper_tune(self, select_fns=None):
        """Evaluate a batch of parameterised selection functions and rank them.

        Args:
            select_fns: callables invoked as ``fn(self.dc, is_closed=...)`` that
                return ``(preds, non_adp_preds, desc, para_dc, right_choice_rate)``.
                The first element is reported as the adaptive accuracy and the
                second as the non-adaptive accuracy. ``para_dc`` must provide
                ``'paras'`` and ``'fns'``.
                NOTE(review): this is the opposite order to the
                ``(non_adp_preds, preds)`` pair returned by
                ``select_categorical_score`` -- confirm against the sweep function.

        Returns:
            Dict of column lists (``model``, ``dataset``, ``paras``, ``fns``,
            ``non_adp_acc``, ``adp_acc``, ``right_choice_rate``) with one entry
            per selection function. The (up to) three best descriptions by
            non-adaptive accuracy are printed.
        """
        if select_fns is None:
            select_fns = []
        metric_fn = self._metric_fn()

        golds = [i['gold'] for i in self.dc]
        print('#'*20)
        print(f"#Total number of samples: {len(self.dc)}")
        select_top_dc = {}
        out_put_dc = {'model': [], 'dataset': [], 'paras': [], 'fns': [], 'non_adp_acc': [], 'adp_acc': [], 'right_choice_rate': []}
        for select_fn in select_fns:
            preds, non_adp_preds, desc, para_dc, right_choice_rate = select_fn(self.dc, is_closed=self._is_closed())
            adp_acc = self._mean_metric(metric_fn, preds, golds)
            non_adp_acc = self._mean_metric(metric_fn, non_adp_preds, golds)
            # Keyed by description: a repeated description overwrites the earlier score.
            select_top_dc[desc] = non_adp_acc
            out_put_dc['model'].append(self.model_name)
            out_put_dc['dataset'].append(self.dataset_name)
            out_put_dc['paras'].append(para_dc['paras'])
            out_put_dc['fns'].append(para_dc['fns'])
            out_put_dc['non_adp_acc'].append(non_adp_acc)
            out_put_dc['adp_acc'].append(adp_acc)
            out_put_dc['right_choice_rate'].append(right_choice_rate)

        ranked = sorted(select_top_dc.items(), key=lambda kv: kv[1], reverse=True)
        # Slice instead of indexing so fewer than three candidates do not raise IndexError.
        for rank, (desc, score) in enumerate(ranked[:3], start=1):
            print(f"##{rank} method: {desc}: {score}")

        return out_put_dc
                


        
                    


In [47]:
# files_dc = {
#     r'eval_res\selfrag-pqa-fullspan.json': r'eval_res\selfragadp-pqa.json',
#     r'eval_res\selfrag-tqa-fullspan.json': r'eval_res\selfragadp-tqa.json',
#     r'eval_res\selfrag-health-fullspan.json': r'eval_res\selfragadp-health.json',
#     r'eval_res\selfrag-arc-fullspan.json': r'eval_res\selfragadp-arc.json',
#     r'eval_res\llama2chat-pqa-fullspan.json': r'eval_res\llama2chat-pqa-selfadp.json',
#     r'eval_res\llama2chat-tqa-fullspan.json': r'eval_res\llama2chat-tqa-selfadp.json',
#     r'eval_res\llama2chat-health-fullspan.json': r'eval_res\llama2chat-health-selfadp.json',
#     r'eval_res\llama2chat-health-fullspan-w_exp.json': r'eval_res\llama2chat-health-selfadp.json',
#     r'eval_res\llama2chat-arc-fullspan.json': r'eval_res\llama2chat-arc-selfadp.json',
#     r'eval_res\llama2chat-arc-fullspan-w_exp.json': r'eval_res\llama2chat-arc-selfadp.json',
#     r'eval_res\llama3Ins-pqa-fullspan.json': r'eval_res\llama3Ins-pqa-selfadp.json',
#     r'eval_res\llama3Ins-tqa-fullspan.json': r'eval_res\llama3Ins-tqa-selfadp.json',
#     r'eval_res\llama3Ins-health-fullspan.json': r'eval_res\llama3Ins-health-selfadp.json',
#     r'eval_res\llama3Ins-health-fullspan-w_exp.json': r'eval_res\llama3Ins-health-selfadp.json',
#     r'eval_res\llama3Ins-arc-fullspan.json': r'eval_res\llama3Ins-arc-selfadp.json',
#     r'eval_res\llama3Ins-arc-fullspan-w_exp.json': r'eval_res\llama3Ins-arc-selfadp.json',
    
# }
# Maps each post-processed generation file (key) to the matching "selfadp"
# retrieval-probability file (value). The evaluation cells derive the
# "selfeval" and "selfadp_v2" paths from the value by string replacement.
# selfrag files use the "-fullspan" suffix; the llama files use "_processed_2mil7".
files_dc = {
    f'post_processed\\{model}-{dataset}{suffix}.json': f'eval_res\\{model}-{dataset}-selfadp.json'
    for model, suffix in (
        ('selfrag', '-fullspan'),
        ('llama2chat', '_processed_2mil7'),
        ('llama3Ins', '_processed_2mil7'),
    )
    for dataset in ('pqa', 'tqa', 'health', 'arc')
}

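# Notes:
# 1. health works better without election (answer voting); arc works better with election.
# 2. Taking the log of ctx_score improves every setting except llama2 on pqa.
# 3. 1 - log(ctx_score) is overall worse than log(ctx_score).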

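TODO
   1. selfeval: try a 1-5 rating-scale scoring method
   2. cfd for no_ret vs. ret: plot the bucketed distributions (round(2)) in a single figure
   3. health: needs strong reasoning -> self-reflect; arc: overfitted -> confidence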

In [57]:
# Evaluate every selection strategy on each result file listed in `files_dc`.
for k, v in files_dc.items():
    path = k

    # Adaptive-retrieval probabilities for this file.
    ret_p_path = v
    with open(ret_p_path, 'r') as f:
        ret_p_file = json.load(f)
    ret_p = ret_p_file["retrieval_p"]
    # Reset per file so a value from the previous iteration is never reused.
    ret_p_hard = ret_p_file.get("retrieval_p_hard")
    if ret_p_hard is None:
        print(f"no ret_p_hard for {k}")

    # Optional second version of the probabilities.
    ret_p_path_v2 = v.replace('selfadp', 'selfadp_v2')
    try:
        with open(ret_p_path_v2, 'r') as f:
            ret_p_file_v2 = json.load(f)
        ret_p_v2 = ret_p_file_v2["retrieval_p"]
        ret_p_hard_v2 = ret_p_file_v2["retrieval_p_hard"]
    except (FileNotFoundError, KeyError):
        ret_p_v2 = None
        ret_p_hard_v2 = None
        print(f"no adp_v2 file for {k}")

    # Self-evaluation scores.
    eval_path = v.replace('selfadp', 'selfeval')
    with open(eval_path, 'r') as f:
        eval_file = json.load(f)
    eval_dc = eval_file["scores"]

    if 'pqa' in path or 'tqa' in path:
        metric = "loose_match"
    else:
        metric = 'loose_acc'

    analyzer = Analyzer(path)
    # Length checks run only after this file's analyzer and eval scores exist,
    # so they compare data belonging to the same file.
    assert len(analyzer.dc) == len(eval_dc) == len(ret_p)
    if ret_p_hard is not None:
        assert len(ret_p_hard) == len(ret_p)
    if ret_p_v2 is not None:
        assert len(analyzer.dc) == len(ret_p_v2) == len(ret_p_hard_v2)

    fn_ls = [select_by_confidence,
             select_by_ctx,
             partial(select_by_selfeval, eval_dc=eval_dc),
             partial(select, eval_dc=eval_dc, paras=[1,1,0.1], ret_p=ret_p, thres = 0.5),
             partial(select_categorical_score, eval_dc=eval_dc, ret_p=ret_p, thres = 0.5),
             select_right_ans, select_optimal_top, select_optimal_all]
    if 'selfrag' in path:
        fn_ls = [partial(select_selfrag, ret_p=ret_p, thres = 0.5)] + fn_ls
    analyzer.metric_evaluation(select_fns=fn_ls)

no ret_p_hard for post_processed\selfrag-pqa-fullspan.json
no adp_v2 file for post_processed\selfrag-pqa-fullspan.json
####################
#Model: selfrag
#Dataset: pqa
#Task: pqa
#Args: fullspan
####################
#Total number of samples: 1399
#No retrieval loose_match: 0.28377412437455324
#Use all loose_match: 0.4481772694781987
#Method: select_selfrag
#Adp threshold: 0.5
#Right choice rate: 730 out of 868 (84.1%)
Ret: 332, NoRet: 1067, Ret rate: 23.7%
#All ret loose_match: 0.5218012866333095
#Adp loose_match: 0.3173695496783417
# 
#Method: select_by_confidence
#Right choice rate: 703 out of 868
Percentage of right choice: 80.99078341013825%
#All ret loose_match: 0.5025017869907077
# 
#Method: select_by_ctx
#Right choice rate: 711 out of 868
Percentage of right choice: 81.91244239631337%
#All ret loose_match: 0.5082201572551823
# 
#Method: select_by_selfeval
#Right choice rate: 653 out of 868
Percentage of right choice: 75.23041474654379%
#All ret loose_match: 0.46676197283774123

In [88]:
from itertools import product
# Search grid for the score-combination sweep: every way to split a budget of
# `total_apples` integer units across `num_people` weights.
total_apples = 9
num_people = 3

# All integer weight vectors that sum to the budget, in itertools.product order.
possible_splits_ = [
    list(combo)
    for combo in product(range(total_apples + 1), repeat=num_people)
    if sum(combo) == total_apples
]

# Keep the integer vectors and add a copy of each with the last weight scaled by 0.1.
possible_splits = possible_splits_ + [
    head_weights + [round(last_weight * 0.1, 1)]
    for *head_weights, last_weight in possible_splits_
]

# Transform triples: the first slot is always None; the other two range over
# None / log / exp (the sweep output shows them applied to ctx_score and eval_score).
possible_fns = [
    [None, second_fn, third_fn]
    for second_fn in (None, np.log, np.exp)
    for third_fn in (None, np.log, np.exp)
]


In [89]:


# Sweep every weight split / transform combination on each result file and
# append the per-combination accuracies to output.csv.
for k, v in files_dc.items():
    path, ret_p_path = k, v
    ret_p_path_v2 = v.replace('selfadp', 'selfadp_v2')
    eval_path = v.replace('selfadp', 'selfeval')

    # Retrieval probabilities (both versions) and self-evaluation scores.
    with open(ret_p_path, 'r') as f:
        ret_p_file = json.load(f)
        ret_p, ret_p_hard = ret_p_file["retrieval_p"], ret_p_file["retrieval_p_hard"]
    with open(ret_p_path_v2, 'r') as f:
        ret_p_file = json.load(f)
        ret_p_v2, ret_p_hard_v2 = ret_p_file["retrieval_p"], ret_p_file["retrieval_p_hard"]
    with open(eval_path, 'r') as f:
        eval_file = json.load(f)
        eval_dc = eval_file["scores"]

    metric = "loose_match" if ('pqa' in path or 'tqa' in path) else 'loose_acc'

    analyzer = Analyzer(path)
    # Every per-sample array must line up with the loaded samples.
    assert len(analyzer.dc) == len(eval_dc) == len(ret_p) == len(ret_p_hard) == len(ret_p_v2) == len(ret_p_hard_v2)

    # One selection function per (weights, transforms) pair; weights vary slowest.
    select_fns = [
        partial(select_by_cfd_ctx_selfeval_adp, eval_dc=eval_dc, paras=paras, fns=fns, ret_p=ret_p, thres = 0.5)
        for paras in possible_splits
        for fns in possible_fns
    ]

    output_dc = analyzer.hyper_tune(select_fns=select_fns)

    final_df = pd.DataFrame(output_dc)
    # The first write creates the file with a header; later writes append rows only.
    csv_exists = os.path.isfile('output.csv')
    final_df.to_csv('output.csv', mode='a' if csv_exists else 'w', header=not csv_exists, index=False)

####################
#Model: llama2chat
#Dataset: pqa
#Task: pqa
#Args: processed 2mil7
####################
#Total number of samples: 1399
##1 method: calculated formula: 5 * identity(cfd_score) + 3 * log(ctx_score) + 0.1 * identity(eval_score): 0.540385989992852
##2 method: calculated formula: 6 * identity(cfd_score) + 2 * identity(ctx_score) + 0.1 * identity(eval_score): 0.538241601143674
##3 method: calculated formula: 4 * identity(cfd_score) + 4 * identity(ctx_score) + 0.1 * exp(eval_score): 0.5375268048606148
####################
#Model: llama2chat
#Dataset: tqa
#Task: tqa
#Args: processed 2mil7
####################
#Total number of samples: 11313
##1 method: calculated formula: 4 * identity(cfd_score) + 2 * identity(ctx_score) + 0.3 * identity(eval_score): 0.6397065323079643
##2 method: calculated formula: 4 * identity(cfd_score) + 2 * log(ctx_score) + 0.3 * identity(eval_score): 0.6391761690091046
##3 method: calculated formula: 7 * identity(cfd_score) + 1 * exp(ctx_score) + 1 

In [175]:
# Collect, for every model/dataset pair, the sweep rows whose non-adaptive
# accuracy is within 0.01 of that pair's best row.
res_df = pd.read_csv('output.csv')
ls = []
for model in ['llama2chat', 'llama3Ins']:
    for dataset in ['pqa', 'tqa', 'health', 'arc']:
        sub_df = res_df[(res_df['model'] == model) & (res_df['dataset'] == dataset)]
        if sub_df.empty:
            # No sweep results recorded for this pair; nothing to rank.
            continue
        sorted_df = sub_df.sort_values(by='non_adp_acc', ascending=False)
        top_score = sorted_df['non_adp_acc'].iloc[0]
        # Boolean mask instead of a row-by-row iterrows() loop; row order is kept.
        near_top = sorted_df[top_score - sorted_df['non_adp_acc'] < 0.01]
        ls.extend(near_top.values.tolist())

print(len(res_df), len(ls))

7920 2710


In [177]:

# Reduce each near-top row to its (paras, fns) columns (positions 2 and 3).
para_n_fns = [(row[2], row[3]) for row in ls]

In [149]:
# Number of (paras, fns) pairs collected from the near-top rows.
# NOTE(review): this cell's execution count (149) is older than the cell that
# builds `para_n_fns` (177), so the recorded output is probably stale -- re-run.
len(para_n_fns)

4539

In [178]:
from collections import Counter
# Count how often each (paras, fns) combination appears among the near-top rows
# and list the combinations from most to least frequent.
x = Counter(para_n_fns)
sorted_x = x.most_common()

In [179]:
# Display the ((paras, fns), count) pairs, most frequent first.
sorted_x  #(('[3, 3, 0.3]', "['identity', 'identity', 'identity']"), 5),

[(('[5, 2, 0.2]', "['identity', 'identity', 'identity']"), 6),
 (('[3, 4, 0.2]', "['identity', 'log', 'identity']"), 6),
 (('[5, 1, 0.3]', "['identity', 'exp', 'exp']"), 6),
 (('[4, 3, 0.2]', "['identity', 'log', 'identity']"), 6),
 (('[4, 3, 0.2]', "['identity', 'identity', 'exp']"), 6),
 (('[5, 3, 0.1]', "['identity', 'log', 'identity']"), 5),
 (('[4, 4, 0.1]', "['identity', 'identity', 'exp']"), 5),
 (('[6, 2, 0.1]', "['identity', 'identity', 'exp']"), 5),
 (('[5, 1, 0.3]', "['identity', 'exp', 'identity']"), 5),
 (('[3, 5, 0.1]', "['identity', 'log', 'exp']"), 5),
 (('[3, 5, 0.1]', "['identity', 'log', 'identity']"), 5),
 (('[6, 1, 0.2]', "['identity', 'exp', 'identity']"), 5),
 (('[5, 3, 0.1]', "['identity', 'log', 'exp']"), 5),
 (('[5, 3, 0.1]', "['identity', 'identity', 'exp']"), 5),
 (('[3, 4, 0.2]', "['identity', 'identity', 'identity']"), 5),
 (('[3, 4, 0.2]', "['identity', 'identity', 'exp']"), 5),
 (('[4, 4, 0.1]', "['identity', 'log', 'exp']"), 5),
 (('[6, 1, 0.2]', "['ide

In [132]:
# Show only the combinations that appear at least 7 times among the near-top rows.
for combo, hits in sorted_x:
    if hits >= 7:
        print((combo, hits))
        