In [2]:
import faiss 
import json
import torch 
import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import pipeline


class DataLoader:
    """Read and write JSON documents on disk.

    The instance remembers a single input path used by ``load_data``;
    ``save_data`` takes its own destination path.
    """

    def __init__(self, path):
        """Remember ``path``, the JSON file that ``load_data`` will read."""
        self.path = path

    def load_data(self):
        """Parse the JSON file at ``self.path`` and return the result.

        Errors from ``open`` (e.g. ``FileNotFoundError``) and from
        ``json.load`` (``json.JSONDecodeError``) propagate to the caller.
        """
        # Pin UTF-8 instead of relying on the platform's locale encoding,
        # which would mis-decode non-ASCII text on some systems.
        with open(self.path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def save_data(self, data, save_path):
        """Serialize ``data`` as JSON into the file at ``save_path``.

        The file is created or truncated. Returns ``None``.
        """
        with open(save_path, 'w', encoding='utf-8') as file:
            json.dump(data, file)


class TransformerEncoder:
    """Retrieve passages with a sentence encoder, then extract answers.

    ``encoder`` embeds the query for a nearest-neighbour lookup in a
    vector index; ``QnA`` is a question-answering pipeline that is run on
    each retrieved passage.
    """

    def __init__(self, model_name, pipe_model, token_model):
        """Load the models.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            pipe_model: Model passed to the 'question-answering' pipeline.
            token_model: Tokenizer passed to the same pipeline.
        """
        self.encoder = SentenceTransformer(model_name)
        self.QnA = pipeline('question-answering', model=pipe_model,
                            tokenizer=token_model)

    def find_answers(self, text, query, index, k=5):
        """Answer ``query`` from the ``k`` passages closest to it.

        Args:
            text: Sequence of passages, indexable by the ids that
                ``index.search`` returns (assumed to be in the same order
                the index was built from -- verify against the indexing
                code).
            query: The question string.
            index: Object exposing ``search(vectors, k)`` that returns a
                ``(distances, ids)`` pair, e.g. a FAISS index.
            k: Number of neighbours to request from the index.

        Returns:
            The pipeline outputs for every retrieved passage, sorted by
            their ``'score'`` entry in descending order. Empty when the
            index returns no valid hit.
        """
        embedded = self.encoder.encode([query], convert_to_numpy=True,
                                       show_progress_bar=True)
        # FAISS only accepts C-contiguous float32 input.
        query_encoded = np.ascontiguousarray(embedded, dtype=np.float32)
        _, neighbours = index.search(query_encoded, k)
        # FAISS fills missing neighbours with -1 when fewer than k vectors
        # are found; indexing text[-1] would silently pick the last
        # passage, so those slots are skipped.
        answers = [self.QnA({'question': query, 'context': str(text[i])})
                   for i in map(int, neighbours[0]) if i >= 0]
        return sorted(answers, key=lambda x: x['score'], reverse=True)

    def print_answers(self, answers):
        """Print one ``score`` / ``answer`` line per entry of ``answers``."""
        for answer in answers:
            print(f"score: {answer['score']:.6f} \t answer: {answer['answer']}")


if __name__ == '__main__':
    # Demo run: answer one question against the stored corpus and index.
    query = "How many people have died during Black Death?"

    # Every record of the JSON file is read for its 'text' entry.
    data = DataLoader('../data_hub/data_2.json').load_data()
    text_corpus = [record['text'] for record in data]

    embedder = TransformerEncoder(
        model_name='distilbert-base-nli-stsb-mean-tokens',
        pipe_model='ktrapeznikov/albert-xlarge-v2-squad-v2',
        token_model='albert-xlarge-v2',
    )
    # Prebuilt vector index read from disk.
    index = faiss.read_index('../data_hub/faiss.index')

    answers = embedder.find_answers(text_corpus, query, index)
    embedder.print_answers(answers)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

score: 0.554148 	 answer: 75–200 million
score: 0.000126 	 answer: 15%.
score: 0.000045 	 answer: between 17 million and 50 million,
score: 0.000001 	 answer: 37.9 million
score: 0.000000 	 answer: pandemic prevention in some respects.
