In [1]:
import elasticsearch 
import json
import torch

from sentence_transformers import SentenceTransformer, util
from pprint import pprint

In [2]:
# Connect to the local Elasticsearch node.
es = elasticsearch.Elasticsearch('http://127.0.0.1:9200/')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the sentence-embedding model and move it onto the selected device.
# NOTE(review): sentence-transformers lists the *-nli-stsb-mean-tokens checkpoints as
# deprecated; consider a newer model — TODO confirm before switching (scores will change).
model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens').to(device)

In [3]:
def load_questions(data_loc):
    """Load and return the decoded contents of a JSON questions file.

    Args:
        data_loc: path to the JSON file. Presumably it holds a list of question
            strings (the notebook iterates over the result) — TODO confirm schema.

    Returns:
        The decoded JSON document.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which breaks non-ASCII questions on e.g. Windows (cp1252).
    with open(data_loc, encoding='utf-8') as f:
        return json.load(f)


# Section titles filtered out of every Elasticsearch query unless the caller overrides them.
DEFAULT_EXCLUDED_SECTIONS = ('Data and graphs', 'External links', 'Further reading',
                             'Medical journals', 'See also')


def search(index, query, excluded_sections=None, index_name='pandemic_docs', size=50):
    """Full-text match ``query`` against the ``text`` field, skipping excluded sections.

    Args:
        index: Elasticsearch client (anything exposing
            ``search(index=..., body=..., size=...)``).
        query: query string matched against the ``text`` field.
        excluded_sections: section titles removed via a ``must_not`` terms filter on
            ``section_title.keyword``. Defaults to ``DEFAULT_EXCLUDED_SECTIONS``.
            This is passed in explicitly rather than read from a notebook global,
            so the function also works on a fresh kernel.
        index_name: Elasticsearch index to query.
        size: maximum number of hits to return.

    Returns:
        The list of raw hit dicts found at ``response['hits']['hits']``.
    """
    if excluded_sections is None:
        excluded_sections = DEFAULT_EXCLUDED_SECTIONS
    search_query = {
        "query": {
            "bool": {
                "should": {"match": {"text": query}},
                "must_not": {"terms": {"section_title.keyword": list(excluded_sections)}},
            }
        }
    }
    results = index.search(index=index_name, body=search_query, size=size)
    return results['hits']['hits']


def return_results(results):
    """Split Elasticsearch hits into parallel lists of source fields.

    Args:
        results: iterable of hit dicts, each carrying a ``_source`` mapping with
            ``text``, ``article_title`` and ``section_title`` keys.

    Returns:
        Tuple ``(texts, article_titles, section_titles)`` of equal-length lists,
        in hit order.
    """
    def _pluck(field):
        # One pass over the hits per field, extracting _source[field].
        return [hit['_source'][field] for hit in results]

    return _pluck('text'), _pluck('article_title'), _pluck('section_title')


def search_similarity(questions, index, model, top_k=10):
    """Retrieve candidates from Elasticsearch, then re-rank them by embedding similarity.

    For each question, the hits returned by ``search`` are embedded with ``model``
    and scored against the question embedding using ``util.semantic_search``.
    The ``top_k`` best-scoring passages are kept.

    Args:
        questions: iterable of question strings.
        index: Elasticsearch client, passed through to ``search``.
        model: sentence-embedding model exposing ``encode``.
        top_k: number of re-ranked passages kept per question (default 10).

    Returns:
        A list of ``{'question': ..., 'results': [...]}`` dicts, one per question.
        Each result is ``{'title', 'section_title', 'BERT_score'}``, sorted by
        descending ``BERT_score``. A question with no Elasticsearch hits gets an
        empty ``results`` list.
    """
    search_output = []
    for question in questions:
        hits = search(index, question)
        text, article_title, section_title = return_results(hits)

        # Reset per question: previously this name was only bound inside the hit loop,
        # so a question with no hits raised UnboundLocalError or reused stale results.
        ranked_result = []
        if text:
            query_embeddings = model.encode(question).reshape(1, -1)
            text_embeddings = model.encode(text, show_progress_bar=True, convert_to_numpy=True)
            search_res = util.semantic_search(query_embeddings, text_embeddings, top_k=top_k)[0]

            result_output = [
                {
                    'title': article_title[val['corpus_id']],
                    'section_title': section_title[val['corpus_id']],
                    'BERT_score': val['score'],
                }
                for val in search_res
            ]
            # Sort once after collecting, instead of re-sorting after every appended hit.
            ranked_result = sorted(result_output, key=lambda x: x['BERT_score'], reverse=True)

        search_output.append({'question': question, 'results': ranked_result})
    return search_output


if __name__ == '__main__':
    # Section titles to exclude from retrieval.
    excluded_search = [
        'Data and graphs',
        'External links',
        'Further reading',
        'Medical journals',
        'See also',
    ]

    # Load the questions, retrieve and re-rank passages for each one, then show the ranking.
    qs = load_questions('../data_hub/questions.json')
    output = search_similarity(qs, es, model)
    pprint(output)

  results = index.search(index='pandemic_docs', body=search_query, size=50)


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

[{'question': 'How many people have died during Black Death?',
  'results': [{'BERT_score': 0.25094470381736755,
               'section_title': 'Economic impact',
               'title': 'HIV/AIDS'},
              {'BERT_score': 0.2435573786497116,
               'section_title': 'Epidemiology',
               'title': 'HIV/AIDS'},
              {'BERT_score': 0.23495584726333618,
               'section_title': 'Tuberculosis',
               'title': 'Pandemic'},
              {'BERT_score': 0.22389301657676697,
               'section_title': 'Ebola',
               'title': 'Science diplomacy and pandemics'},
              {'BERT_score': 0.2073315978050232,
               'section_title': 'Epidemiology',
               'title': 'Cholera'},
              {'BERT_score': 0.19188454747200012,
               'section_title': 'Prognosis',
               'title': 'HIV/AIDS'},
              {'BERT_score': 0.1677030622959137,
               'section_title': 'Deaths',
               'title':