From a242c6ebeec46691c1583397d4430905775d03ca Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Thu, 22 Dec 2022 10:29:18 +0000 Subject: [PATCH] refactor: add array comparison test helper (#2024) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/_util.py | 88 +++++++++++++++++++++++++++++ tests/test_0355-mixins.py | 116 ++++++++++++++++++-------------------- 2 files changed, 142 insertions(+), 62 deletions(-) diff --git a/src/awkward/_util.py b/src/awkward/_util.py index b9b2fdb370..f94400387d 100644 --- a/src/awkward/_util.py +++ b/src/awkward/_util.py @@ -844,3 +844,91 @@ def maybe_posaxis(layout, axis, depth): return axis + depth + additional_depth - 1 else: return None + + +def arrays_approx_equal( + left, + right, + rtol: float = 1e-5, + atol: float = 1e-8, + dtype_exact: bool = True, + check_parameters=True, +) -> bool: + # TODO: this should not be needed after refactoring nplike mechanism + import numpy + + import awkward.forms.form + + left_behavior = ak._util.behavior_of(left, behavior=ak.behavior) + right_behavior = ak._util.behavior_of(right, behavior=ak.behavior) + + left = ak.to_packed(ak.to_layout(left, allow_record=False), highlevel=False) + right = ak.to_packed(ak.to_layout(right, allow_record=False), highlevel=False) + + def is_approx_dtype(left, right) -> bool: + if not dtype_exact: + for family in numpy.integer, numpy.floating: + if numpy.issubdtype(left, family): + return numpy.issubdtype(right, family) + return left == right + + def visitor(left, right) -> bool: + if not type(left) is type(right): + return False + + if left.length != right.length: + return False + + if check_parameters and not awkward.forms.form._parameters_equal( + left.parameters, right.parameters + ): + return False + + # Allow an `__array__` to be set with no value in `ak.behavior`; + # this is sometimes useful in testing. What we _don't_ want is for one + # array to have a behavior class and another to lack it. + array = left.parameter("__array__") + if not ( + array is None + or (left_behavior.get(array) is right_behavior.get(array)) + or not check_parameters + ): + return False + + if left.is_list: + return numpy.array_equal(left.offsets, right.offsets) and visitor( + left.content, right.content + ) + elif left.is_regular: + return (left.size == right.size) and visitor(left.content, right.content) + elif left.is_numpy: + return is_approx_dtype(left.dtype, right.dtype) and numpy.allclose( + left.data, right.data, rtol=rtol, atol=atol, equal_nan=False + ) + elif left.is_option: + return numpy.array_equal( + left.index.data < 0, right.index.data < 0 + ) and visitor(left.content, right.content) + elif left.is_union: + return (len(left.contents) == len(right.contents)) and all( + [ + visitor(left.project(i).to_packed(), right.project(i).to_packed()) + for i, _ in enumerate(left.contents) + ] + ) + elif left.is_record: + record = left.parameter("__record__") + return ( + ( + record is None + or (left_behavior.get(record) is right_behavior.get(record)) + or not check_parameters + ) + and (left.fields == right.fields) + and (left.is_tuple == right.is_tuple) + and all([visitor(x, y) for x, y in zip(left.contents, right.contents)]) + ) + elif left.is_empty: + return True + + return visitor(left, right) diff --git a/tests/test_0355-mixins.py b/tests/test_0355-mixins.py index 5ff11fe708..727532acd6 100644 --- a/tests/test_0355-mixins.py +++ b/tests/test_0355-mixins.py @@ -1,41 +1,13 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE -import numbers import numpy as np -import pytest import awkward as ak to_list = ak.operations.to_list -def _assert_equal_enough(obtained, expected): - if isinstance(obtained, dict): - assert isinstance(expected, dict) - assert set(obtained.keys()) == set(expected.keys()) - for key in obtained.keys(): - _assert_equal_enough(obtained[key], expected[key]) - elif isinstance(obtained, list): - assert isinstance(expected, list) - assert len(obtained) == len(expected) - for x, y in zip(obtained, expected): - _assert_equal_enough(x, y) - elif isinstance(obtained, tuple): - assert isinstance(expected, tuple) - assert len(obtained) == len(expected) - for x, y in zip(obtained, expected): - _assert_equal_enough(x, y) - elif isinstance(obtained, numbers.Real) and isinstance(expected, numbers.Real): - assert pytest.approx(obtained) == expected - else: - assert obtained == expected - - -def assert_equal_enough(obtained, expected): - _assert_equal_enough(obtained.tolist(), expected) - - def test_make_mixins(): behavior = {} @@ -105,51 +77,71 @@ def weighted_add(self, other): with_name="WeightedPoint", behavior=behavior, ) - assert to_list(one + wone) == [ - [{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}, {"x": 6, "y": 6.6}], - [], - [{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}], - ] - assert_equal_enough( - wone + wtwo, - [ + assert ak._util.arrays_approx_equal( + one + wone, + ak.Array( [ - { - "x": 0.9524937500390619, - "y": 1.052493750039062, - "weight": 2.831969279439222, - }, - {"x": 2.0, "y": 2.2, "weight": 5.946427498927402}, - { - "x": 2.9516640394605282, - "y": 3.1549921183815837, - "weight": 8.632349833200564, - }, + [{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}, {"x": 6, "y": 6.6}], + [], + [{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}], ], - [], + behavior=behavior, + with_name="Point", + ), + dtype_exact=False, + ) + assert ak._util.arrays_approx_equal( + wone + wtwo, + ak.Array( [ - { - "x": 3.9515600270076154, - "y": 4.206240108030463, - "weight": 11.533018588312771, - }, - {"x": 5.0, "y": 5.5, "weight": 14.866068747318506}, + [ + { + "x": 0.9524937500390619, + "y": 1.052493750039062, + "weight": 2.831969279439222, + }, + {"x": 2.0, "y": 2.2, "weight": 5.946427498927402}, + { + "x": 2.9516640394605282, + "y": 3.1549921183815837, + "weight": 8.632349833200564, + }, + ], + [], + [ + { + "x": 3.9515600270076154, + "y": 4.206240108030463, + "weight": 11.533018588312771, + }, + {"x": 5.0, "y": 5.5, "weight": 14.866068747318506}, + ], ], - ], + behavior=behavior, + with_name="WeightedPoint", + ), + dtype_exact=False, ) - assert_equal_enough( + assert ak._util.arrays_approx_equal( abs(one), - [ - [1.4866068747318506, 2.973213749463701, 4.459820624195552], - [], - [5.946427498927402, 7.433034373659253], - ], + ak.Array( + [ + [1.4866068747318506, 2.973213749463701, 4.459820624195552], + [], + [5.946427498927402, 7.433034373659253], + ], + behavior=behavior, + with_name="Point", + ), + dtype_exact=False, ) - assert_equal_enough( + assert ak._util.arrays_approx_equal( one.distance(wtwo), [ [0.14142135623730953, 0.0, 0.31622776601683783], [], [0.4123105625617664, 0.0], ], + dtype_exact=False, + check_parameters=False, )