Skip to content

Commit

Permalink
wip: try decorator approach
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Jan 26, 2024
1 parent 54f6c25 commit d6d2e52
Show file tree
Hide file tree
Showing 72 changed files with 649 additions and 388 deletions.
56 changes: 8 additions & 48 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,35 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import parameters_union
from awkward._requirements import import_required_module

np = NumpyMetadata.instance()
numpy = Numpy.instance()

try:
import pyarrow

error_message = None
if parse_version(pyarrow.__version__) < parse_version("7.0.0"):
raise ImportError

except ModuleNotFoundError:
except (ModuleNotFoundError, ImportError):
pyarrow = None
error_message = """to use {0}, you must install pyarrow:
pip install pyarrow
or
conda install -c conda-forge pyarrow
"""

else:
if parse_version(pyarrow.__version__) < parse_version("7.0.0"):
pyarrow = None
error_message = "pyarrow 7.0.0 or later required for {0}"


def import_pyarrow(name: str) -> ModuleType:
if pyarrow is None:
raise ImportError(error_message.format(name))
return pyarrow
return import_required_module("pyarrow")


def import_pyarrow_parquet(name: str) -> ModuleType:
if pyarrow is None:
raise ImportError(error_message.format(name))

import pyarrow.parquet as out

return out
return import_required_module("pyarrow.parquet")


def import_pyarrow_compute(name: str) -> ModuleType:
if pyarrow is None:
raise ImportError(error_message.format(name))

import pyarrow.compute as out

return out
return import_required_module("pyarrow.compute")


def import_fsspec(name: str) -> ModuleType:
try:
import fsspec

except ModuleNotFoundError as err:
raise ImportError(
f"""to use {name}, you must install fsspec:
pip install fsspec
or
conda install -c conda-forge fsspec
"""
) from err

import_pyarrow_parquet(name)

return fsspec
return import_required_module("fsspec")


if pyarrow is not None:
Expand Down
227 changes: 36 additions & 191 deletions src/awkward/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,185 +2,27 @@

from __future__ import annotations

import importlib
import sys
import warnings
import contextlib
from collections.abc import Callable, Collection, Generator, Mapping
from functools import lru_cache, partial, wraps
from functools import wraps
from inspect import isgeneratorfunction

import packaging.version
from packaging.requirements import Requirement
from packaging.version import Version

if sys.version_info < (3, 12):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata

from awkward._errors import with_named_operation_context
from awkward._typing import Any, NamedTuple, TypeAlias, TypeVar
from awkward._requirements import (
dependency_specification_context,
with_has_requirements,
)
from awkward._typing import Any, TypeAlias, TypeVar

# First, we parse the dependency information into an immutable specification
T = TypeVar("T")
DispatcherType: TypeAlias = "Callable[..., Generator[Collection[Any], None, T]]"
DispatchedType: TypeAlias = "Callable[..., T]"


class DependencyGroup(NamedTuple):
name: str
dependencies: tuple[str, ...]


class DependencySpecification(NamedTuple):
groups: tuple[DependencyGroup, ...]
non_groups: tuple[str, ...]


def normalize_dependency_specification(
dependencies: Collection[str] | Mapping[str, Collection[str]] | None,
) -> DependencySpecification:
"""Normalize a dependency specification into a hashable object"""
if dependencies is None:
return DependencySpecification((), ())
elif isinstance(dependencies, Mapping):
return DependencySpecification(
groups=tuple(
DependencyGroup(key, tuple(values))
for key, values in dependencies.items()
),
non_groups=(),
)
else:
return DependencySpecification(groups=(), non_groups=tuple(dependencies))


def regularize_dunder_version(version: str) -> str:
return version.replace("/", ".")


def iter_missing_dependencies(dependencies: tuple[str, ...]):
"""Build an iterator over the requirement, version pairs of dependencies that are not
satisfied by the current environment"""
for _requirement in dependencies:
requirement = Requirement(_requirement)
if requirement.extras:
raise RuntimeError(
"High-level functions must not declare dependencies specifications with extras"
)

# Try and find version
try:
# Try and get the version the canonical way
_version = importlib_metadata.version(requirement.name)
except importlib_metadata.PackageNotFoundError:
# Otherwise, fall back on `__version__` (e.g. for ROOT)
mod = importlib.import_module(requirement.name)
try:
mod_version = mod.__version__
except AttributeError:
warnings.warn(
f"Could not identify the version of installed package {requirement.name}",
stacklevel=2,
)
# Don't treat this as an error
continue
# Packages like ROOT seem to be playing poorly with standards, so we'll apply a simple regularization transform
_version = regularize_dunder_version(mod_version)

# Try and parse version
try:
version = Version(_version)
except packaging.version.InvalidVersion:
warnings.warn(
f"Could not parse the version of installed package {requirement.name}: {_version!r}",
stacklevel=2,
)
# Don't treat this as an error
continue

if version not in requirement.specifier:
yield requirement, version


@lru_cache
def build_runtime_dependency_validation_error(
dependency_spec: DependencySpecification,
) -> Exception | None:
"""Build an exception object for the given dependency specification if it
is not satisfied by the current environment"""
missing_extras = []
missing_dependencies: list[tuple[Requirement, Version | None]] = []
if dependency_spec.groups:
for extra, extra_dependencies in dependency_spec.groups:
extra_missing_dependencies = [
*iter_missing_dependencies(extra_dependencies)
]
missing_dependencies.extend(extra_missing_dependencies)
if extra_missing_dependencies:
missing_extras.append(extra)
else:
missing_dependencies[:] = iter_missing_dependencies(dependency_spec.non_groups)

if not missing_dependencies:
return None

missing_requirement_lines = [
(
f" * {req} — you do not have this package"
if ver is None
else f" * {req} — you have {ver} installed"
)
for req, ver in missing_dependencies
]
missing_requirement_message = "\n".join(missing_requirement_lines)

# Install string
missing_dependencies_string = " ".join(
[str(req) for req, ver in missing_dependencies]
)
missing_requirements_direct_message = (
f"If you use pip, you can install these packages with "
f"`python -m pip install {missing_dependencies_string}`.\n"
"Otherwise, if you use Conda, install the corresponding packages "
"for the correct versions. "
)

missing_extras_lines = [f" * {extra}" for extra in missing_extras]
missing_extras_list_string = "\n".join(missing_extras_lines)
missing_extras_string = ",".join(missing_extras)
maybe_missing_extras_message = (
(
f"{missing_requirements_direct_message}\n\n"
"These dependencies can also be conveniently installed using the following extras:\n\n"
f"{missing_extras_list_string}\n\n"
f"If you're using `pip`, then you can install these extras with `pip install awkward[{missing_extras_string}]`"
)
if missing_extras
else missing_requirements_direct_message
)
return ImportError(
f"This function has the following dependency requirements that are not met by your current environment:\n\n"
f"{missing_requirement_message}\n\n"
f"{maybe_missing_extras_message}"
)


def validate_runtime_dependencies(
dependency_spec: DependencySpecification,
):
exception = build_runtime_dependency_validation_error(dependency_spec)
if exception is None:
return
else:
raise exception


def on_dispatch_trivial():
return


def with_type_dispatch(
func: DispatcherType, on_dispatch_internal: Callable[[], None] = on_dispatch_trivial
func: DispatcherType,
internal_dispatch_context_factory=contextlib.nullcontext,
) -> DispatchedType:
if isgeneratorfunction(func):

Expand Down Expand Up @@ -208,22 +50,21 @@ def dispatch(*args, **kwargs):
return result

# Failed to find a custom overload, so resume the original function
on_dispatch_internal()

try:
next(gen_or_result)
except StopIteration as err:
return err.value
else:
raise AssertionError(
"high-level functions should only implement a single yield statement"
)
with internal_dispatch_context_factory():
try:
next(gen_or_result)
except StopIteration as err:
return err.value
else:
raise AssertionError(
"high-level functions should only implement a single yield statement"
)
else:

@wraps(func)
def dispatch(*args, **kwargs):
on_dispatch_internal()
return func(*args, **kwargs)
with internal_dispatch_context_factory():
return func(*args, **kwargs)

return dispatch

Expand All @@ -232,24 +73,28 @@ def high_level_function(
module: str = "ak",
name: str | None = None,
*,
dependencies: Collection[str] | Mapping[str, Collection[str]] | None = None,
dependencies: Collection[str | Mapping[str, Collection[str]]] | None = None,
) -> Callable[[DispatcherType], DispatchedType]:
"""Decorate a high-level function such that it may be overloaded by third-party array objects"""

# Callback for dispatches that use internal implementation
on_dispatch_internal = partial(
validate_runtime_dependencies, normalize_dependency_specification(dependencies)
)

def capture_func(func: DispatcherType) -> DispatchedType:
def decorator(func: DispatcherType) -> DispatchedType:
if name is None:
captured_name = func.__qualname__
else:
captured_name = name

return with_named_operation_context(
with_type_dispatch(func, on_dispatch_internal=on_dispatch_internal),
f"{module}.{captured_name}",
# Context manager for dispatches that use internal implementation
def context_factory():
spec = func_has_requirements.get_specification()
return dependency_specification_context(spec)

func_has_requirements = with_has_requirements(
with_named_operation_context(
with_type_dispatch(func, context_factory),
f"{module}.{captured_name}",
)
)

return capture_func
return func_has_requirements

return decorator
Loading

0 comments on commit d6d2e52

Please sign in to comment.