diff --git a/pyvptree/include/DistanceFunctions.hpp b/pyvptree/include/DistanceFunctions.hpp index c819289..df9eb2b 100644 --- a/pyvptree/include/DistanceFunctions.hpp +++ b/pyvptree/include/DistanceFunctions.hpp @@ -19,8 +19,6 @@ using ndarrayd = std::vector; using ndarrayf = std::vector; using ndarrayli = std::vector; -typedef float hamdis_t; - #if defined(_MSC_VER) #define ALIGN_16 __declspec(align(16)) #elif defined(__GNUC__) @@ -28,21 +26,21 @@ typedef float hamdis_t; #endif /* Hamming distances for multiples of 64 bits */ -template hamdis_t hamming(const uint64_t *bs1, const uint64_t *bs2) { +template int64_t hamming(const uint64_t *bs1, const uint64_t *bs2) { const size_t nwords = nbits / 64; size_t i; - hamdis_t h = 0; + int64_t h = 0; for (i = 0; i < nwords; i++) h += _mm_popcnt_u64(bs1[i] ^ bs2[i]); return h; } /* specialized (optimized) functions */ -template <> hamdis_t hamming<64>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]); } +template <> int64_t hamming<64>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]); } -template <> hamdis_t hamming<128>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]); } +template <> int64_t hamming<128>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]); } -template <> hamdis_t hamming<256>(const uint64_t *pa, const uint64_t *pb) { +template <> int64_t hamming<256>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]) + _mm_popcnt_u64(pa[2] ^ pb[2]) + _mm_popcnt_u64(pa[3] ^ pb[3]); } @@ -182,7 +180,7 @@ float distL2f(const arrayf &p1, const arrayf &p2) { return std::sqrt(result); } -hamdis_t distHamming(const arrayli &p1, const arrayli &p2) { +int64_t distHamming(const arrayli &p1, const arrayli &p2) { return hamming<256>(reinterpret_cast(&p1[0]), reinterpret_cast(&p2[0])); } diff --git a/pyvptree/include/VPTree.hpp b/pyvptree/include/VPTree.hpp index 4a89847..e93c98a 100644 --- a/pyvptree/include/VPTree.hpp +++ b/pyvptree/include/VPTree.hpp @@ -20,9 +20,9 @@ namespace vptree { -class VPLevelPartition { +template class VPLevelPartition { public: - VPLevelPartition(float radius, unsigned int start, unsigned int end) { + VPLevelPartition(distance_type radius, unsigned int start, unsigned int end) { // For each partition, the vantage point is the first point within the partition (pointed by indexStart) _radius = radius; @@ -44,16 +44,16 @@ class VPLevelPartition { unsigned int start() { return _indexStart; } unsigned int end() { return _indexEnd; } unsigned int size() { return _indexEnd - _indexStart + 1; } - void setRadius(float radius) { _radius = radius; } - float radius() { return _radius; } + void setRadius(distance_type radius) { _radius = radius; } + distance_type radius() { return _radius; } - void setChild(VPLevelPartition *left, VPLevelPartition *right) { + void setChild(VPLevelPartition *left, VPLevelPartition *right) { _left = left; _right = right; } - VPLevelPartition *left() { return _left; } - VPLevelPartition *right() { return _right; } + VPLevelPartition *left() { return _left; } + VPLevelPartition *right() { return _right; } private: void clear() { @@ -67,7 +67,7 @@ class VPLevelPartition { _right = nullptr; } - float _radius; + distance_type _radius; // _indexStart and _indexEnd are index pointers to examples within the examples list, not index of coordinates // within the coordinate buffer.For instance, _indexEnd pointing to last element of a coordinate buffer of 9 entries @@ -76,11 +76,11 @@ class VPLevelPartition { unsigned int _indexStart; // points to the first of the example in which this level starts unsigned int _indexEnd; - VPLevelPartition *_left = nullptr; - VPLevelPartition *_right = nullptr; + VPLevelPartition *_left = nullptr; + VPLevelPartition *_right = nullptr; }; -template class VPTree { +template class VPTree { public: struct VPTreeElement { @@ -96,7 +96,7 @@ template class VPTree { struct VPTreeSearchResultElement { std::vector indexes; - std::vector distances; + std::vector distances; }; VPTree() = default; @@ -137,7 +137,7 @@ template class VPTree { } // An optimized version for 1 NN search - void search1NN(const std::vector &queries, std::vector &indices, std::vector &distances) { + void search1NN(const std::vector &queries, std::vector &indices, std::vector &distances) { if (_rootPartition == nullptr) { return; @@ -152,7 +152,7 @@ template class VPTree { #endif for (int i = 0; i < queries.size(); ++i) { const T &query = queries[i]; - float dist = 0; + distance_type dist = 0; unsigned int index = -1; search1NN(_rootPartition, query, index, dist); distances[i] = dist; @@ -168,15 +168,15 @@ template class VPTree { void build(const std::vector &array) { // Select vantage point - std::vector _toSplit; + std::vector *> _toSplit; - auto *root = new VPLevelPartition(0, 0, _examples.size() - 1); + auto *root = new VPLevelPartition(0, 0, _examples.size() - 1); _toSplit.push_back(root); _rootPartition = root; while (!_toSplit.empty()) { - VPLevelPartition *current = _toSplit.back(); + VPLevelPartition *current = _toSplit.back(); _toSplit.pop_back(); unsigned int start = current->start(); @@ -199,20 +199,20 @@ template class VPTree { VPDistanceComparator(_examples[start])); /* // distance from vantage point (which is at start index) and the median element */ - float medianDistance = distance(_examples[start].val, _examples[median].val); + auto medianDistance = distance(_examples[start].val, _examples[median].val); current->setRadius(medianDistance); // Schedule to build next levels // Left is every one within the median distance radius - VPLevelPartition *left = nullptr; + VPLevelPartition *left = nullptr; if (start + 1 <= median) { - left = new VPLevelPartition(-1, start + 1, median); + left = new VPLevelPartition(-1, start + 1, median); _toSplit.push_back(left); } - VPLevelPartition *right = nullptr; + VPLevelPartition *right = nullptr; if (median + 1 <= end) { - right = new VPLevelPartition(-1, median + 1, end); + right = new VPLevelPartition(-1, median + 1, end); _toSplit.push_back(right); } @@ -222,17 +222,16 @@ template class VPTree { // Internal temporary struct to organize K closest elements in a priorty queue struct VPTreeSearchElement { - VPTreeSearchElement(int index, float dist) : index(index), dist(dist) {} + VPTreeSearchElement(int index, distance_type dist) : index(index), dist(dist) {} int index; - float dist; + distance_type dist; bool operator<(const VPTreeSearchElement &v) const { return dist < v.dist; } }; - void exaustivePartitionSearch(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue &knnQueue, - float tau) { + void exaustivePartitionSearch(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue &knnQueue, distance_type tau) { for (int i = partition->start(); i <= partition->end(); ++i) { - float dist = distance(val, _examples[i].val); + auto dist = distance(val, _examples[i].val); if (dist < tau || knnQueue.size() < k) { if (knnQueue.size() == k) { @@ -246,20 +245,20 @@ template class VPTree { } } - void searchKNN(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue &knnQueue) { + void searchKNN(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue &knnQueue) { - float tau = std::numeric_limits::max(); + auto tau = std::numeric_limits::max(); // stores the distance to the partition border at the time of the storage. Since tau value will change // whiling performing the DFS search from on level, the storage distance will be checked again when about // to dive into that partition. It might not be necessary to dig into the partition anymore if tau decreased. - std::vector> toSearch = {{-1, partition}}; + std::vector *>> toSearch = {{-1, partition}}; while (!toSearch.empty()) { auto [distToBorder, current] = toSearch.back(); toSearch.pop_back(); - float dist = distance(val, _examples[current->start()].val); + auto dist = distance(val, _examples[current->start()].val); if (dist < tau || knnQueue.size() < k) { if (knnQueue.size() == k) { @@ -298,7 +297,7 @@ template class VPTree { unsigned int rightPartitionSize = (current->right() != nullptr) ? current->right()->size() : 0; bool notEnoughPointsOutside = rightPartitionSize < (k - neighborsSoFar); - float toBorder = dist - current->radius(); + auto toBorder = dist - current->radius(); // we might not have enough point outside to reject the inside partition, so we might need to search // for both @@ -321,7 +320,7 @@ template class VPTree { unsigned int leftPartitionSize = (current->left() != nullptr) ? current->left()->size() : 0; bool notEnoughPointsInside = leftPartitionSize < (k - neighborsSoFar); - float toBorder = current->radius() - dist; + auto toBorder = current->radius() - dist; if (notEnoughPointsInside) { toSearch.push_back({-1, current->right()}); @@ -338,19 +337,19 @@ template class VPTree { } } - void search1NN(VPLevelPartition *partition, const T &val, unsigned int &resultIndex, float &resultDist) { + void search1NN(VPLevelPartition *partition, const T &val, unsigned int &resultIndex, distance_type &resultDist) { - resultDist = std::numeric_limits::max(); + resultDist = std::numeric_limits::max(); resultIndex = -1; - std::vector> toSearch = {{-1, partition}}; + std::vector *>> toSearch = {{-1, partition}}; while (!toSearch.empty()) { auto [distToBorder, current] = toSearch.back(); toSearch.pop_back(); - float dist = distance(val, _examples[current->start()].val); + auto dist = distance(val, _examples[current->start()].val); if (dist < resultDist) { resultDist = dist; resultIndex = _examples[current->start()].originalIndex; @@ -364,7 +363,7 @@ template class VPTree { if (dist > current->radius()) { // may need to search inside as well - float toBorder = dist - current->radius(); + auto toBorder = dist - current->radius(); if (toBorder < resultDist && current->left() != nullptr) { toSearch.push_back({toBorder, current->left()}); } @@ -374,7 +373,7 @@ template class VPTree { toSearch.push_back({-1, current->right()}); } } else { - float toBorder = current->radius() - dist; + auto toBorder = current->radius() - dist; // may need to search outside as well if (toBorder < resultDist && current->right() != nullptr) { toSearch.push_back({toBorder, current->right()}); @@ -427,7 +426,7 @@ template class VPTree { protected: std::vector _examples; - VPLevelPartition *_rootPartition = nullptr; + VPLevelPartition *_rootPartition = nullptr; }; } // namespace vptree diff --git a/pyvptree/src/PythonBindings.cpp b/pyvptree/src/PythonBindings.cpp index 8326c69..3b0b923 100644 --- a/pyvptree/src/PythonBindings.cpp +++ b/pyvptree/src/PythonBindings.cpp @@ -17,11 +17,11 @@ class VPTreeNumpyAdapter { public: VPTreeNumpyAdapter() = default; - void set(const ndarrayf &array) { _tree = vptree::VPTree(array); } + void set(const ndarrayf &array) { _tree = vptree::VPTree(array); } std::tuple>, std::vector>> searchKNN(const ndarrayf &queries, unsigned int k) { - std::vector::VPTreeSearchResultElement> results; + std::vector::VPTreeSearchResultElement> results; _tree.searchKNN(queries, k, results); std::vector> indexes; @@ -46,22 +46,22 @@ class VPTreeNumpyAdapter { } private: - vptree::VPTree _tree; + vptree::VPTree _tree; }; class VPTreeBinaryNumpyAdapter { public: VPTreeBinaryNumpyAdapter() = default; - void set(const ndarrayli &array) { _tree = vptree::VPTree(array); } + void set(const ndarrayli &array) { _tree = vptree::VPTree(array); } - std::tuple>, std::vector>> searchKNN(const ndarrayli &queries, unsigned int k) { + std::tuple>, std::vector>> searchKNN(const ndarrayli &queries, unsigned int k) { - std::vector::VPTreeSearchResultElement> results; + std::vector::VPTreeSearchResultElement> results; _tree.searchKNN(queries, k, results); std::vector> indexes; - std::vector> distances; + std::vector> distances; indexes.resize(results.size()); distances.resize(results.size()); for (int i = 0; i < results.size(); ++i) { @@ -71,17 +71,17 @@ class VPTreeBinaryNumpyAdapter { return std::make_tuple(indexes, distances); } - std::tuple, std::vector> search1NN(const ndarrayli &queries) { + std::tuple, std::vector> search1NN(const ndarrayli &queries) { std::vector indices; - std::vector distances; + std::vector distances; _tree.search1NN(queries, indices, distances); return std::make_tuple(std::move(indices), std::move(distances)); } private: - vptree::VPTree _tree; + vptree::VPTree _tree; }; PYBIND11_MODULE(_pyvptree, m) { diff --git a/pyvptree/tests/test_vptree.py b/pyvptree/tests/test_vptree.py index 4e4c238..dec6dd7 100644 --- a/pyvptree/tests/test_vptree.py +++ b/pyvptree/tests/test_vptree.py @@ -81,7 +81,7 @@ def test_binary(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.int64)[:, ::-1] assert np.array_equal(exaustive_distances, vptree_distances) @@ -106,7 +106,7 @@ def test_large_binary(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.int64)[:, ::-1] assert np.array_equal(exaustive_distances, vptree_distances) @@ -132,7 +132,7 @@ def test_k_equals_dataset(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] np.testing.assert_allclose(exaustive_distances, vptree_distances, rtol=1e-06) @@ -158,7 +158,7 @@ def test_large_dataset(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] vptree_distances2 = np.sort(vptree_distances, axis=-1) @@ -186,7 +186,7 @@ def test_large_dataset_highdim(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] vptree_distances2 = np.sort(vptree_distances, axis=-1) @@ -211,7 +211,7 @@ def test_dataset_split_less_than_k(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] assert np.array_equal(exaustive_indices, vptree_indices) @@ -236,7 +236,7 @@ def test_query_larger_than_dataset(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] assert np.array_equal(exaustive_indices, vptree_indices) @@ -261,7 +261,7 @@ def test_compare_with_exaustive_knn(): vptree.set(data) vptree_indices, vptree_distances = vptree.searchKNN(queries, k) - vptree_indices = np.array(vptree_indices, dtype=np.int64)[:, ::-1] + vptree_indices = np.array(vptree_indices, dtype=np.uint64)[:, ::-1] vptree_distances = np.array(vptree_distances, dtype=np.float32)[:, ::-1] assert np.array_equal(exaustive_indices, vptree_indices) @@ -287,7 +287,7 @@ def test_compare_with_exaustive_1nn(): vptree.set(data) vptree_indices, vptree_distances = vptree.search1NN(queries) - vptree_indices = np.array(vptree_indices, dtype=np.int64) + vptree_indices = np.array(vptree_indices, dtype=np.uint64) vptree_distances = np.array(vptree_distances, dtype=np.float32) assert np.array_equal(exaustive_indices, vptree_indices)