In [18]:
import os  
import faiss
import time
import numpy as np 
import bspli
import torch
import pandas as pd

# Load the MNIST vectors and build an exact (brute-force) faiss L2 index over them.
# The flat index's answers are the ground truth used to evaluate the bspli index below.
mnist = np.load("dataset/mnist-784-euclidean.npy")
print(f'mnist data shape: {mnist.shape}')

flat = faiss.IndexFlatL2(mnist.shape[1])
flat.add(mnist)

# faiss searches a batch of queries, so the single row is reshaped to (1, dim).
D, FLAT_I = flat.search(mnist[0].reshape(1, -1), k=100)
print(f'brute query: {FLAT_I}')

mnist data shape: (60000, 784)
brute query: [[    0 32248  8728 18932 30483 24149 42338 52295 26251 50173 53634 24330
  54159 57528  1482 53428 18123 31379 52864 10536 29719 36087 30489 23947
  20034 52057 33825 21654 31008 55208 22477 44282 47968 54203 19825  1634
  27378 33909 15378 24708 34474 26413 16017 46824 46358  1516 34557 16832
  21629 29021 10740 24107  5688 52665  1864  5036 39031  1978 40546 22322
  52231 37284 24730  5970 21976 16945  9568 36697 25675 54189 11396 42555
  33445 52540 44263 18404 19186 24232 54184 25762 14736 33970  5210 59212
   8642 22569 15052  2933  6772 22963  6516   832 21244 21583 35838 59846
  21210 13502 52559 13862]]


In [24]:
# View the NumPy array as a torch tensor (torch.from_numpy shares the memory) and
# train the bspli index on it.
mnist_tensor = torch.from_numpy(mnist)
print(f'mnist tensor shape: {mnist_tensor.shape}')

index = bspli.index.Indexing(
    gl_size=100000,
    ll_size=100,
    random_partitioning=False,
)
index.train(mnist_tensor)

# NOTE(review): `_l_model` is a private attribute of bspli's Indexing — presumably the
# collection of trained local models; confirm and prefer a public accessor if one exists.
print(f"local model len:{len(index._l_model)}")

mnist tensor shape: torch.Size([60000, 784])
torch.Size([1995, 785])
torch.Size([58005, 785])
first stage partitioning finish
partitioning blocks : 2
training local model
training local model
1, 100 loss: 0.07005531787872314
1, 200 loss: 0.07005526065826416
1, 300 loss: 0.07005530834197998
1, 400 loss: 0.07005532264709473
1, 500 loss: 0.07005532264709473
2, 100 loss: 0.07005532264709473
2, 200 loss: 0.07005532264709473
2, 300 loss: 0.07005532264709473
2, 400 loss: 0.07005532264709473
2, 500 loss: 0.07005532264709473
3, 100 loss: 0.07005532264709473
3, 200 loss: 0.07005532264709473
3, 300 loss: 0.07005532264709473
3, 400 loss: 0.07005532264709473
3, 500 loss: 0.07005532264709473
4, 100 loss: 0.07005532264709473
4, 200 loss: 0.07005532264709473
4, 300 loss: 0.07005532264709473
4, 400 loss: 0.07005532264709473
4, 500 loss: 0.07005532264709473
5, 100 loss: 0.07005532264709473
5, 200 loss: 0.07005532264709473
5, 300 loss: 0.07005532264709473
5, 400 loss: 0.07005532264709473
5, 500 loss: 0.0

In [25]:
%%time
qp = torch.from_numpy(mnist[0])
# print(qp)
pred = index.query(qp, k=100)
pred = pred.to(torch.int)
print(f"pred: {pred}")

predicted local model: 1
pred: tensor([    0, 32248,  8728, 18932, 30483, 24149, 42338, 52295, 26251, 50173,
        53634, 24330, 54159, 57528,  1482, 53428, 18123, 31379, 52864, 10536,
        29719, 36087, 30489, 23947, 20034, 52057, 33825, 21654, 31008, 55208,
        22477, 44282, 47968, 54203, 19825,  1634, 27378, 33909, 15378, 24708,
        34474, 26413, 16017, 46824, 46358,  1516, 34557, 16832, 21629, 29021,
        10740, 24107,  5688, 52665,  1864,  5036, 39031,  1978, 40546, 22322,
        52231, 37284, 24730,  5970, 21976, 16945,  9568, 36697, 25675, 54189,
        11396, 42555, 33445, 52540, 44263, 18404, 19186, 24232, 54184, 25762,
        14736, 33970,  5210, 59212,  8642, 22569, 15052,  2933,  6772, 22963,
         6516,   832, 21244, 21583, 35838, 59846, 21210, 13502, 52559, 13862],
       dtype=torch.int32)
CPU times: total: 1.44 s
Wall time: 246 ms


In [21]:
result = []  # one (mean query seconds, mean recall) tuple per benchmark_knn_query call


def recall(pred, true):
    """Fraction of the ground-truth neighbour ids that appear in ``pred``.

    Parameters
    ----------
    pred : array-like (NumPy array or CPU torch tensor)
        Ids returned by the index under test.
    true : array-like
        Ground-truth ids (e.g. the ``I`` array from ``faiss.Index.search``);
        any shape, it is flattened.

    Returns
    -------
    float
        Number of entries of ``true`` found in ``pred`` divided by ``true.size``;
        NaN when ``true`` is empty (recall is undefined there).
    """
    true = np.asarray(true).ravel()
    if true.size == 0:
        return float("nan")
    # Test membership of the ground-truth ids (not of pred) so that duplicate ids
    # in pred cannot be counted more than once and push recall above 1.
    hits = np.isin(true, np.asarray(pred))
    return float(hits.sum()) / true.size


def benchmark_knn_query(data, index, size=1000, k=100, gt_index=None, seed=None):
    """Benchmark ``index.query`` on random rows of ``data`` against exact k-NN.

    Samples ``size`` distinct rows of ``data`` as queries, times each
    ``index.query(q, k=k)`` call, and scores its answer with ``recall`` against
    the neighbours returned by ``gt_index.search``.

    Parameters
    ----------
    data : np.ndarray
        2-D array, one vector per row.
    index :
        Index under test; must expose ``query(tensor, k=...)``.
    size : int
        Number of distinct query rows to sample (1..len(data)).
    k : int
        Neighbours requested from both the index under test and the ground truth.
    gt_index : optional
        Ground-truth index with a faiss-style ``search(x, k)`` returning
        ``(distances, ids)``; defaults to the notebook-level ``flat`` index.
    seed : int, optional
        Seed for sampling the query rows; ``None`` uses NumPy's global RNG.

    Returns
    -------
    tuple of float
        ``(mean query time in seconds, mean recall)``. The same tuple is also
        appended to the global ``result`` list.

    Raises
    ------
    ValueError
        If ``size`` is not positive, or (from ``choice``) exceeds ``len(data)``.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    gt = flat if gt_index is None else gt_index
    sampler = np.random if seed is None else np.random.default_rng(seed)
    indices = sampler.choice(data.shape[0], size, replace=False)

    query_time = 0.0
    recall_sum = 0.0
    for i in indices:
        q = torch.from_numpy(data[i])
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        qk = index.query(q, k=k)  # pass the caller's k through to the index
        query_time += time.perf_counter() - start

        # faiss searches a batch, so the single row is reshaped to (1, dim).
        _, true_ids = gt.search(data[i].reshape(1, -1), k=k)
        recall_sum += recall(qk, true_ids)

    # Average over the number of queries actually run, not a fixed 1000.
    stats = (query_time / size, recall_sum / size)
    result.append(stats)
    return stats

In [26]:
# Benchmark the bspli index on 1000 random MNIST rows; the stats tuple is appended to `result`.
benchmark_knn_query(data=mnist, index=index, size=1000, k=100);

predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1
predicted local model: 1


In [27]:
# Tabulate every benchmark run recorded so far (one row per benchmark_knn_query call)
# instead of printing a list of unlabeled tuples.
pd.DataFrame(result, columns=["avg_query_time_s", "avg_recall"])

[(0.007357717037200928, 0.019600000000000006), (0.25210712623596193, 0.972840000000001)]
