In [1]:
import sys
sys.path.append("./Bspli/")
import os  
import faiss
import time
import numpy as np 
import index 
import torch
import pandas as pd

# Load the MNIST vectors from a local .npy file. The resulting array is the
# dataset that the later cells index and query.
mnist = np.load("dataset/mnist-784-euclidean.npy")
# Print shape and dtype so the loaded array can be sanity-checked at a glance.
print(f'mnist data shape: {mnist.shape}')
print(f'mnist dtype: {mnist.dtype}')

# Disabled alternative input: 100000 uniformly distributed random 2-D points
# in [0, 255), cast to float32, usable in place of the MNIST file.
# mnist = np.random.uniform(low=0, high=255, size=(100000, 2))
# mnist = mnist.astype(np.float32)
# print(f'mnist dtype: {mnist.dtype}')
# print(f'mnist data shape: {mnist.shape}')


mnist data shape: (60000, 784)
mnist dtype: float32


In [2]:
%%time
# Brute-force search: a FAISS flat L2 index over all MNIST vectors. The ids it
# returns (FLAT_I) serve as the reference that `recall` is computed against.
flat = faiss.IndexFlatL2(mnist.shape[1])
flat.add(mnist)

# Query with the first vector, reshaped to a one-row 2-D batch, asking for 100
# neighbours. D is the first value returned by search and is not used below;
# FLAT_I holds the neighbour ids that are printed.
D, FLAT_I = flat.search(mnist[0].reshape(1, mnist.shape[1]), k=100) 
print(f'brute query: {FLAT_I}')

brute query: [[    0 32248  8728 18932 30483 24149 42338 52295 26251 50173 53634 24330
  54159 57528  1482 53428 18123 31379 52864 10536 29719 36087 30489 23947
  20034 52057 33825 21654 31008 55208 22477 44282 47968 54203 19825  1634
  27378 33909 15378 24708 34474 26413 16017 46824 46358  1516 34557 16832
  21629 29021 10740 24107  5688 52665  1864  5036 39031  1978 40546 22322
  52231 37284 24730  5970 21976 16945  9568 36697 25675 54189 11396 42555
  33445 52540 44263 18404 19186 24232 54184 25762 14736 33970  5210 59212
   8642 22569 15052  2933  6772 22963  6516   832 21244 21583 35838 59846
  21210 13502 52559 13862]]
CPU times: total: 62.5 ms
Wall time: 71.3 ms


In [3]:
# Wrap the NumPy array as a torch tensor; this tensor is what idx.train receives.
mnist_tensor = torch.from_numpy(mnist)
print(f'mnist tensor shape: {mnist_tensor.shape}')

# Build the index from the project-local `index` module.
# NOTE(review): the argument names suggest a two-level layout ("g_" = global,
# "l_" = local) with per-level sizes, epoch counts, hidden sizes and block
# ranges -- confirm the exact meaning of each against index.Indexing.
idx = index.Indexing(
    gl_size=10000, 
    ll_size=2000,
    g_epoch_num=3,
    l_epoch_num=10,
    g_hidden_size=5,
    l_hidden_size=5,
    g_block_range=4,
    l_block_range=4,
    random_partitioning=False
)
idx.train(mnist_tensor)

# Reads the private attribute `_l_model` to report how many entries it holds
# after training.
print(f"local model len:{len(idx._l_model)}")

mnist tensor shape: torch.Size([60000, 784])
torch.Size([1995, 785])
torch.Size([8906, 785])
torch.Size([4610, 785])
torch.Size([3738, 785])
torch.Size([7329, 785])
torch.Size([4915, 785])
torch.Size([8447, 785])
torch.Size([4660, 785])
torch.Size([2740, 785])
torch.Size([2620, 785])
torch.Size([382, 785])
torch.Size([9658, 785])
first stage partitioning finish
partitioning blocks : 12
training local model
training local model
training local model
training local model
training local model
training local model
training local model
training local model
training local model
training local model
training local model
training local model
trainging global model
global index train smaple count: 132
finish
local model len:12


  ma = torch.tensor(means)
  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)


In [4]:
%%time

def recall(pred, true):
    """Return the hit ratio of ``pred`` against ``true``.

    The number of elements of ``pred`` that occur anywhere in ``true`` is
    divided by ``true.size``.

    Args:
        pred: array-like of predicted ids.
        true: array of reference ids; must expose ``.size``.

    Returns:
        The hit count divided by ``true.size``.
    """
    hit_mask = np.isin(pred, true)
    hit_count = hit_mask.sum()
    return hit_count / true.size

# Single-query check: search the trained index with the first MNIST vector and
# score the returned ids against the brute-force ids in FLAT_I.
qp = torch.from_numpy(mnist[0])
# print(qp)
pred = idx.query(qp, k=100)
print(f"recall: {recall(pred, FLAT_I)}")

# Cast the result to an integer tensor before printing it.
pred = pred.to(torch.int)
print(f"pred: {pred}")

recall: 0.73
pred: tensor([    0,  8728, 30483, 24149, 42338, 52295, 50173, 53634, 24330,  1482,
        53428, 18123, 31379, 52864, 10536, 36087, 30489, 23947, 20034, 52057,
        21654, 44282, 54203, 19825,  1634, 27378, 33909, 15378, 24708,  1516,
        34557, 29021, 10740, 24107,  5688, 52665,  1864,  5036, 39031,  1978,
        40546, 22322,  5970, 21976, 16945, 36697, 25675, 54189, 11396, 42555,
        33445, 52540, 44263, 24232, 54184, 25762, 14736, 33970,  5210, 59212,
         8642, 22569, 15052,  2933, 22963,  6516,   832, 21244, 35838, 21210,
        13502, 52559, 13862, 41980, 43997, 53812, 13193, 46698, 15975, 46968,
        50187,  3188, 28384, 24271, 21661, 15907, 33469, 38896, 15444, 23225,
        29049, 50071, 45567, 43031, 26313,  2395, 11936, 16112, 49334, 44476],
       dtype=torch.int32)
CPU times: total: 453 ms
Wall time: 82.3 ms


In [5]:
# Module-level accumulator: each call to benchmark_knn_query appends one
# (time, recall) tuple to this list.
result = []

def benchmark_knn_query(data, index, size=1000, k=100, g_block_range=8, l_block_range=5):
    """Benchmark ``index.query`` on randomly sampled rows of ``data``.

    For each sampled row the index query is timed, the reference neighbours
    are fetched from the module-level FAISS index ``flat``, and the recall of
    the index result against that reference is computed with ``recall``.
    One ``(mean query seconds, mean recall)`` tuple is appended to the
    module-level ``result`` list; nothing is returned.

    Args:
        data: 2-D NumPy array; rows are the candidate query vectors.
        index: object exposing
            ``query(q, k=..., g_block_range=..., l_block_range=...)``.
        size: number of distinct rows to sample as queries.
        k: number of neighbours requested from both ``index`` and ``flat``.
        g_block_range: forwarded unchanged to ``index.query`` (default 8).
        l_block_range: forwarded unchanged to ``index.query`` (default 5).

    Raises:
        ValueError: if ``size`` is not positive, or (from
            ``np.random.choice``) if it exceeds the number of rows.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")

    # Sample without replacement so no row is benchmarked twice.
    sample_ids = np.random.choice(data.shape[0], size, replace=False)
    dim = data.shape[1]
    total_time = 0.0
    total_recall = 0.0

    for i in sample_ids:
        q = torch.from_numpy(data[i])
        # Time only the index lookup; perf_counter is the monotonic,
        # high-resolution clock intended for measuring short intervals.
        start = time.perf_counter()
        # Request the same k as the ground truth so recall is well defined.
        qk = index.query(q, k=k, g_block_range=g_block_range, l_block_range=l_block_range)
        total_time += time.perf_counter() - start
        _, true_ids = flat.search(data[i].reshape(1, dim), k=k)
        total_recall += recall(qk, true_ids)

    # Average over the number of queries actually issued, not a fixed 1000.
    result.append((total_time / size, total_recall / size))

In [6]:
# Benchmark the trained index on 1000 randomly sampled MNIST rows with k=100;
# the outcome is appended to the module-level `result` list.

benchmark_knn_query(mnist, idx, size=1000, k=100)

In [7]:
# Show the accumulated benchmark tuples, then the value reported by the
# index's get_search_blocks_num().
# NOTE(review): presumably the number of blocks visited while searching --
# confirm against index.Indexing.
print(result)
print(idx.get_search_blocks_num())

[(0.10986903166770935, 0.9176100000000027)]
59


In [14]:
%%time
# Baseline for comparison: a scikit-learn BallTree over the same vectors.
from sklearn.neighbors import BallTree

tree = BallTree(mnist, leaf_size=2000)
# Query the 100 nearest neighbours of the first vector (as a one-row batch).
# `dist` and `ind` are not used again in this cell.
dist, ind = tree.query(mnist[0].reshape(1, mnist.shape[1]), k=100) 

# Rebuild the FAISS flat L2 index and the reference ids for the first vector,
# using the same construction as the earlier brute-force cell.
flat = faiss.IndexFlatL2(mnist.shape[1])
flat.add(mnist)

D, FLAT_I = flat.search(mnist[0].reshape(1, mnist.shape[1]), k=100) 

def recall(pred, true):
    """Count the elements of ``pred`` found in ``true``, relative to ``true.size``.

    Args:
        pred: array-like of predicted ids.
        true: array of reference ids; must expose ``.size``.

    Returns:
        Number of ``pred`` elements present in ``true`` divided by
        ``true.size``.
    """
    return np.isin(pred, true).sum() / true.size


def benchmark_knn_query(data, index, size=1000, k=100):
    """Benchmark ``index.query`` (e.g. a scikit-learn tree) on sampled rows.

    For each sampled row the query is timed, the reference neighbours are
    fetched from the module-level FAISS index ``flat``, and the recall of the
    returned neighbour ids against that reference is computed with
    ``recall``. One ``(mean query seconds, mean recall)`` tuple is appended
    to the module-level ``result`` list; nothing is returned.

    Args:
        data: 2-D NumPy array; rows are the candidate query vectors.
        index: object exposing ``query(q, k=...)``. It may return either the
            neighbour ids directly or a ``(distances, indices)`` tuple.
        size: number of distinct rows to sample as queries.
        k: number of neighbours requested from both ``index`` and ``flat``.

    Raises:
        ValueError: if ``size`` is not positive, or (from
            ``np.random.choice``) if it exceeds the number of rows.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")

    # Sample without replacement so no row is benchmarked twice.
    sample_ids = np.random.choice(data.shape[0], size, replace=False)
    # Use the dimensionality of `data` itself, not of a global array.
    dim = data.shape[1]
    total_time = 0.0
    total_recall = 0.0

    for i in sample_ids:
        q = data[i].reshape(1, dim)
        # Time only the index lookup; perf_counter is the monotonic,
        # high-resolution clock intended for measuring short intervals.
        start = time.perf_counter()
        # Request the same k as the ground truth so recall is well defined.
        found = index.query(q, k=k)
        total_time += time.perf_counter() - start
        # A (distances, indices) tuple must be reduced to the indices,
        # otherwise distance values would be matched against neighbour ids.
        neighbor_ids = found[1] if isinstance(found, tuple) else found
        _, true_ids = flat.search(q, k=k)
        total_recall += recall(neighbor_ids, true_ids)

    # Average over the number of queries actually issued, not a fixed 1000.
    result.append((total_time / size, total_recall / size))


CPU times: total: 2.55 s
Wall time: 2.66 s


1.0