In [1]:
import numpy as np
import pandas as pd
import re
import time
import os
import pickle
from datasketch import MinHash, MinHashLSHForest

In [2]:
def preprocess(text):
    """Normalise an entity string into a list of lower-case word tokens.

    Punctuation (any character that is neither a word character nor
    whitespace) is removed, the text is lower-cased, and the result is split
    on runs of whitespace and underscores (entity names such as
    ``lionel_messi`` use ``_`` as the word separator).

    Args:
        text: the raw string to tokenise.

    Returns:
        A list of non-empty tokens; ``[]`` when ``text`` has no word characters.
    """
    text = re.sub(r'[^\w\s]', '', text)
    # Split on *runs* of separators and drop empties: a single-character split
    # yields '' for doubled/leading/trailing separators, and a shared '' token
    # would make otherwise unrelated entities look similar to MinHash.
    return [tok for tok in re.split(r'[\s_]+', text.lower()) if tok]

In [3]:
def get_forest(data, perms):
    """Build a queryable MinHash LSH Forest over a collection of strings.

    Each string is tokenised with ``preprocess`` and summarised as a MinHash
    of its tokens. The MinHash is added to the forest under the string's
    position in ``data``, so query results map back via ``data[key]``.

    Args:
        data: iterable of strings to index.
        perms: number of MinHash permutations. Queries against the returned
            forest should use MinHash objects built with the same value.

    Returns:
        An indexed ``MinHashLSHForest`` ready for ``query``.
    """
    start_time = time.time()
    forest = MinHashLSHForest(num_perm=perms)

    # Add each MinHash as soon as it is built rather than collecting every
    # MinHash object in memory first; add() keeps its own hashed copy of the
    # signature, so the object is not needed afterwards.
    for i, entity in enumerate(data):
        m = MinHash(num_perm=perms)
        for s in preprocess(entity):
            m.update(s.encode('utf8'))
        forest.add(i, m)

    # Keys added to the forest are not searchable until index() is called.
    forest.index()
    print('It took {} seconds to build forest.'.format(time.time() - start_time))
    return forest

In [4]:
def predict(query_entity, entities, perms, num_results, forest):
    """Return up to ``num_results`` indexed entities similar to a query string.

    The query is tokenised with ``preprocess`` and hashed into a MinHash the
    same way ``get_forest`` hashes indexed entities, then looked up in the
    LSH Forest.

    Args:
        query_entity: free-text query, e.g. ``'messi'``.
        entities: the sequence the forest was built from; forest keys are
            assumed to be positions in it (as ``get_forest`` assigns them).
        perms: number of MinHash permutations; should equal the value used to
            build ``forest`` so the signatures are comparable.
        num_results: maximum number of candidates to return.
        forest: an indexed ``MinHashLSHForest``.

    Returns:
        A list of matching entity strings, or ``None`` when the forest
        returns no candidates.
    """
    m = MinHash(num_perm=perms)
    for s in preprocess(query_entity):
        m.update(s.encode('utf8'))

    # NOTE(review): forest.query returns approximate top-k candidates; the
    # order of the returned keys is not guaranteed to reflect similarity.
    keys = forest.query(m, num_results)
    if not keys:
        return None
    # Keys are used directly as list positions; no numpy conversion needed.
    return [entities[idx] for idx in keys]

In [5]:
# Dataset location. Set the KDWD_DIR environment variable to override the
# machine-specific default instead of editing this cell.
folder = os.environ.get('KDWD_DIR', 'D:/Projects/DH/Intern/VinAI/Repos/secret/dataset/KDWD')
file = os.path.join('training_data', 'entity2id.txt')
entities = []
with open(os.path.join(folder, file), 'r', encoding='utf8') as f:
    # The first line is skipped; presumably a header/count line rather than
    # an entity. TODO confirm against the file format.
    next(f, None)
    # Stream the file instead of readlines() so the whole file is not held in
    # memory twice.
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing the two-field unpack below.
            continue
        # The id column is ignored: entities are indexed by file position,
        # which is what the forest keys refer to.
        ent, _ = line.split('\t')
        entities.append(ent)

In [6]:
%%time

permutations=128
forest=get_forest(entities,permutations)

It took 200.7253224849701 seconds to build forest.
Wall time: 3min 20s


In [7]:
# Persist the built forest so it can be reloaded instead of rebuilt.
# Set LSH_FOREST_PATH to override the machine-specific default location.
# NOTE: only unpickle this file from a trusted location; pickle.load can run
# arbitrary code.
forest_path = os.environ.get(
    'LSH_FOREST_PATH',
    'D:/Projects/DH/Intern/VinAI/Repos/secret/models/weights/lsh_forest.pkl',
)
os.makedirs(os.path.dirname(forest_path) or '.', exist_ok=True)
# `with` guarantees the file is closed even if pickling fails part-way.
with open(forest_path, 'wb') as pickleOut:
    pickle.dump(forest, pickleOut)

In [13]:
# Look up entities similar to a sample query; the list is the cell's output.
query_entity = 'messi'
num_results = 5
predict(
    query_entity=query_entity,
    entities=entities,
    perms=permutations,
    num_results=num_results,
    forest=forest,
)

['lionel_messi']