In [1]:
import faiss
import json 
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


class DataLoader:

    def __init__(self, path):
        self.path = path


    def load_data(self):
        with open(self.path, 'r') as file:
            data = json.load(file)
        return data


    def save_data(self, data ,save_path):
        with open(save_path, 'w') as file:
            json.dump(data, file)


class SentenceEncoder:

    def __init__(self, text_data, model_name):
        self.data = text_data
        self.model = SentenceTransformer(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.encoder = self.model.to(self.device)


    def encode_text(self, text):
        embeddings = self.encoder.encode(text, show_progress_bar=True)
        return embeddings

    
    def search(self, query, index, k=5):
        query_encoded = self.encode_text(query).reshape(1, -1)
        D, I = index.search(query_encoded, k)
        return D[0], I[0] 


if __name__ == '__main__':
    data = DataLoader('../data_hub/data.json').load_data()

    enc = SentenceEncoder(data, 'distilbert-base-nli-stsb-mean-tokens')

    text_corpus = [i['text'] for i in data]
    text_embeddings = enc.encode_text(text_corpus)

    index = faiss.IndexFlatL2(text_embeddings.shape[1])
    index = faiss.IndexIDMap(index)
    index.add_with_ids(text_embeddings, np.array(range(0, text_embeddings.shape[0])))
    # faiss.write_index(index, '../data_hub/pandemics')

    q = "How many leprosy outbreaks are known to happen?"
    results = enc.search(q, index)
    results[0][::-1].sort()
    search_results = [(i, data[j]['title'], j) for i, j in zip(results[0], results[1])]
    print(search_results)




Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[(260.04678, 'Plague of Cyprian', 15), (258.3327, 'HIV/AIDS', 10), (250.29892, 'Spanish flu', 19), (247.85165, 'Swine influenza', 21), (238.0826, 'Pandemic severity index', 14)]
