In [55]:
import faiss
import json 
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


class DataLoader:
    """Read and write JSON documents on disk.

    The instance remembers one source path (``path``) used by
    :meth:`load_data`; :meth:`save_data` writes to whatever path it is given.
    """

    def __init__(self, path):
        """Store the path of the JSON file that ``load_data`` will read."""
        self.path = path

    def load_data(self):
        """Parse the JSON file at ``self.path`` and return the decoded object.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the content is not valid JSON.
        """
        # JSON is UTF-8 by specification; pin the encoding so the result does
        # not depend on the platform's locale default.
        with open(self.path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def save_data(self, data, save_path):
        """Serialise ``data`` as JSON into ``save_path``, overwriting it.

        Args:
            data: any object ``json.dump`` can serialise.
            save_path: destination file path.
        """
        with open(save_path, 'w', encoding='utf-8') as file:
            json.dump(data, file)


class SentenceEncoder:
    """Embed text with a SentenceTransformer model and query a vector index.

    ``text_data`` is kept as ``self.data``; ``search`` looks up
    ``self.data[id]['title']`` for every id the index returns, so the ids
    stored in the index must be positions into ``text_data``.
    """

    def __init__(self, text_data, model_name):
        """Load ``model_name`` and move it to the GPU when one is available.

        Args:
            text_data: indexable collection of records, each with a
                ``'title'`` key (read by ``search``).
            model_name: model identifier passed to ``SentenceTransformer``.
        """
        self.data = text_data
        self.model = SentenceTransformer(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.encoder = self.model.to(self.device)

    def encode_text(self, text):
        """Return the embeddings of ``text`` as a tensor.

        ``text`` is forwarded unchanged to the model's ``encode`` method with
        ``convert_to_tensor=True`` and the progress bar enabled.
        """
        return self.encoder.encode(
            text, convert_to_tensor=True, show_progress_bar=True
        )

    def search(self, query, index, k=5):
        """Return up to ``k`` ``(title, score)`` pairs for ``query``.

        Args:
            query: a single query string.
            index: object exposing ``search(vectors, k)`` that returns
                ``(scores, ids)`` arrays with one row per query vector.
            k: maximum number of neighbours to request.

        Returns:
            List of ``(title, score)`` tuples in the order the index returned
            them. The list is shorter than ``k`` when the index holds fewer
            than ``k`` vectors.
        """
        query_encoded = self.encode_text([query]).cpu().detach().numpy()
        scores, ids = index.search(query_encoded, k)
        # FAISS pads missing neighbours with id -1. Indexing self.data with -1
        # would silently return the LAST record as a bogus hit, so skip them.
        return [
            (self.data[doc_id]['title'], score)
            for doc_id, score in zip(ids[0], scores[0])
            if doc_id >= 0
        ]


if __name__ == '__main__':
    # Load the corpus and build the encoder over it; `search` later resolves
    # hit ids back to records of `data`.
    data = DataLoader('../data_hub/data_2.json').load_data()
    enc = SentenceEncoder(data, 'distilbert-base-nli-stsb-mean-tokens')

    # Embed every document's text and bring the tensor to host memory so it
    # can be handed to FAISS as a NumPy array.
    text_corpus = [i['text'] for i in data]
    text_embeddings = enc.encode_text(text_corpus).cpu()

    # Flat inner-product index wrapped in an id map, so each vector is stored
    # under its position in `text_corpus`.
    # NOTE(review): the embeddings are not normalised here, so the scores are
    # raw dot products rather than cosine similarities; confirm that is intended.
    index = faiss.IndexIDMap(faiss.IndexFlatIP(text_embeddings.shape[1]))
    # FAISS ids are 64-bit; request int64 explicitly because NumPy's default
    # integer dtype is platform-dependent.
    ids = np.arange(len(text_corpus), dtype=np.int64)
    index.add_with_ids(text_embeddings.numpy(), ids)
    faiss.write_index(index, '../data_hub/faiss.index')

    q = "Which diseases can be transmitted by animals?"
    results = enc.search(q, index)
    print(results)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[('Swine influenza', 89.7691), ('Targeted immunization strategies', 78.84313), ('Pandemic prevention', 71.77255), ('HIV/AIDS', 67.6545), ('PREDICT (USAID)', 61.000816)]
