In [1]:
import json
from tqdm import tqdm
import numpy as np

In [17]:
# preds = '/mnt/scratch/chenqu/stateful_search/29_eval/test_preds.txt'
preds = '/mnt/scratch/chenqu/stateful_search/33/test_preds.txt'
dev_file = '/mnt/scratch/chenqu/msmarco/preprocessed/session_dev_small.txt'
test_file = '/mnt/scratch/chenqu/msmarco/preprocessed/session_test.txt'
train_file = '/mnt/scratch/chenqu/msmarco/preprocessed/session_train_original.txt'

In [18]:
with open(preds) as fin:
    res = json.load(fin)

In [19]:
res.keys()

dict_keys(['qrels', 'run', 'ranker_test_all_label_ids', 'guids', 'preds'])

In [21]:
qrels_file = '/mnt/scratch/chenqu/stateful_search/33/qrels.txt'
run_file = '/mnt/scratch/chenqu/stateful_search/33/run.txt'
with open(qrels_file, 'w') as qrels, open(run_file, 'w') as run:
    
    for pred, label, guid in zip(res['preds'], res['ranker_test_all_label_ids'], res['guids']):
        guid_splits = guid.split('_')
        query_id = ''.join(guid_splits[: 2])
        doc_id = ''.join(guid_splits)
        
        qrels.write('{} 0 {} {}\n'.format(query_id, doc_id, label))
        run.write('{} Q0 {} 0 {} SYSTEM\n'.format(query_id, doc_id, pred))
    

In [20]:
len(res['ranker_test_all_label_ids'])

3807950

In [21]:
def mrr(preds, labels, doc_num_list):
    mrr_list = []
    for num in tqdm(doc_num_list):
        cur_preds = preds[: num]
        cur_labels = labels[: num]
        mrr_list.append(single_mrr(cur_preds, cur_labels))
        
        preds = preds[num :]
        labels = labels[num :]
        
    return {'mrr': np.average(mrr_list)}, mrr_list
        
def single_mrr(preds, labels):
    score = 0.0
    index_rank = np.argsort(preds)[::-1]
    for rank, i in enumerate(index_rank):
        if labels[i] == 1:
            score = 1.0 / (rank + 1.0)
            break
            
    return score

In [19]:
res['test_doc_num_list'][0]

50

In [25]:
mrr_res, _ = mrr(res['preds'][50:100], res['ranker_test_all_label_ids'][50:100], res['test_doc_num_list'][1:2])


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00, 1567.38it/s][A

In [26]:
print(mrr_res)

{'mrr': 1.0}


In [12]:
def trec_eval(preds, labels, guids):
    qrels = {}
    run = {}
    for pred, label, guid in zip(preds, labels, guids):
        guid_splits = guid.split('_')
        query_id = '_'.join(guid_splits[: 2])
        doc_id = guid_splits[-1]
        
        if query_id in qrels:
            qrels[query_id][doc_id] = int(label)
        else:
            qrels[query_id] = {doc_id: int(label)}
        
        if query_id in run:
            run[query_id][doc_id] = float(pred)
        else:
            run[query_id] = {doc_id: float(pred)}
            
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {'recip_rank', 'ndcg'})
    res = evaluator.evaluate(run)
    mrr_list = [v['recip_rank'] for v in res.values()]
    ndcg_list = [v['ndcg'] for v in res.values()]
    return {'mrr': np.average(mrr_list), 'ndcg': np.average(ndcg_list)}, qrels, run

In [30]:
guids = []
qid, did = 0, 0
for num in tqdm(res['test_doc_num_list'][1:2]):
    qid += 1
    did = 0
    for i in range(num):
        did += 1
        guids.append('fake_{}_{}'.format(qid, did))


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00, 1542.59it/s][A

In [32]:
import pytrec_eval
trec_eval(res['preds'][50:100], res['ranker_test_all_label_ids'][50:100], guids)

({'mrr': 0.16666666666666666, 'ndcg': 0.3562071871080222},
 {'fake_1': {'1': 1,
   '2': 0,
   '3': 0,
   '4': 0,
   '5': 0,
   '6': 0,
   '7': 0,
   '8': 0,
   '9': 0,
   '10': 0,
   '11': 0,
   '12': 0,
   '13': 0,
   '14': 0,
   '15': 0,
   '16': 0,
   '17': 0,
   '18': 0,
   '19': 0,
   '20': 0,
   '21': 0,
   '22': 0,
   '23': 0,
   '24': 0,
   '25': 0,
   '26': 0,
   '27': 0,
   '28': 0,
   '29': 0,
   '30': 0,
   '31': 0,
   '32': 0,
   '33': 0,
   '34': 0,
   '35': 0,
   '36': 0,
   '37': 0,
   '38': 0,
   '39': 0,
   '40': 0,
   '41': 0,
   '42': 0,
   '43': 0,
   '44': 0,
   '45': 0,
   '46': 0,
   '47': 0,
   '48': 0,
   '49': 0,
   '50': 0}},
 {'fake_1': {'1': 0.5541460514068604,
   '2': 4.136287316214293e-06,
   '3': 7.670581908314489e-06,
   '4': 5.175919795874506e-06,
   '5': 5.107867309561698e-06,
   '6': 0.33400648832321167,
   '7': 5.2628261073550675e-06,
   '8': 5.175919795874506e-06,
   '9': 0.00019116289331577718,
   '10': 5.107867309561698e-06,
   '11': 5.107867309