Skip to content

Commit

Permalink
Improve numpy.approx array-scalar comparisons
Browse files Browse the repository at this point in the history
So that `self.expected` in ApproxNumpy is always a numpy array.
  • Loading branch information
tadeu committed Mar 16, 2018
1 parent 42c84f4 commit a754f00
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions _pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,27 @@ def __repr__(self):
def __eq__(self, actual):
import numpy as np

# self.expected is supposed to always be an array here

if not np.isscalar(actual):
try:
actual = np.asarray(actual)
except: # noqa
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))

if (not np.isscalar(self.expected) and not np.isscalar(actual)
and actual.shape != self.expected.shape):
if not np.isscalar(actual) and actual.shape != self.expected.shape:
return False

return ApproxBase.__eq__(self, actual)

def _yield_comparisons(self, actual):
import numpy as np

# For both `actual` and `self.expected`, they can independently be
# either a `numpy.array` or a scalar (but both can't be scalar,
# in this case an `ApproxScalar` is used).
# They are treated in `__eq__` before being passed to
# `ApproxBase.__eq__`, which is the only method that calls this one.
# `actual` can either be a numpy array or a scalar, it is treated in
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
# only method that calls this one.

if np.isscalar(self.expected):
for i in np.ndindex(actual.shape):
yield np.asscalar(actual[i]), self.expected
elif np.isscalar(actual):
if np.isscalar(actual):
for i in np.ndindex(self.expected.shape):
yield actual, np.asscalar(self.expected[i])
else:
Expand Down Expand Up @@ -202,7 +198,7 @@ def __eq__(self, actual):
the pre-specified tolerance.
"""
if _is_numpy_array(actual):
return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok)
return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected

# Short-circuit exact equality.
if actual == self.expected:
Expand Down

0 comments on commit a754f00

Please sign in to comment.