Skip to content

Commit

Permalink
Merge pull request #23 from Dobatymo/generic-typing
Browse files Browse the repository at this point in the history
Support other distance types apart from float
  • Loading branch information
Dobatymo committed Aug 3, 2023
2 parents f764e56 + 9eca9f7 commit 61cc736
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 67 deletions.
14 changes: 6 additions & 8 deletions pyvptree/include/DistanceFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,28 @@ using ndarrayd = std::vector<arrayd>;
using ndarrayf = std::vector<arrayf>;
using ndarrayli = std::vector<arrayli>;

typedef float hamdis_t;

#if defined(_MSC_VER)
#define ALIGN_16 __declspec(align(16))
#elif defined(__GNUC__)
#define ALIGN_16 __attribute__((__aligned__(16)))
#endif

/* Hamming distances for multiples of 64 bits */
template <size_t nbits> hamdis_t hamming(const uint64_t *bs1, const uint64_t *bs2) {
template <size_t nbits> int64_t hamming(const uint64_t *bs1, const uint64_t *bs2) {
const size_t nwords = nbits / 64;
size_t i;
hamdis_t h = 0;
int64_t h = 0;
for (i = 0; i < nwords; i++)
h += _mm_popcnt_u64(bs1[i] ^ bs2[i]);
return h;
}

/* specialized (optimized) functions */
template <> hamdis_t hamming<64>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]); }
template <> int64_t hamming<64>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]); }

template <> hamdis_t hamming<128>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]); }
template <> int64_t hamming<128>(const uint64_t *pa, const uint64_t *pb) { return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]); }

template <> hamdis_t hamming<256>(const uint64_t *pa, const uint64_t *pb) {
template <> int64_t hamming<256>(const uint64_t *pa, const uint64_t *pb) {
return _mm_popcnt_u64(pa[0] ^ pb[0]) + _mm_popcnt_u64(pa[1] ^ pb[1]) + _mm_popcnt_u64(pa[2] ^ pb[2]) + _mm_popcnt_u64(pa[3] ^ pb[3]);
}

Expand Down Expand Up @@ -182,7 +180,7 @@ float distL2f(const arrayf &p1, const arrayf &p2) {
return std::sqrt(result);
}

hamdis_t distHamming(const arrayli &p1, const arrayli &p2) {
int64_t distHamming(const arrayli &p1, const arrayli &p2) {

return hamming<256>(reinterpret_cast<const uint64_t *>(&p1[0]), reinterpret_cast<const uint64_t *>(&p2[0]));
}
79 changes: 39 additions & 40 deletions pyvptree/include/VPTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

namespace vptree {

class VPLevelPartition {
template <typename distance_type> class VPLevelPartition {
public:
VPLevelPartition(float radius, unsigned int start, unsigned int end) {
VPLevelPartition(distance_type radius, unsigned int start, unsigned int end) {
// For each partition, the vantage point is the first point within the partition (pointed by indexStart)

_radius = radius;
Expand All @@ -44,16 +44,16 @@ class VPLevelPartition {
unsigned int start() { return _indexStart; }
unsigned int end() { return _indexEnd; }
unsigned int size() { return _indexEnd - _indexStart + 1; }
void setRadius(float radius) { _radius = radius; }
float radius() { return _radius; }
void setRadius(distance_type radius) { _radius = radius; }
distance_type radius() { return _radius; }

void setChild(VPLevelPartition *left, VPLevelPartition *right) {
void setChild(VPLevelPartition<distance_type> *left, VPLevelPartition<distance_type> *right) {
_left = left;
_right = right;
}

VPLevelPartition *left() { return _left; }
VPLevelPartition *right() { return _right; }
VPLevelPartition<distance_type> *left() { return _left; }
VPLevelPartition<distance_type> *right() { return _right; }

private:
void clear() {
Expand All @@ -67,7 +67,7 @@ class VPLevelPartition {
_right = nullptr;
}

float _radius;
distance_type _radius;

// _indexStart and _indexEnd are indices of examples within the examples list, not indices of coordinates
// within the coordinate buffer. For instance, _indexEnd pointing to the last element of a coordinate buffer of 9 entries
Expand All @@ -76,11 +76,11 @@ class VPLevelPartition {
unsigned int _indexStart; // points to the first of the example in which this level starts
unsigned int _indexEnd;

VPLevelPartition *_left = nullptr;
VPLevelPartition *_right = nullptr;
VPLevelPartition<distance_type> *_left = nullptr;
VPLevelPartition<distance_type> *_right = nullptr;
};

template <typename T, float (*distance)(const T &, const T &)> class VPTree {
template <typename T, typename distance_type, distance_type (*distance)(const T &, const T &)> class VPTree {
public:
struct VPTreeElement {

Expand All @@ -96,7 +96,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

struct VPTreeSearchResultElement {
std::vector<unsigned int> indexes;
std::vector<float> distances;
std::vector<distance_type> distances;
};

VPTree() = default;
Expand Down Expand Up @@ -137,7 +137,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
}

// An optimized version for 1 NN search
void search1NN(const std::vector<T> &queries, std::vector<unsigned int> &indices, std::vector<float> &distances) {
void search1NN(const std::vector<T> &queries, std::vector<unsigned int> &indices, std::vector<distance_type> &distances) {

if (_rootPartition == nullptr) {
return;
Expand All @@ -152,7 +152,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
#endif
for (int i = 0; i < queries.size(); ++i) {
const T &query = queries[i];
float dist = 0;
distance_type dist = 0;
unsigned int index = -1;
search1NN(_rootPartition, query, index, dist);
distances[i] = dist;
Expand All @@ -168,15 +168,15 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
void build(const std::vector<VPTreeElement> &array) {

// Select vantage point
std::vector<VPLevelPartition *> _toSplit;
std::vector<VPLevelPartition<distance_type> *> _toSplit;

auto *root = new VPLevelPartition(0, 0, _examples.size() - 1);
auto *root = new VPLevelPartition<distance_type>(0, 0, _examples.size() - 1);
_toSplit.push_back(root);
_rootPartition = root;

while (!_toSplit.empty()) {

VPLevelPartition *current = _toSplit.back();
VPLevelPartition<distance_type> *current = _toSplit.back();
_toSplit.pop_back();

unsigned int start = current->start();
Expand All @@ -199,20 +199,20 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
VPDistanceComparator(_examples[start]));

/* // distance from vantage point (which is at start index) and the median element */
float medianDistance = distance(_examples[start].val, _examples[median].val);
auto medianDistance = distance(_examples[start].val, _examples[median].val);
current->setRadius(medianDistance);

// Schedule to build next levels
// Left is every one within the median distance radius
VPLevelPartition *left = nullptr;
VPLevelPartition<distance_type> *left = nullptr;
if (start + 1 <= median) {
left = new VPLevelPartition(-1, start + 1, median);
left = new VPLevelPartition<distance_type>(-1, start + 1, median);
_toSplit.push_back(left);
}

VPLevelPartition *right = nullptr;
VPLevelPartition<distance_type> *right = nullptr;
if (median + 1 <= end) {
right = new VPLevelPartition(-1, median + 1, end);
right = new VPLevelPartition<distance_type>(-1, median + 1, end);
_toSplit.push_back(right);
}

Expand All @@ -222,17 +222,16 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

// Internal temporary struct to organize K closest elements in a priority queue
struct VPTreeSearchElement {
VPTreeSearchElement(int index, float dist) : index(index), dist(dist) {}
VPTreeSearchElement(int index, distance_type dist) : index(index), dist(dist) {}
int index;
float dist;
distance_type dist;
bool operator<(const VPTreeSearchElement &v) const { return dist < v.dist; }
};

void exaustivePartitionSearch(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue<VPTreeSearchElement> &knnQueue,
float tau) {
void exaustivePartitionSearch(VPLevelPartition<distance_type> *partition, const T &val, unsigned int k, std::priority_queue<VPTreeSearchElement> &knnQueue, distance_type tau) {
for (int i = partition->start(); i <= partition->end(); ++i) {

float dist = distance(val, _examples[i].val);
auto dist = distance(val, _examples[i].val);
if (dist < tau || knnQueue.size() < k) {

if (knnQueue.size() == k) {
Expand All @@ -246,20 +245,20 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
}
}

void searchKNN(VPLevelPartition *partition, const T &val, unsigned int k, std::priority_queue<VPTreeSearchElement> &knnQueue) {
void searchKNN(VPLevelPartition<distance_type> *partition, const T &val, unsigned int k, std::priority_queue<VPTreeSearchElement> &knnQueue) {

float tau = std::numeric_limits<float>::max();
auto tau = std::numeric_limits<distance_type>::max();

// Stores the distance to the partition border at the time of storage. Since the value of tau will change
// while performing the DFS search from one level, the stored distance will be checked again when about
// to dive into that partition. It might not be necessary to dig into the partition anymore if tau has decreased.
std::vector<std::tuple<float, VPLevelPartition *>> toSearch = {{-1, partition}};
std::vector<std::tuple<distance_type, VPLevelPartition<distance_type> *>> toSearch = {{-1, partition}};

while (!toSearch.empty()) {
auto [distToBorder, current] = toSearch.back();
toSearch.pop_back();

float dist = distance(val, _examples[current->start()].val);
auto dist = distance(val, _examples[current->start()].val);
if (dist < tau || knnQueue.size() < k) {

if (knnQueue.size() == k) {
Expand Down Expand Up @@ -298,7 +297,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

unsigned int rightPartitionSize = (current->right() != nullptr) ? current->right()->size() : 0;
bool notEnoughPointsOutside = rightPartitionSize < (k - neighborsSoFar);
float toBorder = dist - current->radius();
auto toBorder = dist - current->radius();

// we might not have enough points outside to reject the inside partition, so we might need to search
// both
Expand All @@ -321,7 +320,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

unsigned int leftPartitionSize = (current->left() != nullptr) ? current->left()->size() : 0;
bool notEnoughPointsInside = leftPartitionSize < (k - neighborsSoFar);
float toBorder = current->radius() - dist;
auto toBorder = current->radius() - dist;

if (notEnoughPointsInside) {
toSearch.push_back({-1, current->right()});
Expand All @@ -338,19 +337,19 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
}
}

void search1NN(VPLevelPartition *partition, const T &val, unsigned int &resultIndex, float &resultDist) {
void search1NN(VPLevelPartition<distance_type> *partition, const T &val, unsigned int &resultIndex, distance_type &resultDist) {

resultDist = std::numeric_limits<float>::max();
resultDist = std::numeric_limits<distance_type>::max();
resultIndex = -1;

std::vector<std::tuple<float, VPLevelPartition *>> toSearch = {{-1, partition}};
std::vector<std::tuple<distance_type, VPLevelPartition<distance_type> *>> toSearch = {{-1, partition}};

while (!toSearch.empty()) {

auto [distToBorder, current] = toSearch.back();
toSearch.pop_back();

float dist = distance(val, _examples[current->start()].val);
auto dist = distance(val, _examples[current->start()].val);
if (dist < resultDist) {
resultDist = dist;
resultIndex = _examples[current->start()].originalIndex;
Expand All @@ -364,7 +363,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

if (dist > current->radius()) {
// may need to search inside as well
float toBorder = dist - current->radius();
auto toBorder = dist - current->radius();
if (toBorder < resultDist && current->left() != nullptr) {
toSearch.push_back({toBorder, current->left()});
}
Expand All @@ -374,7 +373,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {
toSearch.push_back({-1, current->right()});
}
} else {
float toBorder = current->radius() - dist;
auto toBorder = current->radius() - dist;
// may need to search outside as well
if (toBorder < resultDist && current->right() != nullptr) {
toSearch.push_back({toBorder, current->right()});
Expand Down Expand Up @@ -427,7 +426,7 @@ template <typename T, float (*distance)(const T &, const T &)> class VPTree {

protected:
std::vector<VPTreeElement> _examples;
VPLevelPartition *_rootPartition = nullptr;
VPLevelPartition<distance_type> *_rootPartition = nullptr;
};

} // namespace vptree
20 changes: 10 additions & 10 deletions pyvptree/src/PythonBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class VPTreeNumpyAdapter {
public:
VPTreeNumpyAdapter() = default;

void set(const ndarrayf &array) { _tree = vptree::VPTree<arrayf, dist_optimized_float>(array); }
void set(const ndarrayf &array) { _tree = vptree::VPTree<arrayf, float, dist_optimized_float>(array); }

std::tuple<std::vector<std::vector<unsigned int>>, std::vector<std::vector<float>>> searchKNN(const ndarrayf &queries, unsigned int k) {

std::vector<vptree::VPTree<arrayf, dist_optimized_float>::VPTreeSearchResultElement> results;
std::vector<vptree::VPTree<arrayf, float, dist_optimized_float>::VPTreeSearchResultElement> results;
_tree.searchKNN(queries, k, results);

std::vector<std::vector<unsigned int>> indexes;
Expand All @@ -46,22 +46,22 @@ class VPTreeNumpyAdapter {
}

private:
vptree::VPTree<arrayf, dist_optimized_float> _tree;
vptree::VPTree<arrayf, float, dist_optimized_float> _tree;
};

class VPTreeBinaryNumpyAdapter {
public:
VPTreeBinaryNumpyAdapter() = default;

void set(const ndarrayli &array) { _tree = vptree::VPTree<arrayli, distHamming>(array); }
void set(const ndarrayli &array) { _tree = vptree::VPTree<arrayli, int64_t, distHamming>(array); }

std::tuple<std::vector<std::vector<unsigned int>>, std::vector<std::vector<float>>> searchKNN(const ndarrayli &queries, unsigned int k) {
std::tuple<std::vector<std::vector<unsigned int>>, std::vector<std::vector<int64_t>>> searchKNN(const ndarrayli &queries, unsigned int k) {

std::vector<vptree::VPTree<arrayli, distHamming>::VPTreeSearchResultElement> results;
std::vector<vptree::VPTree<arrayli, int64_t, distHamming>::VPTreeSearchResultElement> results;
_tree.searchKNN(queries, k, results);

std::vector<std::vector<unsigned int>> indexes;
std::vector<std::vector<float>> distances;
std::vector<std::vector<int64_t>> distances;
indexes.resize(results.size());
distances.resize(results.size());
for (int i = 0; i < results.size(); ++i) {
Expand All @@ -71,17 +71,17 @@ class VPTreeBinaryNumpyAdapter {
return std::make_tuple(indexes, distances);
}

std::tuple<std::vector<unsigned int>, std::vector<float>> search1NN(const ndarrayli &queries) {
std::tuple<std::vector<unsigned int>, std::vector<int64_t>> search1NN(const ndarrayli &queries) {

std::vector<unsigned int> indices;
std::vector<float> distances;
std::vector<int64_t> distances;
_tree.search1NN(queries, indices, distances);

return std::make_tuple(std::move(indices), std::move(distances));
}

private:
vptree::VPTree<arrayli, distHamming> _tree;
vptree::VPTree<arrayli, int64_t, distHamming> _tree;
};

PYBIND11_MODULE(_pyvptree, m) {
Expand Down
Loading

0 comments on commit 61cc736

Please sign in to comment.