In [1]:
import jsonlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.special import softmax
import copy
import math
import os

In [2]:
def sigmoid(x):
    return 1 / (1 + math.exp(-x))
    
def prob_processed_data(prob_result):
    for item in prob_result:
        evidence_data = []
        for key, value in item['evidence'].items():
            if 'probability' in value:
                prob = value['probability']
                prob = softmax(prob)
                max_prob = 1 - prob[1]
                evidence_data.append((key, max_prob))
        
        evidence_data_sorted = sorted(evidence_data, key=lambda x: x[1], reverse=True)
        sorted_evidence = {key: item['evidence'][key] for key, _ in evidence_data_sorted}
        item['evidence'] = sorted_evidence
    
    result_data_list = [{'id': item['id'], 'doc_ids': item['evidence']} for item in prob_result]
    for item in result_data_list:
        item['doc_ids'] = [int(key) for key in item['doc_ids'].keys()]
    return result_data_list


def mix_rerank_processed_data(reranker_result, prob_result, alpha):
    k1 = alpha
    k2 = 1 - alpha
    result = []
    for claim_index in range(len(reranker_result)):
        doc_ids = reranker_result[claim_index]['doc_ids']
        relevance_scores = reranker_result[claim_index]['scores']
        prob_list = prob_result[claim_index]['evidence']
        evidence_data = []
        for doc_for_one_claim in range(len(doc_ids)):
            doc_id = doc_ids[doc_for_one_claim]
            relevance_score = relevance_scores[doc_for_one_claim]
            relevance_score = sigmoid(relevance_score)*2 #because the original output is (-infinite，0)，we should map it to (0,1) as relevance
            prob = prob_list[str(doc_id)]['probability']
            prob = softmax(prob)
            # max_prob = k1 * (max(prob[0], prob[2]) -  min(prob[0], prob[2])) - k2 * prob[1]
            max_prob = 1 - prob[1]
            fixed_score = k1*relevance_score + k2*max_prob
            # fixed_score = relevance_score

            evidence_data.append((doc_id, fixed_score, relevance_score, max_prob))

        evidence_data_sorted = sorted(evidence_data, key=lambda x: x[1], reverse=True)
        sorted_evidence = [key for key, _ , _ , _ in evidence_data_sorted]
        sorted_scores = [score for _ , score, _ , _ in evidence_data_sorted]
        sorted_relevance_scores = [relevance for _ , _ , relevance , _ in evidence_data_sorted]
        sorted_prob = [prob for _ , _ , _ , prob  in evidence_data_sorted]
        
        one_result = {}
        one_result["id"] = reranker_result[claim_index]['id']
        one_result["claim"] = reranker_result[claim_index]['claim']
        one_result["doc_ids"] = sorted_evidence
        one_result["scores"] = sorted_scores
        one_result["relevance_scores"] = sorted_relevance_scores
        one_result["prob_scores"] = sorted_prob
        result.append(one_result)
    return result

In [3]:
def filter_scifact(data_list):
    data_list_copy = copy.deepcopy(data_list)
    
    for item in data_list_copy:
        filtered_evidence = {k: v for k, v in item['evidence'].items() if v['provenance'] == 'citation'}
        item['evidence'] = filtered_evidence
        
    return data_list_copy
    
def filter_open(data_list):
    data_list_copy = copy.deepcopy(data_list)
    
    for item in data_list_copy:
        filtered_evidence = {k: v for k, v in item['evidence'].items() if v['provenance'] == 'pooling'}
        item['evidence'] = filtered_evidence
        
    return data_list_copy

In [4]:
def calculate_hit(claim_set, predicted_set):
    for top_k in [50,20,10,5,3,1]:
        eval_dataset = {data['id']: data for data in claim_set}
        hit_one = 0
        hit_all = 0
        total = 0
        hit_one_evi = 0
        hit_all_evi = 0
        total_has_evi = 0
        for retrieval in predicted_set:
            total += 1
            try:
                data = eval_dataset[retrieval['id']]
            except KeyError:
                print('error')
            pred_doc_ids = set(retrieval['doc_ids'][:top_k])
            true_doc_ids = set(map(int, data['evidence'].keys()))
        
            if pred_doc_ids.intersection(true_doc_ids) or not true_doc_ids:
                hit_one += 1

            if pred_doc_ids.issuperset(true_doc_ids):
                hit_all += 1
            if true_doc_ids:
                total_has_evi += 1
                if pred_doc_ids.intersection(true_doc_ids):
                    hit_one_evi += 1
                if pred_doc_ids.issuperset(true_doc_ids):
                    hit_all_evi += 1

        hit_one_evidence = round(hit_one_evi / total_has_evi, 4)
        hit_all_evidence = round(hit_all_evi / total_has_evi, 4)
        
        hit_one = round(hit_one / total, 4)
        hit_all = round(hit_all / total, 4)
    
        result = hit_one_evidence, hit_all_evidence, hit_one, hit_all

        print(result)

    return hit_one_evidence

In [5]:
def calculate_recall_at_k(claim_set, predicted_set):
    for recall_k in [50,20,10,5,3,1]:
        eval_dataset = {data['id']: data for data in claim_set}
        total_evidence = 0
        correct_retrieved = 0
        for retrieval in predicted_set:
            try:
                data = eval_dataset[retrieval['id']]
            except KeyError:
                print('KeyError: claim_id not found')
            pred_doc_ids = set(retrieval['doc_ids'][:recall_k])   
            true_doc_ids = set(map(int, data['evidence'].keys()))
            correct_retrieved += len(pred_doc_ids.intersection(true_doc_ids))
            total_evidence += len(list(data['evidence'].keys()))
        recall_at_k = round(correct_retrieved / total_evidence,4)

        
        print(f"Recall@{recall_k}: {recall_at_k}")
    
    return recall_at_k

def calculate_precision_at_k(claim_set, predicted_set):
    for precision_k in [50, 20, 10, 5, 3, 1]:
        eval_dataset = {data['id']: data for data in claim_set}
        total_retrieved = 0
        correct_retrieved = 0
        for retrieval in predicted_set:
            try:
                data = eval_dataset[retrieval['id']]
            except KeyError:
                print('KeyError: claim_id not found')
                continue 
            pred_doc_ids = set(retrieval['doc_ids'][:precision_k])   
            true_doc_ids = set(map(int, data['evidence'].keys()))
            correct_retrieved += len(pred_doc_ids.intersection(true_doc_ids))
            total_retrieved += min(len(retrieval['doc_ids']), precision_k)
        
        precision_at_k = round(correct_retrieved / total_retrieved, 4) if total_retrieved > 0 else 0.0
        print(f"Precision@{precision_k}: {precision_at_k}")
    
    return precision_at_k


In [6]:
claim_path_train = "dataset/scifact_open/claims.jsonl"
claim_set_train = list(jsonlines.open(claim_path_train))
citation = filter_scifact(claim_set_train)
pooling = filter_open(claim_set_train)

In [7]:
t5 = "document_retrieval_result/rerank_t5/t5_scifact_open_2000.jsonl"
rank_t5 = list(jsonlines.open(t5))

In [8]:
prob_result =  "result_summary/prob_n5_scifact_open.jsonl"
prob_result = list(jsonlines.open(prob_result))

In [9]:
combo = mix_rerank_processed_data(rank_t5,prob_result,0.5)

In [10]:
calculate_recall_at_k(citation, combo)

Recall@50: 0.9234
Recall@20: 0.8852
Recall@10: 0.8565
Recall@5: 0.8134
Recall@3: 0.7321
Recall@1: 0.5598


0.5598

In [11]:
calculate_recall_at_k(pooling, combo)

Recall@50: 0.9163
Recall@20: 0.7649
Recall@10: 0.5936
Recall@5: 0.4701
Recall@3: 0.3586
Recall@1: 0.1195


0.1195

In [12]:
# Generate train/dev data for +VeriRel
# Here is an example using ComboScorer in SciFact-Open's documents.
# +VeriRel in paper is trained on SciFact only, including 809 cliams with documents in SciFact corpus, as shown in an example file.

target = combo
claim_path_train = "dataset/scifact/claims_train.jsonl"
# ouput_path = "reranker_train_data/train_n5_verirel.jsonl"

claim_set_train = list(jsonlines.open(claim_path_train))


output_data = []

for item_idx in range(len(target)):
    id = target[item_idx]['id']
    claim = target[item_idx]['claim']
    doc_ids = target[item_idx]['doc_ids'][:20]
    prf_scores = target[item_idx]['scores'][:20]

    gold_evi = list(claim_set_train[item_idx]['evidence'].keys())

    for doc_idx in range(len(doc_ids)):
        if str(doc_ids[doc_idx]) in gold_evi:
            prf_scores[doc_idx] = 1
            gold_evi.remove(str(doc_ids[doc_idx]))

    for extra_evi in gold_evi:
        doc_ids.append(int(extra_evi))
        prf_scores.append(1)

    output_data.append({
        'id': id,
        'claim': claim,
        'doc_ids': doc_ids,
        'prf_scores': prf_scores
    })

# with jsonlines.open(ouput_path, 'w') as writer:
#     writer.write_all(output_data)

In [13]:
output_data[1]

{'id': 8,
 'claim': '25% of patients with melanoma and an objective response to PD-1 blockade will experience a progression in their melanoma.',
 'doc_ids': [90543925,
  46759314,
  85563812,
  3471191,
  16869160,
  4535882,
  4468861,
  16151191,
  27452674,
  73448986,
  8443224,
  4430143,
  20634012,
  52896012,
  37970308,
  73496636,
  58602640,
  13758726,
  25967339,
  208190996,
  13734012],
 'prf_scores': [0.8430081279765058,
  0.8069977897657188,
  0.7339445404429519,
  0.7284783915761297,
  0.701802160556745,
  0.6954451328224922,
  0.5576161104618667,
  0.5490565537207442,
  0.5452018016621938,
  0.529645609493147,
  0.5235234311643397,
  0.5022624192675493,
  0.49779895210932745,
  0.495493161813086,
  0.49500382823649736,
  0.4948032039116302,
  0.49388343836439896,
  0.48632340028069526,
  0.48132901319658183,
  0.4659646362193301,
  1]}

In [14]:
output_data[7]

{'id': 38,
 'claim': 'A deficiency of vitamin B6 increases blood levels of homocysteine.',
 'doc_ids': [45829252,
  11911440,
  22556029,
  35969491,
  2326835,
  5702170,
  8294579,
  4511158,
  33409100,
  16252863,
  15834427,
  4515153,
  6084615,
  12810152,
  35765068,
  207584290,
  23481830,
  10662555,
  24278506,
  5640510],
 'prf_scores': [0.902665050508054,
  0.8708963809070753,
  0.8698555186254426,
  0.8024508961400137,
  0.797819553858055,
  0.765661512688123,
  0.7008841344217085,
  0.6732646584018274,
  1,
  0.6219188457835844,
  0.5662139375124435,
  0.5643951489571466,
  0.5605270635818482,
  0.5512924270833537,
  0.5438508768769463,
  0.5366398960834721,
  0.5223664523460679,
  0.46648852831982635,
  0.4561933331730313,
  0.4511174614591327]}