# Question answering using ElasticSearch and SciBERT
This notebook attempts to answer most of the questions in the vaccines and [therapeutics tasks](https://www.kaggle.com/allen-institute-for-ai/CORD-19-research-challenge/tasks?taskId=561) using a combination of ElasticSearch for the initial information retrieval and SciBERT for answering the questions. It is loosely based on the [BERTserini paper](https://arxiv.org/pdf/1902.01718.pdf) by Yang et al. from the David R. Cheriton School of Computer Science at the University of Waterloo.

Roughly what it does is the following:

1.   Retrieve relevant papers based on keywords (the keywords are annotated by humans)
2.   Train a SciBERT model on the SQuAD 2.0 set for Question and Answering
3.   Predict the answer for each of the questions based on each of the relevant articles and display the results.

For easy reading, some of the sections have been hidden. To retrain the BERT model you need a Google Cloud Storage bucket to which you have write access. Since Kaggle doesn't provide this, the output of the model has been added as a data set. If you want to run the model yourself, set the `is_kaggle` parameter below to `False` and enter the name of your Google Cloud Storage bucket below (in the Prerequisites section).

## Abstract

As Coronavirus is spreading across the world, there is a rapid acceleration in related literature, making it difficult for the medical research community to keep up. In this notebook, we combined \[[Elasticsearch](https://github.com/elastic/elasticsearch)] and \[[BERT model](https://arxiv.org/pdf/1810.04805.pdf)] to mine useful text information from the article dataset and answer related queries.

To begin with, Elasticsearch is one of the most popular and powerful enterprise search engines, and BERT achieved state-of-the-art results on question answering and language inference benchmarks when it was published. For each query, we determined its keywords and used Elasticsearch to retrieve the 50 most relevant articles from the dataset. We then fine-tuned the BERT model (SciBERT) on SQuAD 2.0 in Google Cloud, after which it extracts the most suitable answer to the query from each candidate article. The procedure is as follows.

$\{ \text{query} = (\text{all articles},\text{keywords}) \} \xrightarrow{\text{Elasticsearch}} \{ (50 \text{ candidate articles}) \} \xrightarrow{\text{BERT}} \text{Answer}$ 

Used on their own, neither method is ideal: Elasticsearch can return irrelevant results, and running BERT over every article takes a lot of time. Similar to \[[BERTserini](https://arxiv.org/pdf/1902.01718.pdf)], which integrates BERT with another retrieval toolkit, we combined Elasticsearch and BERT, hoping to answer the queries in an efficient and accurate way.
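
The two-stage procedure can be sketched in a few lines. This is only an illustration of the retrieve-then-read idea, not the code used in this notebook: `answer_query`, `toy_retrieve` and `toy_read` are hypothetical names, and the toy functions stand in for Elasticsearch and BERT respectively.

```python
from typing import Callable, List, Tuple


def answer_query(
    question: str,
    keywords: List[str],
    retrieve: Callable[[List[str], int], List[str]],
    read: Callable[[str, str], Tuple[str, float]],
    top_k: int = 50,
) -> Tuple[str, float]:
    """Retrieve up to top_k candidate texts, then keep the reader's most confident answer."""
    candidates = retrieve(keywords, top_k)
    answers = [read(question, text) for text in candidates]
    # Fall back to an empty answer when nothing was retrieved.
    return max(answers, key=lambda answer: answer[1], default=("", 0.0))


# Toy stand-ins (hypothetical, for illustration only).
CORPUS = [
    "remdesivir reduced viral replication in vitro",
    "bats are a natural reservoir of coronaviruses",
    "a vaccine candidate protected mice from infection",
]


def toy_retrieve(keywords, top_k):
    # Rank texts by the number of keywords they contain and drop texts without any match.
    scored = [(sum(keyword in text for keyword in keywords), text) for text in CORPUS]
    ranked = sorted((item for item in scored if item[0] > 0), key=lambda item: -item[0])
    return [text for _, text in ranked[:top_k]]


def toy_read(question, text):
    # Pretend the first three words are the answer span; the confidence is the
    # fraction of question words that also occur in the text.
    question_words = set(question.lower().split())
    overlap = len(question_words & set(text.split()))
    return " ".join(text.split()[:3]), overlap / len(question_words)


best = answer_query("which vaccine protected mice", ["vaccine", "mice"], toy_retrieve, toy_read)
```

The point of the split is that the cheap retrieval step limits how many texts the expensive reader has to process.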

## Prerequisites
This code depends on Google's BERT implementation and the training script they've written to retrain the BERT model for the SQuAD challenge.
It also requires an ElasticSearch server as well as a Google TPU for faster training and prediction. Alternatively, you can run this code locally, though a GPU is highly recommended.

In [None]:
is_kaggle = True

In [None]:
!git clone https://github.com/ofjpostema/bert.git

To train the BERT model you need a Google Cloud Storage bucket to which you have write access. Enter its name below.

In [None]:
import os

output_dir_name = 'bert_output'
if is_kaggle:
    QUESTION_DIR = "/kaggle/input/elasticsearchquestions"
    from kaggle_datasets import KaggleDatasets
    BUCKET_NAME = KaggleDatasets().get_gcs_path("bertanswers")
    BUCKET = BUCKET_NAME.split("/")[-1]
    OUTPUT_DIR = '/kaggle/input/bertanswers'
else:
    QUESTION_DIR = os.path.join("..", "data", "interim", "questions")
    # makedirs also creates any missing parent directories
    os.makedirs(QUESTION_DIR, exist_ok=True)

    BUCKET = 'of-covid-19-clean'
    BUCKET_NAME = 'gs://{}'.format(BUCKET)
    OUTPUT_DIR = 'gs://{}/{}'.format(BUCKET, output_dir_name)

GS_QUESTION_DIR_BASE = "questions"
GS_QUESTION_DIR = 'gs://{}/{}'.format(BUCKET, GS_QUESTION_DIR_BASE)
ANSWER_DIR = "answers"
GS_ANSWER_DIR = '{}/{}'.format(BUCKET_NAME, ANSWER_DIR)
GCP_PROJECT = 'covid-19-271609'

In [None]:
# NOTE: This is only relevant for Google Colab. On Kaggle, Kaggle's own OAuth authentication is used.
if not is_kaggle:
    from google.colab import auth
    auth.authenticate_user()

In [None]:
if not is_kaggle:
    # These packages are not preinstalled on Google Colab; install elasticsearch and elasticsearch-dsl first
    from elasticsearch_dsl import connections, Index, Search
    from elasticsearch_dsl import Document, Text, Boolean
    from elasticsearch import Elasticsearch

In [None]:
import os
import pandas as pd
import json
import pprint
pp = pprint.PrettyPrinter(indent=2)
from collections import Counter
import re
import numpy as np
from tqdm.notebook import tqdm
import datetime
import random
import string
import sys
import tensorflow as tf
import collections

## ElasticSearch
This section of the code processes all of the documents and reads them into the ElasticSearch index. This requires an Elasticsearch server to be up-and-running. Since we ran the server locally, we won't run the code below on Kaggle, however, all the code to preprocess the data, insert it and then search in the resulting index is here. If you're new to ElasticSearch I'd recommend first reading the [getting started guide](https://www.elastic.co/guide/en/elasticsearch/reference/current/getting-started.html).

In [None]:
# Create a connection to ElasticSearch
if not is_kaggle:
    connections.create_connection(hosts=['localhost'], timeout=20)

In [None]:
datasets = ['biorxiv_medrxiv', 'comm_use_subset', 'custom_license', 'noncomm_use_subset']
if not is_kaggle:
    # Define the paths to the data
    dir_data_raw = os.path.join("..", "data", "raw")
    data_dir_interim = os.path.join("..", "data", "interim")
else:
    dir_data_raw = os.path.join("..", "input", "CORD-19-research-challenge")
    data_dir_interim = os.path.join("interim")
    if not os.path.isdir(data_dir_interim):
        os.mkdir(data_dir_interim)

### Formatting
These formatting helper functions are courtesy of [xhlulu](https://www.kaggle.com/xhlulu/cord-19-eda-parse-json-and-generate-clean-csv). They format the body of a paper using its section headings. This is used to format the document body before it's ingested into ElasticSearch and SciBERT.

In [None]:
def format_body(body_text):
    texts = [(di['section'], di['text']) for di in body_text]
    texts_di = {di['section']: "" for di in body_text}
    
    for section, text in texts:
        texts_di[section] += text

    body = ""

    for section, text in texts_di.items():
        body += section
        body += "\n\n"
        body += text
        body += "\n\n"
    
    return body
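
As a small illustration of what the helper produces, the sketch below applies a condensed but equivalent version of `format_body` to a made-up `body_text` list (the example data is invented; real entries come from the `body_text` field of the CORD-19 JSON files).

```python
def format_body(body_text):
    # Same behaviour as the helper above: concatenate the text per section heading,
    # keeping the headings in order of first appearance.
    texts_by_section = {}
    for entry in body_text:
        section = entry["section"]
        texts_by_section[section] = texts_by_section.get(section, "") + entry["text"]
    return "".join(
        "{}\n\n{}\n\n".format(section, text) for section, text in texts_by_section.items()
    )


# Invented example data: two paragraphs belong to the same section.
body_text = [
    {"section": "Introduction", "text": "Coronaviruses are RNA viruses. "},
    {"section": "Results", "text": "The compound inhibited replication."},
    {"section": "Introduction", "text": "They infect many species."},
]
formatted = format_body(body_text)
print(formatted)
```

Note that paragraphs of the same section are joined without a separator, so whether two paragraphs run together depends on the whitespace already present in the source text.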

### Preprocessing
We'll attempt to extract the results and conclusion sections from the articles. This is done based on the heading titles. We first collected the most frequently occurring section titles and then used these to determine which parts belong to the results and conclusions sections respectively.
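
The heading matching can be tried in isolation. The sketch below uses the same cleaning rule and heading lists as the parsing code in this section; `classify_heading` is a hypothetical helper name introduced only for this example.

```python
import re

SECTION_HEADINGS = {
    "results": ["results and discussion", "results"],
    "conclusion": ["conclusion", "conclusions", "discussion and conclusions"],
}


def classify_heading(raw_heading):
    """Return 'results', 'conclusion' or None for a raw section heading."""
    # Drop everything except letters, digits and spaces, then lowercase and trim.
    cleaned = re.sub(r"[^a-zA-Z0-9 ]", "", raw_heading).lower().strip()
    for section, headings in SECTION_HEADINGS.items():
        if cleaned in headings:
            return section
    return None


labels = [
    classify_heading(heading)
    for heading in [" Results ", "Results and Discussion:", "Conclusions.", "3. Results"]
]
print(labels)
```

The last heading is not matched: the digit survives the cleaning step, so numbered headings such as "3. Results" fall through unless the heading lists are extended.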

In [None]:
def parse_article(full_path, file_path):
    """
    Parse an article's body text and extract the full text, the results and the conclusion.

    Note: this helper writes to the module-level `metadata` DataFrame at the row given by
    the module-level `index`, so it has to be called from inside the loop over `metadata`.
    
    full_path: str: The fully qualified path to the file
    file_path: str: The file path starting from the data_raw dir
    """
    section_headings = {
        "results": ["results and discussion", "results"],
        "conclusion": ["conclusion", "conclusions", "discussion and conclusions"],
        #TODO: Intro
    }
    with open(full_path) as file:
        json_article = json.load(file)["body_text"]
    # Store the formatted full body text
    metadata.loc[index, 'full_text'] = format_body(json_article)
    for body_text in json_article:
        # Clean the section headings, lowercase and trim them
        section_heading = re.sub(r'[^a-zA-Z0-9 ]', '', body_text["section"]).lower().strip()
        for section, headings in section_headings.items():
            if section_heading in headings:
                # Append to the value already stored in the DataFrame: the `article` row
                # yielded by iterrows() is a copy and never sees earlier updates.
                metadata.loc[index, section] += body_text["text"]

We'll load the metadata file and add the full text, file path and the results and conclusions sections to it, using the previously defined methods.

In [None]:
if not is_kaggle:
    # Load the metadata and initialize the new, empty, columns
    metadata = pd.read_csv(os.path.join(dir_data_raw, "metadata.csv"))
    metadata["full_text"] = ""
    metadata["file_path"] = None
    metadata["results"] = ""
    metadata["conclusion"] = ""

In [None]:
if not is_kaggle:
    for index, article in tqdm(metadata.iterrows()):
        # We only need to update if there's a full text
        if article["has_full_text"]:
            for dataset in datasets:
                file_path = os.path.join(dataset, dataset, str(article["sha"]) + ".json")
                full_path = os.path.join(dir_data_raw, file_path)
                if os.path.exists(full_path):
                    # Only record the path of the dataset that actually contains the file
                    metadata.loc[index, "file_path"] = file_path
                    parse_article(full_path, file_path)
                    break

#### Checkpointing
Optional: Store the results in a CSV file.

In [None]:
if not is_kaggle:
    # index=False avoids an extra "Unnamed: 0" column when the file is read back
    metadata.to_csv(os.path.join(data_dir_interim, "1_full_data.csv"), index=False)

In [None]:
if not is_kaggle:
    metadata = pd.read_csv(os.path.join(data_dir_interim, "1_full_data.csv"))

### Ingestion
Create a document type using the ElasticSearch DSL package for the data and upload all papers that have a full text to the ElasticSearch server.

In [None]:
if not is_kaggle:
    class Paper(Document):
        id = Text(required=True)
        title = Text(required=True)
        authors = Text(required=True)
        abstract = Text(required=True)
        text = Text(required=True)
        results = Text(required=True)
        conclusion = Text(required=True)
        bibliography = Text(required=False)

        class Index:
            # Default index for these documents (elasticsearch-dsl 6.2 and later)
            name = 'covid'

In [None]:
if not is_kaggle:
    # Handle to the "covid" index. The loop variable below is deliberately not
    # called `index`, so that it doesn't overwrite this object.
    covid_index = Index("covid")

    def as_text(value):
        # Missing values are NaN (a float) in pandas; store them as empty strings
        return value if isinstance(value, str) else ""

    for _, paper in tqdm(metadata.iterrows()):
        # We'll only upload the document if it has a full text.
        if paper["has_full_text"]:
            paper_doc = Paper(
                id=as_text(paper["sha"]),
                title=as_text(paper["title"]),
                authors=as_text(paper["authors"]),
                abstract=as_text(paper["abstract"]),
                text=as_text(paper["full_text"]),
                results=as_text(paper["results"]),
                conclusion=as_text(paper["conclusion"]),
                bibliography=""
            )
            paper_doc.save(index="covid")
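
Saving the papers one by one needs a round trip per document. As an alternative sketch (not what was used for this notebook), the rows could be turned into bulk actions first; dictionaries of this shape are what `elasticsearch.helpers.bulk(client, actions)` accepts. `to_bulk_actions`, `as_text` and `FIELD_TO_COLUMN` are hypothetical names, and the example rows are invented.

```python
# Maps the document fields to the metadata columns they are filled from.
FIELD_TO_COLUMN = {
    "title": "title",
    "authors": "authors",
    "abstract": "abstract",
    "text": "full_text",
    "results": "results",
    "conclusion": "conclusion",
}


def as_text(value):
    # Missing values show up as NaN (a float) in pandas; store them as empty strings.
    return value if isinstance(value, str) else ""


def to_bulk_actions(rows, index_name="covid"):
    """Yield one bulk index action per paper that has a full text."""
    for row in rows:
        if not row.get("has_full_text"):
            continue
        source = {field: as_text(row.get(column)) for field, column in FIELD_TO_COLUMN.items()}
        source["id"] = as_text(row.get("sha"))
        yield {"_index": index_name, "_id": source["id"], "_source": source}


# Invented example rows; in the notebook these would come from the metadata DataFrame,
# for example via metadata.to_dict("records").
rows = [
    {"sha": "abc", "has_full_text": True, "title": "A", "authors": float("nan"),
     "abstract": "x", "full_text": "body", "results": "", "conclusion": ""},
    {"sha": "def", "has_full_text": False, "title": "B"},
]
actions = list(to_bulk_actions(rows))
```

Using the SHA as `_id` would also make re-running the ingestion overwrite existing documents instead of duplicating them.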

In [None]:
if not is_kaggle:
    client = Elasticsearch()

In [None]:
queries = [
    {
        "id": 1,
        "question": "What is the clinical effectiveness of antiviral agents?",
        "keywords": ["clinical effectiveness", "therapeutic", "antiviral agents"],
    },
    {
        "id": 2,
        "question": "What is the effectiveness of drugs being developed and tried to treat COVID-19 patients?",
        "keywords": ["clinical trials", "bench trials", "viral inhibitors", "naproxen", "clarithromycin", "minocycline", "viral replication"],
    },
    {
        "id": 3,
        "question": "Are there potential complications of Antibody-Dependent Enhancement (ADE) in vaccine recipients?",
        "keywords": ["complications", "Antibody-Dependent Enhancement", "vaccine", "antiviral proteins"],
    },
    {
        "id": 4,
        "question": "Are there animal models that offer predictive value for a human vaccine?",
        "keywords": ["animal models", "predictive", "vaccine"],
    },
    {
        "id": 5,
        "question": "How to distribute scarce therapeutics?",
        "keywords": ["distribution", "therapeutics", "antiviral agents", "decision making", "prioritizing"],
    },
    {
        "id": 6,
        "question": "How to expand production capacity of antiviral agents?",
        "keywords": ["production capacity", "therapeutic", "antiviral agents"],
    },
    {
        "id": 7,
        "question": "Are there universal coronavirus vaccines?",
        "keywords": ["coronavirus vaccine", "universal vaccine"],
    },
    {
        "id": 8,
        "question": "Which animal models are there?",
        "keywords": ["animal models", "challenge studies"],
    },
    {
        "id": 9,
        "question": "Which prophylaxis clinical studies are there?",
        "keywords": ["prevention", "prophylaxis", "clinical study"],
    },
    {
        "id": 10,
        "question": "What is the clinical effectiveness of antiviral agents?",
        "keywords": ["clinical effectiveness", "therapeutic", "antiviral agents"],
    },
]

### Collecting data
We'll collect all relevant data from ElasticSearch. To do this we

1. Create a search query based on the question
2. Get all nouns from the query
3. Get synonyms for all nouns
4. Search using this search query (in the abstract and the keywords)
5. Collect the top 50 results
6. Create a train.json file for this question, posing it to each article
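
The query-building step can be sketched without a running server. The example below only constructs the request body for a `bool` query with one `match` clause per keyword; `build_query` is a hypothetical helper name, and searching the `text` field is an assumption taken from the search code in this section.

```python
def build_query(keywords, field="text", offset=0, size=50):
    """Build a bool/should query body: every matching keyword adds to a document's score."""
    return {
        "from": offset,
        "size": size,
        "query": {
            "bool": {
                "should": [{"match": {field: keyword}} for keyword in keywords],
            }
        },
    }


body = build_query(["animal models", "predictive", "vaccine"])
print(body["query"]["bool"]["should"])
```

Because all clauses are `should` clauses, a document only has to match one keyword to be returned, and documents matching more keywords rank higher.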

In [None]:
def search_get_results(question, limit_from, limit_size):
    # A dict literal keeps only the last value for a repeated key, so the generic
    # COVID terms and the question keywords are merged into a single "should" list.
    should = [{"match": {"text": "covid"}}, {"match": {"text": "ncov"}}]
    should += [{"match": {"text": keyword}} for keyword in question["keywords"]]
    response = client.search(
        index="covid",
        body={
            "from": limit_from,
            "size": limit_size,
            "query": {"bool": {"should": should}},
        }
    )
    return response


def get_all_results(question, min_score):
    """
    Get all results from elasticsearch that have a minimum score. This method continues going 
    through the pages (using the from and size parameters) until it has found an article that 
    has a score below the min score, or until there are no more results.
    
    question: dict: The question
    min_score: float: The minimum score an article should have.
    """
    limit_size = 50
    limit_from = 0
    hits = []
    while True:
        search_results = search_get_results(question, limit_from, limit_size)
        page = search_results["hits"]["hits"]
        if not page:
            # No more results: stop instead of looping forever
            break
        hits = hits + page
        limit_from += limit_size
        if hits[-1]["_score"] <= min_score:
            break
        
    # We'll delete articles with a score under the min_score.
    while hits and hits[-1]["_score"] < min_score:
        hits.pop()
        
    return hits
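
The paging logic can be checked without Elasticsearch by swapping the search call for a stand-in. In this sketch `collect_hits` is a hypothetical variant of the function above that takes the search function as an argument, and `fake_search` returns invented hits sorted by descending score, which is the order Elasticsearch uses by default.

```python
def collect_hits(search_page, min_score, page_size=50):
    """Page through score-sorted results until a hit falls below min_score."""
    hits, offset = [], 0
    while True:
        page = search_page(offset, page_size)
        kept = [hit for hit in page if hit["_score"] >= min_score]
        hits.extend(kept)
        # Stop on an empty page, or as soon as part of a page is below the threshold.
        if not page or len(kept) < len(page):
            return hits
        offset += page_size


# Invented stand-in for the Elasticsearch call: 120 hits with scores 100, 99, 98, ...
FAKE_HITS = [{"_id": i, "_score": 100 - i} for i in range(120)]


def fake_search(offset, size):
    return FAKE_HITS[offset:offset + size]


hits = collect_hits(fake_search, min_score=40)
print(len(hits), hits[-1]["_score"])
```

With these numbers the first page is kept in full and the second page is cut off at the threshold, so only two requests are made.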

We are going to go through each of the previously defined queries and retrieve all documents that match the keywords. These articles are then combined in a JSON file that follows the input format for [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/). There is one file for each question with one paragraph for each paper.
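
As an illustration of that file layout, the sketch below builds the SQuAD-style structure for one question and two invented articles. `build_squad_input` is a hypothetical helper; the id pattern `q_<question id>_h_<article id>` is the one used in this notebook.

```python
import json


def build_squad_input(question_id, question, articles):
    """Build a SQuAD-style dict; `articles` is a list of (article_id, text) tuples."""
    paragraphs = [
        {
            "context": text.lower(),
            "qas": [{
                "id": "q_{}_h_{}".format(question_id, article_id),
                "question": question,
                "is_impossible": False,
            }],
        }
        for article_id, text in articles
    ]
    # One entry in "data" per question file, with one paragraph per article.
    return {"version": "v0.1", "data": [{"title": question, "paragraphs": paragraphs}]}


# Invented example articles.
squad_input = build_squad_input(
    4,
    "Are there animal models that offer predictive value for a human vaccine?",
    [("sha1", "Mice were IMMUNISED."), ("sha2", "Ferrets shed virus.")],
)
serialized = json.dumps(squad_input)
```

Each question id encodes both the query and the article, which is what allows the predictions to be mapped back to a paper later on.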

In [None]:
if not is_kaggle:
    for query in queries:
        hits = get_all_results(query, 11)
        print("{} hits for query {}".format(len(hits), query["id"]))
        input_questions = {
            "version": "v0.1",
            "data": [
                {
                    # All articles for a question share one entry, so use the question as its title
                    "title": query["question"],
                    "paragraphs": []
                }
            ]
        }
        # Pose the question to every retrieved article
        for hit in hits:
            input_questions["data"][0]["paragraphs"].append({
                "qas": [{
                    "question": query["question"],
                    "id": "q_{}_h_{}".format(query["id"], hit["_source"]["id"]),
                    "is_impossible": False
                }],
                "context": hit["_source"]["text"].lower()
            })
        with open(os.path.join(QUESTION_DIR, 
                                "q_{}.json".format(
                                    query["id"]
                                )), 'w') as outfile:
            json.dump(input_questions, outfile)

In [None]:
if not is_kaggle:
    !gsutil -m cp -r $QUESTION_DIR $GS_QUESTION_DIR

## Training
We first need to fine-tune SciBERT on SQuAD 2.0 so that it can actually answer questions.

In [None]:
if is_kaggle:
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
        TPU_ADDRESS = tpu.master()
    except ValueError:
        TPU_ADDRESS = None
else:
    assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
    TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is => ', TPU_ADDRESS)

In [None]:
if not is_kaggle:
    !wget https://s3-us-west-2.amazonaws.com/ai2-s2-research/scibert/tensorflow_models/scibert_scivocab_uncased.tar.gz
    !tar -xf scibert_scivocab_uncased.tar.gz

In [None]:
if not is_kaggle:
    !gsutil mv /content/scibert_scivocab_uncased $BUCKET_NAME

In [None]:
if not is_kaggle:
    !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
    !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json

In [None]:
if not is_kaggle:
    !pip install -U tensorflow==1.15.0

In [None]:
if not is_kaggle:
    !python bert/run_squad.py \
      --vocab_file=$BUCKET_NAME/scibert_scivocab_uncased/vocab.txt \
      --bert_config_file=$BUCKET_NAME/scibert_scivocab_uncased/bert_config.json \
      --init_checkpoint=$BUCKET_NAME/scibert_scivocab_uncased/bert_model.ckpt \
      --do_train=True \
      --train_file=train-v2.0.json \
      --do_predict=True \
      --predict_file=dev-v2.0.json \
      --train_batch_size=24 \
      --learning_rate=3e-5 \
      --num_train_epochs=2.0 \
      --use_tpu=True \
      --tpu_name=$TPU_ADDRESS \
      --max_seq_length=512 \
      --doc_stride=128 \
      --version_2_with_negative=True \
      --output_dir=$OUTPUT_DIR

## Predicting
This assumes that there is already a trained model in the previously mentioned directory in the Google Cloud bucket. If you want to predict using a TPU you also need to enter the TPU's address.

In [None]:
if not is_kaggle:
    from google.cloud import storage
    storage_client = storage.Client(project=GCP_PROJECT)
    bucket = storage_client.get_bucket(BUCKET)

In [None]:
if not is_kaggle:
    question_blobs = bucket.list_blobs(
        prefix=GS_QUESTION_DIR_BASE
    )

To make predictions, you need a trained model and a GCS bucket to which you have write access.

In [None]:
if not is_kaggle:
    for question in tqdm(question_blobs):
        question_name = question.name
        output_dir_answer = question.name.split(".")[0].split("/")[-1]
        !python bert/run_squad.py \
          --vocab_file=$BUCKET_NAME/scibert_scivocab_uncased/vocab.txt \
          --bert_config_file=$BUCKET_NAME/scibert_scivocab_uncased/bert_config.json \
          --init_checkpoint=$BUCKET_NAME/scibert_scivocab_uncased/bert_model.ckpt \
          --do_train=False \
          --max_query_length=30  \
          --do_predict=True \
          --predict_file=$BUCKET_NAME/$question_name \
          --use_tpu=True \
          --tpu_name=$TPU_ADDRESS \
          --predict_batch_size=8 \
          --n_best_size=3 \
          --max_seq_length=512 \
          --doc_stride=128 \
          --output_dir=$BUCKET_NAME/answers/$output_dir_answer/

## Postprocessing
After the predictions have been made, we need to do some postprocessing. This mainly consists of matching the position of the found answer to the original text and extracting a passage around it that contains the answer.
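
The idea can be sketched on a made-up text. This is a simplified stand-in for the functions in this section: `extract_passage` is a hypothetical name, and a regular expression replaces `nltk.sent_tokenize` so that the example has no dependencies.

```python
import re


def extract_passage(text, word_index, answer, max_distance=5):
    """Return the sentences around the word at word_index that together contain the answer."""
    # Naive sentence splitter (assumption: sentences end in '.', '!' or '?').
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())

    # Find the sentence that contains the word with the given index.
    words_seen, centre = 0, None
    for idx_sentence, sentence in enumerate(sentences):
        words_seen += len(sentence.split())
        if words_seen > word_index:
            centre = idx_sentence
            break
    if centre is None:
        return None

    # Widen the window one sentence at a time until it contains the answer.
    for distance in range(1, max_distance + 1):
        window = " ".join(sentences[max(0, centre - distance):centre + distance + 1])
        if answer in window:
            return window
    return None


text = (
    "Coronaviruses infect many species. Remdesivir inhibited viral replication in cell culture. "
    "Further trials are needed. Bats are a reservoir. Funding was provided by nobody."
)
passage = extract_passage(text, word_index=5, answer="inhibited viral replication")
print(passage)
```

The window always includes the neighbouring sentences, so the passage gives some context around the answer rather than the bare span.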

In [None]:
import nltk
nltk.download('punkt')

In [None]:
from bert import tokenization

In [None]:
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

In [None]:
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

In [None]:
def parse_questions(paragraphs):
    """
    Parse the questions from the paragraphs in the questions file.
    This will extract the article texts and the question.
    args
    paragraphs: list: A list of paragraph dicts.

    return 
    articles: dict
    question: str
    """
    qa_id_text = {}
    question_text = None
    for paragraph in paragraphs:
        for qas in paragraph["qas"]:
            qa_id_text[qas["id"]] = paragraph["context"]
            # Every paragraph poses the same question, so keeping the last one is fine
            question_text = qas["question"]
    return qa_id_text, question_text

def get_sentence_index(word_index, sentences):
    """
    Get the index of a sentence given the index of a word in the text.

    args
    word_index: int: The index of the word
    sentences: list<str>: The list of sentences

    return
    sentence_index: int: The index of the sentence
    """
    i = 0
    for idx_sentence, sentence in enumerate(sentences):
        sentence_tokens = sentence.split(" ")
        i += len(sentence_tokens)
        if i > word_index:
            return idx_sentence
    # The word index lies beyond the last sentence
    return None

def get_passage(word_index, sentences, text_to_find):
    """
    Get a passage from the text, given a list of sentences, the text to find and 
    an index of the word.
    """
    selected_sentence = get_sentence_index(word_index, sentences)
    if selected_sentence is None:
        # The word index could not be mapped to a sentence
        return None
    distance = 1
    # Widen the window around the selected sentence until it contains the text
    while distance < 6:
        combined_sentences = " ".join([
            sentence 
            for idx_sentence, sentence in enumerate(sentences) 
            if  idx_sentence >= selected_sentence - distance and 
                idx_sentence <= selected_sentence + distance
            ]
        )
        if text_to_find in combined_sentences:
            return combined_sentences
        distance += 1
    return None

def get_questions():
    """
    Get a list or iterator of all the questions.
    """
    if is_kaggle:
        return [(f, os.path.join(QUESTION_DIR, f)) 
                for f in os.listdir(QUESTION_DIR) 
                if os.path.isfile(os.path.join(QUESTION_DIR, f))]
    else:
        question_blobs = bucket.list_blobs(
            prefix=GS_QUESTION_DIR_BASE
        )
        return [(question_blob.name, question_blob) 
                for question_blob in question_blobs]
    
def get_answers():
    """
    Get a list or iterator of all the answers.
    """
    if is_kaggle:
        return [(f, os.path.join(OUTPUT_DIR, f)) for f in os.listdir(OUTPUT_DIR) 
                if os.path.isfile(os.path.join(OUTPUT_DIR, f))]
    else:
        answer_blobs = bucket.list_blobs(
            prefix=ANSWER_DIR
        )
        return [(answer_blob.name, answer_blob) 
                for answer_blob in answer_blobs 
                if answer_blob.name.endswith("nbest_predictions.json")]

    
def get_qa_content(qa):
    """
    Get the contents of a questions or answers file.
    """
    if is_kaggle:
        with open(qa, "r") as in_file:
            return json.load(in_file)
    else:
        if type(qa) == str:
            # This means that it's only a path to the file on GCS.
            qa = storage.blob.Blob(qa, bucket)

        return json.loads(qa.download_as_string())

In [None]:
answers = get_answers()

qa_overview = collections.defaultdict(list)

for answer_name, answer in tqdm(answers):
    # Load the question file, to get the original texts
    print(answer_name)
    if is_kaggle:
        question_key = answer_name.split(".")[0]
    else:
        question_key = answer_name.split("/")[-2]

    # We'll first load the original questions file that was used to predict on
    original_json = get_qa_content(QUESTION_DIR + "/"+question_key+".json")
    qa_id_text, question_text = parse_questions(original_json["data"][0]["paragraphs"])

    # Now we'll get the predicted answers
    question_results = get_qa_content(answer)

    for question, results in tqdm(question_results.items()):
        if results[0]["text"] != "empty":
            text_tokenized = qa_id_text[question].split(" ")
            text = tokenizer._clean_text(qa_id_text[question])
            text = re.sub(' +', ' ', text)

            sentences = nltk.sent_tokenize(text)
            passage = get_passage(results[0]["start_orig_doc"], sentences, results[0]["text"])

            if passage:
                id_question = question.split("_")[1]
                id_article = question.split("_")[3].split(".")[0]
                qa_overview[id_question].append({
                    "id": id_article,
                    "passage": passage, 
                    "probability": results[0]["probability"]
                })

for id_question, articles in qa_overview.items():
    # Show the most probable answers first
    qa_overview[id_question] = sorted(articles, key=lambda i: i['probability'], reverse=True)
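
The selection and ranking step can be illustrated on invented predictions. The sketch assumes the n-best structure used above (a list of candidates per question id, each with a `text` and a `probability`, best candidate first); `best_answers` is a hypothetical helper.

```python
def best_answers(nbest_by_question, min_probability=0.0):
    """Keep the top prediction per question id and rank them, most probable first."""
    top = []
    for question_id, predictions in nbest_by_question.items():
        best = predictions[0]  # assumption: every n-best list is sorted, best first
        if best["text"] != "empty" and best["probability"] >= min_probability:
            top.append({
                "id": question_id,
                "text": best["text"],
                "probability": best["probability"],
            })
    return sorted(top, key=lambda answer: answer["probability"], reverse=True)


# Invented n-best predictions for three articles.
nbest = {
    "q_1_h_a": [{"text": "remdesivir", "probability": 0.6}, {"text": "in vitro", "probability": 0.3}],
    "q_1_h_b": [{"text": "empty", "probability": 1.0}],
    "q_1_h_c": [{"text": "lopinavir", "probability": 0.9}],
}
ranked = best_answers(nbest)
print([answer["id"] for answer in ranked])
```

A `min_probability` threshold is one way to hide low-confidence passages; the value to use would have to be chosen by inspecting the results.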

## Answers
This section attempts to provide answers to the questions that were asked before.

In [None]:
from tabulate import tabulate
from IPython.display import HTML, display
import ipywidgets as widgets

In [None]:
for query in queries:
    display(HTML("<h3>{}</h3>".format(query["question"])))
    data = [[article["id"], article["passage"]] for article in qa_overview[str(query["id"])]]
    display(HTML(tabulate(data, 
                        tablefmt='html', 
                        headers=["Paper SHA","Passage"])))