@@ -201,6 +201,20 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
201201 np .testing .assert_array_equal (sum_ , expected )
202202
203203
204+ @pytest .mark .array_type (skip = ATS_SPARSE_DS )
205+ @pytest .mark .parametrize ("func" , [stats .min , stats .max ])
206+ def test_min_max (array_type : ArrayType [CpuArray | GpuArray | DiskArray | types .DaskArray ], axis : Literal [0 , 1 ] | None , func : StatFunNoDtype ) -> None :
207+ rng = np .random .default_rng (0 )
208+ np_arr = rng .random ((100 , 100 ))
209+ arr = array_type (np_arr )
210+
211+ result = to_np_dense_checked (func (arr , axis = axis ), axis , arr )
212+
213+ expected = (np .min if func is stats .min else np .max )(np_arr , axis = axis )
214+ np .testing .assert_array_equal (result , expected )
215+
216+
217+ @pytest .mark .parametrize ("func" , [stats .sum , stats .min , stats .max ])
204218@pytest .mark .parametrize (
205219 "data" ,
206220 [
@@ -211,14 +225,15 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
211225)
212226@pytest .mark .parametrize ("axis" , [0 , 1 ])
213227@pytest .mark .array_type (Flags .Dask )
214- def test_sum_dask_shapes (array_type : ArrayType [types .DaskArray ], axis : Literal [0 , 1 ], data : list [list [int ]]) -> None :
228+ def test_dask_shapes (array_type : ArrayType [types .DaskArray ], axis : Literal [0 , 1 ], data : list [list [int ]], func : StatFunNoDtype ) -> None :
215229 np_arr = np .array (data , dtype = np .float32 )
216230 arr = array_type (np_arr )
217231 assert 1 in arr .chunksize , "This test is supposed to test 1×n and n×1 chunk sizes"
218- sum_ = cast ("NDArray[Any] | types.CupyArray" , stats .sum (arr , axis = axis ).compute ())
219- if isinstance (sum_ , types .CupyArray ):
220- sum_ = sum_ .get ()
221- np .testing .assert_almost_equal (np_arr .sum (axis = axis ), sum_ )
232+ stat = cast ("NDArray[Any] | types.CupyArray" , func (arr , axis = axis ).compute ())
233+ if isinstance (stat , types .CupyArray ):
234+ stat = stat .get ()
235+ np_func = getattr (np , func .__name__ )
236+ np .testing .assert_almost_equal (stat , np_func (np_arr , axis = axis ))
222237
223238
224239@pytest .mark .array_type (skip = ATS_SPARSE_DS )
0 commit comments