Skip to content

Commit

Permalink
python_api: let approx() take nonnumeric values
Browse files Browse the repository at this point in the history
  • Loading branch information
jvansanten committed Sep 3, 2020
1 parent 91dbdb6 commit da8a223
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 16 deletions.
31 changes: 19 additions & 12 deletions src/_pytest/python_api.py
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Mapping
from collections.abc import Sized
from decimal import Decimal
from numbers import Number
from numbers import Complex
from types import TracebackType
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -145,7 +145,10 @@ def __repr__(self) -> str:
)

def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()):
try:
if set(actual.keys()) != set(self.expected.keys()):
return False
except AttributeError:
return False

return ApproxBase.__eq__(self, actual)
Expand All @@ -160,8 +163,6 @@ def _check_type(self) -> None:
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
elif not isinstance(value, Number):
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))


class ApproxSequencelike(ApproxBase):
Expand All @@ -176,7 +177,10 @@ def __repr__(self) -> str:
)

def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected):
try:
if len(actual) != len(self.expected):
return False
except TypeError:
return False
return ApproxBase.__eq__(self, actual)

Expand All @@ -189,10 +193,6 @@ def _check_type(self) -> None:
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
elif not isinstance(x, Number):
raise _non_numeric_type_error(
self.expected, at="index {}".format(index)
)


class ApproxScalar(ApproxBase):
Expand Down Expand Up @@ -238,6 +238,15 @@ def __eq__(self, actual) -> bool:
if actual == self.expected:
return True

# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
return False

# Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex
# numbers.
Expand Down Expand Up @@ -486,8 +495,6 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:

if isinstance(expected, Decimal):
cls = ApproxDecimal # type: Type[ApproxBase]
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
Expand All @@ -500,7 +507,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
):
cls = ApproxSequencelike
else:
raise _non_numeric_type_error(expected, at=None)
cls = ApproxScalar

return cls(expected, rel, abs, nan_ok)

Expand Down
48 changes: 44 additions & 4 deletions testing/python/approx.py
Expand Up @@ -329,6 +329,9 @@ def test_tuple_wrong_len(self):
assert (1, 2) != approx((1,))
assert (1, 2) != approx((1, 2, 3))

def test_tuple_vs_other(self):
assert 1 != approx((1,))

def test_dict(self):
actual = {"a": 1 + 1e-7, "b": 2 + 1e-8}
# Dictionaries became ordered in python3.6, so switch up the order here
Expand All @@ -346,6 +349,13 @@ def test_dict_wrong_len(self):
assert {"a": 1, "b": 2} != approx({"a": 1, "c": 2})
assert {"a": 1, "b": 2} != approx({"a": 1, "b": 2, "c": 3})

def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})

def test_dict_vs_other(self):
assert 1 != approx({"a": 0})

def test_numpy_array(self):
np = pytest.importorskip("numpy")

Expand Down Expand Up @@ -466,18 +476,48 @@ def test_foo():
@pytest.mark.parametrize(
"x",
[
pytest.param(None),
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param([[1]], id="nested-list"),
pytest.param({"key": "string"}, id="dict-with-string"),
pytest.param({"key": {"key": 1}}, id="nested-dict"),
],
)
def test_expected_value_type_error(self, x):
with pytest.raises(TypeError):
approx(x)

@pytest.mark.parametrize(
"x",
[
pytest.param(None),
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param({"key": "string"}, id="dict-with-string"),
],
)
def test_nonnumeric_okay_if_equal(self, x):
assert x == approx(x)

@pytest.mark.parametrize(
"x",
[
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param({"key": "string"}, id="dict-with-string"),
],
)
def test_nonnumeric_false_if_unequal(self, x):
"""For nonnumeric types, x != pytest.approx(y) reduces to x != y"""
assert "ab" != approx("abc")
assert ["ab"] != approx(["abc"])
# in particular, both of these should return False
assert {"a": 1.0} != approx({"a": None})
assert {"a": None} != approx({"a": 1.0})

assert 1.0 != approx(None)
assert None != approx(1.0) # noqa: E711

assert 1.0 != approx([None])
assert None != approx([1.0]) # noqa: E711

@pytest.mark.parametrize(
"op",
[
Expand Down

0 comments on commit da8a223

Please sign in to comment.