The goal of this notebook is to demonstrate how a fast k-nearest-neighbour (k-NN) search index (FAISS) can be used to find the CPC (Cooperative Patent Classification) contexts most relevant to a given search query. The retrieved contexts could serve as an additional source of information when predicting how relevant a target phrase is to an anchor phrase within a CPC context.

In [None]:
! pip install -q sentence-transformers faiss-gpu 
from sentence_transformers import SentenceTransformer
from scipy.stats import pearsonr
from sklearn import metrics
import time
import traceback
import sys
import faiss
import os
import pandas as pd
import numpy as np

In [None]:
class Config:
    """Notebook-level switches collected in one place."""

    # Flag named for toggling (re)building of the search index.
    # NOTE(review): check the index-building cell to confirm it actually consults this value.
    build_index = False

Generate CPC title embeddings and build the FAISS search index

In [None]:
model = SentenceTransformer('all-mpnet-base-v2')
cpc_df = pd.read_csv('../input/cpc-codes/titles.csv')

INDEX_PATH = 'cpc_index'

# Embedding every CPC title is the expensive step: only do it when a rebuild is
# requested via Config.build_index or when no saved index exists yet.
if Config.build_index or not os.path.exists(INDEX_PATH):
    # normalize_embeddings=True makes inner product (IndexFlatIP) equal to cosine similarity.
    cpc_embs = model.encode(cpc_df.title.values.tolist(), batch_size=256, normalize_embeddings=True)
    cpc_embs = np.ascontiguousarray(cpc_embs, dtype='float32')

    # See here for details on choosing the index: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
    # Derive the dimension from the embeddings instead of hard-coding it to the model's 768.
    index = faiss.IndexIDMap(faiss.IndexFlatIP(cpc_embs.shape[1]))
    # FAISS ids must be int64; np.array(range(...)) can be int32 on some platforms.
    # Ids are row positions in cpc_df, which `search` uses to look rows back up.
    index.add_with_ids(cpc_embs, np.arange(len(cpc_df), dtype='int64'))

    faiss.write_index(index, INDEX_PATH)
else:
    index = faiss.read_index(INDEX_PATH)

# A stale index built from a different titles.csv would map ids to the wrong rows.
assert index.ntotal == len(cpc_df), 'FAISS index size does not match the CPC titles table'

# Move the index to GPU for fast retrieval, but only when this FAISS build has GPU
# support and a GPU is actually present (otherwise keep searching on CPU).
if hasattr(faiss, 'StandardGpuResources') and faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()
    index = faiss.index_cpu_to_gpu(res, 0, index)

Search the index

In [None]:
def search(data, index, query_emb, num_results=15):
    """Return the CPC rows whose embeddings are closest to ``query_emb``.

    Args:
        data: Row-indexable sequence of CPC rows (e.g. ``cpc_df.values.tolist()``).
            The first seven columns are read as code, title, section, class,
            subclass, group and main_group, in that order (assumed to match the
            CSV column order -- confirm against the data file).
        index: FAISS-style index whose ``search(queries, k)`` returns
            ``(scores, ids)`` arrays of shape ``(n_queries, k)``; ids are
            positions into ``data``.
        query_emb: A single query embedding, shape ``(d,)`` or ``(1, d)``.
        num_results: Maximum number of neighbours to request.

    Returns:
        List of ``(row_dict, score)`` tuples in the order returned by the index.
        It may hold fewer than ``num_results`` entries when the index contains
        fewer vectors.
    """
    fields = ('code', 'title', 'section', 'class', 'subclass', 'group', 'main_group')

    # FAISS expects a C-contiguous float32 matrix with one row per query.
    query = np.asarray(query_emb, dtype='float32').reshape(1, -1)
    scores, ids = index.search(query, num_results)

    results = []
    for _id, score in zip(ids[0].tolist(), scores[0].tolist()):
        # FAISS pads missing neighbours with id -1; data[-1] would silently
        # return the last row as a bogus hit, so skip those slots.
        if _id < 0:
            continue
        row = data[_id]
        results.append(({field: row[i] for i, field in enumerate(fields)}, score))
    return results


query_text = 'abatement'
# Encode with the same normalisation as the indexed titles so inner product = cosine similarity.
query_emb = model.encode([query_text], normalize_embeddings=True)[0]
search_results = search(cpc_df.values.tolist(), index, query_emb)

# De-duplicate titles while keeping FAISS's relevance order
# (a set would print them in arbitrary order and lose the ranking).
unique_titles = dict.fromkeys(result['title'] for result, _score in search_results)
print('\n'.join(unique_titles))