In [46]:
from annoy import AnnoyIndex
import pickle
import os
import time
from tqdm import tqdm_notebook as tqdm
# from tqdm import tqdm
import numpy as np
from functools import wrapper

In [54]:
track_vec = pickle.load(open('track_vec.pkl', 'rb'))
test_user_vec = pickle.load(open('val_user_vec.pkl', 'rb'))

In [115]:
f = 10
t = AnnoyIndex(f, metric='dot')
for i in range(len(track_vec)):
    t.add_item(i, track_vec[i])

t.build(50)

True

In [61]:
t.get_nns_by_vector(test_user_vec[0], 500)

[308298,
 38538,
 77075,
 115613,
 231232,
 231230,
 231225,
 269762,
 115615,
 77081,
 192692,
 269768,
 154160,
 346843,
 38548,
 77093,
 231235,
 346837,
 231281,
 38566,
 4,
 147,
 38829,
 38579,
 160,
 115748,
 192737,
 38622,
 77554,
 77248,
 231679,
 231473,
 154538,
 308349,
 140,
 270754,
 269953,
 269959,
 270338,
 40925,
 109,
 154247,
 154315,
 77786,
 154394,
 241,
 39359,
 271243,
 192903,
 271110,
 270017,
 660,
 232666,
 154510,
 1220,
 44939,
 193129,
 308664,
 155694,
 231785,
 232493,
 270093,
 308791,
 269928,
 309109,
 497,
 77502,
 158734,
 347622,
 39442,
 116199,
 236654,
 38998,
 348091,
 38720,
 232274,
 348841,
 347518,
 348501,
 77560,
 347765,
 231634,
 193271,
 156241,
 77357,
 193436,
 116801,
 39366,
 270415,
 77969,
 276214,
 195148,
 348123,
 270578,
 193527,
 312033,
 310016,
 154627,
 348215,
 233155,
 196600,
 236260,
 196246,
 120270,
 193550,
 78669,
 116467,
 193220,
 193585,
 193907,
 41537,
 270753,
 193249,
 2267,
 232098,
 154710,
 347503,
 1

In [84]:
def find_nearest_exhaustive(data, queries, k):
    if len(data.shape) == 1:
        data = np.array([x for x in data])
    n_items = data.shape[0]
    n_feat = data.shape[1]
    n_queries = len(queries)
    
    def single_query(query):
        start = time.time()
        if type(query) is not np.ndarray:
            query = np.array(query)
        res = np.argsort(-data.dot(query))[:k]
        interval = time.time() - start
        return interval, res
    times = []
    results = []
    for i in tqdm(range(n_queries)):
        interval, res = single_query(queries[i])
        times.append(interval)
        results.append(res)
    mean_time = sum(times) / len(times)
    print('-' * 26)
    print('Exhaustive Brute-force Search\n')
    print('Mean Query Search: %.6f' % mean_time)
    
    return mean_time, results    

In [85]:
bf_mean_time, bf_results = find_nearest_exhaustive(track_vec, test_user_vec, 500)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))


--------------------------
Exhaustive Brute-force Search

Mean Query Search: 0.046829


In [86]:
pickle.dump(bf_results, open('./bf_results', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

In [106]:
def wrap_with(obj, method, mapping):
    '''
    obj: the model that can respond to the query
    method: the name of the query method
    mapping: what input be mapped
    '''
    get_map = lambda x: [x[mapping[i]] for i in range(len(mapping))]
    def wrapped(*args, **kwrds):
        return obj.__getattribute__(method)(*get_map(args))
    return wrapped

<div style='background-color: #ffffb3;border-left: 5px solid #e6e600;padding: 0.5em;position:relative;'> 

<div style='position: absolute;top: 0px;left: 0;bottom: 0;right: 0;z-index: 0;overflow: hidden;font-size:60px; font-weight:bold; color: gray;opacity: 0.3;font-family: Arial, sans-serif'>⚠</div> <br>


Please wrap your model so that it can take input `(query, k)`


</div>

In [116]:
annoy10_wrapped = wrap_with(t, 'get_nns_by_vector', [0, 1])

In [112]:
def find_nearest_algo(data, queries, true_label, model_wrapped, k):
    if len(data.shape) == 1:
        data = np.array([x for x in data])
    n_items = data.shape[0]
    n_feat = data.shape[1]
    n_queries = len(queries)
    def single_query(query):
        start = time.time()
        res = model_wrapped(query, k)
        interval = time.time() - start
        return interval, res
    def get_recall(predict, truth):
        return len([x for x in predict if x in truth]) / len(truth)
    times = []
    recalls = []
    for i in tqdm(range(n_queries)):
        interval, res = single_query(queries[i])
        recall = get_recall(res, true_label[i])
        times.append(interval)
        recalls.append(recall)
    mean_time = sum(times) / len(times)
    mean_recall = sum(recalls) / len(recalls)
    print('-' * 26)
    print('Algorithm with k\' = %d\n' % k)
    print('Mean Query Search Time: %.6f' % mean_time)
    print('Mean Recall: %.6f' % mean_recall)
    
    return mean_time, mean_recall 

In [117]:
algo100_time, algo100_recall = find_nearest_algo(track_vec, val_user_vec, bf_results, annoy10_wrapped, 500)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))


--------------------------
Algorithm with k' = 500

Mean Query Search Time: 0.003408
Mean Recall: 0.178295
