# Math-BERT and Anserini Demo on ARQMath CLEF 2020 Main Task 


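This notebook demonstrates how to highlight the parts of retrieved ARQMath Task 1 answer posts that are most relevant to a topic, using contextualized token vectors from a Math-BERT model (with a RoBERTa tokenizer). It is adapted from a demo for searching the [COVID-19 Open Research Dataset](https://pages.semanticscholar.org/coronavirus-research).
Each document is the parent question's title, the question body and the answer body joined together with HTML tags removed; the ranked documents for a topic are read from a precomputed Task 1 run file.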

In [1]:
from IPython.core.display import display, HTML

First, import the Python dependencies and set the data and model paths.

In [2]:
import json
import os
import torch
import numpy
from transformers import *
import warnings
import glob
from tqdm.notebook import tqdm
import os
import jsonlines
import re
import xml.etree.ElementTree as ET
from collections import defaultdict
import numpy as np
from scipy import spatial
import numpy as np
import gc
import math
import torch
import pandas as pd
warnings.filterwarnings('ignore')

# All input locations live under one root directory. It defaults to the original
# author's machine; set the ARQMATH_ROOT environment variable to run elsewhere.
ARQMATH_ROOT = os.environ.get('ARQMATH_ROOT', '/data/szr207')

# JSON-lines files of question / answer post objects (indexed by 'post_id' when loaded)
ques_path = os.path.join(ARQMATH_ROOT, 'dataset/ArqMath/jsons/questions/all.ques.jsonl')
ans_path = os.path.join(ARQMATH_ROOT, 'dataset/ArqMath/jsons/answers/all.ans.jsonl')
# Task 1 topic XML file
topic_file_path = os.path.join(ARQMATH_ROOT, 'dataset/ArqMath/Task1/Topics/Topics_V2.0.xml')
# Directory holding run files
runs_path = os.path.join(ARQMATH_ROOT, 'projects/ArqMath/runs')
# Language-model checkpoint directory passed to AutoModel.from_pretrained
model_path = os.path.join(ARQMATH_ROOT, 'github/transformers/examples/language-modeling/output/checkpoint-3000000')

In [3]:
# Checkpoint used to embed both queries and documents.
model = AutoModel.from_pretrained(model_path)
# Stock 'roberta-base' tokenizer; presumably matches the checkpoint's
# vocabulary — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained('roberta-base', do_lower_case=False)

In [7]:
dict_ques = {}
dict_ans = {}
dict_aid_body = {}

# Index every question and every answer object by its 'post_id'.
for source_path, index in ((ques_path, dict_ques), (ans_path, dict_ans)):
    with jsonlines.open(source_path) as post_reader:
        for post in tqdm(post_reader):
            index[post['post_id']] = post

# Document text for each answer: "<question title>. <question body>. <answer body>",
# with anything that looks like an HTML tag removed from each part.
for answer_id, answer in tqdm(dict_ans.items()):
    parent = dict_ques[answer['parent_id']]
    dict_aid_body[answer_id] = '. '.join(
        re.sub('<[^<]+?>', '', part)
        for part in (parent['title'], parent['body'], answer['body'])
    )

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1435643.0), HTML(value='')))




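Next, read the topics. The variable `dict_idx_text` maps each topic id to its query text (the topic title and question body, with HTML tags removed).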

In [4]:
class Topic:
    """
    A single Task 1 topic.

    Attributes:
        topic_id: topic id string (e.g. 'A.1').
        title: topic title text.
        question: question body text.
        lst_tags: list of tag strings.
    """

    def __init__(self, topic_id, title, question, tags):
        self.topic_id = topic_id
        self.title = title
        self.question = question
        self.lst_tags = tags


class TopicReader:
    """
    Read a Task 1 topic XML file into a map from topic id to Topic.

    Every child of the XML root is one topic. Its ``number`` attribute is the
    topic id, and its first three sub-elements are, in order, the title, the
    question body and a comma-separated tag list.

    Use ``get_topic`` to look up a topic by id; it returns a Topic with the
    four attributes topic_id, title, question and lst_tags.
    """

    def __init__(self, topic_file_path):
        self.__map_topics = self.__read_topics(topic_file_path)

    def __read_topics(self, topic_file_path):
        """Parse the XML file and return {topic_id: Topic} in file order."""
        map_topics = {}
        root = ET.parse(topic_file_path).getroot()
        for child in root:
            topic_id = child.attrib['number']
            title = child[0].text
            question = child[1].text
            # ElementTree reports an empty element's text as None; treat that
            # as "no tags" instead of failing on None.split(",").
            tags_text = child[2].text
            lst_tag = tags_text.split(",") if tags_text else []
            map_topics[topic_id] = Topic(topic_id, title, question, lst_tag)
        return map_topics

    def get_topic(self, topic_id):
        """Return the Topic with this id, or None if the file has no such topic."""
        return self.__map_topics.get(topic_id)

def remove_stop(query, stopwords_path='englishST.txt'):
    """
    Remove stopwords from a query string.

    Args:
        query: text to filter; it is split on single spaces.
        stopwords_path: file with one stopword per line. The default
            'englishST.txt' is resolved relative to the current working directory.

    Returns:
        The tokens that are not stopwords, re-joined with single spaces.
    """
    with open(stopwords_path) as f:
        # Strip trailing newlines/whitespace; a set gives O(1) membership
        # tests instead of scanning the whole stopword list for every token.
        all_stopwords = {line.strip() for line in f}
    return ' '.join(word for word in query.split(' ') if word not in all_stopwords)

def remove_punct(my_str, punctuations=r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''):
    """
    Delete every punctuation character from a string.

    Args:
        my_str: input text.
        punctuations: characters to delete. The default set includes the
            backslash. It is written as a raw string so that ``\\,`` is not
            an invalid escape sequence.

    Returns:
        ``my_str`` with all characters in ``punctuations`` removed.
    """
    # str.translate deletes in one pass, avoiding repeated string concatenation.
    return my_str.translate(str.maketrans('', '', punctuations))

queries = []
topic_reader = TopicReader(topic_file_path)
dict_idx_text = {}

# One query string per topic: "<title>. <question body>", with anything that
# looks like an HTML tag removed from both parts.
for topic_id, topic in tqdm(topic_reader._TopicReader__map_topics.items()):
    title = re.sub('<[^<]+?>', '', topic.title)
    body = topic.question
    body_pro = re.sub('<[^<]+?>', '', body)
    query = '. '.join([title, body_pro])
    dict_idx_text[topic_id] = query

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




In [5]:
def show_query(query):
    """Return an IPython HTML banner showing ``query`` after a bold "Query" label."""
    opening = ('<br/><div style="font-family: Times New Roman; font-size: 20px;'
               'padding-bottom:12px"><b>Query</b>: ')
    closing = '</div>'
    return HTML(opening + query + closing)

def show_document(idx, doc):
    """
    Return IPython HTML for one retrieved document.

    Args:
        idx: document id shown in the bold "Document <idx>:" label.
        doc: document text, inserted as-is (not HTML-escaped).
    """
    # The closing </div> was missing before, which left the element unterminated.
    return HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px">'
                f'<b>Document {idx}:</b> -- '
                f'{doc}. </div>')

def show_query_results(query, list_docs, top_k=10):
    """
    Display the query banner followed by the first ``top_k`` documents.

    Each document must be a dict with 'docid' and 'body' keys.
    Returns the displayed slice ``list_docs[:top_k]``.
    """
    top_docs = list_docs[:top_k]
    display(show_query(query))
    for hit in top_docs:
        display(show_document(hit['docid'], hit['body']))
    return top_docs

In [236]:
topic_id = 'A.21'
query = dict_idx_text[topic_id]

# Results of one Task 1 run: space-separated, no header row.
# (An alternative run, tf.psu-task1-mlt.base-auto-both-A.tsv under runs_path, is tab-separated.)
df_judge = pd.read_csv(
    os.path.join('/data/szr207/projects/ArqMath/notebooks', "tf.psu-task1-rrf.anserini.bert-auto-both-P.tsv"),
    sep=' ',
    names=['topic_id', 'Q', 'post_id', 'rank', 'score', 'algo'],
)

# Keep only this topic's rows, in file order.
df_q = df_judge.loc[df_judge['topic_id'] == topic_id]
post_idx = df_q['post_id'].to_numpy()

# One {'body', 'docid'} record per retrieved answer post, in run order.
list_docs = [{'body': dict_aid_body[idx], 'docid': idx} for idx in tqdm(post_idx)]

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [237]:
# hits = show_query_results(query, list_docs, 10)
# len(hits)

Let's extract contextualized vectors of the query and of the retrieved documents from the Math-BERT model, and use them to highlight relevant passages.

First, extract the contextualized vectors of the query above:

$$q_1, \ldots, q_T = \text{MathBERT}(\text{query})$$

In [238]:
def extract_bert(text, tokenizer, model, chunk_size=510):
    """
    Encode ``text`` and return one contextual vector per content token.

    The text is tokenized once with special tokens. The content tokens (all ids
    between the first and the last id) are split into windows of at most
    ``chunk_size`` tokens. Each window is re-wrapped with the original first
    and last ids, run through ``model`` without gradients, and the outputs for
    the two wrapper positions are dropped again.

    Args:
        text: input string.
        tokenizer: object with ``encode(text, add_special_tokens=True)`` and
            ``convert_ids_to_tokens(ids)`` (e.g. a Hugging Face tokenizer).
        model: callable whose first output, for ids of shape (1, L), is a
            tensor of shape (1, L, H) (e.g. a Hugging Face model's last
            hidden state).
        chunk_size: maximum number of content tokens per model call. The
            default of 510 plus the two wrapper ids gives 512-id inputs.

    Returns:
        (text_ids, text_words, state):
        - text_ids: the full id tensor of shape (1, N + 2), special tokens included;
        - text_words: the N content tokens as strings;
        - state: an (N, H) tensor whose row i belongs to text_words[i].
    """
    text_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
    text_words = tokenizer.convert_ids_to_tokens(text_ids[0])[1:-1]

    first_id = text_ids[0, :1]   # e.g. <s> / [CLS]
    last_id = text_ids[0, -1:]   # e.g. </s> / [SEP]
    content_ids = text_ids[0, 1:-1]
    # Count only content tokens; keep at least one chunk so empty text still
    # yields a (0, H) state instead of failing on an empty concatenation.
    n_chunks = max(1, math.ceil(content_ids.size(0) / chunk_size))

    states = []
    for ci in range(n_chunks):
        chunk = content_ids[ci * chunk_size:(ci + 1) * chunk_size]
        # Wrap every chunk in BOTH special ids. Stripping [1:-1] below then
        # removes only the wrappers, never a real token, so the rows of
        # `state` stay aligned with `text_words`.
        chunk_ids = torch.cat([first_id, chunk, last_id])
        with torch.no_grad():
            chunk_state = model(chunk_ids.unsqueeze(0))[0]
        states.append(chunk_state[:, 1:-1, :])

    state = torch.cat(states, dim=1)
    return text_ids, text_words, state[0]

In [239]:
# Query encoding: (ids, content tokens, per-token state matrix) from extract_bert.
query_ids, query_words, query_state = extract_bert(query, tokenizer=tokenizer, model=model)

Second, let's extract contextualized vectors for each of the top 10 retrieved documents (called *paragraphs* below):

$$p_1^k, \ldots, p_{T_k}^k = \text{MathBERT}(\text{paragraph}^k)$$

In [240]:
ii = 1  # NOTE(review): not used by the cells visible in this notebook

# Encode the 10 top-ranked documents. Each entry is extract_bert's
# (ids, tokens, state) tuple for that document's body text.
paragraph_states = [
    extract_bert(doc['body'], tokenizer, model)
    for doc in tqdm(list_docs[:10])
]

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Token indices sequence length is longer than the specified maximum sequence length for this model (801 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1564 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1001 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (636 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2289 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for 




We then compute the cosine similarity matrix between the query and each paragraph:

$$A^k = [a^k_{ij}] \in \mathbb{R}^{|\text{query}| \times |\text{paragraph}^k|},$$

where

$$a^k_{ij} = \frac{q_i^\top p_j^k}{\| q_i \| \| p_j^k \|}$$


In [241]:
def cross_match(state1, state2):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        state1: tensor of shape (T1, H).
        state2: tensor of shape (T2, H).

    Returns:
        A (T1, T2) tensor whose [i, j] entry is cos(state1[i], state2[j]).
        An all-zero row gets similarity 0 instead of NaN.
    """
    import torch.nn.functional as F

    # normalize clamps the norm at a small eps, so zero rows do not divide by zero.
    a = F.normalize(state1, p=2, dim=1)
    b = F.normalize(state2, p=2, dim=1)
    # A matrix product avoids materialising the (T1, T2, H) broadcast tensor
    # that an elementwise multiply-then-sum would build.
    return a @ b.t()

In [242]:
# One cosine-similarity matrix (query tokens x document tokens) per top-10 document.
# Iterate over a range, which has a length, instead of over enumerate(...),
# so tqdm can show a real total rather than an indeterminate bar.
sim_matrices = []
for pid in tqdm(range(len(list_docs[:10]))):
    sim_matrices.append(cross_match(query_state, paragraph_states[pid][2]))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




First, let's retrieve the most relevant paragraphs, where we define the top-$M$ most relevant paragraphs as 

$$\arg\text{top-$M$}_{k=1}^K \max_{i=1,\ldots,|\text{query}|} \max_{j=1,\ldots, |\text{paragraph}^k|} A_{ij}^k$$

that is, a paragraph containing a word that closely matches some query word is considered relevant. (Here $K = M = 10$, so this step simply reorders the retrieved documents.)

In [243]:
# A document's relevance is its single highest query-token / document-token similarity.
paragraph_relevance = [sim.max().item() for sim in sim_matrices]

# Indices of the 10 most relevant documents, most relevant first.
rel_index = numpy.argsort(paragraph_relevance)[::-1][:10]

In [244]:
def show_sections(section):
    """Return an indented IPython HTML block for ``section``, with every 'Ġ' character removed."""
    cleaned = section.replace("Ġ", "")
    style = 'font-family: Times New Roman; font-size: 18px; padding-bottom:10px; margin-left: 15px'
    return HTML('<div style="' + style + '">' + ' -- ' + cleaned + ' </div>')

# display(show_query(query))
# # display(show_document(ii, list_docs[0]))
# # for ri in numpy.sort(rel_index):
# for ri in rel_index:
#     print(ri)
#     display(show_document(list_docs[ri]['docid'], list_docs[ri]['body']))

To look at the details, we highlight relevant phrases in each paragraph, where the relevant words of paragraph $k$ are

$$\arg\text{top-$M$}_{j=1,\ldots, |\text{paragraph}^k|} \max_{i=1,\ldots,|\text{query}|} A_{ij}^k$$

that is, a paragraph word is considered relevant if it is highly similar to at least one query word (the code below uses $M = 2$). Given these words, we highlight the sentence containing each of them.

In [245]:
def highlight_paragraph(ptext, rel_words, max_win=10):
    """
    Join ``ptext`` into a string, wrapping the sentences around ``rel_words`` in green.

    A sentence boundary is a "." token followed by a token that starts with an
    uppercase letter or "[". A leading "Ġ" word marker, as produced by RoBERTa
    tokenizers, is ignored in both checks. For each relevant index, the span is
    wrapped in <font color='green'>. It runs from the boundary before the index
    (or from the end of the previous highlight) up to the boundary after it (or
    to the end of the text). Indices inside an earlier highlight are skipped.

    Args:
        ptext: list of token strings.
        rel_words: indices into ``ptext`` of relevant tokens, in ascending order.
        max_win: currently unused; kept so existing calls keep working.

    Returns:
        The tokens joined with single spaces, including the <font> tags.
    """
    def is_boundary(pos):
        # Bounds-checked: a trailing "." (no next token) is not a boundary.
        if pos + 1 >= len(ptext) or ptext[pos].lstrip("Ġ") != ".":
            return False
        # Slicing tolerates an empty next token.
        first = ptext[pos + 1].lstrip("Ġ")[:1]
        return first.isupper() or first == "["

    para = ""
    prev_idx = 0
    for jj in rel_words:
        if prev_idx > jj:
            # Already covered by the previous highlight.
            continue

        # Walk back to the boundary that ends the preceding sentence; without
        # one, start right after the previous highlight.
        sent_start = prev_idx - 1
        for kk in range(jj, prev_idx - 1, -1):
            if is_boundary(kk):
                sent_start = kk
                break

        # Walk forward to the boundary that ends this sentence; without one,
        # highlight through the last token.
        sent_end = len(ptext)
        for kk in range(jj, len(ptext) - 1):
            if is_boundary(kk):
                sent_end = kk
                break

        para += " " + " ".join(ptext[prev_idx:sent_start + 1])
        para += "<font color='green'>" + " ".join(ptext[sent_start + 1:sent_end]) + "</font> "
        prev_idx = sent_end

    if prev_idx < len(ptext):
        para += " ".join(ptext[prev_idx:])

    return para

In [246]:
display(show_query(query))

# Visit the selected documents in list_docs order. For each, highlight the
# sentences around its two tokens with the highest similarity to any query token.
for ri in numpy.sort(rel_index):
    sim = sim_matrices[ri].detach().numpy()
    best_per_token = sim.max(0)  # best query match for each document token
    rel_words = numpy.sort(numpy.argsort(best_per_token)[-2:])
    p_tokens = paragraph_states[ri][1]
    para = highlight_paragraph(p_tokens, rel_words)
    print(ri)
    display(show_sections(para))

0


1


2


3


4


5


6


7


8


9
