Skip to content

Commit

Permalink
Tests for angular distances and fix small bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Nov 11, 2015
1 parent b33c45a commit ae74624
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
4 changes: 4 additions & 0 deletions annoy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def add_item(self, i, vector):
def get_nns_by_vector(self, vector, n, search_k=-1, include_distances=False):
# Same
return super(AnnoyIndex, self).get_nns_by_vector(self.check_list(vector), n, search_k, include_distances)

def get_nns_by_item(self, i, n, search_k=-1, include_distances=False):
# Wrapper to support named arguments
return super(AnnoyIndex, self).get_nns_by_item(i, n, search_k, include_distances)
6 changes: 4 additions & 2 deletions src/annoylib.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ struct Angular {
}
static inline T normalized_distance(T distance) {
// Used when requesting distances from Python layer
return sqrt(distance);
// Turns out sometimes the squared distance is -0.0
// so we have to make sure it's a positive number.
return sqrt(std::max(distance, T(0)));
}
};

Expand Down Expand Up @@ -201,7 +203,7 @@ struct Euclidean {
n->a += -n->v[z] * (iv[z] + jv[z]) / 2;
}
static inline T normalized_distance(T distance) {
return sqrt(distance);
return sqrt(std::max(distance, T(0)));
}
};

Expand Down
24 changes: 24 additions & 0 deletions test/annoy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,30 @@ def test_get_nns_search_k(self):
self.assertEqual(i.get_nns_by_item(0, 3, 10), [0, 1, 2])
self.assertEqual(i.get_nns_by_vector([3, 2, 1], 3, 10), [2, 1, 0])

def test_include_dists(self):
# Double checking issue 112
f = 40
i = AnnoyIndex(f)
v = numpy.random.normal(size=f)
i.add_item(0, v)
i.add_item(1, -v)
i.build(10)

indices, dists = i.get_nns_by_item(0, 2, 10, True)
self.assertEqual(indices, [0, 1])
self.assertAlmostEquals(dists[0], 0.0)
self.assertAlmostEquals(dists[1], 2.0)

def test_include_dists_check_ranges(self):
f = 3
i = AnnoyIndex(f)
for j in xrange(100000):
i.add_item(j, numpy.random.normal(size=f))
i.build(10)
indices, dists = i.get_nns_by_item(0, 100000, include_distances=True)
self.assertTrue(max(dists) < 2.0)
self.assertAlmostEquals(min(dists), 0.0)


class EuclideanIndexTest(TestCase):
def test_get_nns_by_vector(self):
Expand Down

0 comments on commit ae74624

Please sign in to comment.