|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | from functools import partial |
5 | | -from typing import TYPE_CHECKING, Literal, cast, get_args |
| 5 | +from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | from numpy.exceptions import AxisError |
@@ -71,7 +71,7 @@ def _dask_block( |
71 | 71 | fns = {fn.__name__: fn for fn in (min, max, sum)} |
72 | 72 |
|
73 | 73 | axis = _normalize_axis(axis, a.ndim) |
74 | | - rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[misc,call-overload] |
| 74 | + rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload] |
75 | 75 | shape = _get_shape(rv, axis=axis, keepdims=keepdims) |
76 | 76 | return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) |
77 | 77 |
|
@@ -109,5 +109,8 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite |
109 | 109 | raise AssertionError(msg) |
110 | 110 |
|
111 | 111 |
|
112 | | -def _dtype_kw(dtype: DTypeLike | None, op: Ops) -> DTypeKw: |
| 112 | +DT = TypeVar("DT", bound="DTypeLike") |
| 113 | + |
| 114 | + |
| 115 | +def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]: |
113 | 116 | return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {} |
0 commit comments