In [1]:
# Work from the project root, presumably so the local modules imported below
# (embedding_model, utils) resolve. NOTE(review): hard-coded absolute local path.
%cd /mnt/z/Project/semantic_search/qualia

/mnt/z/Project/semantic_search/qualia


In [2]:
import pandas as pd
import transformers
from embedding_model import Model

In [4]:
# Sanity check of the working directory set by the %cd cell.
!pwd

/mnt/z/Project/semantic_search/qualia


In [5]:
# Sentence-embedding model wrapper from the local embedding_model module;
# the pretrained checkpoint is chosen in the next cell.
m = Model()

In [6]:
# Select the 'stsb-distilbert-base' encoder; its embeddings are 768-dimensional
# (see the data_vec.shape output further down).
m.make_pretrained('stsb-distilbert-base')

Model Compiled: stsb-distilbert-base


In [7]:
# Loading dataset
# arXiv dump: one big JSON file, parsed fully into memory by json.load.
# `with` closes the handle as soon as parsing finishes (it used to be leaked);
# JSON text is UTF-8 by spec, so don't depend on the platform default encoding.
import json
ARXIV_JSON_PATH = '/mnt/z/Project/semantic_search/qualia/data/arxivData.json'
with open(ARXIV_JSON_PATH, encoding='utf-8') as arxiv_file:
    arxiv = json.load(arxiv_file)



In [8]:
def join_ds(x, useful_keys):
    """
    Join the values of ``useful_keys`` from record ``x`` into one '. '-separated string.

    Every key is used, in the given order, so the function works for any number
    of keys (not just exactly two).

    Args:
        x: mapping (e.g. one arXiv record) containing every key in ``useful_keys``;
            the values must be strings.
        useful_keys: ordered sequence of keys to join.

    Returns:
        str, e.g. "<title>. <summary>" for useful_keys=['title', 'summary'].
    """
    return '. '.join(x[key] for key in useful_keys)

useful_keys = ['title', 'summary']
# One "title. summary" document string per arXiv record; a record's label is its position.
arxiv_data = [join_ds(record, useful_keys) for record in arxiv]
arxiv_data_labels = [*range(len(arxiv_data))]


In [10]:
# Integer label -> document text; handed to the Index via HNSW_PARAMS to render search hits.
label_mapping = {label: text for label, text in zip(arxiv_data_labels, arxiv_data)}

In [13]:
# Encode only the first `sample_len` documents (presumably to keep this experiment fast);
# encode_sentences returns one embedding row per input text, here (200, 768).
sample_len = 200
data_vec = m.encode_sentences(arxiv_data[:sample_len])
data_vec.shape

(200, 768)

In [34]:
data_labels = arxiv_data_labels[:sample_len]
# Labels for the encoded sample: row i of data_vec gets label i
# (assuming encode_sentences preserves input order).

In [14]:
# Make Index from docs
import os
import shutil
import re
import hnswlib
import numpy as np
from utils import *

In [26]:
# Configuration for the Index class (defined in a later cell of this notebook).
# Keys not given here ("save_dir", "item_batch_size") fall back to Index defaults.
HNSW_PARAMS = {
    "save_file": 'hnsw_index.bin',
    "M": 16,                            # hnswlib M (graph connectivity); Index default would be 200
    "ef_construction": 200,             # hnswlib build-time candidate list size
    "num_threads": MAX_SEARCH_THREADS,  # presumably provided by `from utils import *`
    "num_elements": 50,                 # NOTE(review): stored by Index but never read by it
    "label_mapping": label_mapping      # label -> text, used by Index.search to render hits
}

In [27]:
# from search_index import Index

In [28]:
# NOTE(review): `Index` is defined in a cell further down (it executed earlier, as In [19]);
# move that cell above this one or Restart & Run All will fail here.
# Construction only reads the pickled index size; the index itself is not loaded yet.
hnsw_index = Index(HNSW_PARAMS)

In [29]:
# Confirm M from HNSW_PARAMS was picked up (Index falls back to 200 when absent).
hnsw_index.M

16

In [30]:
# First run only: create and save an empty index with capacity 10 and pickle that size.
# This does not load it for use; init_search() must run before update_index()/search().
hnsw_index.define_index(idx_size=10)

First time?
Saving Fresh New index: /mnt/z/Project/semantic_search/qualia/model/v1/hnsw_index/hnsw_index.bin
Save index size: 10


In [35]:
# Add the sampled embeddings under their integer labels, then re-save the index.
# Returns False without adding anything unless init_search() has already run
# (NOTE(review): that cell sits further down; it executed earlier, as In [31]).
hnsw_index.update_index(data_vec, data_labels)


            Loading Previously saved index:
                Loading File:     hnsw_index.bin
                New max_elements: 410
        
Adding n=10 batches of items from the data
200 200
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Adding keys to update list
Saving new index of size 200
Save index size: 200


In [49]:
# print(hnsw_index.index.get_ids_list())

In [31]:
# Load the saved index into memory; search() and update_index() refuse to run until this is done.
hnsw_index.init_search()

Loading saved index


In [48]:
# Index.search() reads the notebook-global MAX_NEAREST_NBRS at call time, so rebinding it
# here changes how many neighbours come back (it replaces the earlier value, 5).
MAX_NEAREST_NBRS = 2
hnsw_index.search(m.encode_sentences("A Structured Self-attentive Sentence Embedding"))

{'matches': [{81: ('A Structured Self-attentive Sentence Embedding. This paper proposes a new model for extracting an interpretable sentence\nembedding by introducing self-attention. Instead of using a vector, we use a\n2-D matrix to represent the embedding, with each row of the matrix attending on\na different part of the sentence. We also propose a self-attention mechanism\nand a special regularization term for the model. As a side effect, the\nembedding comes with an easy way of visualizing what specific parts of the\nsentence are encoded into the embedding. We evaluate our model on 3 different\ntasks: author profiling, sentiment classification, and textual entailment.\nResults show that our model yields a significant performance gain compared to\nother sentence embedding methods in all of the 3 tasks.',
    '45.662')},
  {48: ('Explaining Recurrent Neural Network Predictions in Sentiment Analysis. Recently, a technique called Layer-wise Relevance Propagation (LRP) was shown\nto del

In [41]:
# Re-check of M; no Index method modifies it after construction.
hnsw_index.M

16

In [19]:
class Index():
    """Wrapper around a cosine-space ``hnswlib`` index that is persisted to disk.

    The index lives at ``os.path.join(SAVE_DIR, SAVE_FILE)``; its size is pickled
    separately to ``INDEX_SIZE_PATH`` so a later session can reload it with a
    matching ``max_elements``. ``search()`` and ``update_index()`` require
    ``init_search()`` to have loaded the index first.
    """

    def __init__(self, hnsw_params: dict):
        """
        Configure the wrapper from ``hnsw_params``. Nothing but the pickled index
        size is read from disk here.

        Recognised keys (all optional, defaults in brackets):
            "save_dir":        directory holding the index file [INDEX_DIR]
            "save_file":       index file name ["hnsw_index.bin"]
            "M":               hnswlib ``M`` build parameter [200]
            "ef_construction": hnswlib ``ef_construction`` [200]
            "item_batch_size": rows per ``add_items`` call in update_index [10]
            "num_threads":     threads for add/query calls [MAX_SEARCH_THREADS]
            "num_elements":    kept as ``self.num_elements``; not read by this class [None]
            "label_mapping":   dict {label: text} used by search() [None]
        """
        self.SAVE_DIR = hnsw_params.get("save_dir", INDEX_DIR)
        self.SAVE_FILE = hnsw_params.get("save_file", "hnsw_index.bin")
        # Pickled size from a previous save, or None when none exists yet.
        self.CURR_IDX_SIZE = self.find_current_index_size()
        self.M = hnsw_params.get("M", 200)
        self.ef_construction = hnsw_params.get("ef_construction", 200)
        self.item_batch_size = hnsw_params.get("item_batch_size", 10)
        self.num_threads = hnsw_params.get("num_threads", MAX_SEARCH_THREADS)
        self.num_elements = hnsw_params.get("num_elements", None)
        self.label_mapping = hnsw_params.get("label_mapping", None)
        self.index = hnswlib.Index(space='cosine', dim=EMB_DIM)
        self.index_loaded = False
        # Ids added by the most recent successful update_index() call.
        self.last_update_ids = []

    def _index_path(self):
        """Full path of the saved index file (works whether or not SAVE_DIR ends in '/')."""
        return os.path.join(self.SAVE_DIR, self.SAVE_FILE)

    def _save_index_size(self):
        """Pickle CURR_IDX_SIZE to INDEX_SIZE_PATH, closing the file immediately."""
        with open(INDEX_SIZE_PATH, 'wb') as size_file:
            pickle.dump(self.CURR_IDX_SIZE, size_file)

    def init_search(self):
        """
        Load the saved index from disk and mark it ready for search/update.

        When no size was pickled (CURR_IDX_SIZE is None), pass max_elements=0.
        hnswlib treats that as "keep the capacity stored in the file", whereas
        None cannot be converted by the binding.
        """
        print("Loading saved index")
        max_elements = self.CURR_IDX_SIZE if self.CURR_IDX_SIZE is not None else 0
        self.index.load_index(self._index_path(), max_elements=max_elements)
        self.index_loaded = True

    def find_current_index_size(self):
        """
        Return the index size pickled by a previous save, or None if unavailable.

        NOTE(review): unpickling can execute arbitrary code; INDEX_SIZE_PATH must be
        a trusted, locally written file.
        """
        try:
            with open(INDEX_SIZE_PATH, 'rb') as size_file:
                return pickle.load(size_file)
        except Exception:
            # Best effort: a missing/unreadable size file just means "nothing saved yet".
            print("No Index size pre saved")
            return None

    def define_index(self, idx_size):
        """
        Create a fresh, empty index with capacity ``idx_size`` and save it.

        Writes the index file to the save path and pickles ``idx_size`` to
        INDEX_SIZE_PATH. It does not mark the index as loaded, so call
        init_search() before update_index()/search().
        """
        print("First time?")
        # hnswlib needs the maximum number of elements up front.
        self.CURR_IDX_SIZE = idx_size
        self.index.init_index(max_elements=self.CURR_IDX_SIZE,
                              ef_construction=self.ef_construction,
                              M=self.M)
        print(f"Saving Fresh New index: {self._index_path()}")
        self.index.save_index(self._index_path())
        print(f"Save index size: {self.CURR_IDX_SIZE}")
        self._save_index_size()

    def _batch(self, iterable1, iterable2, n=1):
        """
        Yield aligned ``(iterable1, iterable2)`` slices of at most ``n`` rows.

        Both inputs must expose ``.shape[0]`` and support slicing (e.g. numpy
        arrays), and they must have the same number of rows.
        """
        l1 = iterable1.shape[0]
        l2 = iterable2.shape[0]
        assert l1 == l2, f"data/labels length mismatch: {l1} != {l2}"
        for start in range(0, l1, n):
            stop = start + n  # slicing past the end is safe, no min() needed
            yield iterable1[start:stop], iterable2[start:stop]

    def get_current_count(self):
        """Number of items in the loaded index, or False if init_search() has not run."""
        if not self.index_loaded:
            print("Index Not Loaded, please run init_search() first")
            return False
        return self.index.get_current_count()

    def update_index(self, data, data_labels):
        """
        Add vectors to the saved index in batches, then re-save the index and its size.

        Args:
            data: 2-D array-like of embeddings, one row per item. A single 1-D
                vector is treated as one row.
            data_labels: integer ids, one per row of ``data``.

        Returns:
            False if init_search() has not run, or None if reloading the saved
            index failed. On success it also returns None; the added ids are
            kept in ``self.last_update_ids``.
        """
        if not self.index_loaded:
            print("Index Not Loaded, please run init_search() first")
            return False

        data = np.asarray(data)
        if data.ndim == 1 and data.size:
            # A single embedding vector: treat it as one row, not `dim` rows.
            data = data[np.newaxis, :]
        data_labels = np.atleast_1d(np.asarray(data_labels))
        num_elements = len(data)

        # Fall back to the live count when no size was ever pickled (None + int would crash).
        base_size = (self.CURR_IDX_SIZE if self.CURR_IDX_SIZE is not None
                     else self.index.get_current_count())
        new_max = base_size + num_elements

        print(f"""
            Loading Previously saved index:
                Loading File:     {self.SAVE_FILE}
                New max_elements: {new_max}
        """)
        try:
            # Reload from disk with enough capacity for the new items.
            self.index.load_index(self._index_path(), max_elements=new_max)
            self.CURR_IDX_SIZE = new_max
        except Exception as e:
            print(e)
            print("Probably not saved earlier")
            return None

        added_ids = []
        print(f"Adding items in batches of n={self.item_batch_size}")
        for batch_vecs, batch_ids in self._batch(data, data_labels, n=self.item_batch_size):
            self.index.add_items(data=batch_vecs,
                                 ids=batch_ids,
                                 num_threads=self.num_threads)
            added_ids.extend(int(i) for i in batch_ids)
        self.last_update_ids = added_ids

        print(f"Saving new index of size {self.index.get_current_count()}")
        self.index.save_index(self._index_path())

        self.CURR_IDX_SIZE = self.index.get_current_count()
        print(f"Save index size: {self.CURR_IDX_SIZE}")
        self._save_index_size()

    def search(self, query_vector, k=None):
        """
        Return the ``k`` nearest stored items to ``query_vector``.

        Args:
            query_vector: one embedding (1-D) or a batch (2-D). Only the first
                query's matches are returned.
            k: number of neighbours. Defaults to the global MAX_NEAREST_NBRS,
                read at call time. It is clamped to the number of stored items
                because hnswlib raises when asked for more neighbours than it holds.

        Returns:
            False if the index is not loaded. Otherwise
            ``{"matches": [{id: (text, similarity), ...}, ...]}``, closest first.
            ``text`` comes from ``label_mapping``, or is None when no mapping was
            given. ``similarity`` is ``"%.3f" % ((1 - cosine_distance) * 100)``.
        """
        if not self.index_loaded:
            print("Index Not Loaded, please run init_search() first")
            return False

        if k is None:
            k = MAX_NEAREST_NBRS
        k = min(k, self.index.get_current_count())
        if k <= 0:
            print("No Matches found")
            return {"matches": []}

        queries = np.atleast_2d(np.asarray(query_vector, dtype=np.float32))
        labels, distances = self.index.knn_query(queries,
                                                 k=k,
                                                 num_threads=self.num_threads)
        if labels is None or len(labels) == 0:
            print("No Matches found")
            return {"matches": []}

        resp = []
        for q_ids, dists in zip(labels, distances):
            ranked = sorted(zip(q_ids, dists), key=lambda pair: pair[1])
            matches = []
            for qid, d in ranked:
                qid = int(qid)  # numpy uint64 -> plain int for dict keys / JSON
                text = self.label_mapping[qid] if self.label_mapping is not None else None
                matches.append({qid: (text, "%.3f" % ((1 - d) * 100))})
            resp.append(matches)

        return {"matches": resp[0]}
            

In [16]:
# Default neighbour count read by Index.search(). NOTE(review): this ran as In [16], before the
# search cell rebinds it to 2, so the 5 shown is presumably the value from `from utils import *`.
MAX_NEAREST_NBRS

5