In [1]:
import faiss
import json
import numpy as np
import torch

from transformers import AutoModel, AutoTokenizer
from pprint import pprint


class DataLoader:

    def __init__(self, path):
        self.path = path


    def load_data(self):
        with open(self.path, 'r') as file:
            data = json.load(file)
        return data


    def save_data(self, data ,save_path):
        with open(save_path, 'w') as file:
            json.dump(data, file)


class TransformerEncoder:

    def __init__(self, text_document, model_type):#, input_shape=768):
        self.data = text_document
        self.tokenizer = AutoTokenizer.from_pretrained(model_type)
        self.model = AutoModel.from_pretrained(model_type)
        # self.index = faiss.IndexIDMap(faiss.IndexFlatIP(input_shape))


    def text_encoder(self):
        embeds = [self.model(**self.tokenizer(doc, return_tensors='pt'))[0].detach().squeeze() for doc in self.data]
        return [torch.mean(embed, dim=0) for embed in embeds]

    
    def query_encoder(self, query):
        query_tokens = self.tokenizer(query, return_tensors='pt', padding=True)
        embed = self.model(**query_tokens)[0]#.detach()
        return torch.mean(embed, dim=0)

    
    def search(self, query, index, k=5):
        encoded_query = self.query_encoder(query).detach().numpy()
        I, D = index.search(encoded_query, k)
        return [(self.data[i], j) for i, j in zip(D[0], I[0])]


if __name__ == '__main__':
    data = DataLoader('./data/sentences.json').load_data()

    encoder = TransformerEncoder(data, 'distilbert-base-uncased')
    vectors = encoder.text_encoder()
    index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
    index.add_with_ids(np.array([vec.numpy() for vec in vectors]), np.array(range(0, len(data))))

    query = 'How to prevent the spread of viral infections?'
    result = encoder.search(query, index)
    print(result)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[('As of 2018, approximately 37.9 million people are infected with HIV globally.', 20.992712), ('Current pandemics include COVID-19 (SARS-CoV-2) and HIV/AIDS.', 20.251728), ('Cholera is an infection of the small intestine by some strains of the bacterium Vibrio cholerae.', 19.224836), ('Common symptoms of COVID-19 include fever, cough, fatigue, breathing difficulties, and loss of smell.', 18.865395), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 17.826128)]
