In [1]:
import os  
import faiss
import time
import numpy as np 
import bspli
import torch
import pandas as pd

# Load the SIFT vectors from the local dataset directory.
sift = np.load("dataset/sift-128-euclidean.npy")
print(f'sift data shape: {sift.shape}')

# Flat L2 faiss index over every vector; used as the brute-force reference.
flat = faiss.IndexFlatL2(sift.shape[1])
flat.add(sift)

# Query the flat index with the first row, shaped as a one-row batch.
first_query = sift[0].reshape(1, -1)
D, FLAT_I = flat.search(first_query, k=100)
print(f'brute query: {FLAT_I}')

sift data shape: (1000000, 128)
brute query: [[     0      2      6  83606 631203 677834 246710 677793 480592  10336
  658180 799350 738996 516942 965310 451321 725637 480903 719046 248230
  799488 500141 799404 466880 529593 688749 558961 686828 183625 432221
  532473 187896 678008 772144  89757 432521 182418 633385 596413 657774
  679499 769701 930778 236210 528709 216605 738730 192507 271323 559065
  134358  81704 631964 206873 271151 851764 261934 225014 404206 632106
  256176 547359 514307 523094 630017 258188 705267 216395 419350 204933
  269211 197644 276460  65965 551986 876717 228378  95134  87235 719140
  407157  79225 808018 559250 420871 525531 162637 764500 547845 724103
  547004 219183 832018 533417  42705 197990 276806 720756 116581 729207]]


In [2]:
# bspli is trained on a torch tensor, so wrap the NumPy array first.
sift_tensor = torch.from_numpy(sift)
print(f'sift tensor shape: {sift_tensor.shape}')

# NOTE(review): gl_size / ll_size semantics are defined by bspli; the values
# here are the ones this experiment was run with.
index = bspli.index.Indexing(
    gl_size=1_000_000,
    ll_size=300,
    random_partitioning=False,
)
index.train(sift_tensor)

# Report how many local models the trained index holds.
print(f"local model len:{len(index._l_model)}")

sift tensor shape: torch.Size([1000000, 128])
torch.Size([343802, 129])
torch.Size([656198, 129])
first stage partitioning finish
partitioning blocks : 2
training local model
1, 100 loss: 0.07621049404144287
1, 200 loss: 0.07621051788330079
1, 300 loss: 0.07621056556701661
1, 400 loss: 0.07621056556701661
1, 500 loss: 0.07621056556701661
1, 600 loss: 0.07621056556701661
1, 700 loss: 0.07621056556701661
1, 800 loss: 0.07621056556701661
1, 900 loss: 0.07621054172515869
1, 1000 loss: 0.07621056556701661
1, 1100 loss: 0.07621056556701661
1, 1200 loss: 0.07621056556701661
1, 1300 loss: 0.07621056556701661
1, 1400 loss: 0.07621056556701661
1, 1500 loss: 0.07621056556701661
1, 1600 loss: 0.07621056556701661
1, 1700 loss: 0.07621056556701661
1, 1800 loss: 0.07611056327819825
1, 1900 loss: 0.07621056079864502
1, 2000 loss: 0.07621056079864502
1, 2100 loss: 0.07621054649353028
1, 2200 loss: 0.07621056079864502
1, 2300 loss: 0.07621055126190185
1, 2400 loss: 0.07621054649353028
1, 2500 loss: 0.07

  ma = torch.tensor(means)


1, 400 loss: 0.005632616281509399
1, 500 loss: 0.0031326165795326232
1, 600 loss: 0.008132616877555848
1, 700 loss: 0.005632616281509399
1, 800 loss: 0.0081326162815094
1, 900 loss: 0.0056326168775558474
1, 1000 loss: 0.005632616281509399
1, 1100 loss: 0.0081326162815094
1, 1200 loss: 0.005632616281509399
1, 1300 loss: 0.0056326168775558474
1, 1400 loss: 0.0031326165795326232
2, 100 loss: 0.005632616281509399
2, 200 loss: 0.0031326165795326232
2, 300 loss: 0.0106326162815094
2, 400 loss: 0.0106326162815094
2, 500 loss: 0.005632616281509399
2, 600 loss: 0.008132616877555848
2, 700 loss: 0.0031326165795326232
2, 800 loss: 0.005632616281509399
2, 900 loss: 0.0031326165795326232
2, 1000 loss: 0.0081326162815094
2, 1100 loss: 0.008132616877555848
2, 1200 loss: 0.0106326162815094
2, 1300 loss: 0.005632616281509399
2, 1400 loss: 0.0056326168775558474
3, 100 loss: 0.0081326162815094
3, 200 loss: 0.005632616281509399
3, 300 loss: 0.0106326162815094
3, 400 loss: 0.0031326165795326232
3, 500 loss

In [3]:
%%time
# Single query through the trained index, using the first SIFT row as the query point.
qp = torch.from_numpy(sift[0])
# Cast the returned tensor to int32 before printing it.
pred = index.query(qp, k=100).to(torch.int)
print(f"pred: {pred}")

predicted local model: 1
pred: tensor([ 17174,  50996,  38474,  49754,  35718,  40714,  47756, 782781,  47722,
         32395,  19600,  10252, 743912, 626270,   9427, 816359, 706007, 599982,
        653453,   7712, 866242, 323382, 961315, 130376, 136915, 648948,   4142,
        425958, 914245, 560147, 985477, 287816, 334001, 651986, 247638,  10353,
        387955, 352830,  20955, 755964, 364932, 448072, 494567, 246103, 536723,
        182870,   7161,   9379, 457276, 759266, 559445, 415125, 506411, 979444,
        323562,  47837, 748438, 825405, 359529, 715956,   8310, 485432, 334348,
        773715, 250453, 283940, 366302, 986731, 239505, 895640, 384280, 624510,
        647472, 763841, 172667, 153847, 393083, 960221, 816115, 652122, 829206,
         21242, 839129, 895246, 408407, 855222, 314417, 878834, 904718, 710267,
          7660, 369629, 845820, 774397, 842194, 915632,  49678, 864095,   9451,
        146684], dtype=torch.int32)
CPU times: total: 11.9 s
Wall time: 1.98 s


In [4]:
# One (mean query seconds, mean recall) tuple is appended here per benchmark run.
# Re-running this cell resets the accumulated history.
result = list()

def recall(pred, true):
    """Return the share of ground-truth slots matched by the prediction.

    Counts the elements of `pred` that occur anywhere in `true` and divides
    that count by `true.size`. Repeated values in `pred` are each counted.

    Args:
        pred: array-like of predicted ids.
        true: NumPy array of reference ids (its `.size` is the denominator).

    Returns:
        The hit count divided by `true.size`.
    """
    hit_mask = np.isin(pred, true)
    return np.sum(hit_mask) / true.size


def benchmark_knn_query(data, index, size=1000, k=100):
    """Benchmark `index.query` against the brute-force `flat` index.

    Samples `size` distinct rows of `data`, times `index.query` on each row,
    and scores the returned ids with `recall` against the ids that
    `flat.search` returns for the same row. The tuple
    ``(mean query seconds, mean recall)`` is appended to the module-level
    `result` list; nothing is returned.

    Relies on the module-level names `flat`, `recall` and `result`.

    Args:
        data: 2-D NumPy array whose rows are used as query vectors.
        index: object exposing ``query(q, k=...)`` that accepts a torch tensor.
        size: number of rows sampled without replacement; must be positive.
        k: number of neighbours requested from both `index` and `flat`.

    Raises:
        ValueError: if `size` is not positive.
    """
    if size <= 0:
        raise ValueError("size must be a positive integer")

    indices = np.random.choice(data.shape[0], size, replace=False)
    query_time = 0.0
    cur_recall = 0.0

    for i in indices:
        q = torch.from_numpy(data[i])
        # perf_counter is the monotonic, high-resolution clock meant for
        # measuring short intervals; only the learned-index query is timed.
        start = time.perf_counter()
        qk = index.query(q, k=k)  # honour the caller's k (was hard-coded to 100)
        query_time += time.perf_counter() - start

        # Ground truth comes from the brute-force index, outside the timed region.
        _, true_ids = flat.search(data[i].reshape(1, data.shape[1]), k=k)
        cur_recall += recall(qk, true_ids)

    # Average over the number of queries actually run (was a literal 1000).
    result.append((query_time / size, cur_recall / size))

In [5]:
# Run the benchmark on the trained index; the outcome is appended to `result`.

benchmark_knn_query(data=sift, index=index, size=1000, k=100)

predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1


In [10]:
# Show every (mean query seconds, mean recall) pair accumulated so far.
print(f"{result}")

[(0.025773720026016234, 6e-05), (0.005614554643630982, 8.999999999999999e-05)]
