In [1]:
import pandas as pd
import numpy as np
from scipy.spatial import cKDTree
import vecnn
import time
import hnswlib
import faiss
from typing import Any, Tuple

import h5py

# Location of the downloaded LAION HDF5 files, relative to this notebook.
DATA_PATH = '../../data'

laion_path = f'{DATA_PATH}/laion2B-en-clip768v2-n=300K.h5'
laion_gold_queries_path = f'{DATA_PATH}/public-queries-2024-laion2B-en-clip768v2-n=10k.h5'
laion_gold_path = f'{DATA_PATH}/gold-standard-dbsize=300K--public-queries-2024-laion2B-en-clip768v2-n=10k.h5'

# Every file is opened explicitly read-only inside a `with` block, so the HDF5 handle is
# closed as soon as its datasets have been copied into in-memory numpy arrays (previously
# the handles were left open and `f` was simply rebound).
with h5py.File(laion_path, 'r') as f:
    laion_data = np.array(f["emb"]).astype("float32") # shape: (300000, 768)

with h5py.File(laion_gold_queries_path, 'r') as f:
    laion_gold_queries = np.array(f["emb"]).astype("float32") # shape: (10000, 768)

with h5py.File(laion_gold_path, 'r') as f:
    laion_gold_dists = np.array(f["dists"]).astype("float32") # shape: (10000, 1000), seem to be sorted in ascending order
    laion_gold_knns = np.array(f["knns"]).astype("uint64") - 1 # -1 bc 1-based indexing in downloaded data. shape: (10000, 1000), same shape as dists

In [2]:
# Sanity check: locate the row with the largest dot product against the first gold query
# with a plain Python loop, then print what vecnn.linear_knn returns for the same query.
q = laion_gold_queries[0]

# Sentinel start values; both are replaced as soon as a row's dot product exceeds max_dot.
best_i = 1000000000
max_dot = -33333333
for i, row in enumerate(laion_data):
    d = np.dot(row, q)
    if d > max_dot:
        max_dot, best_i = d, i

print("predicted: ", best_i)
# NOTE: this prints only the shape of the downloaded neighbour row, not the indices themselves.
print("from download:", laion_gold_knns[0].shape)

# data = np.random.random((1000,768)).astype("float32")
# queries = np.random.random((300,768)).astype("float32")
ds = vecnn.Dataset(laion_data)
res = vecnn.linear_knn(ds, q, 1000, "dot")
print("from vecnn: ", res.indices[:20])

predicted:  224632
from download: (1000,)
from vecnn:  [224632 123279   9918    465 134211 227188 274488 176266 170212 134008
  58730 293160 158095 242165  35680 260287 254897 183112  80773 167011]


In [16]:
def overlap(a, b):
    """Return the share of `a` that is also present in `b`.

    The number of distinct values common to both sequences is divided by
    ``len(a)``. Both inputs must have the same length; otherwise the
    assertion fails.
    """
    assert len(a) == len(b)
    # np.intersect1d yields the sorted, de-duplicated values found in both inputs.
    shared = np.intersect1d(a, b)
    return shared.size / len(a)

# Draw a fixed-seed random subset of DATA_N rows (without replacement) from the full data.
DATA_N = 10000
np.random.seed(42)
subset_rows = np.random.choice(laion_data.shape[0], size=DATA_N, replace=False)
small_laion_data = laion_data[subset_rows]

q = laion_gold_queries[0]

# Reference result on the subset, obtained from vecnn.linear_knn.
ds_small = vecnn.Dataset(small_laion_data)
true_knn = vecnn.linear_knn(ds_small, q, 1000, "dot")
true_res = true_knn.indices

# Comparing the reference with itself exercises overlap() on identical inputs.
print(overlap(true_res, true_res))

# Currently disabled comparisons against other index structures:

# vptree = vecnn.VpTree(ds_small)
# res = vptree.knn(q, 1000).indices
# print("vp_tree: ", overlap(res, true_res))

# hnsw = vecnn.Hnsw(ds_small, 0.5, 40, 10, 10, "dot")
# res = hnsw.knn(q, 1000).indices
# print("hnsw: ", overlap(res, true_res))

# rustcv_hnsw = vecnn.RustCvHnsw(ds_small, 40)
# res = rustcv_hnsw.knn(q, 1000, 1000).indices
# print("rustcv_hnsw: ", overlap(res, true_res))


# Build the RNN graph over the subset and print the indices it returns for the query.
rnn_graph = vecnn.RNNGraph(ds_small, 10, 3, 10, 20, "dot")
rnn_knn = rnn_graph.knn(q, 1000, 10)
res = rnn_knn.indices
print(res)


1.0
[4030 7019 2350  757 7550 4651 8627 8605 4191 5212 4879 9553 3051 5219
 8549 5828 1528 9308 9593 2081 6356 2035  647 9539 3925 8349 5236 1136
 9949 4783 6671 2531 8806 2989 4317 9655 7278 7050 1027 5381 3060 1731
 4533 6845 2179 8271 9325 6825 7992 8194 7445 2169 6340 4199 5940 3268
 3061 7150    4 5648 7564 2996 9943 3491 1982 8293 2548 9287 4367 7340
 5536 4168 2764 9919 6704 8938 6738 8760 5909 2557 6195 6420  920 9376
 2297 5653 4810 8993 2543 8138 2381 5887 9265 8466 6734 7017 5260 5822
 7437 6293 2380 1659 7356 4789 2948]
