# Candidate selection and reranking

### Preliminary NDCG
| Question ID | BERT (faiss) | ES (BoW) |
|-------------|--------------|----------|
| A.101       | 0.0955       | 0.0400   |
| A.31        | 0.0615       | 0.1808   |
| A.78        | 0            | 0.0435   |
| all         | 0.0523       | 0.0881   |


### Number of Relevant Answers Retrieved
| Question ID | ES@1k | ES@5k | ES@10k | BERT@1k | Total relevant |
|-------------|-------|-------|--------|---------|----------------|
| A.101       | 2     | 3     | 3      | 1       | 18             |
| A.31        | 17    | 25    | 30     | 6       | 45             |
| A.78        | 2     | 3     | 4      | 0       | 16             |
| all         | 21    | 32    | 38     | 7       | 79             |

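The first table gives NDCG scores for faiss+BERT retrieval and for ES+BoW retrieval. The second counts how many of the judged-relevant answers each method retrieves at each cutoff. This notebook aims to retrieve 10k candidates using ES and then rerank them using BERT to report the top 1k.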
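The retrieve-then-rerank idea fits in a few lines of plain Python. This is a minimal sketch, not the notebook's code. `cosine`, `rerank` and the `embeddings` dictionary are hypothetical names, and the toy 2-d vectors stand in for BERT embeddings. The first stage fixes the candidate set. The second stage only reorders it, so a relevant answer that ES never returns cannot be recovered by reranking.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors; None if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return None  # undefined; scipy's cosine distance typically yields nan here
    return dot / (norm_u * norm_v)


def rerank(query_vec, candidate_ids, embeddings, k):
    """Rescore first-stage candidates by cosine similarity and keep the best k.

    candidate_ids: IDs from the first stage (e.g. ES), in first-stage order.
    embeddings: hypothetical dict mapping candidate ID -> dense vector.
    """
    scored = []
    for cid in candidate_ids:
        vec = embeddings.get(cid)
        if vec is None:
            continue  # no embedding available for this candidate
        sim = cosine(query_vec, vec)
        if sim is None:
            continue  # skip degenerate all-zero vectors
        scored.append((cid, sim))
    scored.sort(key=lambda pair: pair[1], reverse=True)  # best first
    return scored[:k]


toy_embeddings = {"a": [1.0, 0.0], "b": [0.6, 0.8], "c": [0.0, 1.0]}
print(rerank([1.0, 0.0], ["c", "b", "a"], toy_embeddings, k=2))
```

ES returned `c` first here, but after reranking `a` and `b` come out on top and `c` drops below the cutoff.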

In [1]:
```python
import os
import re
import xml.etree.ElementTree as ET
from collections import defaultdict

import jsonlines
import numpy as np
from elasticsearch import Elasticsearch
from scipy import spatial
from tqdm import tqdm
```

In [2]:
```python
class Topic:
    """
    A topic for Task 1. Each topic has a topic_id (str), a title, the question body,
    and a list of tags.
    """

    def __init__(self, topic_id, title, question, tags):
        self.topic_id = topic_id
        self.title = title
        self.question = question
        self.lst_tags = tags


class TopicReader:
    """
    Reads all topics in the topic file into a map keyed by topic id. Each value is a
    Topic with 4 attributes: id, title, question and list of tags.

    Use get_topic(topic_id) to get a Topic back (None if the id is unknown).
    """

    def __init__(self, topic_file_path):
        self.__map_topics = self.__read_topics(topic_file_path)

    def __read_topics(self, topic_file_path):
        map_topics = {}
        root = ET.parse(topic_file_path).getroot()
        for child in root:
            topic_id = child.attrib['number']
            title = child[0].text
            question = child[1].text
            lst_tag = child[2].text.split(",")
            map_topics[topic_id] = Topic(topic_id, title, question, lst_tag)
        return map_topics

    def get_topic(self, topic_id):
        return self.__map_topics.get(topic_id)


def load_stopwords(path='englishST.txt'):
    # One stopword per line: strip the trailing newline and keep a set for fast lookup.
    with open(path) as f:
        return {line.strip() for line in f}


STOPWORDS = load_stopwords()  # read once instead of on every remove_stop call


def remove_stop(query):
    return ' '.join(word for word in query.split(' ') if word not in STOPWORDS)


def remove_punct(my_str):
    # Drop every character listed below (including backslash and comma); currently unused.
    punctuations = '''!()-[]{};:'"\\,<>./?@#$%^&*_~'''
    return ''.join(char for char in my_str if char not in punctuations)


def strip_tags(text):
    # Remove HTML tags, keeping the text (and inline LaTeX) between them.
    return re.sub('<[^<]+?>', '', text)


es = Elasticsearch(['http://csxindex05:9200/'], verify_certs=True)
topic_file_path = "/data/szr207/dataset/ArqMath/Task1/Sample Topics/Task1_Samples_V2.0.xml"
topic_reader = TopicReader(topic_file_path)

queries = []                  # raw "title. body" per topic, embedded with BERT later
dict_q_a = defaultdict(list)  # topic id -> answer post ids returned by ES
for topic_id in ['A.31', 'A.78', 'A.101']:
    topic = topic_reader.get_topic(topic_id)
    query = strip_tags(topic.title) + '. ' + strip_tags(topic.question)
    queries.append(query)
    # ES gets a lowercased, stopword-free version of the query.
    es_query = remove_stop(query.lower())

    print(topic_id, topic.title, topic.question)
    print("=================================")
    request = {
        "size": 500,
        "query": {
            "match": {
                "body": es_query
            }
        }
    }

    res = es.search(index="answer_bulk_index", body=request)
    for hit in res['hits']['hits']:
        dict_q_a[topic_id].append(str(hit['_source']['post_id']))
```

```text
A.31 Doubt on Implication of Logical reasoning <p>Its evident that in the truth table of <span class="math-container" id="q_253">$p \to q$</span> </p>  <p>When <span class="math-container" id="q_254">$p$</span> is False and <span class="math-container" id="q_255">$q$</span> is True, Then <span class="math-container" id="q_256">$p \to q$</span> is True.</p>  <p>But in some instances i could not convince myself about this truth value.</p>  <p>For example:</p>  <p><span class="math-container" id="q_257">$p$</span>: Quadrilateral is Cyclic</p>  <p><span class="math-container" id="q_258">$q$</span>: Opposite angles are supplementary</p>  <p>Now is <span class="math-container" id="q_259">$p$</span> is False and <span class="math-container" id="q_260">$q$</span> is True, how can <span class="math-container" id="q_261">$p \to q$</span> can be True?</p> 
A.78 if <span class="math-container" id="q_730">$\sum_{n=1}^{\infty} a_n$</span> converges absolutely then prove that <span class="math-contai
```
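The retrieval cell turns each topic into an ES request in four steps. It strips HTML tags, joins the title and body with ". ", lowercases the text, and drops stopwords. It then wraps the result in a `match` query on the `body` field. Here is the same construction as a standalone sketch. `build_es_request` is a hypothetical helper, and the stopwords are passed in as a set, not read from `englishST.txt`.

```python
import re

# Same tag-stripping pattern as the notebook.
TAG_RE = re.compile(r"<[^<]+?>")


def build_es_request(title, body, stopwords, size=500):
    """Hypothetical helper: build the ES match-query request body for one topic."""
    text = TAG_RE.sub("", title) + ". " + TAG_RE.sub("", body)
    tokens = text.lower().split(" ")           # lowercase before stopword removal
    kept = [t for t in tokens if t not in stopwords]
    return {"size": size, "query": {"match": {"body": " ".join(kept)}}}


req = build_es_request("Doubt on <b>Implication</b>", "<p>Is p the case</p>", {"on", "is", "the"})
print(req["query"]["match"]["body"])  # doubt implication. p case
```

Splitting on single spaces leaves punctuation attached to words (`implication.`). ES's analyzer normally tokenizes the query text again, so this rarely matters for matching.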

`dict_q_a` maps each query ID to the list of answer post IDs returned by Elasticsearch.
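The counts in the "Number of Relevant Answers Retrieved" table can be reproduced roughly like this. You need `dict_q_a` plus the set of judged-relevant answer IDs per topic, for example read from the qrels file. This is a sketch under assumptions: `count_relevant_at_k` is a hypothetical helper, and treating any grade above 0 as relevant is my assumption. IDs are compared as strings because `dict_q_a` stores them as `str`.

```python
def count_relevant_at_k(ranked_ids, relevant_ids, cutoffs=(1000, 5000, 10000)):
    """Hypothetical helper: number of relevant IDs in the top-k of a ranking, per cutoff."""
    relevant = set(relevant_ids)
    return {k: sum(1 for pid in ranked_ids[:k] if pid in relevant) for k in cutoffs}


# Toy example with small cutoffs.
print(count_relevant_at_k(["5", "3", "9", "1"], ["3", "1", "7"], cutoffs=(1, 2, 4)))
# {1: 0, 2: 1, 4: 2}
```

Answer `7` is relevant but never retrieved, so no cutoff can count it. This is the gap between the ES@10k and "Total relevant" columns.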

In [3]:
```python
root_path = '/data/szr207/dataset/ArqMath/jsons/answers/'

# Map answer post_id -> answer body with HTML tags stripped (empty bodies are skipped).
dict_aid_post = {}
with jsonlines.open(os.path.join(root_path, 'all.ans.jsonl')) as reader:
    for obj in tqdm(reader):
        if obj['body']:
            dict_aid_post[obj['post_id']] = re.sub('<[^<]+?>', '', obj['body'])
```

1435643it [01:04, 22336.92it/s]
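The tag-stripping regex `<[^<]+?>` deletes everything from a `<` up to the next `>`. Text between tags, including the LaTeX inside `math-container` spans, survives. The small check below shows that, plus one pitfall I'd watch for: a literal `<` in the text, such as an unescaped inequality, swallows everything up to the next `>`. Whether the ArqMath bodies actually contain unescaped `<` is an assumption worth checking. `strip_tags` is a hypothetical wrapper around the notebook's pattern.

```python
import re

TAG_RE = re.compile(r"<[^<]+?>")  # the notebook's pattern


def strip_tags(html):
    """Hypothetical wrapper: remove HTML tags, keeping the text between them."""
    return TAG_RE.sub("", html)


sample = '<p>Let <span class="math-container" id="q_1">$x$</span> be real.</p>'
print(strip_tags(sample))                  # Let $x$ be real.
print(strip_tags("$a < b$ and $c > d$"))   # $a  d$  (text between < and > is lost)
```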


In [49]:
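```python
# Precomputed BERT embeddings: one row per answer, and one row per query (A.31, A.78, A.101).
# NOTE: dict_aid_idx (answer post_id -> row in ans_bert.npy), used below, is not defined
# in this notebook and has to come from wherever ans_bert.npy was built.
ans_emb = np.load('ans_bert.npy')
ques_emb = np.load('query.bert.31.78.101.npy')
```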

In [4]:
```python
def Sort_Tuple(tup):
    # Sort (qid, post_id, score) tuples in place by score (the third element),
    # in ascending order; callers reverse the result to get best-first.
    tup.sort(key=lambda x: x[2])
    return tup
```
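`Sort_Tuple` sorts in place and in ascending order, so every caller has to reverse the list before slicing. The standard library can select the top k in descending order directly. Below is a sketch with hypothetical helper names `top_k` and `top_k_sorted`. Neither mutates its input, and both differ from `Sort_Tuple(...)[::-1][:k]` only in the order of tied scores.

```python
import heapq
from operator import itemgetter


def top_k(triples, k):
    """Hypothetical helper: k highest-scoring (qid, post_id, score) triples, best first."""
    return heapq.nlargest(k, triples, key=itemgetter(2))


def top_k_sorted(triples, k):
    """Same selection with a full sort; the input list is left unchanged."""
    return sorted(triples, key=itemgetter(2), reverse=True)[:k]


triples = [("A.31", "1", 0.2), ("A.31", "2", 0.9), ("A.31", "3", 0.5)]
print(top_k(triples, 2))  # [('A.31', '2', 0.9), ('A.31', '3', 0.5)]
```

`heapq.nlargest` avoids a full sort when k is much smaller than the list. With only 500 to 1000 candidates per query, either version is fine.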

In [57]:
```python
# Rerank the ES candidates by cosine similarity between precomputed BERT embeddings.
dict_q_idx = {'A.31': 0, 'A.78': 1, 'A.101': 2}  # query id -> row in ques_emb
result_list = []

for q_id in ['A.31', 'A.78', 'A.101']:
    tup_postid_sim = []
    for post_id in dict_q_a[q_id]:
        # dict_aid_idx must map answer post_id -> row in ans_emb (not defined in this notebook).
        result = 1 - spatial.distance.cosine(ques_emb[dict_q_idx[q_id]], ans_emb[dict_aid_idx[int(post_id)]])
        tup_postid_sim.append((q_id, post_id, result))
    print(len(tup_postid_sim))
    print(q_id)
    # Keep the 1000 highest-scoring candidates, best first.
    result_list.append(Sort_Tuple(tup_postid_sim)[::-1][:1000])
```

1000
A.31
1000
A.78
1000
A.101


In [34]:
result_list

[[('A.31', '2246314', 0.9594488143920898),
  ('A.31', '1508462', 0.9512484669685364),
  ('A.31', '205398', 0.9503062963485718),
  ('A.31', '134854', 0.9488744139671326),
  ('A.31', '2170920', 0.9486956596374512),
  ('A.31', '2650569', 0.9482060074806213),
  ('A.31', '2451383', 0.9478105306625366),
  ('A.31', '1443248', 0.9476836919784546),
  ('A.31', '1788145', 0.9474408030509949),
  ('A.31', '2707867', 0.9471750259399414),
  ('A.31', '1797473', 0.9470195770263672),
  ('A.31', '1880720', 0.9468636512756348),
  ('A.31', '2593223', 0.9464331865310669),
  ('A.31', '484774', 0.9463779330253601),
  ('A.31', '832924', 0.9463376998901367),
  ('A.31', '2839560', 0.9463103413581848),
  ('A.31', '3022023', 0.9462641477584839),
  ('A.31', '2599825', 0.9459764957427979),
  ('A.31', '2260940', 0.945824384689331),
  ('A.31', '1632012', 0.9452061057090759),
  ('A.31', '1886470', 0.9451159238815308),
  ('A.31', '1001158', 0.9450810551643372),
  ('A.31', '92249', 0.944877564907074),
  ('A.31', '2236010

In [58]:
```python
# Write a TREC-style run file: qid, iteration, post_id, rank, score, run tag.
with open('es.bert.100.dat', 'w') as eval_file:
    for res in result_list:
        for rank, (qid, post_id, score) in enumerate(res, start=1):
            eval_file.write(f"{qid}\t1\t{post_id}\t{rank}\t{score}\trr\n")
```
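Each line of the run file has six tab-separated columns:

- topic ID
- an iteration field (the notebook writes `1`; `Q0` is the usual placeholder)
- answer post ID
- rank
- score
- run tag

As far as I know, trec_eval orders documents by the score column and ignores both the iteration and rank columns. So it is worth checking that ranks and scores agree before evaluating. This is a minimal sketch: `parse_run` and `ranks_match_scores` are hypothetical helpers, not part of trec_eval.

```python
def parse_run(lines):
    """Hypothetical helper: parse run lines into {qid: [(post_id, rank, score), ...]}."""
    runs = {}
    for line in lines:
        qid, _iteration, post_id, rank, score, _tag = line.split()
        runs.setdefault(qid, []).append((post_id, int(rank), float(score)))
    return runs


def ranks_match_scores(runs):
    """True if, for every query, increasing rank never comes with an increasing score."""
    for entries in runs.values():
        scores = [score for _, _, score in sorted(entries, key=lambda e: e[1])]
        if any(a < b for a, b in zip(scores, scores[1:])):
            return False
    return True


good = ["A.31\t1\t2246314\t1\t0.95\trr", "A.31\t1\t1508462\t2\t0.94\trr"]
print(ranks_match_scores(parse_run(good)))  # True
```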

### Reranking with a Fine-tuned RoBERTa Model (Hugging Face)

In [5]:
```python
from transformers import pipeline

# Alternative models (commented out):
# feat_ext = pipeline("feature-extraction", model="shauryr/arqmath-roberta-base", tokenizer='roberta-base', device=0)
# feat_ext = pipeline("feature-extraction", model="bert-large-cased", tokenizer='bert-large-cased')

# Feature extractor from the ArqMath-pretrained RoBERTa checkpoint, on GPU 0.
feat_ext = pipeline("feature-extraction", model="shauryr/arqmath-roberta-base-2M", tokenizer='roberta-base', device=0)
```

In [6]:
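```python
# One vector per query: mean-pool the per-token features of the raw query text.
ques_emb = []
for query in queries:
    ques_emb.append(np.mean(feat_ext(query)[0], axis=0))
    # Alternative: keep only the first (<s>) token's vector.
    # ques_emb.append(np.asarray(feat_ext(query)[0][0]))
```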
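For a single string, the feature-extraction pipeline returns one vector per token, nested as `[batch][token][hidden]`. `np.mean(feat_ext(query)[0], axis=0)` averages those token vectors into one query vector. The commented-out line instead keeps only the first token's vector. The sketch below shows the two pooling choices in plain Python, with toy 2-d vectors in place of RoBERTa-base's 768-d outputs. The helper names are hypothetical.

```python
def mean_pool(token_vectors):
    """Hypothetical helper: average per-token vectors into one fixed-size vector."""
    n = len(token_vectors)
    dim = len(token_vectors[0])
    return [sum(vec[i] for vec in token_vectors) / n for i in range(dim)]


def first_token(token_vectors):
    """Hypothetical helper: keep only the first token's (<s> / [CLS]) vector."""
    return list(token_vectors[0])


tokens = [[1.0, 2.0], [3.0, 4.0]]  # two tokens, hidden size 2
print(mean_pool(tokens))    # [2.0, 3.0]
print(first_token(tokens))  # [1.0, 2.0]
```

Mean pooling uses every token, which tends to help when the model was not trained to summarize the input in its first token.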

In [6]:
```python
def reduce_str(body, N):
    # Truncate body to its first N whitespace-separated words.
    words = body.split()
    if len(words) > N:
        return ' '.join(words[:N])
    return body
```

In [7]:
```python
import math

dict_q_idx = {'A.31': 0, 'A.78': 1, 'A.101': 2}  # query id -> index in ques_emb
result_list = []
nan_count = 0
for qid in ['A.31', 'A.78', 'A.101']:
    tup_postid_sim = []
    for post_id in tqdm(dict_q_a[qid]):
        try:
            # Mean-pooled answer embedding. Over-long answers typically raise
            # IndexError ("index out of range in self"); missing bodies raise KeyError.
            # Alternative: truncate first with reduce_str(dict_aid_post[int(post_id)], 300).
            a_emb = np.mean(feat_ext(dict_aid_post[int(post_id)])[0], axis=0)
        except Exception as err:
            print('skipped', post_id, repr(err))
            continue
        result = 1 - spatial.distance.cosine(ques_emb[dict_q_idx[qid]], a_emb)
        if math.isnan(result):
            # Cosine can come out NaN (e.g. for an all-zero vector); drop the candidate.
            nan_count += 1
            print(qid, post_id, nan_count)
            continue
        tup_postid_sim.append((qid, post_id, result))

    print(len(tup_postid_sim))
    print(qid)
    # Keep the 1000 highest-scoring candidates, best first.
    result_list.append(Sort_Tuple(tup_postid_sim)[::-1][:1000])
```

100%|██████████| 500/500 [00:16<00:00, 30.09it/s]
  0%|          | 2/500 [00:00<00:41, 12.05it/s]

500
A.31


100%|██████████| 500/500 [00:27<00:00, 18.45it/s]
  1%|          | 4/500 [00:00<00:21, 23.52it/s]

500
A.78


100%|██████████| 500/500 [00:27<00:00, 18.46it/s]

500
A.101





In [13]:
result_list

[[('A.31', '1612068', 0.9517469053243353),
  ('A.31', '2170920', 0.9473377987624941),
  ('A.31', '2839560', 0.940099963135127),
  ('A.31', '725030', 0.9399963859525906),
  ('A.31', '2815674', 0.9349993200823086),
  ('A.31', '2416512', 0.9326792298067289),
  ('A.31', '2196521', 0.9323500999315928),
  ('A.31', '2130036', 0.9320589886211682),
  ('A.31', '2873453', 0.9313222845943607),
  ('A.31', '2307082', 0.9308457495942378),
  ('A.31', '1192419', 0.9302692248036004),
  ('A.31', '620550', 0.9292281648626275),
  ('A.31', '30622', 0.9286528427209723),
  ('A.31', '260904', 0.9284681982235318),
  ('A.31', '3014093', 0.9282936082761182),
  ('A.31', '1265126', 0.9282760196109724),
  ('A.31', '1337264', 0.9282662981789258),
  ('A.31', '219357', 0.9279817116762755),
  ('A.31', '1107725', 0.9279564355731625),
  ('A.31', '1589923', 0.9277479881367119),
  ('A.31', '2593223', 0.9275930643912051),
  ('A.31', '660565', 0.9273210759428132),
  ('A.31', '2385191', 0.926284957596511),
  ('A.31', '1370866'

In [17]:
np.asarray(feat_ext(dict_aid_post[1180567])).shape

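IndexError: index out of range in self

This is the failure the reranking loop guards against. Answer 1180567 is most likely longer than RoBERTa-base's 512-token input limit, so the position-embedding lookup goes out of range.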
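Answers longer than the model's input limit fail in `feat_ext`. Instead of skipping them, or cutting them hard with `reduce_str`, one option is to embed overlapping word windows and pool the window vectors (mean or max). This is a sketch under assumptions, not something the notebook does. `word_windows` is a hypothetical helper, and 300 words is only a rough proxy for 512 tokens: mathematical text often splits into many subword pieces, so the window may need to be smaller.

```python
def word_windows(text, max_words=300, stride=250):
    """Hypothetical helper: split text into overlapping windows of at most max_words words."""
    words = text.split()
    if len(words) <= max_words:
        return [" ".join(words)]
    windows = []
    for start in range(0, len(words), stride):
        windows.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):
            break  # this window already reaches the end of the text
    return windows


text = " ".join(f"w{i}" for i in range(10))
print(word_windows(text, max_words=4, stride=3))
# ['w0 w1 w2 w3', 'w3 w4 w5 w6', 'w6 w7 w8 w9']
```

Each window could then go through `feat_ext` separately, with the per-window vectors averaged into one answer embedding.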

In [8]:
```python
# Write a TREC-style run file: qid, iteration, post_id, rank, score, run tag.
with open('es.rr.finetune.2M.dat', 'w') as eval_file:
    for res in result_list:
        for rank, (qid, post_id, score) in enumerate(res, start=1):
            eval_file.write(f"{qid}\t1\t{post_id}\t{rank}\t{score}\trr\n")
```

## Evaluating NDCG

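Reranking the ES candidates raises overall NDCG above both baselines: from 0.0881 (ES alone) to 0.1126 with BERT, and to 0.1204 with RoBERTa-base. Most of the gain comes from A.101 (0.0400 → 0.1140 → 0.1568). A.78 is essentially unchanged and A.31 does not improve. Overall, reranking with BERT is clearly worthwhile.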
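For reference, NDCG divides the discounted cumulative gain of the ranked list by that of an ideal ordering of all judged answers. The sketch below uses the relevance grade as the gain and a log2(rank + 1) discount, the common formulation. trec_eval's `ndcg` measure may differ in details such as tie handling, so treat this as an approximation rather than a reimplementation. `dcg` and `ndcg` are hypothetical helper names.

```python
import math


def dcg(gains):
    """Discounted cumulative gain: sum of gain / log2(rank + 1), ranks starting at 1."""
    return sum(g / math.log2(rank + 1) for rank, g in enumerate(gains, start=1))


def ndcg(ranked_ids, qrels):
    """NDCG of a ranked list against graded judgments {post_id: relevance}.

    Unjudged IDs get gain 0. The ideal ordering uses every judged answer, so
    relevant answers that were never retrieved still lower the score.
    """
    gains = [qrels.get(pid, 0) for pid in ranked_ids]
    ideal_dcg = dcg(sorted(qrels.values(), reverse=True))
    return dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0


qrels = {"a": 3, "b": 1}
print(ndcg(["a", "b"], qrels))  # 1.0
print(ndcg(["b", "a"], qrels))  # below 1.0: the better answer is ranked second
print(ndcg(["a"], qrels))       # below 1.0: relevant answer "b" was never retrieved
```

The last case is why recall at the candidate stage caps what reranking can achieve. With ES finding only 21 of the 79 relevant answers in its top 1k, most of the ideal DCG is out of reach for any reordering.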

In [2]:
!java -jar jtreceval-0.0.5-jar-with-dependencies.jar -q -m ndcg /data/szr207/dataset/ArqMath/Task1/Sample\ Topics/qrel.V1.0.tsv /data/szr207/projects/ArqMath/runs/es.rr.1k.dat

ndcg                  	A.101	0.1140
ndcg                  	A.31	0.1804
ndcg                  	A.78	0.0435
ndcg                  	all	0.1126


In [3]:
!java -jar jtreceval-0.0.5-jar-with-dependencies.jar -q -m ndcg /data/szr207/dataset/ArqMath/Task1/Sample\ Topics/qrel.V1.0.tsv /data/szr207/projects/ArqMath/runs/es.roberta-base.dat

ndcg                  	A.101	0.1568
ndcg                  	A.31	0.1602
ndcg                  	A.78	0.0443
ndcg                  	all	0.1204


RoBERTa-base gave the best results among the Transformer-based models tried, so it was chosen for further pretraining.