# Reranking using SciBERT

This notebook reranks the Anserini baseline run using SciBERT.


## Prerequisite

### CORD-19 Dataset (Rev 2020-06-19)

Download [here](https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases.html)

### Trec Eval


A program used for IR evaluation. Please have [trec_eval](https://github.com/usnistgov/trec_eval) installed. Alternatively, use the [Python version](https://github.com/cvangysel/pytrec_eval).

Install original trec_eval:
1. Go to the parent directory of this repo
2. `git clone https://github.com/usnistgov/trec_eval.git`
3. `cd trec_eval && make`

### Anserini Baseline Run

Get Round 4 Anserini Baseline from https://github.com/castorini/anserini/blob/master/docs/experiments-covid.md
```
wget -P ../tmp/ https://www.dropbox.com/s/lbgevu4wiztd9e4/anserini.covid-r4.abstract.qdel.bm25.txt
```
    
### Topics

The list of queries to be used against CORD-19 dataset.

```
wget -P ../tmp/ https://ir.nist.gov/covidSubmit/data/topics-rnd4.xml
```

### Qrels

The human relevance judgments of the retrieved documents for each topic. They are used to evaluate the performance of an IR system.

```

wget -P ../tmp/ https://ir.nist.gov/covidSubmit/data/qrels-covid_d4_j0.5-4.txt
```


In [None]:
import torch
import pickle
import numpy
from tqdm import tqdm
from transformers import *
from collections import defaultdict
import warnings
import os

%load_ext autoreload
%autoreload
from lib.utils import read_topics, get_abstract

# Silence FutureWarning messages so that cell output stays readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Input locations, relative to the directory the notebook is run from.
PATH_TO_CORD_19_DATA = "../../CORD-19/2020-06-19/"
# The prerequisite `wget -P ../tmp/ ...` command stores the baseline run in ../tmp/,
# next to the topics and qrels files.
PATH_TO_ANSERINI_RUN = "../tmp/anserini.covid-r4.abstract.qdel.bm25.txt"
PATH_TO_TOPICS = "../tmp/topics-rnd4.xml"
PATH_TO_TREC = "../../trec_eval/trec_eval"
PATH_TO_QRELS = "../tmp/qrels-covid_d4_j0.5-4.txt"


In [None]:
# Evaluate the unmodified Anserini baseline with trec_eval. These numbers are the
# reference point for judging whether SciBERT reranking can improve the run.

baseline_eval_cmd = PATH_TO_TREC + " -c -m all_trec " + PATH_TO_QRELS + " " + PATH_TO_ANSERINI_RUN + " > ../tmp/out.txt"
ret = os.system(baseline_eval_cmd)

if ret == 0:
    with open("../tmp/out.txt", "r") as f:
        print(f.read())
else:
    # Report the failure instead of printing nothing, so a problem with the
    # command is visible in the cell output.
    print("trec_eval failed (exit status {}): {}".format(ret, baseline_eval_cmd))


**Anserini Baseline Before Reranking**
```
runid                 	all	anserini.covid-r4.abstract.qdel.bm25.txt
num_q                 	all	45
num_ret               	all	447306
num_rel               	all	15765
num_rel_ret           	all	12201
map                   	all	0.2994
gm_map                	all	0.2325
Rprec                 	all	0.3484
bpref                 	all	0.5407
recip_rank            	all	0.8660
iprec_at_recall_0.00  	all	0.9165
iprec_at_recall_0.10  	all	0.6471
iprec_at_recall_0.20  	all	0.5301
iprec_at_recall_0.30  	all	0.4325
iprec_at_recall_0.40  	all	0.3391
iprec_at_recall_0.50  	all	0.2526
iprec_at_recall_0.60  	all	0.1871
iprec_at_recall_0.70  	all	0.1276
iprec_at_recall_0.80  	all	0.0689
iprec_at_recall_0.90  	all	0.0258
iprec_at_recall_1.00  	all	0.0018
P_5                   	all	0.7911
P_10                  	all	0.7689
P_15                  	all	0.7407
P_20                  	all	0.7189
P_30                  	all	0.6911
P_100                 	all	0.5398
P_200                 	all	0.4162
P_500                 	all	0.2657
P_1000                	all	0.1713
recall_5              	all	0.0157
recall_10             	all	0.0294
recall_15             	all	0.0426
recall_20             	all	0.0549
recall_30             	all	0.0788
recall_100            	all	0.1913
recall_200            	all	0.2856
recall_500            	all	0.4238
recall_1000           	all	0.5233
infAP                 	all	0.2994
gm_bpref              	all	0.4535
Rprec_mult_0.20       	all	0.6115
Rprec_mult_0.40       	all	0.5108
Rprec_mult_0.60       	all	0.4363
Rprec_mult_0.80       	all	0.3878
Rprec_mult_1.00       	all	0.3484
Rprec_mult_1.20       	all	0.3164
Rprec_mult_1.40       	all	0.2899
Rprec_mult_1.60       	all	0.2675
Rprec_mult_1.80       	all	0.2482
Rprec_mult_2.00       	all	0.2325
utility               	all	-9397.8667
11pt_avg              	all	0.3208
binG                  	all	0.1522
G                     	all	0.1313
ndcg                  	all	0.6607
ndcg_rel              	all	0.5722
Rndcg                 	all	0.4991
ndcg_cut_5            	all	0.7299
ndcg_cut_10           	all	0.7081
ndcg_cut_15           	all	0.6794
ndcg_cut_20           	all	0.6650
ndcg_cut_30           	all	0.6403
ndcg_cut_100          	all	0.5201
ndcg_cut_200          	all	0.4570
ndcg_cut_500          	all	0.4650
ndcg_cut_1000         	all	0.5236
map_cut_5             	all	0.0151
map_cut_10            	all	0.0270
map_cut_15            	all	0.0380
map_cut_20            	all	0.0483
map_cut_30            	all	0.0669
map_cut_100           	all	0.1449
map_cut_200           	all	0.1963
map_cut_500           	all	0.2499
map_cut_1000          	all	0.2746
relative_P_5          	all	0.7911
relative_P_10         	all	0.7689
relative_P_15         	all	0.7407
relative_P_20         	all	0.7189
relative_P_30         	all	0.6911
relative_P_100        	all	0.5455
relative_P_200        	all	0.4541
relative_P_500        	all	0.4353
relative_P_1000       	all	0.5233
success_1             	all	0.8000
success_5             	all	0.9333
success_10            	all	1.0000
set_P                 	all	0.0273
set_relative_P        	all	0.7831
set_recall            	all	0.7831
set_map               	all	0.0218
set_F                 	all	0.0523
num_nonrel_judged_ret 	all	13706
```

In [None]:
# Load the Anserini baseline run: one whitespace-separated hit per line.
with open(PATH_TO_ANSERINI_RUN, "r") as run_file:
    anserini_run = list(run_file)

# Collect the uids ranked 200 or better for any topic. Only these abstracts
# are fetched, because fetching every retrieved document would be too many.
set_of_top_k_uid_all_topics = set()

for hit in anserini_run:
    qid, _q0, uid, rank, _score, _tag = hit.split()
    if int(rank) > 200:
        continue
    set_of_top_k_uid_all_topics.add(uid)

# Topics (queries) to run against CORD-19.
topics = read_topics(PATH_TO_TOPICS)

# Abstracts of the selected uids.
abstracts_dict = get_abstract(PATH_TO_CORD_19_DATA, set_of_top_k_uid_all_topics)

# SciBERT tokenizer and encoder (alternative checkpoint: monologg/biobert_v1.1_pubmed).
# NOTE(review): do_lower_case=False is passed although the checkpoint name says
# "uncased" - confirm that this is intended.
scibert_checkpoint = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(scibert_checkpoint, do_lower_case=False)
model = AutoModel.from_pretrained(scibert_checkpoint)


# Reranking using SciBERT + Words-level Matches


Given the Anserini run on abstracts, we use SciBERT to rerank the top K abstracts of each topic and evaluate the result with trec_eval as before.

In short, we use SciBERT to encode the abstracts and queries into vectors of size [num_of_tokens, 768]. If a text is longer than the maximum input size of BERT, we split it into consecutive windows (chunks), encode each one, and concatenate the results. We use the vectors of the last hidden states of the BERT model.


Once both abstracts and queries are encoded, we compute the token-level cosine similarity between them and score the relevance of an abstract to a query by its most strongly matching pair of tokens.

This approach is inspired by [Notebook](https://github.com/castorini/anserini-notebooks/blob/master/Pyserini%2BSciBERT_on_COVID_19_Demo.ipynb).

In [None]:
# Default number of content tokens per forward pass. The first and last ids
# produced by the tokenizer ([CLS] and [SEP]) are added to every chunk on top
# of this, so a full chunk has CHUNK_SIZE + 2 input ids.
CHUNK_SIZE = 510

def extract_scibert(text, tokenizer, model, chunk_size=None):
    """
    Encode text to vectors
    @text string to be encoded
    @tokenizer BertTokenizer object
    @model BertModel object
    @chunk_size maximum number of content tokens per forward pass
        (defaults to CHUNK_SIZE)
    @return tensor of size [num_tokens, hidden_size] (last hidden state of BERT,
        special tokens removed); num_tokens is 0 for text without content tokens
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    token_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True))
    # Separate the special tokens from the content so that chunking is driven by
    # the content length only. Counting the special tokens as well would
    # schedule an extra forward pass on a chunk holding nothing but [CLS][SEP].
    first_id, last_id = token_ids[:1], token_ids[-1:]
    content_ids = token_ids[1:-1]

    states = []
    # max(..., 1) keeps a single pass for empty content, so the result is a
    # well-formed [0, hidden_size] tensor rather than an error from torch.cat.
    for start in range(0, max(content_ids.size(0), 1), chunk_size):
        chunk_ids = torch.cat([first_id, content_ids[start:start + chunk_size], last_id])

        with torch.no_grad():
            hidden = model(chunk_ids.unsqueeze(0))[0]  # last hidden states
        # Drop the vectors of the special tokens that wrap the chunk.
        states.append(hidden[:, 1:-1, :])

    return torch.cat(states, dim=1)[0]
    

In [None]:

def encode_abstract_query_narrative_and_save(topics, abstracts_dict, extract_scibert, tokenizer, model, path_to_output, **kwargs):
    """
    Encode topics and abstracts using given encoding function, save it to path_to_output
    @topics list of dict; the "query" and "narrative" entries of each dict are encoded
    @abstracts_dict dict whose values are the abstract texts to encode
    @extract_scibert BERT encoding function, 
        input: text, tokenizer, model **kwargs, 
        output  tensor of size [num, 768]
    @tokenizer to pass to extract_scibert
    @model to pass to extract_scibert
    @path_to_output file name of the pickle, created inside ../tmp/
    @kwargs extra keyword arguments forwarded to every extract_scibert call
        (abstracts, queries and narratives alike)

    The pickle holds {"abstract": ..., "query": ..., "narrative": ...}; each value
    is a dict that maps the encoded text itself to the encoder output.
    """

    # encode abstract
    encoded_abstract = dict()

    for text in tqdm(abstracts_dict.values()):
        # Empty abstracts carry nothing to encode, and a text that has been
        # encoded already would only overwrite its own entry.
        if not text or text in encoded_abstract:
            continue
        encoded_abstract[text] = extract_scibert(text, tokenizer, model, **kwargs)

    # encode topics
    encoded_queries = dict()
    encoded_narratives = dict()

    for topic in topics:
        query, narrative = topic["query"], topic["narrative"]
        # Forward **kwargs here as well: encoders with required extra arguments
        # (e.g. a pooling method) cannot be called without them.
        if query not in encoded_queries:
            encoded_queries[query] = extract_scibert(query, tokenizer, model, **kwargs)
        if narrative not in encoded_narratives:
            encoded_narratives[narrative] = extract_scibert(narrative, tokenizer, model, **kwargs)

    # save for future use
    bert_vectors = {
            "abstract": encoded_abstract, 
            "query": encoded_queries, 
            "narrative": encoded_narratives
    }

    with open("../tmp/" + path_to_output, "wb") as f:
        pickle.dump(bert_vectors, f)



In [None]:
# Encode every abstract, query and narrative with the token-level SciBERT
# encoder; the vectors are pickled to ../tmp/scibert.pkl.
encoded_vectors_output = "scibert.pkl"
encode_abstract_query_narrative_and_save(
    topics=topics,
    abstracts_dict=abstracts_dict,
    extract_scibert=extract_scibert,
    tokenizer=tokenizer,
    model=model,
    path_to_output=encoded_vectors_output,
)
    
    
    

In [None]:
# Reload the pickled encodings from disk. The path is spelled out here, so
# this cell does not depend on variables set by the encoding cell.
with open("../tmp/scibert.pkl", mode="rb") as vectors_file:
    bert_vectors = pickle.load(vectors_file)

In [None]:
def cross_match(state1, state2):
    """
    Pairwise cosine similarity between the rows of two matrices
    @state1 tensor of size [n1, dim]
    @state2 tensor of size [n2, dim]
    @return tensor of size [n1, n2]; entry (i, j) is the cosine similarity of
        state1[i] and state2[j]. An all-zero row has similarity 0 to everything.
    """
    # normalize() divides by max(norm, eps), so an all-zero row stays zero
    # instead of turning into NaN through a 0/0 division.
    state1 = torch.nn.functional.normalize(state1, p=2, dim=1)
    state2 = torch.nn.functional.normalize(state2, p=2, dim=1)
    # One matrix product yields all dot products without materialising an
    # [n1, n2, dim] intermediate tensor.
    return state1 @ state2.t()

def rerank(topics, anserini_run, abstracts_dict, top_k, path_to_reranked_run, topics_field, bert_vectors):
    """
    Rerank the original run and save the reranked run in path_to_reranked_run
    @topics list of dict where dict represents each topic
    [
        {
            'number': '1', 
            'query': 'coronavirus origin', 
            'question': 'what is the origin of COVID-19',
            'narrative': "seeking range of information about ..."
         },
         ...
    ]
    @anserini_run list of string run in anserini_run
    @abstracts_dict dict where key=uid and value=abstract
    @top_k to rerank, the rest will remain same
    @path_to_reranked_run file name of the reranked run, created inside ../tmp/
    @topics_field query or narrative
    @bert_vectors dict k=query,abstract,narrative v=vectors

    For every topic the first top_k hits that have an abstract are scored by the
    maximum token-level cosine similarity between topic and abstract and sorted
    by that score; hits without an abstract are dropped. Following hits keep
    their original order and score, up to 1000 - top_k of them.
    """
    max_hits = 1000
    reranked = defaultdict(list)  # first k hits: [uid, similarity]
    kept = defaultdict(list)  # k+1 to 1000 hits: [uid, original score, original rank]
    query_dict = {e["number"]: e[topics_field] for e in topics}  # choose whether to use query or narrative
    encoded_queries = bert_vectors[topics_field] # choose whether to use query or narrative
    encoded_abstract = bert_vectors["abstract"]

    def _vectors(entry):
        # The encoder returns a bare tensor. Entries stored as a tuple or list
        # with the vectors in last position are accepted as well.
        return entry[-1] if isinstance(entry, (tuple, list)) else entry

    # calculate similarity
    for line in anserini_run:
        if not line.strip():
            continue  # tolerate blank lines in the run
        qid, _, uid, j, score, _ = line.split()
        if len(reranked[qid]) < top_k:
            abstract = abstracts_dict.get(uid)
            if not abstract:
                continue # Some uid don't have abstract. But why they show up in Anserini run?

            enc_abs = _vectors(encoded_abstract[abstract])
            enc_query = _vectors(encoded_queries[query_dict[qid]])
            sim = cross_match(enc_query, enc_abs)

            reranked[qid].append([uid, torch.max(sim).item()])

        elif len(kept[qid]) < max_hits - top_k:
            kept[qid].append([uid, score, j])

    # create reranked run and save to path_to_reranked_run
    template = "{} Q0 {} {} {} anserini_scibert"
    run = list()

    for qid in reranked:
        # trec_eval orders documents by the score column and ignores the rank
        # column. Cosine similarities are at least -1, so an offset of
        # (best kept score + 2) keeps every reranked hit above every kept hit.
        kept_scores = [float(score) for _, score, _ in kept[qid]]
        offset = max(10.0, max(kept_scores) + 2.0) if kept_scores else 10.0

        rank = 1
        for uid, score in sorted(reranked[qid], key=lambda hit: hit[1], reverse=True):
            run.append(template.format(qid, uid, rank, score + offset))
            rank += 1

        for uid, score, j in kept[qid]:
            run.append(template.format(qid, uid, rank, score))
            rank += 1

        # each topic has at most 1000 uid (can be less if original run has less)
        assert rank <= max_hits + 1

    with open("../tmp/" + path_to_reranked_run, "w") as f:
        f.write("\n".join(run))

In [None]:
# Rerank the first 100 hits of every topic against the chosen topic field and
# write the new run to ../tmp/scibert_sim_rerank.txt.
topics_field = "narrative" # query or narrative
path_to_reranked_run = "scibert_sim_rerank.txt"
rerank(
    topics=topics,
    anserini_run=anserini_run,
    abstracts_dict=abstracts_dict,
    top_k=100,
    path_to_reranked_run=path_to_reranked_run,
    topics_field=topics_field,
    bert_vectors=bert_vectors,
)


In [None]:

# Evaluate the run written by rerank(): it is saved as "../tmp/" + path_to_reranked_run,
# so the same name is used here, together with the configured qrels path.
reranked_eval_cmd = PATH_TO_TREC + " -c -m all_trec " + PATH_TO_QRELS + " ../tmp/" + path_to_reranked_run + " > ../tmp/out.txt"
ret = os.system(reranked_eval_cmd)
if ret == 0:
    with open("../tmp/out.txt", "r") as f:
        print(f.read())
else:
    # Report the failure instead of printing nothing.
    print("trec_eval failed (exit status {}): {}".format(ret, reranked_eval_cmd))


**Result after reranking**


It turns out that this approach does not improve on the Anserini baseline: MAP drops from 0.2994 before reranking to 0.1947 after reranking.


```

runid                 	all	anserini_scibert
num_q                 	all	45
num_ret               	all	44955
num_rel               	all	15765
num_rel_ret           	all	7678
map                   	all	0.1947
gm_map                	all	0.1564
Rprec                 	all	0.2981
bpref                 	all	0.4178
recip_rank            	all	0.7022
iprec_at_recall_0.00  	all	0.8272
iprec_at_recall_0.10  	all	0.4782
iprec_at_recall_0.20  	all	0.3947
iprec_at_recall_0.30  	all	0.3183
iprec_at_recall_0.40  	all	0.2316
iprec_at_recall_0.50  	all	0.1332
iprec_at_recall_0.60  	all	0.0942
iprec_at_recall_0.70  	all	0.0484
iprec_at_recall_0.80  	all	0.0317
iprec_at_recall_0.90  	all	0.0053
iprec_at_recall_1.00  	all	0.0000
P_5                   	all	0.5244
P_10                  	all	0.5378
P_15                  	all	0.5126
P_20                  	all	0.4933
P_30                  	all	0.4778
P_100                 	all	0.4300
P_200                 	all	0.3614
P_500                 	all	0.2406
P_1000                	all	0.1706
recall_5              	all	0.0091
recall_10             	all	0.0177
recall_15             	all	0.0258
recall_20             	all	0.0333
recall_30             	all	0.0473
recall_100            	all	0.1445
recall_200            	all	0.2385
recall_500            	all	0.3819
recall_1000           	all	0.5172
infAP                 	all	0.1947
gm_bpref              	all	0.3567
Rprec_mult_0.20       	all	0.4542
Rprec_mult_0.40       	all	0.3903
Rprec_mult_0.60       	all	0.3560
Rprec_mult_0.80       	all	0.3260
Rprec_mult_1.00       	all	0.2981
Rprec_mult_1.20       	all	0.2731
Rprec_mult_1.40       	all	0.2512
Rprec_mult_1.60       	all	0.2327
Rprec_mult_1.80       	all	0.2173
Rprec_mult_2.00       	all	0.2044
utility               	all	-657.7556
11pt_avg              	all	0.2330
binG                  	all	0.0900
G                     	all	0.0810
ndcg                  	all	0.4718
ndcg_rel              	all	0.4235
Rndcg                 	all	0.3752
ndcg_cut_5            	all	0.4589
ndcg_cut_10           	all	0.4637
ndcg_cut_15           	all	0.4468
ndcg_cut_20           	all	0.4326
ndcg_cut_30           	all	0.4216
ndcg_cut_100          	all	0.3865
ndcg_cut_200          	all	0.3628
ndcg_cut_500          	all	0.3883
ndcg_cut_1000         	all	0.4718
map_cut_5             	all	0.0072
map_cut_10            	all	0.0128
map_cut_15            	all	0.0172
map_cut_20            	all	0.0214
map_cut_30            	all	0.0293
map_cut_100           	all	0.0768
map_cut_200           	all	0.1181
map_cut_500           	all	0.1640
map_cut_1000          	all	0.1947
relative_P_5          	all	0.5244
relative_P_10         	all	0.5378
relative_P_15         	all	0.5126
relative_P_20         	all	0.4933
relative_P_30         	all	0.4778
relative_P_100        	all	0.4320
relative_P_200        	all	0.3882
relative_P_500        	all	0.3934
relative_P_1000       	all	0.5172
success_1             	all	0.5556
success_5             	all	0.8667
success_10            	all	0.9556
set_P                 	all	0.1708
set_relative_P        	all	0.5172
set_recall            	all	0.5172
set_map               	all	0.0937
set_F                 	all	0.2449
num_nonrel_judged_ret 	all	5724

```


# Reranking using SciBERT + Phrase-level Matches

In [None]:
def extract_scibert_phrase(text, tokenizer, model, pool, window_size = 100, overlap_size = 20):
    """
    Encode text to vectors, one vector per window ("phrase") of words
    @text string to be encoded
    @tokenizer BertTokenizer object
    @model BertModel object
    @pool "cls" or "sum". If cls, use the second output of the model (the pooled
    [CLS] output for BertModel, see link below). If sum, sum all token vectors
    except [CLS] and [SEP].
    see https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel
    @window_size number of whitespace-separated words per phrase
    @overlap_size number of words shared by two consecutive phrases
    @return tensor of size [num_phrases, 768]; num_phrases is at least 1, also
    for text shorter than one window
    @raise ValueError if pool is invalid or window_size <= overlap_size
    """
    if pool not in ("cls", "sum"):
        raise ValueError("Invalid Pool Method. Must be 'cls' or 'sum'.")

    stride = window_size - overlap_size
    if stride <= 0:
        raise ValueError("window_size must be greater than overlap_size.")

    words = text.split()

    # Slide the window until it reaches the end of the text. There is always at
    # least one window, so short texts (e.g. queries) are still encoded, and the
    # last window is allowed to be shorter so that no trailing words are lost.
    starts = [0]
    while starts[-1] + window_size < len(words):
        starts.append(starts[-1] + stride)

    vectors = []
    for start in starts:
        phrase = " ".join(words[start:start + window_size])

        # convert to vectors
        inputs = tokenizer(phrase, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)

        if pool == "cls":
            vectors.append(outputs[1])
        else:
            # sum the token vectors, leaving out the first and last position
            vectors.append(outputs[0][0, 1:-1, :].sum(dim=0).reshape(1, -1))

    return torch.cat(vectors, dim=0)
    

In [None]:

# Encode abstracts and topics at phrase level once per pooling method; each
# result goes to its own pickle in ../tmp/.
phrase_encoding_jobs = (
    ("cls", "scibert_phrase_cls_w_100_o_20.pkl"),
    ("sum", "scibert_phrase_sum_w_100_o_20.pkl"),
)

for pool_method, output_name in phrase_encoding_jobs:
    encode_abstract_query_narrative_and_save(
        topics,
        abstracts_dict,
        extract_scibert_phrase,
        tokenizer,
        model,
        output_name,
        pool=pool_method,
    )

# Reranking with BertForNextSentence