Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add internal with_operation_context #2525

Merged
merged 8 commits into from
Jun 16, 2023
95 changes: 74 additions & 21 deletions src/awkward/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import sys
import threading
import warnings
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Collection, Iterable, Mapping
from functools import wraps

import numpy # noqa: TID251

from awkward._nplikes.numpylike import NumpyMetadata
from awkward._typing import TypeVar
from awkward._typing import Any, TypeVar

np = NumpyMetadata.instance()

Expand Down Expand Up @@ -161,7 +162,7 @@ def format_argument(self, width, value):
if len(valuestr) > width:
valuestr = valuestr[: width - 3] + "..."

elif isinstance(value, (Sequence, Mapping)) and len(value) < 10000:
elif isinstance(value, (Collection, Mapping)) and len(value) < 10000:
valuestr = repr(value)
if len(valuestr) > width:
valuestr = valuestr[: width - 3] + "..."
Expand All @@ -182,66 +183,105 @@ def note(self) -> str:
class OperationErrorContext(ErrorContext):
_width = 80 - 8

def __init__(self, name, arguments):
def any_backend_is_delayed(
self, iterable: Iterable, *, depth: int = 1, depth_limit: int = 2
) -> bool:
from awkward._backends.dispatch import backend_of
from awkward._backends.numpy import NumpyBackend

numpy_backend = NumpyBackend.instance()
if self.primary() is not None or all(
backend_of(x, default=numpy_backend).nplike.is_eager for x in arguments
for obj in iterable:
backend = backend_of(obj, default=None)
# Do we not recognise this as an object with a backend?
if backend is None:
# Is this an iterable object, and are we permitted to recurse?
if isinstance(obj, Collection) and depth != depth_limit:
return self.any_backend_is_delayed(
obj, depth=depth + 1, depth_limit=depth_limit
)
# Assume not delayed!
else:
return False
# Eager backends aren't delayed!
elif backend.nplike.is_eager:
continue
else:
return True
return False

def __init__(self, name, args: Iterable[Any], kwargs: Mapping[str, Any]):
if self.primary() is None and (
self.any_backend_is_delayed(args)
or self.any_backend_is_delayed(kwargs.values())
):
string_args = self._format_args(args)
string_kwargs = self._format_kwargs(kwargs)
else:
# if primary is not None: we won't be setting an ErrorContext
# if all nplikes are eager: no accumulation of large arrays
# --> in either case, delay string generation
string_arguments = PartialFunction(self._string_arguments, arguments)
else:
string_arguments = self._string_arguments(arguments)
string_args = PartialFunction(self._format_args, args)
string_kwargs = PartialFunction(self._format_kwargs, kwargs)

super().__init__(
name=name,
arguments=string_arguments,
args=string_args,
kwargs=string_kwargs,
)

def _string_arguments(self, arguments):
def _format_args(self, arguments: Iterable) -> list[str]:
string_arguments = []
for value in arguments:
string_arguments.append(self.format_argument(self._width, value))

return string_arguments

def _format_kwargs(self, arguments: Mapping[str, Any]) -> dict[str, str]:
string_arguments = {}
for key, value in arguments.items():
if isinstance(key, str):
width = self._width - len(key) - 3
else:
width = self._width

string_arguments[key] = self.format_argument(width, value)

return string_arguments

@property
def name(self):
return self._kwargs["name"]

@property
def arguments(self):
out = self._kwargs["arguments"]
def args(self) -> list:
out = self._kwargs["args"]
if isinstance(out, PartialFunction):
out = self._kwargs["arguments"] = out()
out = self._kwargs["args"] = out()
return out

def format_exception(self, exception):
@property
def kwargs(self) -> dict:
out = self._kwargs["kwargs"]
if isinstance(out, PartialFunction):
out = self._kwargs["kwargs"] = out()
return out

def format_exception(self, exception: Exception) -> str:
return f"{exception}\n{self.note}"

@property
def note(self) -> str:
arguments = []
for name, valuestr in self.arguments.items():
for valuestr in self.args:
arguments.append(f"\n {valuestr}")
for name, valuestr in self.kwargs.items():
if isinstance(name, str):
arguments.append(f"\n {name} = {valuestr}")
else:
arguments.append(f"\n {valuestr}")

extra_line = "" if len(arguments) == 0 else "\n "
calling_note = f'{self.name}({"".join(arguments)}{extra_line})'
return f"""
This error occurred while calling

{self.name}({"".join(arguments)}{extra_line})"""
{calling_note}"""


class SlicingErrorContext(ErrorContext):
Expand Down Expand Up @@ -381,3 +421,16 @@ class FieldNotFoundError(IndexError):


AxisError = numpy.AxisError


T = TypeVar("T", bound=Callable)


def with_operation_context(func: T) -> T:
@wraps(func)
def wrapper(*args, **kwargs):
# NOTE: this decorator assumes that the operation is exposed under `ak.`
with OperationErrorContext(f"ak.{func.__qualname__}", args, kwargs):
return func(*args, **kwargs)

return wrapper
40 changes: 14 additions & 26 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def __init__(self, array):

def __getitem__(self, where):
with ak._errors.OperationErrorContext(
"ak.Array.mask", {0: self._array, 1: where}
"ak.Array.mask", args=[self._array, where], kwargs={}
):
return ak.operations.mask(self._array, where, valid_when=True)

Expand Down Expand Up @@ -1018,7 +1018,8 @@ def __setitem__(self, where, what):
"""
with ak._errors.OperationErrorContext(
"ak.Array.__setitem__",
{"self": self, "field_name": where, "field_value": what},
(self,),
{"where": where, "what": what},
):
if not (
isinstance(where, str)
Expand Down Expand Up @@ -1049,7 +1050,8 @@ def __delitem__(self, where):
"""
with ak._errors.OperationErrorContext(
"ak.Array.__delitem__",
{"self": self, "field_name": where},
(self,),
{"where": where},
):
if not (
isinstance(where, str)
Expand Down Expand Up @@ -1281,11 +1283,7 @@ def __array__(self, *args, **kwargs):
nested lists in a NumPy `"O"` array are severed from the array and
cannot be sliced as dimensions.
"""
arguments = {0: self}
for i, arg in enumerate(args):
arguments[i + 1] = arg
arguments.update(kwargs)
with ak._errors.OperationErrorContext("numpy.asarray", arguments):
with ak._errors.OperationErrorContext("numpy.asarray", (self, *args), kwargs):
return ak._connect.numpy.convert_to_array(self._layout, args, kwargs)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down Expand Up @@ -1354,11 +1352,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
See also #__array_function__.
"""
name = f"{type(ufunc).__module__}.{ufunc.__name__}.{method!s}"
arguments = {}
for i, arg in enumerate(inputs):
arguments[i] = arg
arguments.update(kwargs)
with ak._errors.OperationErrorContext(name, arguments):
with ak._errors.OperationErrorContext(name, inputs, kwargs):
return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)

def __array_function__(self, func, types, args, kwargs):
Expand Down Expand Up @@ -1787,7 +1781,8 @@ def __setitem__(self, where, what):
"""
with ak._errors.OperationErrorContext(
"ak.Record.__setitem__",
{"self": self, "field_name": where, "field_value": what},
(self,),
{"where": where, "what": what},
):
if not (
isinstance(where, str)
Expand Down Expand Up @@ -1819,7 +1814,8 @@ def __delitem__(self, where):
"""
with ak._errors.OperationErrorContext(
"ak.Record.__delitem__",
{"self": self, "field_name": where},
(self,),
{"where": where},
):
if not (
isinstance(where, str)
Expand Down Expand Up @@ -2030,11 +2026,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
See #ak.Array.__array_ufunc__ for a more complete description.
"""
name = f"{type(ufunc).__module__}.{ufunc.__name__}.{method!s}"
arguments = {}
for i, arg in enumerate(inputs):
arguments[i] = arg
arguments.update(kwargs)
with ak._errors.OperationErrorContext(name, arguments):
with ak._errors.OperationErrorContext(name, inputs, kwargs):
return ak._connect.numpy.array_ufunc(ufunc, method, inputs, kwargs)

@property
Expand Down Expand Up @@ -2376,11 +2368,7 @@ def __array__(self, *args, **kwargs):

See #ak.Array.__array__ for a more complete description.
"""
arguments = {0: self}
for i, arg in enumerate(args):
arguments[i + 1] = arg
arguments.update(kwargs)
with ak._errors.OperationErrorContext("numpy.asarray", arguments):
with ak._errors.OperationErrorContext("numpy.asarray", (self, *args), kwargs):
return ak._connect.numpy.convert_to_array(self.snapshot(), args, kwargs)

@property
Expand Down Expand Up @@ -2415,7 +2403,7 @@ def snapshot(self):
formstr, length, container = self._layout.to_buffers()
form = ak.forms.from_json(formstr)

with ak._errors.OperationErrorContext("ak.ArrayBuilder.snapshot", {}):
with ak._errors.OperationErrorContext("ak.ArrayBuilder.snapshot", [], {}):
return ak.operations.ak_from_buffers._impl(
form,
length,
Expand Down
15 changes: 3 additions & 12 deletions src/awkward/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
def all(
array,
axis=None,
Expand Down Expand Up @@ -51,18 +53,7 @@ def all(
See #ak.sum for a more complete description of nested list and missing
value (None) handling in reducers.
"""
with ak._errors.OperationErrorContext(
"ak.all",
{
"array": array,
"axis": axis,
"keepdims": keepdims,
"mask_identity": mask_identity,
"highlevel": highlevel,
"behavior": behavior,
},
):
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


def _impl(array, axis, keepdims, mask_identity, highlevel, behavior):
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
from __future__ import annotations

__all__ = ("almost_equal",)


from awkward._backends.dispatch import backend_of
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._errors import with_operation_context
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._parameters import parameters_are_equal
from awkward.operations.ak_to_layout import to_layout

np = NumpyMetadata.instance()


@with_operation_context
def almost_equal(
left,
right,
Expand Down
15 changes: 3 additions & 12 deletions src/awkward/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import awkward as ak
from awkward._behavior import behavior_of
from awkward._connect.numpy import UNSUPPORTED
from awkward._errors import with_operation_context
from awkward._layout import wrap_layout
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._regularize import regularize_axis

np = NumpyMetadata.instance()


@with_operation_context
def any(
array,
axis=None,
Expand Down Expand Up @@ -51,18 +53,7 @@ def any(
See #ak.sum for a more complete description of nested list and missing
value (None) handling in reducers.
"""
with ak._errors.OperationErrorContext(
"ak.any",
{
"array": array,
"axis": axis,
"keepdims": keepdims,
"mask_identity": mask_identity,
"highlevel": highlevel,
"behavior": behavior,
},
):
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)
return _impl(array, axis, keepdims, mask_identity, highlevel, behavior)


def _impl(array, axis, keepdims, mask_identity, highlevel, behavior):
Expand Down
Loading