In [1]:
import pandas as pd
from utils import clean_text
import faiss
import time

from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder

# Retrieval configuration. The bi-encoder is presumably the same model that
# produced the vectors stored in INDEX_PATH -- TODO confirm against the indexing script.
BI_ENCODER_NAME = 'msmarco-distilbert-base-dot-prod-v3'
INDEX_PATH = 'data_article.index'
CHUNKS_PATH = 'data_chunk.csv'

# First-stage dense retriever; cross-encoders are loaded per model in the benchmark loop.
model = SentenceTransformer(BI_ENCODER_NAME)

index = faiss.read_index(INDEX_PATH)
data_chunk = pd.read_csv(CHUNKS_PATH)

# FAISS result ids are used as positional rows of data_chunk (see fetch_data_info),
# so every vector id must address an existing row.
assert index.ntotal <= len(data_chunk), (
    f'index holds {index.ntotal} vectors but {CHUNKS_PATH} has only {len(data_chunk)} rows'
)

In [2]:
def fetch_data_info(dataframe_idx, score, data=None):
    '''Look up one retrieved chunk and package it with its retrieval score.

    Parameters
    ----------
    dataframe_idx : int
        Positional row in the chunk table. FAISS result ids are used directly
        as positions, so row order must match the order vectors were added to
        the index (assumed -- confirm against the indexing script).
    score : float
        Score returned by the index for this row.
    data : pandas.DataFrame, optional
        Chunk table with ``id`` and ``article`` columns. Defaults to the
        module-level ``data_chunk`` loaded in the first cell.

    Returns
    -------
    dict
        ``{'id': ..., 'article': ..., 'score': score}``.

    Raises
    ------
    IndexError
        If ``dataframe_idx`` is negative or past the end of the table.
        Negative ids are rejected explicitly because FAISS pads missing results
        with -1, which ``iloc`` would otherwise silently map to the last row.
    '''
    if data is None:
        data = data_chunk
    if dataframe_idx < 0:
        raise IndexError(f'invalid chunk position {dataframe_idx} (FAISS padding id?)')

    row = data.iloc[dataframe_idx]
    return {'id': row['id'], 'article': row['article'], 'score': score}

In [3]:
def search(query, top_k, index, model):
    '''Dense-retrieve up to ``top_k`` chunks for a single query string.

    Parameters
    ----------
    query : str
        Query text, passed as-is to ``model.encode``.
    top_k : int
        Number of neighbours to request from the index.
    index : faiss.Index
        Index searched with the encoded query.
    model : SentenceTransformer
        Encoder producing the query vector.

    Returns
    -------
    list of dict
        One ``fetch_data_info`` record per valid hit, in the order the index
        returned them. Fewer than ``top_k`` records come back when the index
        pads the result with -1 ids.
    '''
    query_vector = model.encode([query])
    # FAISS returns (distances, ids), one row per query; only one query was sent.
    distances, ids = index.search(query_vector, top_k)

    results = []
    for chunk_idx, chunk_score in zip(ids[0].tolist(), distances[0].tolist()):
        # -1 means "no result here" (index holds fewer than top_k vectors).
        if chunk_idx < 0:
            continue
        results.append(fetch_data_info(chunk_idx, chunk_score))
    return results

In [15]:
def query_answer(query, query_id, cross_model, top_k=20, top_n=3):
    '''Retrieve, rerank and return the best article ids for one question.

    Parameters
    ----------
    query : str
        Raw question text; cleaned once with ``clean_text``.
    query_id : int
        Identifier copied into every returned record.
    cross_model : CrossEncoder
        Reranker whose ``predict`` scores (query, article) pairs.
    top_k : int, default 20
        Number of chunks retrieved from the global ``index`` before reranking.
    top_n : int, default 3
        Number of reranked results returned.

    Returns
    -------
    list of dict
        Up to ``top_n`` records ``{'question_id', 'rank', 'id'}``, with rank
        starting at 1 for the highest cross-encoder score.
    '''
    # Clean once and reuse the result so retriever and reranker see the same string.
    cleaned_query = clean_text(query)

    # Stage 1: dense retrieval of candidate chunks (uses the global index/model).
    candidates = search(cleaned_query, top_k=top_k, index=index, model=model)
    if not candidates:
        return []

    # Stage 2: score every (query, chunk) pair with the cross-encoder.
    pair_scores = cross_model.predict([[cleaned_query, c['article']] for c in candidates])

    # Highest score first; sorted() stays stable with reverse=True, so ties keep retrieval order.
    reranked = sorted(zip(candidates, pair_scores), key=lambda pair: pair[1], reverse=True)

    return [
        {'question_id': query_id,
         'rank': rank,
         # NOTE(review): presumably chunk ids encode article_id * 10 + chunk number,
         # so // 10 recovers the article id compared against doc_id -- confirm.
         'id': candidate['id'] // 10}
        for rank, (candidate, _) in enumerate(reranked[:top_n], start=1)
    ]

In [5]:
def mrr_score(answers, queries):
    '''Mean reciprocal rank of the gold id across all questions.

    ``answers[i]`` is the ranked list of ids returned for question ``i`` and
    ``queries[i]`` is that question's gold id. Each question contributes
    1 / (1-based position of its first gold hit), or 0 if the gold id is absent.
    Returns 0 when ``answers`` is empty.
    '''
    reciprocal_ranks = []
    for pos, ranked_ids in enumerate(answers):
        first_hit = next(
            (rank for rank, candidate in enumerate(ranked_ids, start=1)
             if candidate == queries[pos]),
            None,
        )
        reciprocal_ranks.append(0 if first_hit is None else 1 / first_hit)

    if not reciprocal_ranks:
        return 0
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

In [6]:
def accuracy_score(answers, queries):
    '''Fraction of questions whose gold id appears anywhere in its answer list.

    ``answers[i]`` is the ranked list of ids returned for question ``i`` and
    ``queries[i]`` is that question's gold id. Returns 0 when ``answers`` is empty.
    '''
    hits = []
    for pos, ranked_ids in enumerate(answers):
        found = any(candidate == queries[pos] for candidate in ranked_ids)
        hits.append(1 if found else 0)

    if not hits:
        return 0
    return sum(hits) / len(hits)

In [7]:
# Cross-encoder rerankers to benchmark: the MiniLM family, then TinyBERT, each
# ordered by its L-N suffix (presumably layer count) from largest to smallest.
cross_models = [
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
]

In [38]:
# The CSV was written together with its index, which reappears as an
# 'Unnamed: 0' column. Drop it (errors='ignore' keeps this safe if the file is
# regenerated without the index) and keep the default RangeIndex: the metric
# functions look up gold ids with queries[i], where i is a 0-based position.
question_list = (
    pd.read_csv('question_test_data_2.csv')
    .drop(columns=['Unnamed: 0'], errors='ignore')
)
question_list

Unnamed: 0,doc_id,question
0,17552,"Who is the author of the memoir ""Nicotine""?"
1,17382,What skills are inmates learning in the innova...
2,17547,Who surrendered to the authorities for changin...
3,17778,What is the name of the journalist and archivi...
4,17841,Which leader of the fringe movement embracing ...
5,18228,"Who narrated the documentary ""I Am Not Your Ne..."
6,18443,Who is the billionaire restaurant owner nomina...
7,18170,Who is the federal judge that ordered Presiden...
8,17980,What business relationship between Donald Trum...
9,17838,Who vowed to take executive action on a nearly...


In [22]:
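# Scratch cell (not executed): the bi-encoder is loaded in the first cell and each
# cross-encoder is constructed inside the benchmark loop below.
# model = SentenceTransformer('msmarco-distilbert-base-dot-prod-v3')
# cross_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)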

In [35]:
def test_model(question_list, cross_model):
    '''Benchmark one cross-encoder reranker over every question in ``question_list``.

    Parameters
    ----------
    question_list : pandas.DataFrame
        Must have a ``question`` column (query text) and a ``doc_id`` column
        (gold article id for that question).
    cross_model : CrossEncoder
        Reranker passed through to ``query_answer``.

    Returns
    -------
    dict
        ``{'accuracy_score': ..., 'mrr_score': ..., 'time': ...}``, where
        ``time`` is the elapsed wall-clock seconds for the whole run.
    '''
    start_time = time.perf_counter()

    # One ranked list of article ids per question, built in question order so it
    # stays aligned with question_list['doc_id'], even when a question gets no hits.
    reranked_result = []
    for question_id, question in enumerate(question_list['question']):
        ranked = query_answer(question, question_id, cross_model)
        reranked_result.append([record['id'] for record in ranked])

    gold_ids = question_list['doc_id']
    accuracy = accuracy_score(reranked_result, gold_ids)
    mrr = mrr_score(reranked_result, gold_ids)

    elapsed_time = time.perf_counter() - start_time

    return {'accuracy_score': accuracy,
            'mrr_score': mrr,
            'time': elapsed_time}

In [36]:
# Score every candidate reranker on the same question set; one row per model.
# Rows are collected in their own list so `test_result` only ever names the DataFrame.
benchmark_rows = []

for cross_model_name in cross_models:
    cross_model = CrossEncoder(cross_model_name)
    result = test_model(question_list, cross_model)
    result['cross_model'] = cross_model_name
    benchmark_rows.append(result)

test_result = pd.DataFrame(benchmark_rows)

In [37]:
# Benchmark summary, one row per cross-encoder. accuracy_score is the share of
# questions whose gold doc_id appears among the reranked results returned by
# query_answer. mrr_score is the mean reciprocal rank of that hit. time is the
# wall-clock seconds measured inside test_model.
# NOTE(review): MRR 0.954545 (= 21/22) cannot come from the 10-question list
# displayed above (10 * MRR would have to be a sum of 1, 1/2 and 1/3 terms).
# That list was also (re)loaded in In [38], after this benchmark ran in In [36].
# Restart & Run All before trusting these numbers.
test_result

Unnamed: 0,accuracy_score,mrr_score,time,cross_model
0,1.0,1.0,96.617535,cross-encoder/ms-marco-MiniLM-L-12-v2
1,1.0,1.0,48.962104,cross-encoder/ms-marco-MiniLM-L-6-v2
2,1.0,0.954545,32.835262,cross-encoder/ms-marco-MiniLM-L-4-v2
3,1.0,1.0,16.814126,cross-encoder/ms-marco-MiniLM-L-2-v2
4,1.0,0.954545,101.113986,cross-encoder/ms-marco-TinyBERT-L-6
5,1.0,1.0,3.559668,cross-encoder/ms-marco-TinyBERT-L-2-v2
