In [10]:
import faiss
import json
import numpy as np
import torch

from transformers import AutoModel, AutoTokenizer


class DataLoader:
    """Read and write JSON files.

    The instance remembers the path given at construction time; that path is
    used by ``load_data``. ``save_data`` writes to the path passed to it.
    """

    def __init__(self, path):
        """Remember the JSON file location used by ``load_data``.

        Args:
            path: Path of the JSON file to read.
        """
        self.path = path

    def load_data(self):
        """Parse the JSON file at ``self.path`` and return the decoded object.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(self.path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def save_data(self, data, save_path):
        """Serialise ``data`` as JSON into ``save_path``.

        Args:
            data: JSON-serialisable object to write.
            save_path: Destination file; an existing file is overwritten.
        """
        with open(save_path, 'w', encoding='utf-8') as file:
            json.dump(data, file)


class TransformerEncoder:
    """Embed text with a pretrained transformer and query a vector index.

    Elements of the TransformerEncoder class were influenced by a blog post
    from Txus Bach:
    https://www.codegram.com/blog/finding-similar-documents-with-transformers/
    """

    def __init__(self, text_document, model_type):
        """Keep the documents and load the tokenizer and model.

        Args:
            text_document: Collection of documents; ``search`` indexes into it
                with the ids returned by the index.
            model_type: Identifier handed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        self.data = text_document
        self.tokenizer = AutoTokenizer.from_pretrained(model_type)
        self.model = AutoModel.from_pretrained(model_type)

    def text_encoder(self, text):
        """Return the mean over axis 0 of the model's first output for ``text``.

        Args:
            text: Input passed to the tokenizer (a single string in this file).

        Returns:
            A ``torch.Tensor`` detached from the autograd graph.
        """
        # truncation=True asks the tokenizer to cut inputs to its configured
        # maximum length instead of handing an over-long sequence to the model.
        # NOTE(review): this relies on the tokenizer defining a max length.
        text_tokens = self.tokenizer(text, truncation=True, return_tensors='pt')
        # Inference only: avoid recording the autograd graph.
        with torch.no_grad():
            embed = self.model(**text_tokens)[0]
        # Drop only the leading (batch) axis. A bare squeeze() would also
        # collapse any other size-1 axis, e.g. a one-token sequence.
        embed = embed.detach().squeeze(0)
        return torch.mean(embed, dim=0)

    def search(self, query, index, k=5):
        """Look up the ``k`` best matches for ``query`` in ``index``.

        Args:
            query: Text to embed with ``text_encoder``.
            index: Object exposing ``search(vectors, k)`` that returns a
                ``(scores, ids)`` pair of 2-D arrays (e.g. a FAISS index).
            k: Number of neighbours requested from the index.

        Returns:
            A list of ``(document, score)`` tuples in the order given by the
            index. Slots whose id is negative are skipped.
        """
        encoded_query = self.text_encoder(query).unsqueeze(dim=0).detach().numpy()
        D, I = index.search(encoded_query, k)
        # FAISS pads missing results with id -1 when fewer than k vectors are
        # available; indexing self.data with -1 would return the wrong document.
        return [(self.data[i], j) for i, j in zip(I[0], D[0]) if i >= 0]
        


if __name__ == '__main__':
    # Demo: embed every document, index the embeddings, run one query.
    data = DataLoader('../data_hub/sentences.json').load_data()

    encoder = TransformerEncoder(data, 'distilbert-base-uncased')
    # One row per document; FAISS works on float32 matrices.
    vectors = np.stack(
        [encoder.text_encoder(d).numpy() for d in data]
    ).astype(np.float32)

    # Take the index width from the embeddings instead of hard-coding 768, so
    # a model with a different hidden size works unchanged. Scores are raw
    # inner products; the vectors are not normalised here.
    index = faiss.IndexIDMap(faiss.IndexFlatIP(vectors.shape[1]))
    # Ids are the positions in `data`; FAISS expects them as int64.
    index.add_with_ids(vectors, np.arange(len(data), dtype=np.int64))

    query = 'Spanish flu casualties?'
    result = encoder.search(query, index, k=5)
    print(result)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually deadly influenza pandemic caused by the H1N1 influenza A virus.', 47.486176), ('As of 2018, approximately 37.9 million people are infected with HIV globally.', 44.16535), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 43.215614), ('The death toll of Spanish Flu is estimated to have been somewhere between 17 million and 50 million, and possibly as high as 100 million, making it one of the deadliest pandemics in human history.', 42.78693), ('The most fatal pandemic in recorded history was the Black Death (also known as The Plague), which killed an estimated 75–200 million people in the 14th century.', 42.419025)]
