Skip to content

Commit

Permalink
Merge pull request #49 from Dobatymo/bktree-add-indices
Browse files Browse the repository at this point in the history
bktree return indices and change handling for duplicates
  • Loading branch information
Dobatymo committed Sep 1, 2023
2 parents 6eeab30 + be1ac95 commit fe0a820
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 57 deletions.
52 changes: 29 additions & 23 deletions pynear/include/BKTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include <map>
#include <optional>

typedef int64_t index_t;

template <typename key_t, typename distance_t> class Metric {
public:
static distance_t distance(const key_t &a, const key_t &b);
Expand All @@ -11,18 +13,19 @@ template <typename key_t, typename distance_t> class Metric {
template <typename key_t, typename distance_t> class BKNode {
public:
key_t key;
index_t index;
std::map<distance_t, BKNode<key_t, distance_t> *> leaves;
std::optional<distance_t> max_distance;

BKNode(key_t key) : key(key) {}
BKNode(key_t key, index_t index) : key(key), index(index) {}

void add_leaf(distance_t distance, key_t key) {
leaves[distance] = new BKNode<key_t, distance_t>(key);
void add_leaf(distance_t distance, key_t key, index_t index) {
leaves[distance] = new BKNode<key_t, distance_t>(key, index);
max_distance = std::max<distance_t>(distance, max_distance.value_or(0));
}

void clear() {
for (auto const &[d, bknode] : leaves) {
for (auto [d, bknode] : leaves) {
delete bknode;
}
leaves.clear();
Expand All @@ -33,33 +36,30 @@ template <typename key_t, typename distance_t> class BKNode {

template <typename key_t, typename distance_t, typename metric> class BKTree {
BKNode<key_t, distance_t> *root;
size_t index;

public:
BKTree() : root(nullptr) {}
BKTree() : root(nullptr), index(0) {}

bool add(key_t key) {
void add(key_t key) {
if (root == nullptr) {
root = new BKNode<key_t, distance_t>(key);
root = new BKNode<key_t, distance_t>(key, index);
} else {
BKNode<key_t, distance_t> *node = root;
distance_t dist;

while (true) {
dist = metric::distance(node->key, key);
auto next_it = node->leaves.find(dist);
if (next_it == node->leaves.end() || dist == 0) {
if (next_it == node->leaves.end()) {
break;
}
node = next_it->second;
}

if (dist > 0) {
node->add_leaf(dist, key);
} else {
return false; // didn't insert key
}
node->add_leaf(dist, key, index);
}
return true; // inserted new key
++index;
}

void update(std::vector<key_t> keys) {
Expand All @@ -68,15 +68,16 @@ template <typename key_t, typename distance_t, typename metric> class BKTree {
}
}

std::tuple<std::vector<distance_t>, std::vector<key_t>> find(key_t key, distance_t threshold) {
std::tuple<std::vector<index_t>, std::vector<distance_t>, std::vector<key_t>> find(key_t key, distance_t threshold) {
static_assert(std::is_signed<distance_t>::value, "Arithmetic required signed distances");

BKNode<key_t, distance_t> *node = root;
std::vector<index_t> indices;
std::vector<distance_t> distances;
std::vector<key_t> keys;

if (node == nullptr) {
return std::make_tuple(distances, keys);
return std::make_tuple(indices, distances, keys);
}

std::deque<BKNode<key_t, distance_t> *> candidates = {node};
Expand All @@ -97,37 +98,40 @@ template <typename key_t, typename distance_t, typename metric> class BKTree {
dist = dist_opt.value();

if (dist <= threshold) {
indices.push_back(candidate->index);
distances.push_back(dist);
keys.push_back(candidate->key);
}

lower = dist - threshold;
upper = dist + threshold;
for (auto const &[d, bknode] : candidate->leaves) {
for (auto [d, bknode] : candidate->leaves) {
if (lower <= d && d <= upper) {
candidates.push_back(bknode);
}
}
}
return std::make_tuple(distances, keys);
return std::make_tuple(indices, distances, keys);
}

std::tuple<std::vector<std::vector<distance_t>>, std::vector<std::vector<key_t>>> find_batch(const std::vector<key_t> &keys,
distance_t threshold) {
std::tuple<std::vector<std::vector<index_t>>, std::vector<std::vector<distance_t>>, std::vector<std::vector<key_t>>>
find_batch(const std::vector<key_t> &keys, distance_t threshold) {
std::vector<std::vector<index_t>> indices_out(keys.size());
std::vector<std::vector<distance_t>> distances_out(keys.size());
std::vector<std::vector<key_t>> keys_out(keys.size());

#if (ENABLE_OMP_PARALLEL)
#pragma omp parallel for schedule(static, 1)
#endif
// i should be size_t, however msvc requires signed integral loop variables (except with -openmp:llvm)
for (int i = 0; i < keys.size(); ++i) {
auto const &[distances_res, keys_res] = find(keys[i], threshold);
for (int i = 0; i < static_cast<int>(keys.size()); ++i) {
auto &&[indices_res, distances_res, keys_res] = find(keys[i], threshold);
indices_out[i] = std::move(indices_res);
distances_out[i] = std::move(distances_res);
keys_out[i] = std::move(keys_res);
}

return std::make_tuple(distances_out, keys_out);
return std::make_tuple(indices_out, distances_out, keys_out);
}

bool empty() { return root == nullptr; }
Expand Down Expand Up @@ -156,5 +160,7 @@ template <typename key_t, typename distance_t, typename metric> class BKTree {
root = nullptr;
}

size_t size() { return static_cast<size_t>(index); }

virtual ~BKTree() { clear(); }
};
18 changes: 14 additions & 4 deletions pynear/src/PythonBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,23 @@ template <distance_func_li distance_f> class HammingMetric : Metric<arrayli, int

template <distance_func_li distance> class BKTreeBinaryNumpyAdapter {
public:
BKTree<arrayli, int64_t, HammingMetric<distance>> tree;
typedef arrayli key_t;
typedef int64_t distance_t;

BKTree<arrayli, distance_t, HammingMetric<distance>> tree;

BKTreeBinaryNumpyAdapter() = default;

void set(const ndarrayli &array) { tree.update(array); }
void set(const std::vector<key_t> &array) { tree.update(array); }

std::tuple<std::vector<std::vector<int64_t>>, std::vector<ndarrayli>> find_threshold(const ndarrayli &queries, int64_t threshold) {
std::tuple<std::vector<std::vector<index_t>>, std::vector<std::vector<distance_t>>, std::vector<std::vector<key_t>>>
find_threshold(const std::vector<key_t> &queries, distance_t threshold) {
return tree.find_batch(queries, threshold);
}

bool empty() { return tree.empty(); }
ndarrayli values() { return tree.values(); }
size_t size() { return tree.size(); }
std::vector<key_t> values() { return tree.values(); }
};

static const char *index_set = "Add vectors to index";
Expand Down Expand Up @@ -235,6 +240,7 @@ PYBIND11_MODULE(_pynear, m) {
.def("find_threshold", &BKTreeBinaryNumpyAdapter<dist_hamming_512>::find_threshold, index_find_threshold, py::arg("vectors"),
py::arg("threshold"))
.def("empty", &BKTreeBinaryNumpyAdapter<dist_hamming_512>::empty)
.def("size", &BKTreeBinaryNumpyAdapter<dist_hamming_512>::size)
.def("values", &BKTreeBinaryNumpyAdapter<dist_hamming_512>::values, index_values);

py::class_<BKTreeBinaryNumpyAdapter<dist_hamming_256>>(m, "BKTreeBinaryIndex256")
Expand All @@ -243,6 +249,7 @@ PYBIND11_MODULE(_pynear, m) {
.def("find_threshold", &BKTreeBinaryNumpyAdapter<dist_hamming_256>::find_threshold, index_find_threshold, py::arg("vectors"),
py::arg("threshold"))
.def("empty", &BKTreeBinaryNumpyAdapter<dist_hamming_256>::empty)
.def("size", &BKTreeBinaryNumpyAdapter<dist_hamming_256>::size)
.def("values", &BKTreeBinaryNumpyAdapter<dist_hamming_256>::values, index_values);

py::class_<BKTreeBinaryNumpyAdapter<dist_hamming_128>>(m, "BKTreeBinaryIndex128")
Expand All @@ -251,6 +258,7 @@ PYBIND11_MODULE(_pynear, m) {
.def("find_threshold", &BKTreeBinaryNumpyAdapter<dist_hamming_128>::find_threshold, index_find_threshold, py::arg("vectors"),
py::arg("threshold"))
.def("empty", &BKTreeBinaryNumpyAdapter<dist_hamming_128>::empty)
.def("size", &BKTreeBinaryNumpyAdapter<dist_hamming_128>::size)
.def("values", &BKTreeBinaryNumpyAdapter<dist_hamming_128>::values, index_values);

py::class_<BKTreeBinaryNumpyAdapter<dist_hamming_64>>(m, "BKTreeBinaryIndex64")
Expand All @@ -259,6 +267,7 @@ PYBIND11_MODULE(_pynear, m) {
.def("find_threshold", &BKTreeBinaryNumpyAdapter<dist_hamming_64>::find_threshold, index_find_threshold, py::arg("vectors"),
py::arg("threshold"))
.def("empty", &BKTreeBinaryNumpyAdapter<dist_hamming_64>::empty)
.def("size", &BKTreeBinaryNumpyAdapter<dist_hamming_64>::size)
.def("values", &BKTreeBinaryNumpyAdapter<dist_hamming_64>::values, index_values);

py::class_<BKTreeBinaryNumpyAdapter<dist_hamming>>(m, "BKTreeBinaryIndex")
Expand All @@ -267,5 +276,6 @@ PYBIND11_MODULE(_pynear, m) {
.def("find_threshold", &BKTreeBinaryNumpyAdapter<dist_hamming>::find_threshold, index_find_threshold, py::arg("vectors"),
py::arg("threshold"))
.def("empty", &BKTreeBinaryNumpyAdapter<dist_hamming>::empty)
.def("size", &BKTreeBinaryNumpyAdapter<dist_hamming>::size)
.def("values", &BKTreeBinaryNumpyAdapter<dist_hamming>::values, index_values);
};
5 changes: 2 additions & 3 deletions pynear/tests/VPTreeTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

#include <Eigen/Core>
#include <chrono>
#include <exception>
#include <iostream>
#include <random>
#include <sstream>
#include <vector>
#include <exception>
#include <stdint.h>
#include <vector>

#if defined(_MSC_VER)
#include <intrin.h>
Expand Down Expand Up @@ -122,7 +122,6 @@ TEST(VPTests, TestEmpty) {
VPTree<Eigen::Vector3d, float, distance> nonEmpty;
nonEmpty.set(queries);
EXPECT_NO_THROW(nonEmpty.search1NN(queries, indices, distances));

}

TEST(VPTests, TestToString) {
Expand Down
36 changes: 32 additions & 4 deletions pynear/tests/test_bktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@ def test_bktree_empty_index(bktree_cls, dimensions):
empty = np.array([], dtype=np.uint8)

tree = bktree_cls()
distances, keys = tree.find_threshold(data, 1)
indices, distances, keys = tree.find_threshold(data, 1)
truth = [[]] * num_points
assert truth == indices
assert truth == distances
assert tree.empty()
assert tree.values() == []

tree.set(empty)
distances, keys = tree.find_threshold(data, 1)
indices, distances, keys = tree.find_threshold(data, 1)
truth = [[]] * num_points
assert truth == indices
assert truth == distances
assert tree.empty()
assert tree.values() == []


@pytest.mark.parametrize("bktree_cls, dimensions", CLASSES)
Expand All @@ -41,9 +47,12 @@ def test_bktree_find_self(bktree_cls, dimensions):

tree = bktree_cls()
tree.set(data)
distances, keys = tree.find_threshold(data, 0)
indices, distances, keys = tree.find_threshold(data, 0)
assert indices == [[i] for i in range(num_points)]
assert distances == [[0]] * num_points
assert keys == data[:, None, :].tolist()
assert tree.size() == num_points
assert sorted(tree.values()) == sorted(data.tolist())


@pytest.mark.parametrize("bktree_cls, dimensions", CLASSES)
Expand All @@ -53,7 +62,26 @@ def test_bktree_find_all(bktree_cls, dimensions):

tree = bktree_cls()
tree.set(data)
distances, keys = tree.find_threshold(data, 255)
indices, distances, keys = tree.find_threshold(data, 255)

assert indices == [list(range(num_points))] * num_points
assert distances == hamming_distance_pairwise(data, data).tolist()
assert keys == np.broadcast_to(data, (num_points, num_points, dimensions)).tolist()
assert tree.size() == num_points
assert sorted(tree.values()) == sorted(data.tolist())


@pytest.mark.parametrize("bktree_cls, dimensions", CLASSES)
def test_bktree_find_duplicates(bktree_cls, dimensions):
num_points = 2
data = np.zeros((num_points, dimensions), dtype=np.uint8)

tree = bktree_cls()
tree.set(data)
indices, distances, keys = tree.find_threshold(data, 255)

assert indices == [list(range(num_points))] * num_points
assert distances == [[0] * num_points] * num_points
assert keys == np.broadcast_to(data, (num_points, num_points, dimensions)).tolist()
assert tree.size() == num_points
assert sorted(tree.values()) == sorted(data.tolist())
33 changes: 10 additions & 23 deletions pynear/tests/test_vptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,16 @@ def _num_dups(distances):
]


def test_binary():
@pytest.mark.parametrize("num_points, k", [(2021, 2), (40021, 3)])
def test_binary(num_points, k):
np.random.seed(seed=42)

dimension = 32
num_points = 2021
data = np.random.normal(scale=255, loc=0, size=(num_points, dimension)).astype(dtype=np.uint8)

num_queries = 8
queries = np.random.normal(scale=255, loc=0, size=(num_queries, dimension)).astype(dtype=np.uint8)

k = 2

exaustive_indices, exaustive_distances = exhaustive_search_hamming(data, queries, k)

vptree = pynear.VPTreeBinaryIndex()
Expand All @@ -140,33 +138,22 @@ def test_binary():
vptree_distances = np.array(vptree_distances, dtype=np.int64)[:, ::-1]

assert np.array_equal(exaustive_distances, vptree_distances)
# assert np.array_equal(exaustive_indices, vptree_indices) # indices order can vary for same distances
# assert np.array_equal(exaustive_indices, vptree_indices) # indices order can vary for same distances


def test_large_binary():
np.random.seed(seed=42)

def test_binary_duplicates():
dimension = 32
num_points = 40021
data = np.random.normal(scale=255, loc=0, size=(num_points, dimension)).astype(dtype=np.uint8)
num_points = 2
data = np.zeros((num_points, dimension), dtype=np.uint8)

num_queries = 8
queries = np.random.normal(scale=255, loc=0, size=(num_queries, dimension)).astype(dtype=np.uint8)

k = 3

exaustive_indices, exaustive_distances = exhaustive_search_hamming(data, queries, k)
k = 2

vptree = pynear.VPTreeBinaryIndex()
vptree.set(data)
vptree_indices, vptree_distances = vptree.searchKNN(queries, k)
indices, distances = vptree.searchKNN(data, k)

vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1]
vptree_distances = np.array(vptree_distances, dtype=np.int64)[:, ::-1]

assert np.array_equal(exaustive_distances, vptree_distances)
if _num_dups(exaustive_distances) == 0:
assert np.array_equal(exaustive_indices, vptree_indices) # indices order can vary for same distances
assert [sorted(i) for i in indices] == [list(range(num_points))] * num_points
assert distances == [[0] * k] * num_points


@pytest.mark.parametrize("vptree_cls, exaustive_metric", CLASSES)
Expand Down

0 comments on commit fe0a820

Please sign in to comment.