Skip to content

Commit

Permalink
TST: Use more pytest fixtures (#53567)
Browse files Browse the repository at this point in the history
* TST: Use more fixtures

* Use more fixtures in test_indexing_slow

* Move addition compression_to_extension

* Use more fixture in sparse test_indexing

* fixturize libsparse
  • Loading branch information
mroeschke committed Jun 9, 2023
1 parent 2a2002d commit 6af5861
Show file tree
Hide file tree
Showing 15 changed files with 367 additions and 324 deletions.
29 changes: 18 additions & 11 deletions pandas/tests/arrays/sparse/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
import pandas._testing as tm
from pandas.core.arrays.sparse import SparseArray

arr_data = np.array([np.nan, np.nan, 1, 2, 3, np.nan, 4, 5, np.nan, 6])
arr = SparseArray(arr_data)

@pytest.fixture
def arr_data():
return np.array([np.nan, np.nan, 1, 2, 3, np.nan, 4, 5, np.nan, 6])


@pytest.fixture
def arr(arr_data):
return SparseArray(arr_data)


class TestGetitem:
def test_getitem(self):
def test_getitem(self, arr):
dense = arr.to_dense()
for i, value in enumerate(arr):
tm.assert_almost_equal(value, dense[i])
tm.assert_almost_equal(arr[-i], dense[-i])

def test_getitem_arraylike_mask(self):
def test_getitem_arraylike_mask(self, arr):
arr = SparseArray([0, 1, 2])
result = arr[[True, False, True]]
expected = SparseArray([0, 2])
Expand Down Expand Up @@ -81,7 +88,7 @@ def test_boolean_slice_empty(self):
res = arr[[False, False, False]]
assert res.dtype == arr.dtype

def test_getitem_bool_sparse_array(self):
def test_getitem_bool_sparse_array(self, arr):
# GH 23122
spar_bool = SparseArray([False, True] * 5, dtype=np.bool_, fill_value=True)
exp = SparseArray([np.nan, 2, np.nan, 5, 6])
Expand All @@ -106,7 +113,7 @@ def test_getitem_bool_sparse_array_as_comparison(self):
exp = SparseArray([3.0, 4.0], fill_value=np.nan)
tm.assert_sp_array_equal(res, exp)

def test_get_item(self):
def test_get_item(self, arr):
zarr = SparseArray([0, 0, 1, 2, 3, 0, 4, 5, 0, 6], fill_value=0)

assert np.isnan(arr[1])
Expand All @@ -129,7 +136,7 @@ def test_get_item(self):


class TestSetitem:
def test_set_item(self):
def test_set_item(self, arr_data):
arr = SparseArray(arr_data).copy()

def setitem():
Expand All @@ -146,12 +153,12 @@ def setslice():


class TestTake:
def test_take_scalar_raises(self):
def test_take_scalar_raises(self, arr):
msg = "'indices' must be an array, not a scalar '2'."
with pytest.raises(ValueError, match=msg):
arr.take(2)

def test_take(self):
def test_take(self, arr_data, arr):
exp = SparseArray(np.take(arr_data, [2, 3]))
tm.assert_sp_array_equal(arr.take([2, 3]), exp)

Expand All @@ -173,14 +180,14 @@ def test_take_fill_value(self):
exp = SparseArray(np.take(data, [1, 3, 4]), fill_value=0)
tm.assert_sp_array_equal(sparse.take([1, 3, 4]), exp)

def test_take_negative(self):
def test_take_negative(self, arr_data, arr):
exp = SparseArray(np.take(arr_data, [-1]))
tm.assert_sp_array_equal(arr.take([-1]), exp)

exp = SparseArray(np.take(arr_data, [-4, -3, -2]))
tm.assert_sp_array_equal(arr.take([-4, -3, -2]), exp)

def test_bad_take(self):
def test_bad_take(self, arr):
with pytest.raises(IndexError, match="bounds"):
arr.take([11])

Expand Down
177 changes: 87 additions & 90 deletions pandas/tests/arrays/sparse/test_libsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,77 +14,74 @@
make_sparse_index,
)

TEST_LENGTH = 20

plain_case = [
[0, 7, 15],
[3, 5, 5],
[2, 9, 14],
[2, 3, 5],
[2, 9, 15],
[1, 3, 4],
]
delete_blocks = [
[0, 5],
[4, 4],
[1],
[4],
[1],
[3],
]
split_blocks = [
[0],
[10],
[0, 5],
[3, 7],
[0, 5],
[3, 5],
]
skip_block = [
[10],
[5],
[0, 12],
[5, 3],
[12],
[3],
]

no_intersect = [
[0, 10],
[4, 6],
[5, 17],
[4, 2],
[],
[],
]

one_empty = [
[0],
[5],
[],
[],
[],
[],
]

both_empty = [ # type: ignore[var-annotated]
[],
[],
[],
[],
[],
[],
]

CASES = [plain_case, delete_blocks, split_blocks, skip_block, no_intersect, one_empty]
IDS = [
"plain_case",
"delete_blocks",
"split_blocks",
"skip_block",
"no_intersect",
"one_empty",
]

@pytest.fixture
def test_length():
return 20


@pytest.fixture(
params=[
[
[0, 7, 15],
[3, 5, 5],
[2, 9, 14],
[2, 3, 5],
[2, 9, 15],
[1, 3, 4],
],
[
[0, 5],
[4, 4],
[1],
[4],
[1],
[3],
],
[
[0],
[10],
[0, 5],
[3, 7],
[0, 5],
[3, 5],
],
[
[10],
[5],
[0, 12],
[5, 3],
[12],
[3],
],
[
[0, 10],
[4, 6],
[5, 17],
[4, 2],
[],
[],
],
[
[0],
[5],
[],
[],
[],
[],
],
],
ids=[
"plain_case",
"delete_blocks",
"split_blocks",
"skip_block",
"no_intersect",
"one_empty",
],
)
def cases(request):
return request.param


class TestSparseIndexUnion:
Expand All @@ -101,7 +98,7 @@ class TestSparseIndexUnion:
[[0, 10], [3, 3], [5, 15], [2, 2], [0, 5, 10, 15], [3, 2, 3, 2]],
],
)
def test_index_make_union(self, xloc, xlen, yloc, ylen, eloc, elen):
def test_index_make_union(self, xloc, xlen, yloc, ylen, eloc, elen, test_length):
# Case 1
# x: ----
# y: ----
Expand Down Expand Up @@ -132,8 +129,8 @@ def test_index_make_union(self, xloc, xlen, yloc, ylen, eloc, elen):
# Case 8
# x: ---- ---
# y: --- ---
xindex = BlockIndex(TEST_LENGTH, xloc, xlen)
yindex = BlockIndex(TEST_LENGTH, yloc, ylen)
xindex = BlockIndex(test_length, xloc, xlen)
yindex = BlockIndex(test_length, yloc, ylen)
bresult = xindex.make_union(yindex)
assert isinstance(bresult, BlockIndex)
tm.assert_numpy_array_equal(bresult.blocs, np.array(eloc, dtype=np.int32))
Expand Down Expand Up @@ -180,12 +177,12 @@ def test_int_index_make_union(self):

class TestSparseIndexIntersect:
@td.skip_if_windows
@pytest.mark.parametrize("xloc, xlen, yloc, ylen, eloc, elen", CASES, ids=IDS)
def test_intersect(self, xloc, xlen, yloc, ylen, eloc, elen):
xindex = BlockIndex(TEST_LENGTH, xloc, xlen)
yindex = BlockIndex(TEST_LENGTH, yloc, ylen)
expected = BlockIndex(TEST_LENGTH, eloc, elen)
longer_index = BlockIndex(TEST_LENGTH + 1, yloc, ylen)
def test_intersect(self, cases, test_length):
xloc, xlen, yloc, ylen, eloc, elen = cases
xindex = BlockIndex(test_length, xloc, xlen)
yindex = BlockIndex(test_length, yloc, ylen)
expected = BlockIndex(test_length, eloc, elen)
longer_index = BlockIndex(test_length + 1, yloc, ylen)

result = xindex.intersect(yindex)
assert result.equals(expected)
Expand Down Expand Up @@ -493,10 +490,10 @@ def test_equals(self):
assert index.equals(index)
assert not index.equals(IntIndex(10, [0, 1, 2, 3]))

@pytest.mark.parametrize("xloc, xlen, yloc, ylen, eloc, elen", CASES, ids=IDS)
def test_to_block_index(self, xloc, xlen, yloc, ylen, eloc, elen):
xindex = BlockIndex(TEST_LENGTH, xloc, xlen)
yindex = BlockIndex(TEST_LENGTH, yloc, ylen)
def test_to_block_index(self, cases, test_length):
xloc, xlen, yloc, ylen, _, _ = cases
xindex = BlockIndex(test_length, xloc, xlen)
yindex = BlockIndex(test_length, yloc, ylen)

# see if survive the round trip
xbindex = xindex.to_int_index().to_block_index()
Expand All @@ -512,13 +509,13 @@ def test_to_int_index(self):

class TestSparseOperators:
@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv", "floordiv"])
@pytest.mark.parametrize("xloc, xlen, yloc, ylen, eloc, elen", CASES, ids=IDS)
def test_op(self, opname, xloc, xlen, yloc, ylen, eloc, elen):
def test_op(self, opname, cases, test_length):
xloc, xlen, yloc, ylen, _, _ = cases
sparse_op = getattr(splib, f"sparse_{opname}_float64")
python_op = getattr(operator, opname)

xindex = BlockIndex(TEST_LENGTH, xloc, xlen)
yindex = BlockIndex(TEST_LENGTH, yloc, ylen)
xindex = BlockIndex(test_length, xloc, xlen)
yindex = BlockIndex(test_length, yloc, ylen)

xdindex = xindex.to_int_index()
ydindex = yindex.to_int_index()
Expand All @@ -542,10 +539,10 @@ def test_op(self, opname, xloc, xlen, yloc, ylen, eloc, elen):

# check versus Series...
xseries = Series(x, xdindex.indices)
xseries = xseries.reindex(np.arange(TEST_LENGTH)).fillna(xfill)
xseries = xseries.reindex(np.arange(test_length)).fillna(xfill)

yseries = Series(y, ydindex.indices)
yseries = yseries.reindex(np.arange(TEST_LENGTH)).fillna(yfill)
yseries = yseries.reindex(np.arange(test_length)).fillna(yfill)

series_result = python_op(xseries, yseries)
series_result = series_result.reindex(ri_index.indices)
Expand Down

0 comments on commit 6af5861

Please sign in to comment.