# Anserini for ARQMath


This notebook adapts a demo for searching the [COVID-19 Open Research Dataset](https://pages.semanticscholar.org/coronavirus-research) (release of 2020/03/20) from AI2 to ARQMath. We search an Anserini index of ARQMath posts (`ARQ_INDEX`) with Pyserini, then highlight the parts of each hit that match the query using contextualized token vectors from ARQ-BERT.

In [4]:
from IPython.display import display, HTML

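First, install the Python dependencies (the `pip` commands are commented out), then import the libraries and set the paths to the index, topics, and model checkpoint: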

In [1]:
# %%capture
# !pip install pyserini==0.8.1.0
# !pip install transformers

import json
import os
import torch
import numpy
from tqdm.notebook import tqdm
from transformers import RobertaModel, RobertaTokenizer
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

ARQ_INDEX = '/data/szr207/dataset/ArqMath/anserini_index'
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-11.0.7.10-4.el7_8.x86_64"
topic_file_path = "/data/szr207/dataset/ArqMath/Task1/Topics/Topics_V2.0.xml"
# topic_file_path = "/data/szr207/dataset/ArqMath/Task1/Sample Topics/Task1_Samples_V2.0.xml"
model_path = "/data/szr207/github/transformers/examples/language-modeling/output/checkpoint-4000000"

In [2]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

Let's load ARQ-BERT, the RoBERTa checkpoint at `model_path`, together with the `roberta-base` tokenizer:

In [3]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # byte-level BPE, case-sensitive
model = RobertaModel.from_pretrained(model_path)
model = model.cuda()

In [4]:
!wget -v -O anserini_index.tar.gz -L https://psu.box.com/shared/static/fpuz0fywuao1twh1nhd1pmu451fs8rr2.gz

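
The `wget` cell above only downloads `anserini_index.tar.gz`; nothing in the notebook unpacks it, and `ARQ_INDEX` points to a separate local path. A minimal sketch for unpacking the archive with Python's standard `tarfile` module, assuming it is a gzip-compressed tarball; the helper name `extract_index` and the destination directory are hypothetical. Afterwards, point `ARQ_INDEX` at the directory that holds the Lucene index files.

```python
import os
import tarfile


def extract_index(archive_path, dest_dir):
    """Unpack a gzip-compressed tarball into dest_dir and list its top-level entries.

    Hypothetical helper: only extract archives from sources you trust.
    """
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: reject absolute paths, '..' members, and special files
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    return sorted(os.listdir(dest_dir))


# Example (hypothetical destination directory):
# print(extract_index("anserini_index.tar.gz", "arqmath_index"))
```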

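You can use `pysearch.SimpleSearcher` to search over an index. The helpers below run a query and render the query and its hits as HTML; the next cell shows the basic usage: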

In [5]:
def show_query(query):
    """HTML print format for the searched query"""
    return HTML('<br/><div style="font-family: Times New Roman; font-size: 20px;'
                'padding-bottom:12px"><b>Query</b>: '+query+'</div>')

def show_document(idx, doc):
    """HTML print format for document fields"""
    return HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px">' +
                f'<b>Document {idx}:</b> {doc.docid} ({doc.score:1.2f}) -- ' +
                f'${doc.raw[:100]}$</div>')

def show_query_results(query, searcher, top_k, num_hits=1000):
    """Search for `query`, display the query and its top_k hits, and return all num_hits hits"""
    hits = searcher.search(query, num_hits)
    display(show_query(query))
    for i, hit in enumerate(hits[:top_k]):
        display(show_document(i + 1, hit))
    return hits
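
The helpers above paste post text directly into HTML, so a literal `<` in a formula such as `$a<b$` can be read by the browser as the start of a tag and swallow the rest of the snippet. A hedged sketch of an escaping variant using the standard-library `html.escape`; `snippet_html` is a hypothetical name, and unlike `show_document` it does not wrap the snippet in `$...$`. Wrap its result in `HTML(...)` to display it in the notebook.

```python
import html


def snippet_html(docid, score, raw, n_chars=100):
    """Render one hit as HTML with the docid and snippet escaped (hypothetical helper)."""
    snippet = html.escape(raw[:n_chars])  # '<' -> '&lt;', '&' -> '&amp;', quotes escaped too
    return ('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px">'
            f'<b>Document:</b> {html.escape(str(docid))} ({score:1.2f}) -- {snippet}</div>')


print(snippet_html("42", 3.14159, "Show that $a<b$ implies $a+c<b+c$."))
```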

In [6]:
from pyserini.search import pysearch

query = ("How to compute this combinatoric sum?. I have the sum  $$\\sum_{k=0}^{n} \\binom{n}{k} k$$  I know the result is $n 2^{n-1}$ but I don't know how you get there. How does one even begin to simplify a sum like this that has binomial coefficients. ")
# query = ("compute combinatoric sum?. sum  $$\sum_{k=0}^{n} \binom{n}{k} k$$  result $n 2^{n-1}$ there. begin simplify sum binomial coefficients.")
searcher = pysearch.SimpleSearcher(ARQ_INDEX)
hits = show_query_results(query, searcher, 20)

From the hits array, use `.lucene_document` to access the underlying indexed Lucene `Document`, and from there call `.get(field)` to fetch specific stored fields. Which fields exist depends on how the ARQMath index was built; for reference, the fields of the CORD-19 index are listed [here](https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/index/generator/CovidGenerator.java#L46).

Next, let's extract contextualized vectors of the query and of the retrieved posts from ARQ-BERT and use them to highlight relevant paragraphs.

First, extract the contextualized vectors of the query above:

$$q_1, \ldots, q_T = \text{ARQ-BERT}(\text{query})$$

In [13]:
def extract_bert(text, tokenizer, model, chunk_len=510):
    """Return (token ids, tokens without <s>/</s>, one contextual vector per token).

    Texts longer than the 512-position limit are split into chunks of at most
    `chunk_len` body tokens; each chunk is wrapped in <s> ... </s>, encoded,
    and the special-token states are dropped before concatenation.
    """
    device = next(model.parameters()).device  # run on whatever device holds the model
    text_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)], device=device)
    text_words = tokenizer.convert_ids_to_tokens(text_ids[0].tolist())[1:-1]

    bos, eos = text_ids[0, :1], text_ids[0, -1:]  # <s> and </s>
    body = text_ids[0, 1:-1]
    n_chunks = max(1, int(numpy.ceil(body.size(0) / chunk_len)))

    states = []
    for ci in range(n_chunks):
        chunk = torch.cat([bos, body[ci * chunk_len:(ci + 1) * chunk_len], eos])
        with torch.no_grad():
            state = model(chunk.unsqueeze(0))[0]  # last hidden states, shape (1, len, hidden)
        states.append(state[:, 1:-1, :])          # drop the <s> and </s> states

    state = torch.cat(states, dim=1)
    return text_ids, text_words, state[0]
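
The intent of the chunking in `extract_bert` is to keep every model input within RoBERTa's 512-token limit: the body tokens (everything between `<s>` and `</s>`) are cut into pieces of at most 510, and each piece is wrapped in `<s>` ... `</s>`. A small standalone sketch of just the index arithmetic; `chunk_spans` is a hypothetical helper, not part of the notebook.

```python
def chunk_spans(n_body_tokens, chunk_len=510):
    """Split body-token positions 0..n_body_tokens-1 into [start, end) spans of at most chunk_len."""
    return [(start, min(start + chunk_len, n_body_tokens))
            for start in range(0, n_body_tokens, chunk_len)]


spans = chunk_spans(1070)
print(spans)                                          # [(0, 510), (510, 1020), (1020, 1070)]
print(max(end - start for start, end in spans) + 2)   # longest input with <s> and </s>: 512
```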

In [14]:
query_ids, query_words, query_state = extract_bert(query, tokenizer, model)

Second, extract contextualized vectors for every retrieved hit (all 1,000, not only the 20 displayed), treating the raw text of hit $k$ as $\text{paragraph}^k$:

$$p_1^k, \ldots, p_{T_k}^k = \text{ARQ-BERT}(\text{paragraph}^k)$$

In [15]:
doc_states = []
for doc in tqdm(hits):
    state = extract_bert(doc.raw, tokenizer, model)
    doc_states.append(state)

The tokenizer warns that some posts are longer than the model's 512-token limit (e.g. 676 > 512, 1072 > 512). The warning is raised by `tokenizer.encode` on the full post; `extract_bert` only feeds the model chunks of at most 512 tokens, so no indexing error occurs.

We then compute the cosine similarity matrix between the query and each paragraph:

$$A^k = [a^k_{ij}] \in \mathbb{R}^{|\text{query}| \times |\text{paragraph}^k|},$$

where

$$a^k_{ij} = \frac{q_i^\top p_j^k}{\| q_i \| \| p_j^k \|}$$
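
A tiny worked example of this formula, with made-up 2-dimensional vectors standing in for ARQ-BERT states (real states are much higher-dimensional); `cosine_matrix` is a hypothetical NumPy counterpart of the `cross_match` function defined in the next cell.

```python
import numpy as np


def cosine_matrix(Q, P):
    """a_ij = q_i . p_j / (||q_i|| ||p_j||) for rows q_i of Q and p_j of P."""
    Qn = Q / np.linalg.norm(Q, axis=1, keepdims=True)
    Pn = P / np.linalg.norm(P, axis=1, keepdims=True)
    return Qn @ Pn.T


Q = np.array([[1.0, 0.0],   # q_1
              [1.0, 1.0]])  # q_2
P = np.array([[2.0, 0.0],   # p_1
              [0.0, 3.0],   # p_2
              [1.0, 1.0]])  # p_3
A = cosine_matrix(Q, P)
print(A.round(3))
```

`A` has shape $|\text{query}| \times |\text{paragraph}| = 2 \times 3$. For instance, $a_{12} = 0$ because $q_1$ and $p_2$ are orthogonal, and $a_{13} = 1/\sqrt{2}$ because they are 45 degrees apart; scaling a vector (as in $p_1 = (2, 0)$) does not change its cosine.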


In [16]:
import torch.nn.functional as F

def cross_match(state1, state2):
    """Cosine-similarity matrix between the rows of state1 (query) and state2 (paragraph)"""
    state1 = F.normalize(state1, dim=1)  # unit-length token vectors
    state2 = F.normalize(state2, dim=1)
    return state1 @ state2.t()           # shape: (|query|, |paragraph|)

In [17]:
sim_matrices = []
for doc_state in tqdm(doc_states):
    # doc_state = (ids, tokens, token vectors); compare query tokens with paragraph tokens
    sim_matrices.append(cross_match(query_state.cpu(), doc_state[-1].cpu()))

In [18]:
import gc

# Release GPU memory: drop the model, collect unreferenced objects, then free cached CUDA blocks
del model
gc.collect()
torch.cuda.empty_cache()


Let's first rank the $K$ retrieved paragraphs by relevance, defining the top-$M$ most relevant paragraphs as

$$\arg\text{top-$M$}_{k=1}^K \max_{i=1,\ldots,|\text{query}|} \max_{j=1,\ldots, |\text{paragraph}^k|} A_{ij}^k$$

that is, a paragraph is considered relevant if at least one of its words closely matches at least one query word.
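
To make the ranking rule concrete, here is a sketch with three made-up similarity matrices; `rank_by_max_similarity` is a hypothetical helper that mirrors the `doc_relevance`/`rel_index` cell below, except that it sorts with a stable descending key so that ties keep their original retrieval order.

```python
import numpy as np


def rank_by_max_similarity(sim_matrices, top_m=None):
    """Order paragraphs by max_ij A^k_ij, highest first; optionally keep only top_m."""
    scores = np.array([float(np.max(A)) for A in sim_matrices])
    order = np.argsort(-scores, kind="stable")  # stable: ties keep retrieval order
    return order if top_m is None else order[:top_m]


sims = [np.array([[0.2, 0.5], [0.1, 0.3]]),   # paragraph 0: best match 0.5
        np.array([[0.9]]),                    # paragraph 1: best match 0.9
        np.array([[0.4, 0.6, 0.1]])]          # paragraph 2: best match 0.6
print(rank_by_max_similarity(sims))           # [1 2 0]
print(rank_by_max_similarity(sims, top_m=2))  # [1 2]
```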

In [19]:
doc_relevance = [torch.max(sim).item() for sim in sim_matrices]

# Indices of all hits sorted by relevance, most relevant first
# (append [:20] to keep only the top 20)
rel_index = numpy.argsort(doc_relevance)[::-1]

In [20]:
rel_index

array([862, 188, 155, 838, 844, 830,  98, 890, 502, 951, 215, 631, 791,
       943, 711, 653, 750, 477, 206, 821, 988, 304, 513, 695, 824, 939,
       803, 936,  72, 404, 934, 501, 770, 990, 193, 123, 181,  75, 494,
       245, 864, 223, 167, 151, 935, 829, 345, 394, 414, 634, 912, 571,
        82,   4, 387,   8,  10,   3, 507, 600, 182, 263, 882, 991, 492,
       158, 217,  48, 969,  50, 201, 740, 924, 875, 511, 418, 500, 359,
       481, 768, 788, 947, 106, 180, 224, 857, 365, 894, 143, 860, 226,
       231,   7, 228, 886, 329, 520, 848, 922, 566, 785, 192, 410, 471,
       828, 115,   9, 257,  84, 963, 382, 326, 966, 555, 980, 489, 371,
        43, 100, 731, 393, 645, 855,  69, 565, 679, 688, 858, 306, 177,
        93, 153, 347, 363, 735, 433, 126, 322, 554, 985, 575, 417, 415,
       110,  32, 426, 230,  89, 164, 845, 170, 216, 102, 238, 871,  45,
       841, 700, 747, 148, 401, 352,  40, 449, 234, 474, 597, 335, 413,
       448, 907, 405, 346, 728, 641, 577, 811, 644, 526,  56, 16

In [None]:
display(show_query(query))
# Show the 10 most relevant hits; `ri` is each hit's position in the original retrieval ranking
for rank, ri in enumerate(rel_index[:10], start=1):
    print(ri)
    display(show_document(rank, hits[ri]))

In [None]:
def show_sections(section, text):
    """HTML print format for a hit's (possibly highlighted) token text; `section` is unused"""
    # Remove "Ġ", RoBERTa's marker for a preceding space
    return HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px; margin-left: 15px">' + 
        f' -- {text.replace("Ġ","")} </div>')

display(show_query(query))
# display(show_document(ii, list_docs[ii]))
for ri in rel_index:
    display(show_sections(hits[ri].raw, " ".join(doc_states[ri][1])))
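
The display cells join RoBERTa tokens with spaces and then delete the `Ġ` marker, so a word split into several sub-word pieces is shown with spaces inside it. In RoBERTa's byte-level BPE, `Ġ` marks a token that was preceded by a space, so concatenating the tokens and turning `Ġ` back into a space restores the words. A minimal sketch for plain ASCII text; the token list is illustrative rather than actual tokenizer output, and `tokenizer.convert_tokens_to_string` from `transformers` performs the full byte-level decoding (other characters, such as newlines, are also remapped by the tokenizer).

```python
def bpe_tokens_to_text(tokens):
    """Rebuild text from RoBERTa/GPT-2 byte-level BPE tokens (ASCII text only).

    'Ġ' marks a token that was preceded by a space; pieces without it are
    glued to the previous token.
    """
    return "".join(tokens).replace("Ġ", " ").strip()


# Illustrative token list (not actual tokenizer output)
tokens = ["How", "Ġto", "Ġcompute", "Ġthis", "Ġcomb", "inator", "ic", "Ġsum", "?"]
print(" ".join(tokens).replace("Ġ", ""))  # How to compute this comb inator ic sum ?
print(bpe_tokens_to_text(tokens))         # How to compute this combinatoric sum?
```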

In [None]:
rel_index

To see more detail, we highlight relevant phrases in each paragraph. We define the relevant words of paragraph $k$ as

$$\arg\text{top-$M$}_{j=1,\ldots, |\text{paragraph}^k|} \max_{i=1,\ldots,|\text{query}|} A_{ij}^k$$

that is, a word is considered relevant if it is highly similar to at least one query word; the code below uses $M = 2$. Given these words, we highlight the text around each of them, up to 10 tokens on either side and stopping at sentence boundaries.
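
A small sketch of the word-selection rule with a made-up $2 \times 4$ similarity matrix (two query tokens, four paragraph tokens); `relevant_token_indices` is a hypothetical helper doing what the last cell does inline with `sim.max(0)` and `argsort`.

```python
import numpy as np


def relevant_token_indices(sim, m=2):
    """Indices j of the m paragraph tokens with the largest max_i A_ij, in text order."""
    best_match = sim.max(axis=0)        # for each paragraph token, its best query-token match
    top = np.argsort(best_match)[-m:]   # the m largest
    return np.sort(top)                 # back to reading order for highlighting


A = np.array([[0.1, 0.8, 0.2, 0.3],    # query token 1
              [0.4, 0.1, 0.9, 0.2]])   # query token 2
print(A.max(axis=0))                   # [0.4 0.8 0.9 0.3]
print(relevant_token_indices(A))       # [1 2]
```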

In [None]:
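def highlight_paragraph(ptext, rel_words, max_win=10):
    """Join the tokens in `ptext` with spaces, wrapping the text around each
    index in `rel_words` in red <font> tags.

    A highlight extends to the enclosing sentence, but at most `max_win` tokens
    on either side of the relevant token. A sentence ends at a "." token
    followed by a token starting with an uppercase letter or "[" (RoBERTa's
    "Ġ" space marker is ignored for this check).
    """
    def is_boundary(k):
        if ptext[k] != "." or k + 1 >= len(ptext):
            return False
        nxt = ptext[k + 1].lstrip("Ġ")
        return nxt[:1].isupper() or nxt[:1] == "["

    para = []
    prev_idx = 0  # first token not yet emitted
    for jj in sorted(rel_words):
        if jj < prev_idx:  # already covered by the previous highlight
            continue

        # Sentence start: just after the nearest boundary left of jj
        start = prev_idx
        for kk in range(jj - 1, prev_idx - 1, -1):
            if is_boundary(kk):
                start = kk + 1
                break

        # Sentence end: just after the nearest boundary at or right of jj
        end = len(ptext)
        for kk in range(jj, len(ptext)):
            if is_boundary(kk):
                end = kk + 1
                break

        # Clip to a window of max_win tokens on each side of jj
        start = max(start, jj - max_win)
        end = min(end, jj + max_win + 1)

        para.extend(ptext[prev_idx:start])
        para.append("<font color='red'>")
        para.extend(ptext[start:end])
        para.append("</font>")
        prev_idx = end

    para.extend(ptext[prev_idx:])
    return " ".join(para)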

In [None]:
display(show_query(query))

for ri in numpy.sort(rel_index):
    sim = sim_matrices[ri].numpy()

    # Indices of the two paragraph tokens that best match any query token, in text order
    rel_words = numpy.sort(numpy.argsort(sim.max(0))[-2:])
    p_tokens = doc_states[ri][1]
    para = highlight_paragraph(p_tokens, rel_words)
    display(show_sections(hits[ri].raw, para))