Skip to content

Commit 54116a4

Browse files
committed
undo disk array
1 parent 6a82c72 commit 54116a4

File tree

5 files changed

+18
-12
lines changed

5 files changed

+18
-12
lines changed

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _generic_op_cs(
8787
if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
8888
assert isinstance(dtype, np.dtype | type | None)
8989
# convert to array so dimensions collapse as expected
90-
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[call-overload]
90+
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[arg-type]
9191
return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
9292

9393

src/fast_array_utils/stats/_mean_var.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
from numpy.typing import NDArray
1717

18-
from ..typing import CpuArray, DiskArray, GpuArray
18+
from ..typing import CpuArray, GpuArray
1919

2020

2121
@no_type_check # mypy is extremely confused
2222
def mean_var_(
23-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
23+
x: CpuArray | GpuArray | types.DaskArray,
2424
/,
2525
*,
2626
axis: Literal[0, 1] | None = None,

src/fast_array_utils/stats/_power.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
from numpy.typing import DTypeLike
1616

17-
from fast_array_utils.typing import CpuArray, DiskArray, GpuArray
17+
from fast_array_utils.typing import CpuArray, GpuArray
1818

1919
# All supported array types except for disk ones and CSDataset
20-
Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.DaskArray
20+
Array: TypeAlias = CpuArray | GpuArray | types.DaskArray
2121

2222
_Arr = TypeVar("_Arr", bound=Array)
2323
_Mat = TypeVar("_Mat", bound=types.CSBase | types.CupyCSMatrix)
@@ -33,7 +33,7 @@ def power(x: _Arr, n: int, /, dtype: DTypeLike | None = None) -> _Arr:
3333
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
3434
if TYPE_CHECKING:
3535
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
36-
return np.power(x, n, dtype=dtype) # type: ignore[operator]
36+
return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]
3737

3838

3939
@_power.register(types.CSBase | types.CupyCSMatrix)

src/fast_array_utils/stats/_typing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: MPL-2.0
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Literal, Protocol, TypedDict
4+
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypedDict, TypeVar
55

66
import numpy as np
77

@@ -51,5 +51,8 @@ def __call__(
5151
Ops: TypeAlias = NoDtypeOps | DtypeOps
5252

5353

54-
class DTypeKw(TypedDict, total=False):
55-
dtype: DTypeLike
54+
_DT = TypeVar("_DT", bound="DTypeLike")
55+
56+
57+
class DTypeKw(TypedDict, Generic[_DT], total=False):
58+
dtype: _DT

src/fast_array_utils/stats/_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from functools import partial
5-
from typing import TYPE_CHECKING, Literal, cast, get_args
5+
from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args
66

77
import numpy as np
88
from numpy.exceptions import AxisError
@@ -71,7 +71,7 @@ def _dask_block(
7171
fns = {fn.__name__: fn for fn in (min, max, sum)}
7272

7373
axis = _normalize_axis(axis, a.ndim)
74-
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[misc,call-overload]
74+
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload]
7575
shape = _get_shape(rv, axis=axis, keepdims=keepdims)
7676
return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
7777

@@ -109,5 +109,8 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
109109
raise AssertionError(msg)
110110

111111

112-
def _dtype_kw(dtype: DTypeLike | None, op: Ops) -> DTypeKw:
112+
DT = TypeVar("DT", bound="DTypeLike")
113+
114+
115+
def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]:
113116
return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}

0 commit comments

Comments
 (0)