Skip to content

Commit

Permalink
Merge branch 'main' into agoose77/fix-replace-slow-protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 22, 2022
2 parents 8559aaa + a242c6e commit 61f2903
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 62 deletions.
88 changes: 88 additions & 0 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,91 @@ def maybe_posaxis(layout, axis, depth):
return axis + depth + additional_depth - 1
else:
return None


def arrays_approx_equal(
left,
right,
rtol: float = 1e-5,
atol: float = 1e-8,
dtype_exact: bool = True,
check_parameters=True,
) -> bool:
# TODO: this should not be needed after refactoring nplike mechanism
import numpy

import awkward.forms.form

left_behavior = ak._util.behavior_of(left, behavior=ak.behavior)
right_behavior = ak._util.behavior_of(right, behavior=ak.behavior)

left = ak.to_packed(ak.to_layout(left, allow_record=False), highlevel=False)
right = ak.to_packed(ak.to_layout(right, allow_record=False), highlevel=False)

def is_approx_dtype(left, right) -> bool:
if not dtype_exact:
for family in numpy.integer, numpy.floating:
if numpy.issubdtype(left, family):
return numpy.issubdtype(right, family)
return left == right

def visitor(left, right) -> bool:
if not type(left) is type(right):
return False

if left.length != right.length:
return False

if check_parameters and not awkward.forms.form._parameters_equal(
left.parameters, right.parameters
):
return False

# Allow an `__array__` to be set with no value in `ak.behavior`;
# this is sometimes useful in testing. What we _don't_ want is for one
# array to have a behavior class and another to lack it.
array = left.parameter("__array__")
if not (
array is None
or (left_behavior.get(array) is right_behavior.get(array))
or not check_parameters
):
return False

if left.is_list:
return numpy.array_equal(left.offsets, right.offsets) and visitor(
left.content, right.content
)
elif left.is_regular:
return (left.size == right.size) and visitor(left.content, right.content)
elif left.is_numpy:
return is_approx_dtype(left.dtype, right.dtype) and numpy.allclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
)
elif left.is_option:
return numpy.array_equal(
left.index.data < 0, right.index.data < 0
) and visitor(left.content, right.content)
elif left.is_union:
return (len(left.contents) == len(right.contents)) and all(
[
visitor(left.project(i).to_packed(), right.project(i).to_packed())
for i, _ in enumerate(left.contents)
]
)
elif left.is_record:
record = left.parameter("__record__")
return (
(
record is None
or (left_behavior.get(record) is right_behavior.get(record))
or not check_parameters
)
and (left.fields == right.fields)
and (left.is_tuple == right.is_tuple)
and all([visitor(x, y) for x, y in zip(left.contents, right.contents)])
)
elif left.is_empty:
return True

return visitor(left, right)
116 changes: 54 additions & 62 deletions tests/test_0355-mixins.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,13 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numbers

import numpy as np
import pytest

import awkward as ak

to_list = ak.operations.to_list


def _assert_equal_enough(obtained, expected):
if isinstance(obtained, dict):
assert isinstance(expected, dict)
assert set(obtained.keys()) == set(expected.keys())
for key in obtained.keys():
_assert_equal_enough(obtained[key], expected[key])
elif isinstance(obtained, list):
assert isinstance(expected, list)
assert len(obtained) == len(expected)
for x, y in zip(obtained, expected):
_assert_equal_enough(x, y)
elif isinstance(obtained, tuple):
assert isinstance(expected, tuple)
assert len(obtained) == len(expected)
for x, y in zip(obtained, expected):
_assert_equal_enough(x, y)
elif isinstance(obtained, numbers.Real) and isinstance(expected, numbers.Real):
assert pytest.approx(obtained) == expected
else:
assert obtained == expected


def assert_equal_enough(obtained, expected):
_assert_equal_enough(obtained.tolist(), expected)


def test_make_mixins():
behavior = {}

Expand Down Expand Up @@ -105,51 +77,71 @@ def weighted_add(self, other):
with_name="WeightedPoint",
behavior=behavior,
)
assert to_list(one + wone) == [
[{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}, {"x": 6, "y": 6.6}],
[],
[{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}],
]
assert_equal_enough(
wone + wtwo,
[
assert ak._util.arrays_approx_equal(
one + wone,
ak.Array(
[
{
"x": 0.9524937500390619,
"y": 1.052493750039062,
"weight": 2.831969279439222,
},
{"x": 2.0, "y": 2.2, "weight": 5.946427498927402},
{
"x": 2.9516640394605282,
"y": 3.1549921183815837,
"weight": 8.632349833200564,
},
[{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}, {"x": 6, "y": 6.6}],
[],
[{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}],
],
[],
behavior=behavior,
with_name="Point",
),
dtype_exact=False,
)
assert ak._util.arrays_approx_equal(
wone + wtwo,
ak.Array(
[
{
"x": 3.9515600270076154,
"y": 4.206240108030463,
"weight": 11.533018588312771,
},
{"x": 5.0, "y": 5.5, "weight": 14.866068747318506},
[
{
"x": 0.9524937500390619,
"y": 1.052493750039062,
"weight": 2.831969279439222,
},
{"x": 2.0, "y": 2.2, "weight": 5.946427498927402},
{
"x": 2.9516640394605282,
"y": 3.1549921183815837,
"weight": 8.632349833200564,
},
],
[],
[
{
"x": 3.9515600270076154,
"y": 4.206240108030463,
"weight": 11.533018588312771,
},
{"x": 5.0, "y": 5.5, "weight": 14.866068747318506},
],
],
],
behavior=behavior,
with_name="WeightedPoint",
),
dtype_exact=False,
)
assert_equal_enough(
assert ak._util.arrays_approx_equal(
abs(one),
[
[1.4866068747318506, 2.973213749463701, 4.459820624195552],
[],
[5.946427498927402, 7.433034373659253],
],
ak.Array(
[
[1.4866068747318506, 2.973213749463701, 4.459820624195552],
[],
[5.946427498927402, 7.433034373659253],
],
behavior=behavior,
with_name="Point",
),
dtype_exact=False,
)
assert_equal_enough(
assert ak._util.arrays_approx_equal(
one.distance(wtwo),
[
[0.14142135623730953, 0.0, 0.31622776601683783],
[],
[0.4123105625617664, 0.0],
],
dtype_exact=False,
check_parameters=False,
)

0 comments on commit 61f2903

Please sign in to comment.