In [1]:
import pandas as pd
import pickle

import re
import numpy as np

from collections import Counter

In [2]:
from rank_bm25 import BM25Okapi
import nltk
from nltk.stem.porter import PorterStemmer

# nltk.download('punkt')
# nltk.download('stopwords')

def stem_tokenize(text, remove_stopwords=True):
    stemmer = PorterStemmer()
    tokens = [word for sent in nltk.sent_tokenize(text) \
                                      for word in nltk.word_tokenize(sent)]
    tokens = [word for word in tokens if word not in \
          nltk.corpus.stopwords.words('english')]
#     tokens = [word for word in tokens if not re.search(r'[^A-Za-z0-9]', word)]
    return [stemmer.stem(word) for word in tokens]

In [25]:
# Files paths

multi_turn_request_file_path = '../data/dev_synthetic.pkl'
question_bank_path = '../data/question_bank.tsv'
run_file_path = '../sample_runs/dev_bm25_multi_turn'

In [18]:
# Reads files and build bm25 corpus (index)
with open(multi_turn_request_file_path, 'rb') as fi:
    dev = pickle.load(fi)
question_bank = pd.read_csv(question_bank_path, sep='\t').fillna('')

question_bank['tokenized_question_list'] = question_bank['question'].map(stem_tokenize)
question_bank['tokenized_question_str'] = question_bank['tokenized_question_list'].map(lambda x: ' '.join(x))

bm25_corpus = question_bank['tokenized_question_list'].tolist()
bm25 = BM25Okapi(bm25_corpus)

In [19]:
# Reads the dev file and create the context_dict to make predictions
context_dict = dict()
for rec_id in dev:
    ctx_id = dev[rec_id]['context_id']
    if ctx_id not in context_dict:
        context_dict[ctx_id] = {'initial_request': dev[rec_id]['initial_request'],
                                'conversation_context': dev[rec_id]['conversation_context']}

In [22]:
# Runs bm25 for every query and stores output in file.

def build_query(context_info):
    query_str = context_info['initial_request']
    for ctx in context_info['conversation_context']:
        query_str += ctx['question'] + ' ' + ctx['answer']
    return query_str

def select_no_duplicate_questions(bm25_q_list, conv_context):
    prev_questions = [x['question'] for x in conv_context]
    bm25_preds = question_bank.set_index('tokenized_question_str').loc[bm25_q_list, 'question'].tolist()
    pred_list = []
    for q in bm25_preds:
        if q not in prev_questions:
            pred_list.append(q)
    return pred_list

with open(run_file_path, 'w') as fo:
  for ctx_id in context_dict:
    query = build_query(context_dict[ctx_id])
    bm25_ranked_list = bm25.get_top_n(stem_tokenize(query, True), 
                                    bm25_corpus,
                                    n=5)
    bm25_q_list = [' '.join(sent) for sent in bm25_ranked_list]
    preds = select_no_duplicate_questions(bm25_q_list, context_dict[ctx_id]['conversation_context'])
    for i, qid in enumerate(preds):
        fo.write('{} 0 "{}" {} {} bm25_multi_turn\n'.format(ctx_id, qid, i, len(preds)-i))
        break # we write only one result per context.

In [26]:
! python clariq_eval_tool.py    --eval_task document_relevance\
                                --data_dir ../data/ \
                                --multi_turn \
                                --experiment_type dev \
                                --run_file {run_file_path} #\
                                # --out_file {run_file_path}.eval

NDCG1: 0.21898645957785742
NDCG3: 0.201618860054938
NDCG5: 0.19652670322787674
NDCG10: 0.1856817702651898
NDCG20: 0.17112798502504814
P1: 0.2747245453338643
P3: 0.2423116067082614
P5: 0.2295632550112837
P10: 0.2003849727864065
P20: 0.15768949953537767
MRR100: 0.35986740195719213
