Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 154 additions & 125 deletions src/xarray_dataclass/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,42 @@
from functools import partial
from inspect import signature
from types import MethodType
from typing import Any, Callable, Optional, Protocol, Type, TypeVar, Union, overload
from typing import (
Any,
Callable,
Optional,
Protocol,
TYPE_CHECKING,
Type,
TypeVar,
Union,
overload,
)


# dependencies
import numpy as np
import xarray as xr
from typing_extensions import ParamSpec


# submodules
from .datamodel import DataModel
from .dataoptions import DataOptions
from .typing import AnyArray, AnyXarray, DataClass, Order, Shape, Sizes
from .util import lazy_import


# type hints
# lazy imports of large modules
if TYPE_CHECKING:
import numpy as np
import xarray as xr
else:
np = lazy_import("numpy")
xr = lazy_import("xarray")


# private type hints
PInit = ParamSpec("PInit")
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
TDataArray = TypeVar("TDataArray", bound="xr.DataArray")


class OptionedClass(DataClass[PInit], Protocol[PInit, TDataArray]):
Expand All @@ -33,29 +51,28 @@ class OptionedClass(DataClass[PInit], Protocol[PInit, TDataArray]):
__dataoptions__: DataOptions[TDataArray]


# runtime functions
@overload
def asdataarray(
dataclass: OptionedClass[PInit, TDataArray],
reference: Optional[AnyXarray] = None,
dataoptions: None = None,
) -> TDataArray: ...


@overload
def asdataarray(
dataclass: DataClass[PInit],
reference: Optional[AnyXarray] = None,
dataoptions: None = None,
) -> xr.DataArray: ...
if TYPE_CHECKING:
# runtime functions
@overload
def asdataarray(
dataclass: OptionedClass[PInit, TDataArray],
reference: Optional[AnyXarray] = None,
dataoptions: None = None,
) -> xr.DataArray: ...

@overload
def asdataarray(
dataclass: DataClass[PInit],
reference: Optional[AnyXarray] = None,
dataoptions: None = None,
) -> xr.DataArray: ...

@overload
def asdataarray(
dataclass: Any,
reference: Optional[AnyXarray] = None,
dataoptions: DataOptions[TDataArray] = DataOptions(xr.DataArray),
) -> TDataArray: ...
@overload
def asdataarray(
dataclass: Any,
reference: Optional[AnyXarray] = None,
dataoptions: DataOptions[TDataArray] = DataOptions(xr.DataArray),
) -> xr.DataArray: ...


def asdataarray(
Expand Down Expand Up @@ -112,19 +129,21 @@ class classproperty:
def __init__(self, func: Any) -> None:
self.__func__ = func

@overload
def __get__(
self,
obj: Any,
cls: Type[OptionedClass[PInit, TDataArray]],
) -> Callable[PInit, TDataArray]: ...
if TYPE_CHECKING:

@overload
def __get__(
self,
obj: Any,
cls: Type[DataClass[PInit]],
) -> Callable[PInit, xr.DataArray]: ...
@overload
def __get__(
self,
obj: Any,
cls: Type[OptionedClass[PInit, TDataArray]],
) -> Callable[PInit, TDataArray]: ...

@overload
def __get__(
self,
obj: Any,
cls: Type[DataClass[PInit]],
) -> Callable[PInit, xr.DataArray]: ...

def __get__(self, obj: Any, cls: Any) -> Any:
return self.__func__(cls)
Expand All @@ -147,23 +166,25 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
setattr(new, "__signature__", sig)
return MethodType(new, cls)

@overload
@classmethod
def shaped(
cls: Type[OptionedClass[PInit, TDataArray]],
func: Callable[[Shape], AnyArray],
shape: Union[Shape, Sizes],
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def shaped(
cls: Type[DataClass[PInit]],
func: Callable[[Shape], AnyArray],
shape: Union[Shape, Sizes],
**kwargs: Any,
) -> xr.DataArray: ...
if TYPE_CHECKING:

@overload
@classmethod
def shaped(
cls: Type[OptionedClass[PInit, TDataArray]],
func: Callable[[Shape], AnyArray],
shape: Union[Shape, Sizes],
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def shaped(
cls: Type[DataClass[PInit]],
func: Callable[[Shape], AnyArray],
shape: Union[Shape, Sizes],
**kwargs: Any,
) -> xr.DataArray: ...

@classmethod
def shaped(
Expand Down Expand Up @@ -191,23 +212,25 @@ def shaped(

return asdataarray(cls(**{key: func(shape)}, **kwargs))

@overload
@classmethod
def empty(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def empty(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...
if TYPE_CHECKING:

@overload
@classmethod
def empty(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def empty(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...

@classmethod
def empty(
Expand All @@ -231,23 +254,25 @@ def empty(
func = partial(np.empty, order=order)
return cls.shaped(func, shape, **kwargs)

@overload
@classmethod
def zeros(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def zeros(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...
if TYPE_CHECKING:

@overload
@classmethod
def zeros(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def zeros(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...

@classmethod
def zeros(
Expand All @@ -271,23 +296,25 @@ def zeros(
func = partial(np.zeros, order=order)
return cls.shaped(func, shape, **kwargs)

@overload
@classmethod
def ones(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def ones(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...
if TYPE_CHECKING:

@overload
@classmethod
def ones(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def ones(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...

@classmethod
def ones(
Expand All @@ -311,25 +338,27 @@ def ones(
func = partial(np.ones, order=order)
return cls.shaped(func, shape, **kwargs)

@overload
@classmethod
def full(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
fill_value: Any,
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def full(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
fill_value: Any,
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...
if TYPE_CHECKING:

@overload
@classmethod
def full(
cls: Type[OptionedClass[PInit, TDataArray]],
shape: Union[Shape, Sizes],
fill_value: Any,
order: Order = "C",
**kwargs: Any,
) -> TDataArray: ...

@overload
@classmethod
def full(
cls: Type[DataClass[PInit]],
shape: Union[Shape, Sizes],
fill_value: Any,
order: Order = "C",
**kwargs: Any,
) -> xr.DataArray: ...

@classmethod
def full(
Expand Down
Loading