In [1]:
import os
import json
import argparse
import numpy as np
import pytrec_eval

In [2]:
# Reported metric name -> pytrec_eval measure key, in print order.
# The generic "recall" measure requested from pytrec_eval yields one
# "recall_<k>" entry per cutoff; only the cutoffs below are reported.
_REPORT_MEASURES = (
    ("Recall@5", "recall_5"),
    ("Recall@10", "recall_10"),
    ("Recall@20", "recall_20"),
    ("Recall@30", "recall_30"),
    ("Recall@100", "recall_100"),
    ("MRR", "recip_rank"),
    ("map", "map"),
    ("ndcg", "ndcg"),
)


def _load_json(path):
    """Load a JSON file and close the handle afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _mean_percent(values):
    """Mean of `values` as a percentage rounded to 2 decimals (NaN if empty).

    Rounding still goes through NumPy so the reported numbers match the
    previous output exactly; the result is converted to a plain float so it
    prints as ``50.71`` rather than ``np.float64(50.71)`` under NumPy 2.
    """
    if not values:
        return float("nan")
    return float(round(100 * np.average(values), 2))


def question_breakdown(result_path, qrel_path, subdomain_path):
    """Print retrieval metrics per question subset and return per-query metrics.

    Parameters
    ----------
    result_path : str
        JSON run file: query id -> {passage id: score}, in the nested format
        pytrec_eval's ``evaluate`` expects. Only the last two "_"-separated
        parts of each query id are kept, so they can be matched against the
        qrels / subset ids.
    qrel_path : str
        JSON qrels file: query id -> {passage id: relevance}. Queries whose
        qrels are the empty-passage placeholder ``{"": 1}`` are skipped.
    subdomain_path : str
        JSON file: subset name -> list of query ids in that subset. An extra
        ``"all"`` subset containing every (normalized) run query id is added.

    Returns
    -------
    dict
        pytrec_eval's per-query output (query id -> {measure: value}) for the
        ``"all"`` subset.
    """
    # Normalize run query ids to their last two "_"-separated components.
    all_result = {
        "_".join(did.split("_")[-2:]): ranking
        for did, ranking in _load_json(result_path).items()
    }
    qrels = _load_json(qrel_path)
    subdomain = _load_json(subdomain_path)
    subdomain["all"] = list(all_result.keys())

    # Global NumPy print setting kept from the original implementation; it
    # does not affect the dict printed below.
    np.set_printoptions(precision=4)

    print(result_path + "\n")
    metric_results = None
    for source, pids in subdomain.items():
        # Set membership keeps the filtering linear; the subset lists can hold
        # thousands of ids (e.g. the "all" subset).
        pid_set = set(pids)
        # QReCC: drop questions whose qrels are the {"": 1} placeholder (missing gold passages).
        sqrels = {
            qid: rels for qid, rels in qrels.items()
            if qid in pid_set and rels != {"": 1}
        }
        sresults = {qid: run for qid, run in all_result.items() if qid in pid_set}

        evaluator = pytrec_eval.RelevanceEvaluator(
            sqrels, {"recip_rank", "recall", "map", "ndcg"})
        # NOTE(review): pytrec_eval only returns entries for queries it can
        # score (present in the run and in sqrels), so the averages below are
        # over len(metrics) queries, which may be fewer than the printed
        # len(sqrels) — verify if runs can miss queries.
        metrics = evaluator.evaluate(sresults)

        eval_metrics = {
            name: _mean_percent([v[key] for v in metrics.values()])
            for name, key in _REPORT_MEASURES
        }
        print(source, len(sqrels))
        print(eval_metrics)
        print("")
        if source == "all":
            metric_results = metrics
    return metric_results

In [3]:
# QReCC test-split inputs shared by the evaluation cells: the qrels (read
# with json.load by question_breakdown) and the question-subset id lists.
qrel_path, qtype_path = (
    'datasets/qrecc/qrels_test.txt',
    'datasets/qrecc/test_question_types.json',
)

In [4]:
# Per-subset breakdown for the run stored in test_rewrite_bm25_scores.json.
temp_path = 'outputs/BM25/test_rewrite_bm25_scores.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_rewrite_bm25_scores.json

trec 371
{'Recall@5': 38.54, 'Recall@10': 53.77, 'Recall@20': 69.95, 'Recall@30': 76.82, 'Recall@100': 98.92, 'MRR': 27.34, 'map': 27.04, 'ndcg': 41.57}

quac 6396
{'Recall@5': 51.21, 'Recall@10': 62.9, 'Recall@20': 74.78, 'Recall@30': 81.1, 'Recall@100': 98.35, 'MRR': 40.32, 'map': 38.98, 'ndcg': 51.68}

nq 1442
{'Recall@5': 51.65, 'Recall@10': 63.8, 'Recall@20': 75.61, 'Recall@30': 81.69, 'Recall@100': 98.96, 'MRR': 40.78, 'map': 39.05, 'ndcg': 52.27}

no-switch 279
{'Recall@5': 68.57, 'Recall@10': 76.89, 'Recall@20': 87.03, 'Recall@30': 90.01, 'Recall@100': 100.0, 'MRR': 55.14, 'map': 53.58, 'ndcg': 64.2}

switch 573
{'Recall@5': 43.71, 'Recall@10': 58.98, 'Recall@20': 73.38, 'Recall@30': 80.47, 'Recall@100': 98.85, 'MRR': 34.23, 'map': 32.76, 'ndcg': 47.02}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209
{'Recall@5': 50.71, '

In [5]:
# Per-subset breakdown for the run stored in test_original_fused_ICL_bm25.json.
temp_path = 'outputs/BM25/test_original_fused_ICL_bm25.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_original_fused_ICL_bm25.json

trec 371
{'Recall@5': 14.82, 'Recall@10': 22.1, 'Recall@20': 29.51, 'Recall@30': 32.35, 'Recall@100': 42.99, 'MRR': 10.3, 'map': 10.27, 'ndcg': 16.83}

quac 6396
{'Recall@5': 12.1, 'Recall@10': 15.2, 'Recall@20': 18.99, 'Recall@30': 21.31, 'Recall@100': 27.7, 'MRR': 9.29, 'map': 8.84, 'ndcg': 12.8}

nq 1442
{'Recall@5': 11.14, 'Recall@10': 15.14, 'Recall@20': 18.36, 'Recall@30': 21.1, 'Recall@100': 29.06, 'MRR': 9.06, 'map': 8.64, 'ndcg': 12.88}

no-switch 279
{'Recall@5': 4.48, 'Recall@10': 5.38, 'Recall@20': 7.17, 'Recall@30': 7.53, 'Recall@100': 13.44, 'MRR': 3.25, 'map': 3.07, 'ndcg': 5.06}

switch 573
{'Recall@5': 7.41, 'Recall@10': 10.73, 'Recall@20': 14.92, 'Recall@30': 16.7, 'Recall@100': 23.72, 'MRR': 5.86, 'map': 5.75, 'ndcg': 9.37}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209
{'Recall@5': 12.06, 'Recall@10': 15.

In [6]:
# Per-subset breakdown for the run stored in test_GPT_rewrite_ICL_post_bm25_scores.json.
temp_path = 'outputs/BM25/test_GPT_rewrite_ICL_post_bm25_scores.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_GPT_rewrite_ICL_post_bm25_scores.json

trec 371
{'Recall@5': 27.36, 'Recall@10': 39.89, 'Recall@20': 49.82, 'Recall@30': 53.59, 'Recall@100': 69.05, 'MRR': 19.02, 'map': 18.86, 'ndcg': 29.08}

quac 6396
{'Recall@5': 59.59, 'Recall@10': 68.28, 'Recall@20': 76.21, 'Recall@30': 80.32, 'Recall@100': 89.86, 'MRR': 49.81, 'map': 48.38, 'ndcg': 57.67}

nq 1442
{'Recall@5': 50.85, 'Recall@10': 60.13, 'Recall@20': 69.58, 'Recall@30': 74.69, 'Recall@100': 86.56, 'MRR': 41.51, 'map': 39.71, 'ndcg': 50.27}

no-switch 279
{'Recall@5': 80.74, 'Recall@10': 84.75, 'Recall@20': 88.83, 'Recall@30': 91.36, 'Recall@100': 94.89, 'MRR': 68.89, 'map': 67.42, 'ndcg': 74.11}

switch 573
{'Recall@5': 34.63, 'Recall@10': 44.48, 'Recall@20': 56.63, 'Recall@30': 63.18, 'Recall@100': 76.56, 'MRR': 25.29, 'map': 24.06, 'ndcg': 35.43}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209
{'Re

In [7]:
# Per-subset breakdown for the run stored in test_GPT_rewrite_fused_ZSL_post_bm25.json.
temp_path = 'outputs/BM25/test_GPT_rewrite_fused_ZSL_post_bm25.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_GPT_rewrite_fused_ZSL_post_bm25.json

trec 371
{'Recall@5': 26.82, 'Recall@10': 35.58, 'Recall@20': 47.12, 'Recall@30': 54.54, 'Recall@100': 72.15, 'MRR': 18.5, 'map': 18.26, 'ndcg': 29.1}

quac 6396
{'Recall@5': 55.02, 'Recall@10': 63.2, 'Recall@20': 71.39, 'Recall@30': 75.7, 'Recall@100': 85.55, 'MRR': 45.43, 'map': 44.11, 'ndcg': 53.31}

nq 1442
{'Recall@5': 45.3, 'Recall@10': 54.69, 'Recall@20': 64.06, 'Recall@30': 68.73, 'Recall@100': 81.92, 'MRR': 36.43, 'map': 34.81, 'ndcg': 45.31}

no-switch 279
{'Recall@5': 67.51, 'Recall@10': 73.14, 'Recall@20': 80.47, 'Recall@30': 84.71, 'Recall@100': 90.24, 'MRR': 57.47, 'map': 56.37, 'ndcg': 64.23}

switch 573
{'Recall@5': 31.16, 'Recall@10': 40.79, 'Recall@20': 51.42, 'Recall@30': 56.9, 'Recall@100': 72.45, 'MRR': 23.08, 'map': 22.02, 'ndcg': 32.8}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209
{'Recall@5':

In [8]:
# Per-subset breakdown for the run stored in test_Editor_rewrite_fused_ICL_editor_post_bm25.json.
temp_path = 'outputs/BM25/test_Editor_rewrite_fused_ICL_editor_post_bm25.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_Editor_rewrite_fused_ICL_editor_post_bm25.json

trec 371
{'Recall@5': 26.28, 'Recall@10': 36.25, 'Recall@20': 47.04, 'Recall@30': 51.62, 'Recall@100': 66.49, 'MRR': 17.43, 'map': 17.08, 'ndcg': 27.15}

quac 6396
{'Recall@5': 62.55, 'Recall@10': 70.46, 'Recall@20': 77.57, 'Recall@30': 81.31, 'Recall@100': 89.95, 'MRR': 53.01, 'map': 51.52, 'ndcg': 60.22}

nq 1442
{'Recall@5': 50.89, 'Recall@10': 59.63, 'Recall@20': 69.28, 'Recall@30': 74.25, 'Recall@100': 85.72, 'MRR': 41.57, 'map': 39.69, 'ndcg': 50.1}

no-switch 279
{'Recall@5': 83.83, 'Recall@10': 89.04, 'Recall@20': 91.8, 'Recall@30': 93.6, 'Recall@100': 96.7, 'MRR': 74.82, 'map': 73.4, 'ndcg': 79.23}

switch 573
{'Recall@5': 32.16, 'Recall@10': 41.17, 'Recall@20': 53.49, 'Recall@30': 59.67, 'Recall@100': 72.16, 'MRR': 22.35, 'map': 20.97, 'ndcg': 32.12}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209


In [9]:
# Per-subset breakdown for the run stored in test_Editor_rewrite_post_bm25_scores.json.
temp_path = 'outputs/BM25/test_Editor_rewrite_post_bm25_scores.json'
res = question_breakdown(
    result_path=temp_path,
    qrel_path=qrel_path,
    subdomain_path=qtype_path,
)

outputs/BM25/test_Editor_rewrite_post_bm25_scores.json

trec 371
{'Recall@5': 32.08, 'Recall@10': 43.26, 'Recall@20': 55.12, 'Recall@30': 60.78, 'Recall@100': 77.18, 'MRR': 21.04, 'map': 20.79, 'ndcg': 32.34}

quac 6396
{'Recall@5': 60.49, 'Recall@10': 68.84, 'Recall@20': 77.01, 'Recall@30': 81.01, 'Recall@100': 91.09, 'MRR': 50.67, 'map': 49.18, 'ndcg': 58.55}

nq 1442
{'Recall@5': 51.88, 'Recall@10': 60.67, 'Recall@20': 70.27, 'Recall@30': 74.93, 'Recall@100': 87.97, 'MRR': 42.69, 'map': 40.64, 'ndcg': 51.34}

no-switch 279
{'Recall@5': 79.05, 'Recall@10': 82.35, 'Recall@20': 87.09, 'Recall@30': 89.67, 'Recall@100': 95.27, 'MRR': 69.76, 'map': 68.4, 'ndcg': 74.86}

switch 573
{'Recall@5': 38.08, 'Recall@10': 47.3, 'Recall@20': 59.63, 'Recall@30': 64.78, 'Recall@100': 78.88, 'MRR': 26.92, 'map': 25.21, 'ndcg': 37.0}

first 267
{'Recall@5': 40.45, 'Recall@10': 56.65, 'Recall@20': 71.0, 'Recall@30': 79.21, 'Recall@100': 100.0, 'MRR': 32.7, 'map': 31.09, 'ndcg': 45.72}

all 8209
{'Recall