In [1]:
import pandas as pd
train_qna = pd.read_csv("input/train_qna.csv")
bm25_train_qna = pd.read_hdf("bm25_train_qna.h5")
cos_sim_train_qna = pd.read_hdf("cos_sim_train_qna_encode.h5")
train_qna['bm25_query'] = bm25_train_qna['bm25_query']
train_qna['vector'] = cos_sim_train_qna['vector']

In [2]:
qna = train_qna.copy()
qna = qna.filter(["relevant_articles", "relevant_titles"])

In [3]:
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True, nb_workers=4)

INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

https://nalepae.github.io/pandarallel/troubleshooting/


In [4]:
""" Iteratively choose 1 line and comment out the rest """
# bm25_results = open("bm25", "rb")
# cos_sim_results = open("cos_sim", "rb")
rrf_results = open("complete", "rb")

In [5]:
import pickle
rank = pickle.load(rrf_results)

In [6]:
rank

0       [105902, 68242, 203317, 61277, 61312, 105870, ...
1       [277981, 277982, 277983, 277974, 277963, 27798...
2       [146474, 146669, 146449, 372602, 146473, 14644...
3       [194673, 69905, 239384, 28266, 345484, 239381,...
4       [331828, 68028, 331850, 331742, 68042, 353997,...
                              ...                        
3191    [39643, 39430, 39959, 39651, 39640, 39639, 211...
3192    [252948, 252956, 252965, 253044, 252711, 25305...
3193    [277683, 112095, 276604, 76360, 76358, 75019, ...
3194    [317621, 124249, 124253, 155918, 124252, 61425...
3195    [246200, 246181, 176344, 77749, 278694, 347943...
Name: rr, Length: 3196, dtype: object

In [7]:
qna["ranking"] = rank # Edit this appropriately
del rank

In [8]:
def get_top_k_relevance_func(x):
    from rrf_utilities import get_top_k_relevance, bm25_corpus
    corpus_lookup = bm25_corpus['law_article']
    return get_top_k_relevance(x.ranking, corpus_look_up = corpus_lookup, unique=True, k = 10)

In [9]:
# qna.relevant_articles

In [10]:
qna["relevant"] = qna.parallel_apply(get_top_k_relevance_func, axis = 1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=799), Label(value='0 / 799'))), HB…

In [None]:
import ast

def precision(predicted, relevant): 
    if type(predicted).__name__ == "str":
        predicted = [str(i) for i in ast.literal_eval(predicted)]
    if type(relevant).__name__ == "str": 
        relevant = [str(i) for i in ast.literal_eval(relevant)] 
    
    return len(set(predicted).intersection(set(relevant)))/len(predicted) if len(predicted) > 0 else 0

def recall(predicted, relevant):
    if type(predicted).__name__ == "str":
        predicted = [str(i) for i in ast.literal_eval(predicted)]
    if type(relevant).__name__ == "str":
        relevant = [str(i) for i in ast.literal_eval(relevant)]
    
    return len(set(predicted).intersection(set(relevant)))/len(relevant) if len(relevant) > 0 else 0 

In [None]:
def f1_score(precision, recall):
    return (2*precision*recall)/(precision + recall) if precision + recall > 0 else 0

def f2_score(precision, recall):
    return (5*precision*recall)/(4*precision + recall) if precision + recall > 0 else 0

In [None]:
def get_indicator(predicted, relevant): 

    return [1 if i in relevant else 0 for i in predicted]

def ave_p(predicted, relevant):

    if type(predicted).__name__ == "str":
        predicted = [str(i) for i in ast.literal_eval(predicted)]
    if type(relevant).__name__ == "str": 
        relevant = [str(i) for i in ast.literal_eval(relevant)] 

    relevant = set(relevant)

    indicator = get_indicator(predicted, relevant)
    avp = 0
    
    for i in range(len(predicted)):
        avp += (precision(predicted[:i + 1], relevant)*indicator[i])
    
    avp = avp/len(relevant) if len(relevant) > 0 else 0
    return avp 

In [None]:
import math

def get_relevance(predicted, relevant): 

    return [1 if i in relevant else 0 for i in predicted]

def non_DCG(predicted, relevant):
    if type(predicted).__name__ == "str":
        predicted = [str(i) for i in ast.literal_eval(predicted)]
    if type(relevant).__name__ == "str": 
        relevant = [str(i) for i in ast.literal_eval(relevant)] 

    relevant = set(relevant)

    relevance = get_relevance(predicted, relevant)
    ideal_relevance = sorted(relevance, reverse=True)

    dcg = 0
    idcg = 0

    for i in range(1, len(relevance) + 1):
        rel_i = relevance[i-1]
        irel_i = ideal_relevance[i-1]

        dcg += rel_i / math.log(1 + i,2)
        idcg += irel_i / math.log(1 + i,2)

    ndcg = dcg/idcg if idcg > 0 else 0

    return ndcg 

In [None]:
from tqdm.auto import tqdm
tqdm.pandas()

In [None]:
qna['precision'] = qna.progress_apply(lambda x: precision(x.relevant, x.relevant_articles), axis = 1) 
qna['recall'] = qna.progress_apply(lambda x: recall(x.relevant, x.relevant_articles), axis = 1) 

qna['f1_score'] = qna.progress_apply(lambda x: f1_score(x.precision, x.recall), axis = 1) 
qna['f2_score'] = qna.progress_apply(lambda x: f2_score(x.precision, x.recall), axis = 1) 

qna['ave_p'] = qna.progress_apply(lambda x: ave_p(x.relevant, x.relevant_articles), axis = 1) 

qna['ndcg'] = qna.progress_apply(lambda x: non_DCG(x.relevant, x.relevant_articles), axis = 1) 

print("mean_avp: " + str(qna['ave_p'].mean()))
print("mean_recall: " + str(qna['recall'].mean())) 
print("mean_f1_score: " + str(qna['f1_score'].mean())) 
print("mean_f2_score: " + str(qna['f2_score'].mean())) 
print("mean_ndcg: " + str(qna['ndcg'].mean())) 

  0%|          | 0/3196 [00:00<?, ?it/s]

  0%|          | 0/3196 [00:00<?, ?it/s]

  0%|          | 0/3196 [00:00<?, ?it/s]

  0%|          | 0/3196 [00:00<?, ?it/s]

  0%|          | 0/3196 [00:00<?, ?it/s]

  0%|          | 0/3196 [00:00<?, ?it/s]

mean_avp: 0.6382613890709948
mean_recall: 0.8711931581143094
mean_f1_score: 0.16129062843706152
mean_f2_score: 0.3148709516043684
mean_ndcg: 0.6996295865909636


In [None]:
qna = qna.drop(columns=["ranking"])

In [None]:
qna

Unnamed: 0,relevant_articles,relevant_titles,relevant,precision,recall,f1_score,f2_score,ave_p,ndcg
0,"[{'law_id': '47/2011/tt-bca', 'article_id': '7'}]",['47/2011/tt-bca'],"[{'law_id': '12/2010/tt-bca', 'article_id': '1...",0.0,0.0,0.000000,0.000000,0.000000,0.000000
1,"[{'law_id': '41/2020/tt-bca', 'article_id': '1...",['41/2020/tt-bca'],"[{'law_id': '41/2020/tt-bca', 'article_id': '1...",0.1,1.0,0.181818,0.357143,1.000000,1.000000
2,"[{'law_id': '159/2020/nđ-cp', 'article_id': '1...",['159/2020/nđ-cp'],"[{'law_id': '159/2020/nđ-cp', 'article_id': '1...",0.1,1.0,0.181818,0.357143,1.000000,1.000000
3,"[{'law_id': '53/2010/qh12', 'article_id': '60'...","['53/2010/qh12', '82/2011/nđ-cp']","[{'law_id': '20/2021/nđ-cp', 'article_id': '11...",0.0,0.0,0.000000,0.000000,0.000000,0.000000
4,"[{'law_id': '63/2020/nđ-cp', 'article_id': '20'}]",['63/2020/nđ-cp'],"[{'law_id': '63/2020/nđ-cp', 'article_id': '19...",0.1,1.0,0.181818,0.357143,0.333333,0.500000
...,...,...,...,...,...,...,...,...,...
3191,"[{'law_id': '06/2021/nđ-cp', 'article_id': '24'}]",['06/2021/nđ-cp'],"[{'law_id': '06/2021/nđ-cp', 'article_id': '24...",0.1,1.0,0.181818,0.357143,1.000000,1.000000
3192,"[{'law_id': '35/2019/nđ-cp', 'article_id': '12'}]",['35/2019/nđ-cp'],"[{'law_id': '35/2019/nđ-cp', 'article_id': '17...",0.1,1.0,0.181818,0.357143,0.333333,0.500000
3193,"[{'law_id': '41/2019/qh14', 'article_id': '3'}]",['41/2019/qh14'],"[{'law_id': '41/2019/qh14', 'article_id': '167...",0.1,1.0,0.181818,0.357143,0.142857,0.333333
3194,"[{'law_id': '100/2019/nđ-cp', 'article_id': '2...",['100/2019/nđ-cp'],"[{'law_id': '57/2010/qh12', 'article_id': '3'}...",0.0,0.0,0.000000,0.000000,0.000000,0.000000


In [None]:
with open("rrf_results_unique", "wb") as out:
    pickle.dump(qna, out)

In [None]:
"""
bm25:
    mean_avp: 0.5154628843263073
    mean_recall: 0.757770129328327
    mean_f1_score: 0.1399954342695269
    mean_f2_score: 0.27350769450268825
    mean_ndcg: 0.5785190622027508

cos_sim:
    mean_avp: 0.6568568473488686
    mean_recall: 0.8622236128493951
    mean_f1_score: 0.15960145403199225
    mean_f2_score: 0.3115998172318573
    mean_ndcg: 0.7118466041872631

rrf:
    mean_avp: 0.6382613890709948
    mean_recall: 0.8711931581143094
    mean_f1_score: 0.16129062843706152
    mean_f2_score: 0.3148709516043684
    mean_ndcg: 0.6996295865909636
"""

'\nbm25:\n    mean_avp: 0.5154628843263073\n    mean_recall: 0.757770129328327\n    mean_f1_score: 0.1399954342695269\n    mean_f2_score: 0.27350769450268825\n    mean_ndcg: 0.5785190622027508\n\ncos_sim:\n    mean_avp: 0.6568568473488686\n    mean_recall: 0.8622236128493951\n    mean_f1_score: 0.15960145403199225\n    mean_f2_score: 0.3115998172318573\n    mean_ndcg: 0.7118466041872631\n'