Skip to content

Commit

Permalink
make assert_equal a special case of assert_close
Browse files Browse the repository at this point in the history
Instead of a distinct `torch.testing.assert_close` and
`torch.testing.assert_equal`, this makes `torch.testing.assert_equal` a
special case of `torch.testing.assert_close` for `rtol=atol=0`. In this
case the closeness definition `abs(actual - expected) <= atol + rtol *
abs(expected)` boils down to `abs(actual - expected) <= 0`. Since
`abs(x)` can never be `<0`, this is equivalent to `abs(a - b) == 0` and
this again boils down to `a == b`.

This makes maintaing the module a lot easier, because we don't need to
keep two functions in sync.

ghstack-source-id: 10cea61c08830ffecaf71e02377b616ed3f29d4c
Pull Request resolved: #58918
  • Loading branch information
pmeier committed May 25, 2021
1 parent c22a347 commit 30657d3
Showing 1 changed file with 82 additions and 172 deletions.
254 changes: 82 additions & 172 deletions torch/testing/_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ._core import _unravel_index

__all__ = ["assert_equal", "assert_close"]
__all__ = ["assert_close", "assert_equal"]


# The UsageError should be raised in case the test function is not used correctly. With this the user is able to
Expand Down Expand Up @@ -242,43 +242,6 @@ def _trace_mismatches(actual: Tensor, expected: Tensor, mismatches: Tensor) -> D
)


@_check_complex_components_individually
def _check_values_equal(
actual: Tensor,
expected: Tensor,
*,
msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
) -> Optional[_TestingErrorMeta]:
"""Checks if the values of two tensors are bitwise equal.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
msg (Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]]): Optional error message. Can be
passed as callable in which case it will be called with the inputs and the result of
:func:`_trace_mismatches`.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
mismatches = torch.ne(actual, expected)
if not torch.any(mismatches):
return None

trace = _trace_mismatches(actual, expected, mismatches)

if msg is None:
msg = (
f"Tensors are not equal!\n\n"
f"Mismatched elements: {trace.total_mismatches} / {trace.number_of_elements} ({trace.mismatch_ratio:.1%})\n"
f"Greatest absolute difference: {trace.max_abs_diff} at {trace.max_abs_diff_idx}\n"
f"Greatest relative difference: {trace.max_rel_diff} at {trace.max_rel_diff_idx}"
)
elif callable(msg):
msg = msg(actual, expected, trace)
return _TestingErrorMeta(AssertionError, msg)


@_check_complex_components_individually
def _check_values_close(
actual: Tensor,
Expand Down Expand Up @@ -322,39 +285,6 @@ def _check_values_close(
return _TestingErrorMeta(AssertionError, msg)


def _check_tensors_equal(
actual: Tensor,
expected: Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
) -> Optional[_TestingErrorMeta]:
"""Checks that the values of two tensors are bitwise equal.
For complex tensors the check is performed on the real and imaginary component separately. Optionally, checks that
some attributes of tensor pairs are equal.
For a description of the parameters see :func:`assert_equal`.
Returns:
Optional[_TestingErrorMeta]: If checks did not pass.
"""
error_meta = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if error_meta:
return error_meta
actual, expected = _equalize_attributes(actual, expected)

error_meta = _check_values_equal(actual, expected, msg=msg)
if error_meta:
return error_meta

return None


def _check_tensors_close(
actual: Tensor,
expected: Tensor,
Expand Down Expand Up @@ -402,10 +332,7 @@ def _check_tensors_close(
return error_meta
actual, expected = _equalize_attributes(actual, expected)

if (rtol == 0.0) and (atol == 0.0):
error_meta = _check_values_equal(actual, expected, msg=msg)
else:
error_meta = _check_values_close(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg)
error_meta = _check_values_close(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg)
if error_meta:
return error_meta

Expand All @@ -421,9 +348,9 @@ class _TensorPair(NamedTuple):
_MAPPING_MSG_FMTSTR = "The failure occurred for key '{}' of the mappings."


def _check_pair(
def _check_pair_close(
pair: Union[_TensorPair, List, Dict],
check_tensors: Callable[[Any, Any], Optional[_TestingErrorMeta]],
**kwargs: Any,
) -> Optional[_TestingErrorMeta]:
"""Checks input pairs.
Expand All @@ -432,27 +359,27 @@ def _check_pair(
Args:
pair (Union[_TensorPair, List, Dict]): Input pair.
check_tensors (Callable[[Any, Any], Optional[Exception]]): Callable used to check if a tensor pair matches.
**kwargs (Any): Keyword arguments passed to :func:`__check_tensors_close`.
Returns:
(Optional[_TestingErrorMeta]): Return value of :attr:`check_tensors`.
"""
if isinstance(pair, list):
for idx, pair_item in enumerate(pair):
error_meta = _check_pair(pair_item, check_tensors)
error_meta = _check_pair_close(pair_item, **kwargs)
if error_meta:
return error_meta.amend_msg(postfix=f"\n\n{_SEQUENCE_MSG_FMTSTR.format(idx)}")
else:
return None
elif isinstance(pair, dict):
for key, pair_item in pair.items():
error_meta = _check_pair(pair_item, check_tensors)
error_meta = _check_pair_close(pair_item, **kwargs)
if error_meta:
return error_meta.amend_msg(postfix=f"\n\n{_MAPPING_MSG_FMTSTR.format(key)}")
else:
return None
else: # isinstance(pair, TensorPair)
return check_tensors(pair.actual, pair.expected)
return _check_tensors_close(pair.actual, pair.expected, **kwargs)


def _to_tensor(array_or_scalar_like: Any) -> Tuple[Optional[_TestingErrorMeta], Optional[Tensor]]:
Expand Down Expand Up @@ -603,94 +530,6 @@ def _parse_inputs(
return _to_tensor_pair(actual, expected)


def assert_equal(
actual: Any,
expected: Any,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
msg: Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]] = None,
) -> None:
"""Asserts that the values of tensor pairs are bitwise equal.
For complex tensors the check is performed on the real and imaginary component separately. Optionally, checks that
some attributes of tensor pairs are equal.
Also supports array-or-scalar-like inputs from which a :class:`torch.Tensor` can be constructed with
:func:`torch.as_tensor`. Still, requires type equality, i.e. comparing a :class:`torch.Tensor` and a
:class:`numpy.ndarray` is not supported.
In case both inputs are :class:`~collections.abc.Sequence`'s or :class:`~collections.abc.Mapping`'s the checks are
performed elementwise.
Args:
actual (Any): Actual input.
expected (Any): Expected input.
check_device (bool): If ``True`` (default), asserts that each tensor pair is on the same
:attr:`~torch.Tensor.device` memory. If this check is disabled **and** it is not on the same
:attr:`~torch.Tensor.device` memory, it is moved CPU memory before the values are compared.
check_dtype (bool): If ``True`` (default), asserts that each tensor pair has the same
:attr:`~torch.Tensor.dtype`. If this check is disabled it does not have the same
:attr:`~torch.Tensor.dtype`, it is copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before the values are compared.
check_stride (bool): If ``True`` (default), asserts that each tensor pair has the same stride.
msg (Optional[Union[str, Callable[[Tensor, Tensor, SimpleNamespace], str]]]): Optional error message to use if
the values of a tensor pair mismatch. Can be passed as callable in which case it will be called with the
tensor pair and a namespace of diagnostic info about the mismatches. See below for details.
Raises:
UsageError: If an array-or-scalar-like pair has different types.
UsageError: If a :class:`torch.Tensor` can't be constructed from an array-or-scalar-like.
UsageError: If any tensor is quantized or sparse. This is a temporary restriction and will be relaxed in the
future.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
AssertionError: If a tensor pair does not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but a tensor pair is not on the same :attr:`~torch.Tensor.device`
memory.
AssertionError: If :attr:`check_dtype`, but a tensor pair does not have the same :attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but a tensor pair does not have the same stride.
AssertionError: If the values of a tensor pair are not bitwise equal.
The namespace that will be passed to :attr:`msg` if its a callable comprises the following attributes:
- total_elements (int): Total number of values.
- total_mismatches (int): Total number of mismatches.
- mismatch_ratio (float): Quotient of total mismatches and total elements.
- max_abs_diff (Union[int, float]): Greatest absolute difference of the inputs.
- max_abs_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest absolute difference.
- max_rel_diff (Union[int, float]): Greatest relative difference of the inputs.
- max_rel_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest relative difference.
For ``max_abs_diff`` and ``max_rel_diff`` the type depends on the :attr:`~torch.Tensor.dtype` of the inputs.
.. seealso::
To assert that the values of a tensor pair are close but are not required to be bitwise equal, use
:func:`assert_close` instead.
"""
# Hide this function from `pytest`'s traceback
__tracebackhide__ = True

error_meta, pair = _parse_inputs(actual, expected)
if error_meta:
raise error_meta.to_error()
else:
pair = cast(Union[_TensorPair, List, Dict], pair)

check_tensors = functools.partial(
_check_tensors_equal,
check_device=check_device,
check_dtype=check_dtype,
check_stride=check_stride,
msg=msg,
)
error_meta = _check_pair(pair, check_tensors)
if error_meta:
raise error_meta.to_error()


def assert_close(
actual: Any,
expected: Any,
Expand Down Expand Up @@ -796,6 +635,11 @@ def assert_close(
For ``max_abs_diff`` and ``max_rel_diff`` the type depends on the :attr:`~torch.Tensor.dtype` of the inputs.
.. seealso::
To assert that the values of corresponding scalars or tensor-likes are bitwise equal, use :func:`assert_equal`
instead.
Examples:
>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
Expand Down Expand Up @@ -883,8 +727,8 @@ def assert_close(
else:
pair = cast(Union[_TensorPair, List, Dict], pair)

check_tensors = functools.partial(
_check_tensors_close,
error_meta = _check_pair_close(
pair,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
Expand All @@ -893,6 +737,72 @@ def assert_close(
check_stride=check_stride,
msg=msg,
)
error_meta = _check_pair(pair, check_tensors)
if error_meta:
raise error_meta.to_error()


assert_equal = functools.partial(assert_close, rtol=0.0, atol=0.0)
assert_equal.__doc__ = r"""Asserts that :attr:`actual` and :attr:`expected` are close.
If :attr:`actual` and :attr:`expected` are real-valued and finite, they are considered equal if there values are
bitwise equal and they have the same :attr:`~torch.Tensor.device` (if :attr:`check_device` is ``True``), same
``dtype`` (if :attr:`check_dtype` is ``True``), and the same stride (if :attr:`check_stride` is ``True``).
``NaN``'s are only considered equal to each other if :attr:`equal_nan` is ``True``.
If :attr:`actual` and :attr:`expected` are complex-valued, they are considered close if both their real and
imaginary components are considered close according to the definition above.
:attr:`actual` and :attr:`expected` can be :class:`~torch.Tensor`'s or any array-or-scalar-like of the same type,
from which :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. In addition, :attr:`actual` and
:attr:`expected` can be :class:`~collections.abc.Sequence`'s or :class:`~collections.abc.Mapping`'s in which case
they are considered close if their structure matches and all their elements are considered close according to the
above definition.
Args:
actual (Any): Actual input.
expected (Any): Expected input.
equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. If ``"relaxed"``,
complex values are considered as ``NaN`` if either the real **or** imaginary component is ``NaN``.
check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
:attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
:attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
:func:`torch.promote_types`) before being compared.
check_stride (bool): If ``True`` (default), asserts that corresponding tensors have the same stride.
msg (Optional[Union[str, Callable[[Tensor, Tensor, DiagnosticInfo], str]]]): Optional error message to use if
the values of corresponding tensors mismatch. Can be passed as callable in which case it will be called
with the mismatching tensors and a namespace of diagnostic info about the mismatches. See below for details.
Raises:
UsageError: If a :class:`torch.Tensor` can't be constructed from an array-or-scalar-like.
UsageError: If any tensor is quantized or sparse. This is a temporary restriction and will be relaxed in the
future.
AssertionError: If corresponding array-likes have different types.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but corresponding tensors are not on the same
:attr:`~torch.Tensor.device`.
AssertionError: If :attr:`check_dtype`, but corresponding tensors do not have the same ``dtype``.
AssertionError: If :attr:`check_stride`, but corresponding tensors do not have the same stride.
AssertionError: If the values of corresponding tensors are not equal.
The namespace of diagnostic information that will be passed to :attr:`msg` if its a callable has the following
attributes:
- ``number_of_elements`` (int): Number of elements in each tensor being compared.
- ``total_mismatches`` (int): Total number of mismatches.
- ``mismatch_ratio`` (float): Total mismatches divided by number of elements.
- ``max_abs_diff`` (Union[int, float]): Greatest absolute difference of the inputs.
- ``max_abs_diff_idx`` (Union[int, Tuple[int, ...]]): Index of greatest absolute difference.
- ``max_rel_diff`` (Union[int, float]): Greatest relative difference of the inputs.
- ``max_rel_diff_idx`` (Union[int, Tuple[int, ...]]): Index of greatest relative difference.
For ``max_abs_diff`` and ``max_rel_diff`` the type depends on the :attr:`~torch.Tensor.dtype` of the inputs.
.. seealso::
To assert that the values of corresponding scalars or tensor-likes are are close but are not required to be
bitwise equal, use :func:`assert_close` instead.
"""

0 comments on commit 30657d3

Please sign in to comment.