Skip to content

Commit a963725

Browse files
committed
add tests
1 parent 54116a4 commit a963725

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

tests/test_stats.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
201201
np.testing.assert_array_equal(sum_, expected)
202202

203203

204+
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
205+
@pytest.mark.parametrize("func", [stats.min, stats.max])
206+
def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
207+
rng = np.random.default_rng(0)
208+
np_arr = rng.random((100, 100))
209+
arr = array_type(np_arr)
210+
211+
result = to_np_dense_checked(func(arr, axis=axis), axis, arr)
212+
213+
expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis)
214+
np.testing.assert_array_equal(result, expected)
215+
216+
217+
@pytest.mark.parametrize("func", [stats.sum, stats.min, stats.max])
204218
@pytest.mark.parametrize(
205219
"data",
206220
[
@@ -211,14 +225,15 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
211225
)
212226
@pytest.mark.parametrize("axis", [0, 1])
213227
@pytest.mark.array_type(Flags.Dask)
214-
def test_sum_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]) -> None:
228+
def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]], func: StatFunNoDtype) -> None:
215229
np_arr = np.array(data, dtype=np.float32)
216230
arr = array_type(np_arr)
217231
assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
218-
sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
219-
if isinstance(sum_, types.CupyArray):
220-
sum_ = sum_.get()
221-
np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_)
232+
stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute())
233+
if isinstance(stat, types.CupyArray):
234+
stat = stat.get()
235+
np_func = getattr(np, func.__name__)
236+
np.testing.assert_almost_equal(stat, np_func(np_arr, axis=axis))
222237

223238

224239
@pytest.mark.array_type(skip=ATS_SPARSE_DS)

0 commit comments

Comments
 (0)