diff --git a/src/xarray_dataclass/dataarray.py b/src/xarray_dataclass/dataarray.py index aaed391..affcd03 100644 --- a/src/xarray_dataclass/dataarray.py +++ b/src/xarray_dataclass/dataarray.py @@ -7,12 +7,20 @@ from functools import partial from inspect import signature from types import MethodType -from typing import Any, Callable, Optional, Protocol, Type, TypeVar, Union, overload +from typing import ( + Any, + Callable, + Optional, + Protocol, + TYPE_CHECKING, + Type, + TypeVar, + Union, + overload, +) # dependencies -import numpy as np -import xarray as xr from typing_extensions import ParamSpec @@ -20,11 +28,21 @@ from .datamodel import DataModel from .dataoptions import DataOptions from .typing import AnyArray, AnyXarray, DataClass, Order, Shape, Sizes +from .util import lazy_import -# type hints +# lazy imports of large modules +if TYPE_CHECKING: + import numpy as np + import xarray as xr +else: + np = lazy_import("numpy") + xr = lazy_import("xarray") + + +# private type hints PInit = ParamSpec("PInit") -TDataArray = TypeVar("TDataArray", bound=xr.DataArray) +TDataArray = TypeVar("TDataArray", bound="xr.DataArray") class OptionedClass(DataClass[PInit], Protocol[PInit, TDataArray]): @@ -33,29 +51,28 @@ class OptionedClass(DataClass[PInit], Protocol[PInit, TDataArray]): __dataoptions__: DataOptions[TDataArray] -# runtime functions -@overload -def asdataarray( - dataclass: OptionedClass[PInit, TDataArray], - reference: Optional[AnyXarray] = None, - dataoptions: None = None, -) -> TDataArray: ... - - -@overload -def asdataarray( - dataclass: DataClass[PInit], - reference: Optional[AnyXarray] = None, - dataoptions: None = None, -) -> xr.DataArray: ... +if TYPE_CHECKING: + # runtime functions + @overload + def asdataarray( + dataclass: OptionedClass[PInit, TDataArray], + reference: Optional[AnyXarray] = None, + dataoptions: None = None, + ) -> xr.DataArray: ... + @overload + def asdataarray( + dataclass: DataClass[PInit], + reference: Optional[AnyXarray] = None, + dataoptions: None = None, + ) -> xr.DataArray: ... -@overload -def asdataarray( - dataclass: Any, - reference: Optional[AnyXarray] = None, - dataoptions: DataOptions[TDataArray] = DataOptions(xr.DataArray), -) -> TDataArray: ... + @overload + def asdataarray( + dataclass: Any, + reference: Optional[AnyXarray] = None, + dataoptions: DataOptions[TDataArray] = DataOptions(xr.DataArray), + ) -> xr.DataArray: ... def asdataarray( @@ -112,19 +129,21 @@ class classproperty: def __init__(self, func: Any) -> None: self.__func__ = func - @overload - def __get__( - self, - obj: Any, - cls: Type[OptionedClass[PInit, TDataArray]], - ) -> Callable[PInit, TDataArray]: ... + if TYPE_CHECKING: - @overload - def __get__( - self, - obj: Any, - cls: Type[DataClass[PInit]], - ) -> Callable[PInit, xr.DataArray]: ... + @overload + def __get__( + self, + obj: Any, + cls: Type[OptionedClass[PInit, TDataArray]], + ) -> Callable[PInit, TDataArray]: ... + + @overload + def __get__( + self, + obj: Any, + cls: Type[DataClass[PInit]], + ) -> Callable[PInit, xr.DataArray]: ... def __get__(self, obj: Any, cls: Any) -> Any: return self.__func__(cls) @@ -147,23 +166,25 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any: setattr(new, "__signature__", sig) return MethodType(new, cls) - @overload - @classmethod - def shaped( - cls: Type[OptionedClass[PInit, TDataArray]], - func: Callable[[Shape], AnyArray], - shape: Union[Shape, Sizes], - **kwargs: Any, - ) -> TDataArray: ... - - @overload - @classmethod - def shaped( - cls: Type[DataClass[PInit]], - func: Callable[[Shape], AnyArray], - shape: Union[Shape, Sizes], - **kwargs: Any, - ) -> xr.DataArray: ... + if TYPE_CHECKING: + + @overload + @classmethod + def shaped( + cls: Type[OptionedClass[PInit, TDataArray]], + func: Callable[[Shape], AnyArray], + shape: Union[Shape, Sizes], + **kwargs: Any, + ) -> TDataArray: ... + + @overload + @classmethod + def shaped( + cls: Type[DataClass[PInit]], + func: Callable[[Shape], AnyArray], + shape: Union[Shape, Sizes], + **kwargs: Any, + ) -> xr.DataArray: ... @classmethod def shaped( @@ -191,23 +212,25 @@ def shaped( return asdataarray(cls(**{key: func(shape)}, **kwargs)) - @overload - @classmethod - def empty( - cls: Type[OptionedClass[PInit, TDataArray]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> TDataArray: ... - - @overload - @classmethod - def empty( - cls: Type[DataClass[PInit]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> xr.DataArray: ... + if TYPE_CHECKING: + + @overload + @classmethod + def empty( + cls: Type[OptionedClass[PInit, TDataArray]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> TDataArray: ... + + @overload + @classmethod + def empty( + cls: Type[DataClass[PInit]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> xr.DataArray: ... @classmethod def empty( @@ -231,23 +254,25 @@ def empty( func = partial(np.empty, order=order) return cls.shaped(func, shape, **kwargs) - @overload - @classmethod - def zeros( - cls: Type[OptionedClass[PInit, TDataArray]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> TDataArray: ... - - @overload - @classmethod - def zeros( - cls: Type[DataClass[PInit]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> xr.DataArray: ... + if TYPE_CHECKING: + + @overload + @classmethod + def zeros( + cls: Type[OptionedClass[PInit, TDataArray]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> TDataArray: ... + + @overload + @classmethod + def zeros( + cls: Type[DataClass[PInit]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> xr.DataArray: ... @classmethod def zeros( @@ -271,23 +296,25 @@ def zeros( func = partial(np.zeros, order=order) return cls.shaped(func, shape, **kwargs) - @overload - @classmethod - def ones( - cls: Type[OptionedClass[PInit, TDataArray]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> TDataArray: ... - - @overload - @classmethod - def ones( - cls: Type[DataClass[PInit]], - shape: Union[Shape, Sizes], - order: Order = "C", - **kwargs: Any, - ) -> xr.DataArray: ... + if TYPE_CHECKING: + + @overload + @classmethod + def ones( + cls: Type[OptionedClass[PInit, TDataArray]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> TDataArray: ... + + @overload + @classmethod + def ones( + cls: Type[DataClass[PInit]], + shape: Union[Shape, Sizes], + order: Order = "C", + **kwargs: Any, + ) -> xr.DataArray: ... @classmethod def ones( @@ -311,25 +338,27 @@ def ones( func = partial(np.ones, order=order) return cls.shaped(func, shape, **kwargs) - @overload - @classmethod - def full( - cls: Type[OptionedClass[PInit, TDataArray]], - shape: Union[Shape, Sizes], - fill_value: Any, - order: Order = "C", - **kwargs: Any, - ) -> TDataArray: ... - - @overload - @classmethod - def full( - cls: Type[DataClass[PInit]], - shape: Union[Shape, Sizes], - fill_value: Any, - order: Order = "C", - **kwargs: Any, - ) -> xr.DataArray: ... + if TYPE_CHECKING: + + @overload + @classmethod + def full( + cls: Type[OptionedClass[PInit, TDataArray]], + shape: Union[Shape, Sizes], + fill_value: Any, + order: Order = "C", + **kwargs: Any, + ) -> TDataArray: ... + + @overload + @classmethod + def full( + cls: Type[DataClass[PInit]], + shape: Union[Shape, Sizes], + fill_value: Any, + order: Order = "C", + **kwargs: Any, + ) -> xr.DataArray: ... @classmethod def full( diff --git a/src/xarray_dataclass/datamodel.py b/src/xarray_dataclass/datamodel.py index 450b9d8..8abd1f1 100644 --- a/src/xarray_dataclass/datamodel.py +++ b/src/xarray_dataclass/datamodel.py @@ -14,6 +14,7 @@ Literal, Optional, Tuple, + TYPE_CHECKING, Type, Union, cast, @@ -21,9 +22,7 @@ # dependencies -import numpy as np -import xarray as xr -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, get_type_hints # submodules @@ -41,7 +40,15 @@ get_name, get_role, ) +from .util import lazy_import +# lazy imports of large modules +if TYPE_CHECKING: + import numpy as np + import xarray as xr +else: + np = lazy_import("numpy") + xr = lazy_import("xarray") # type hints PInit = ParamSpec("PInit") @@ -134,7 +141,7 @@ def __post_init__(self) -> None: if model.names: setattr(self, "name", model.names[0].value) - def __call__(self, reference: Optional[AnyXarray] = None) -> xr.DataArray: # pyright: ignore[reportUnknownParameterType] + def __call__(self, reference: Optional[AnyXarray] = None) -> "xr.DataArray": # pyright: ignore[reportUnknownParameterType] """Create a DataArray object according to the entry.""" from .dataarray import asdataarray @@ -256,7 +263,7 @@ def get_typedarray( # pyright: ignore[reportUnknownParameterType] dims: Dims, dtype: Optional[AnyDType], # pyright: ignore[reportUnknownParameterType] reference: Optional[AnyXarray] = None, # pyright: ignore[reportUnknownParameterType] -) -> xr.DataArray: +) -> "xr.DataArray": """Create a DataArray object with given dims and dtype. Args: diff --git a/src/xarray_dataclass/dataset.py b/src/xarray_dataclass/dataset.py index 9f4bd2f..590faa8 100644 --- a/src/xarray_dataclass/dataset.py +++ b/src/xarray_dataclass/dataset.py @@ -7,12 +7,20 @@ from functools import partial from inspect import signature from types import MethodType -from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar, overload +from typing import ( + Any, + Callable, + Dict, + Optional, + Protocol, + Type, + TYPE_CHECKING, + TypeVar, + overload, +) # dependencies -import numpy as np -import xarray as xr from typing_extensions import ParamSpec @@ -20,11 +28,20 @@ from .datamodel import DataModel from .dataoptions import DataOptions from .typing import AnyArray, AnyXarray, DataClass, Order, Shape, Sizes +from .util import lazy_import + +# lazy imports of large modules +if TYPE_CHECKING: + import numpy as np + import xarray as xr +else: + numpy = lazy_import("xarray") + xr = lazy_import("xarray") # type hints PInit = ParamSpec("PInit") -TDataset = TypeVar("TDataset", bound=xr.Dataset) +TDataset = TypeVar("TDataset", bound="xr.Dataset") class OptionedClass(DataClass[PInit], Protocol[PInit, TDataset]): @@ -34,28 +51,28 @@ class OptionedClass(DataClass[PInit], Protocol[PInit, TDataset]): # runtime functions and classes -@overload -def asdataset( - dataclass: OptionedClass[PInit, TDataset], - reference: Optional[AnyXarray] = None, - dataoptions: None = None, -) -> TDataset: ... +if TYPE_CHECKING: + @overload + def asdataset( + dataclass: OptionedClass[PInit, TDataset], + reference: Optional[AnyXarray] = None, + dataoptions: None = None, + ) -> TDataset: ... -@overload -def asdataset( - dataclass: DataClass[PInit], - reference: Optional[AnyXarray] = None, - dataoptions: None = None, -) -> xr.Dataset: ... - + @overload + def asdataset( + dataclass: DataClass[PInit], + reference: Optional[AnyXarray] = None, + dataoptions: None = None, + ) -> xr.Dataset: ... -@overload -def asdataset( - dataclass: Any, - reference: Optional[AnyXarray] = None, - dataoptions: DataOptions[TDataset] = DataOptions(xr.Dataset), -) -> TDataset: ... + @overload + def asdataset( + dataclass: Any, + reference: Optional[AnyXarray] = None, + dataoptions: DataOptions[TDataset] = DataOptions(xr.Dataset), + ) -> TDataset: ... def asdataset( @@ -112,19 +129,21 @@ class classproperty: def __init__(self, func: Callable[..., Any]) -> None: self.__func__ = func - @overload - def __get__( - self, - obj: Any, - cls: Type[OptionedClass[PInit, TDataset]], - ) -> Callable[PInit, TDataset]: ... + if TYPE_CHECKING: - @overload - def __get__( - self, - obj: Any, - cls: Type[DataClass[PInit]], - ) -> Callable[PInit, xr.Dataset]: ... + @overload + def __get__( + self, + obj: Any, + cls: Type[OptionedClass[PInit, TDataset]], + ) -> Callable[PInit, TDataset]: ... + + @overload + def __get__( + self, + obj: Any, + cls: Type[DataClass[PInit]], + ) -> Callable[PInit, xr.Dataset]: ... def __get__(self, obj: Any, cls: Any) -> Any: return self.__func__(cls) @@ -147,23 +166,25 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any: setattr(new, "__signature__", sig) return MethodType(new, cls) - @overload - @classmethod - def shaped( - cls: Type[OptionedClass[PInit, TDataset]], - func: Callable[[Shape], AnyArray], - sizes: Sizes, - **kwargs: Any, - ) -> TDataset: ... - - @overload - @classmethod - def shaped( - cls: Type[DataClass[PInit]], - func: Callable[[Shape], AnyArray], - sizes: Sizes, - **kwargs: Any, - ) -> xr.Dataset: ... + if TYPE_CHECKING: + + @overload + @classmethod + def shaped( + cls: Type[OptionedClass[PInit, TDataset]], + func: Callable[[Shape], AnyArray], + sizes: Sizes, + **kwargs: Any, + ) -> TDataset: ... + + @overload + @classmethod + def shaped( + cls: Type[DataClass[PInit]], + func: Callable[[Shape], AnyArray], + sizes: Sizes, + **kwargs: Any, + ) -> xr.Dataset: ... @classmethod def shaped( @@ -192,23 +213,25 @@ def shaped( return asdataset(cls(**data_vars, **kwargs)) - @overload - @classmethod - def empty( - cls: Type[OptionedClass[PInit, TDataset]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> TDataset: ... - - @overload - @classmethod - def empty( - cls: Type[DataClass[PInit]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> xr.Dataset: ... + if TYPE_CHECKING: + + @overload + @classmethod + def empty( + cls: Type[OptionedClass[PInit, TDataset]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> TDataset: ... + + @overload + @classmethod + def empty( + cls: Type[DataClass[PInit]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> xr.Dataset: ... @classmethod def empty( @@ -232,23 +255,25 @@ def empty( func = partial(np.empty, order=order) return cls.shaped(func, sizes, **kwargs) - @overload - @classmethod - def zeros( - cls: Type[OptionedClass[PInit, TDataset]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> TDataset: ... - - @overload - @classmethod - def zeros( - cls: Type[DataClass[PInit]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> xr.Dataset: ... + if TYPE_CHECKING: + + @overload + @classmethod + def zeros( + cls: Type[OptionedClass[PInit, TDataset]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> TDataset: ... + + @overload + @classmethod + def zeros( + cls: Type[DataClass[PInit]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> xr.Dataset: ... @classmethod def zeros( @@ -272,23 +297,25 @@ def zeros( func = partial(np.zeros, order=order) return cls.shaped(func, sizes, **kwargs) - @overload - @classmethod - def ones( - cls: Type[OptionedClass[PInit, TDataset]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> TDataset: ... - - @overload - @classmethod - def ones( - cls: Type[DataClass[PInit]], - sizes: Sizes, - order: Order = "C", - **kwargs: Any, - ) -> xr.Dataset: ... + if TYPE_CHECKING: + + @overload + @classmethod + def ones( + cls: Type[OptionedClass[PInit, TDataset]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> TDataset: ... + + @overload + @classmethod + def ones( + cls: Type[DataClass[PInit]], + sizes: Sizes, + order: Order = "C", + **kwargs: Any, + ) -> xr.Dataset: ... @classmethod def ones( @@ -312,25 +339,27 @@ def ones( func = partial(np.ones, order=order) return cls.shaped(func, sizes, **kwargs) - @overload - @classmethod - def full( - cls: Type[OptionedClass[PInit, TDataset]], - sizes: Sizes, - fill_value: Any, - order: Order = "C", - **kwargs: Any, - ) -> TDataset: ... - - @overload - @classmethod - def full( - cls: Type[DataClass[PInit]], - sizes: Sizes, - fill_value: Any, - order: Order = "C", - **kwargs: Any, - ) -> xr.Dataset: ... + if TYPE_CHECKING: + + @overload + @classmethod + def full( + cls: Type[OptionedClass[PInit, TDataset]], + sizes: Sizes, + fill_value: Any, + order: Order = "C", + **kwargs: Any, + ) -> TDataset: ... + + @overload + @classmethod + def full( + cls: Type[DataClass[PInit]], + sizes: Sizes, + fill_value: Any, + order: Order = "C", + **kwargs: Any, + ) -> xr.Dataset: ... @classmethod def full( diff --git a/src/xarray_dataclass/typing.py b/src/xarray_dataclass/typing.py index 8a9741c..fceb59a 100644 --- a/src/xarray_dataclass/typing.py +++ b/src/xarray_dataclass/typing.py @@ -17,7 +17,6 @@ __all__ = ["Attr", "Coord", "Coordof", "Data", "Dataof", "Name"] - # standard library from dataclasses import Field, is_dataclass from enum import Enum @@ -39,6 +38,7 @@ Protocol, Sequence, Tuple, + TYPE_CHECKING, Type, TypeVar, Union, @@ -46,10 +46,18 @@ # dependencies -import numpy as np -import xarray as xr from typing_extensions import ParamSpec, TypeAlias +# submodules +from .util import lazy_import + +# lazy imports of large modules +if TYPE_CHECKING: + import numpy as np + import xarray as xr +else: + np = lazy_import("numpy") + xr = lazy_import("xarray") # type hints (private) PInit = ParamSpec("PInit") @@ -62,7 +70,7 @@ AnyArray: TypeAlias = np.ndarray[Any, Any] AnyDType: TypeAlias = np.dtype[Any] AnyField: TypeAlias = Field[Any] -AnyXarray: TypeAlias = Union[xr.DataArray, xr.Dataset] +AnyXarray: TypeAlias = Union["xr.DataArray", "xr.Dataset"] Dims = Tuple[str, ...] Order = Literal["C", "F"] Shape = Union[Sequence[int], int] diff --git a/src/xarray_dataclass/util.py b/src/xarray_dataclass/util.py new file mode 100644 index 0000000..8a639dc --- /dev/null +++ b/src/xarray_dataclass/util.py @@ -0,0 +1,28 @@ +import importlib +import importlib.util +import sys + + +def lazy_import(module_name: str): + """postponed import of the module with the specified name. + + The import is not performed until the module is accessed in the code. This + reduces the total time for an initial import by waiting to import the module + until its attributes are accessed. + """ + + # see https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + try: + ret = sys.modules[module_name] + return ret + except KeyError: + pass + + spec = importlib.util.find_spec(module_name) + if spec is None: + raise ImportError(f'no module found named "{module_name}"') + spec.loader = importlib.util.LazyLoader(spec.loader) # type: ignore + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module