In [1]:
import elasticsearch 
import json
import torch

from sentence_transformers import SentenceTransformer, util
from pprint import pprint

In [2]:
# Elasticsearch client pointed at the local node
es = elasticsearch.Elasticsearch('http://127.0.0.1:9200/')

# Use the GPU when torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Sentence-embedding model, moved onto the device chosen above
model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
model = model.to(device)

In [3]:
def load_questions(data_loc):
    """Load the questions stored in a JSON file.

    Parameters
    ----------
    data_loc : str or path-like
        Location of the JSON file to read.

    Returns
    -------
    object
        Whatever ``json.load`` decodes from the file. The caller in this
        notebook iterates over the result and treats each item as a question.

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file content is not valid JSON.
    """
    # Name the encoding explicitly: without it ``open`` falls back to the
    # platform's locale encoding, which can mis-decode non-ASCII questions.
    with open(data_loc, encoding='utf-8') as f:
        return json.load(f)


def search(index, query, index_name='pandemic_docs', size=50):
    """Run a full-text match query and return the raw hits.

    The query matches ``query`` against the ``text`` field and excludes
    documents whose ``section_title.keyword`` is exactly 'External links'.

    Parameters
    ----------
    index : object
        Despite the name, this is the client object: its ``search`` method is
        called with ``index``, ``body`` and ``size`` keyword arguments.
    query : str
        Text to match against the ``text`` field.
    index_name : str, optional
        Name of the index to query. Defaults to 'pandemic_docs'.
    size : int, optional
        Maximum number of hits requested. Defaults to 50.

    Returns
    -------
    list
        The value found at ``['hits']['hits']`` in the client's response.
    """
    search_query = {
        "query": {
            "bool": {
                "should": {"match": {"text": query}},
                "must_not": {"term": {"section_title.keyword": 'External links'}},
            }
        }
    }
    # NOTE(review): the saved notebook output echoes this call as the source
    # line of a warning whose text is not visible. It may be the client's
    # `body=` deprecation notice -- confirm against the installed client
    # version before switching to per-field keyword arguments.
    response = index.search(index=index_name, body=search_query, size=size)
    return response['hits']['hits']


def return_results(results):
    """Split a list of search hits into four parallel lists.

    Each hit is read as a mapping with a ``_source`` mapping (holding
    ``text``, ``article_title`` and ``section_title``) and a ``_score``.

    Returns
    -------
    tuple
        ``(text, article_title, section_title, score)``, one list per field,
        all in the same order as ``results``.
    """
    columns = []
    # One pass per source field, in the order the caller unpacks them.
    for field in ('text', 'article_title', 'section_title'):
        columns.append([hit['_source'][field] for hit in results])
    # The score sits on the hit itself rather than inside `_source`.
    columns.append([hit['_score'] for hit in results])
    return tuple(columns)


def search_similarity(questions, index, model):
    """Retrieve candidates for each question and narrow them by embedding similarity.

    For every question the function:

    1. fetches candidate hits with ``search(index, question)``;
    2. encodes the question and the candidate texts with ``model.encode``;
    3. keeps the entries returned by ``util.semantic_search`` (``top_k=10``);
    4. orders the kept entries by the score taken from the search hit,
       highest first.

    Parameters
    ----------
    questions : iterable
        Questions to process; each one is passed to ``search`` and to
        ``model.encode``.
    index : object
        Forwarded unchanged as the first argument of ``search``.
    model : object
        Must provide ``encode``; it is called once with the question and once
        with the list of candidate texts.

    Returns
    -------
    list of dict
        One ``{'question': ..., 'results': [...]}`` entry per question, in
        input order. Each result holds ``title``, ``section_title`` and
        ``score``. A question with no hits gets an empty ``results`` list.
    """
    search_output = []
    for question in questions:
        hits = search(index, question)
        if not hits:
            # Nothing to embed or rank. Previously this path either raised
            # (first question) or reused the previous question's ranking,
            # because the ranked list was only assigned inside the loop over
            # semantic-search results.
            search_output.append({'question': question, 'results': []})
            continue

        text, article_title, section_title, score = return_results(hits)

        query_embeddings = model.encode(question).reshape(1, -1)
        text_embeddings = model.encode(text, show_progress_bar=True, convert_to_numpy=True)
        search_res = util.semantic_search(query_embeddings, text_embeddings, top_k=10)[0]

        # `corpus_id` indexes back into the parallel lists built from the hits.
        result_output = [
            {
                'title': article_title[val['corpus_id']],
                'section_title': section_title[val['corpus_id']],
                'score': score[val['corpus_id']],
            }
            for val in search_res
        ]

        # Sort once, after all entries are collected, rather than on every
        # loop iteration.
        # NOTE(review): the order uses the score carried by the search hit;
        # the similarity score from semantic_search only decides which
        # entries are kept -- confirm this is the intended ranking.
        ranked_result = sorted(result_output, key=lambda x: x['score'], reverse=True)

        search_output.append({'question': question, 'results': ranked_result})
    return search_output


# Entry point: load the questions, run the retrieval pipeline for each of
# them, and pretty-print the collected results.
# Uses the module-level `es` client and `model` created in an earlier cell.
if __name__ == '__main__':
    # The path is relative, so it resolves against the current working directory.
    qs = load_questions('../data_hub/questions.json')
    output = search_similarity(qs, es, model)
    pprint(output)

  results = index.search(index='pandemic_docs', body=search_query, size=50)


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

[{'question': 'How many people have died during Black Death?',
  'results': [{'score': 8.295511,
               'section_title': 'Deaths',
               'title': 'COVID-19 pandemic'},
              {'score': 6.7890177,
               'section_title': 'Prognosis',
               'title': 'Cholera'},
              {'score': 6.6869187,
               'section_title': 'Epidemiology',
               'title': 'Cholera'},
              {'score': 6.2728453,
               'section_title': 'Ebola',
               'title': 'Science diplomacy and pandemics'},
              {'score': 6.0281324,
               'section_title': 'Tuberculosis',
               'title': 'Pandemic'},
              {'score': 5.914283,
               'section_title': 'Problems with the bills',
               'title': 'Bills of mortality'},
              {'score': 5.600337,
               'section_title': 'February 1930',
               'title': '1929–1930 psittacosis pandemic'},
              {'score': 5.5176015,
       