Skip to content

Commit

Permalink
TST: Fix dtype mismatch on 32bit in IntervalTree get_indexer test (pa…
Browse files Browse the repository at this point in the history
  • Loading branch information
jschendel authored and tm9k1 committed Nov 19, 2018
1 parent 34a76a0 commit cb1b288
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 24 deletions.
7 changes: 4 additions & 3 deletions pandas/_libs/intervaltree.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cdef class IntervalTree(IntervalMixin):
self.root.query(result, key)
if not result.data.n:
raise KeyError(key)
return result.to_array()
return result.to_array().astype('intp')

def _get_partial_overlap(self, key_left, key_right, side):
"""Return all positions corresponding to intervals with the given side
Expand Down Expand Up @@ -155,7 +155,7 @@ cdef class IntervalTree(IntervalMixin):
raise KeyError(
'indexer does not intersect a unique set of intervals')
old_len = result.data.n
return result.to_array()
return result.to_array().astype('intp')

def get_indexer_non_unique(self, scalar_t[:] target):
"""Return the positions corresponding to intervals that overlap with
Expand All @@ -175,7 +175,8 @@ cdef class IntervalTree(IntervalMixin):
result.append(-1)
missing.append(i)
old_len = result.data.n
return result.to_array(), missing.to_array()
return (result.to_array().astype('intp'),
missing.to_array().astype('intp'))

def __repr__(self):
return ('<IntervalTree[{dtype},{closed}]: '
Expand Down
63 changes: 42 additions & 21 deletions pandas/tests/indexes/interval/test_interval_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,44 +49,64 @@ def tree(request, leaf_size):
class TestIntervalTree(object):

def test_get_loc(self, tree):
tm.assert_numpy_array_equal(tree.get_loc(1),
np.array([0], dtype='int64'))
tm.assert_numpy_array_equal(np.sort(tree.get_loc(2)),
np.array([0, 1], dtype='int64'))
result = tree.get_loc(1)
expected = np.array([0], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

result = np.sort(tree.get_loc(2))
expected = np.array([0, 1], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

with pytest.raises(KeyError):
tree.get_loc(-1)

def test_get_indexer(self, tree):
tm.assert_numpy_array_equal(
tree.get_indexer(np.array([1.0, 5.5, 6.5])),
np.array([0, 4, -1], dtype='int64'))
result = tree.get_indexer(np.array([1.0, 5.5, 6.5]))
expected = np.array([0, 4, -1], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

with pytest.raises(KeyError):
tree.get_indexer(np.array([3.0]))

def test_get_indexer_non_unique(self, tree):
indexer, missing = tree.get_indexer_non_unique(
np.array([1.0, 2.0, 6.5]))
tm.assert_numpy_array_equal(indexer[:1],
np.array([0], dtype='int64'))
tm.assert_numpy_array_equal(np.sort(indexer[1:3]),
np.array([0, 1], dtype='int64'))
tm.assert_numpy_array_equal(np.sort(indexer[3:]),
np.array([-1], dtype='int64'))
tm.assert_numpy_array_equal(missing, np.array([2], dtype='int64'))

result = indexer[:1]
expected = np.array([0], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

result = np.sort(indexer[1:3])
expected = np.array([0, 1], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

result = np.sort(indexer[3:])
expected = np.array([-1], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

result = missing
expected = np.array([2], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

def test_duplicates(self, dtype):
left = np.array([0, 0, 0], dtype=dtype)
tree = IntervalTree(left, left + 1)
tm.assert_numpy_array_equal(np.sort(tree.get_loc(0.5)),
np.array([0, 1, 2], dtype='int64'))

result = np.sort(tree.get_loc(0.5))
expected = np.array([0, 1, 2], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

with pytest.raises(KeyError):
tree.get_indexer(np.array([0.5]))

indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
tm.assert_numpy_array_equal(np.sort(indexer),
np.array([0, 1, 2], dtype='int64'))
tm.assert_numpy_array_equal(missing, np.array([], dtype='int64'))
result = np.sort(indexer)
expected = np.array([0, 1, 2], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

result = missing
expected = np.array([], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

def test_get_loc_closed(self, closed):
tree = IntervalTree([0], [1], closed=closed)
Expand All @@ -96,8 +116,9 @@ def test_get_loc_closed(self, closed):
with pytest.raises(KeyError):
tree.get_loc(p)
else:
tm.assert_numpy_array_equal(tree.get_loc(p),
np.array([0], dtype='int64'))
result = tree.get_loc(p)
expected = np.array([0], dtype='intp')
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize('leaf_size', [
skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000])
Expand Down

0 comments on commit cb1b288

Please sign in to comment.