Skip to content

Commit

Permalink
remove componentwise comparison of complex values in TestCase.assertE…
Browse files Browse the repository at this point in the history
…qual (#63572)

Summary:
Pull Request resolved: #63572

Addresses #61906. The issue will be fixed later in the stack, when `torch.testing.assert_close` gets the same treatment.

cc ezyang gchanan

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30633527

Pulled By: mruberry

fbshipit-source-id: c2002a4998a7a75cb2ab83f87190bde43a9d4f7c
  • Loading branch information
pmeier authored and facebook-github-bot committed Aug 30, 2021
1 parent a8ffe81 commit 401bbb2
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 119 deletions.
2 changes: 1 addition & 1 deletion test/test_tensor_creation_ops.py
Expand Up @@ -3258,7 +3258,7 @@ def seed(generator):
self.assertTrue((res1 >= 0).all().item())

@dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
torch.complex32, torch.complex64, torch.complex128)
torch.complex64, torch.complex128)
def test_randn(self, device, dtype):
SIZE = 100
for size in [0, SIZE]:
Expand Down
54 changes: 9 additions & 45 deletions test/test_testing.py
Expand Up @@ -88,25 +88,19 @@ def test__comparescalars_debug_msg(self, device):
"atol=1e-05 is only 1.9100000000000003e-05!")
self.assertEqual(debug_msg, expected_msg)

# complex x complex, real difference
# complex x complex
result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1))
expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference "
"of 2.0, but the allowed difference with rtol=1.3e-06 "
"and atol=1e-05 is only 1.39e-05!")
self.assertEqual(debug_msg, expected_msg)

# complex x complex, imaginary difference
result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5))
expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a "
"difference of 2.5, but the allowed difference with "
"rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!")
expected_msg = ("Comparing (1+3j) and (3+1j) gives a difference "
"of 2.8284271247461903, but the allowed difference "
"with rtol=1.3e-06 and atol=1e-05 is only "
"1.4110960958218895e-05!")
self.assertEqual(debug_msg, expected_msg)

# complex x int
result, debug_msg = self._compareScalars(complex(1, -2), 1)
expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a "
"difference of 2.0, but the allowed difference with "
"rtol=1.3e-06 and atol=1e-05 is only 1e-05!")
expected_msg = ("Comparing (1-2j) and 1 gives a difference of 2.0, "
"but the allowed difference with rtol=1.3e-06 and "
"atol=1e-05 is only 1.13e-05!")
self.assertEqual(debug_msg, expected_msg)

# NaN x NaN, equal_nan=False
Expand Down Expand Up @@ -170,28 +164,6 @@ def test__comparetensors_debug_msg(self, device):
"occuring at index 0.")
self.assertEqual(debug_msg, expected_msg)

# Checks complex tensor comparisons (real part)
a = torch.tensor((1 - 1j, 4 + 3j), device=device)
b = torch.tensor((1 - 1j, 1 + 3j), device=device)
result, debug_msg = self._compareTensors(a, b)
expected_msg = ("Real parts failed to compare as equal! "
"With rtol=1.3e-06 and atol={0}, "
"found 1 element(s) (out of 2) whose difference(s) exceeded the "
"margin of error (including 0 nan comparisons). The greatest difference was "
"3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol)
self.assertEqual(debug_msg, expected_msg)

# Checks complex tensor comparisons (imaginary part)
a = torch.tensor((1 - 1j, 4 + 3j), device=device)
b = torch.tensor((1 - 1j, 4 - 21j), device=device)
result, debug_msg = self._compareTensors(a, b)
expected_msg = ("Imaginary parts failed to compare as equal! "
"With rtol=1.3e-06 and atol={0}, "
"found 1 element(s) (out of 2) whose difference(s) exceeded the "
"margin of error (including 0 nan comparisons). The greatest difference was "
"24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol)
self.assertEqual(debug_msg, expected_msg)

# Checks size mismatch
a = torch.tensor((1, 2), device=device)
b = torch.tensor((3), device=device)
Expand Down Expand Up @@ -407,7 +379,7 @@ def test_isclose_comparetensors_complex(self, device, dtype):
tests = (
(complex(1, -1), complex(-1, 1), False),
(complex(1, -1), complex(2, -2), True),
(complex(1, 99), complex(4, 100), False),
(complex(1, 99), complex(4, 100), True),
)

self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
Expand All @@ -421,14 +393,6 @@ def test_isclose_comparetensors_complex(self, device, dtype):
(complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
)
self._isclose_helper(tests, device, dtype, True)

tests = (
(complex(1, 1), complex(1, float('nan')), False),
(complex(1, 1), complex(float('nan'), 1), False),
(complex(float('nan'), 1), complex(float('nan'), 1), True),
(complex(float('nan'), 1), complex(1, float('nan')), False),
(complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
)
self._comparetensors_helper(tests, device, dtype, True)

# Tests that isclose with rtol or atol values less than zero throws a
Expand Down
4 changes: 2 additions & 2 deletions test/test_torch.py
Expand Up @@ -5121,7 +5121,7 @@ def filter_shape(shape, dim):
spacing = [space.cpu().detach().numpy() for space in spacing]
expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order)
actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected))
self.assertEqual(actual, expected, equal_nan="relaxed", atol=1e-4, rtol=0, exact_dtype=False)
self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False)

@onlyOnCPUAndCUDA
@dtypes(torch.long, torch.float32, torch.complex64)
Expand Down Expand Up @@ -5188,7 +5188,7 @@ def test_gradient_type_promotion(self, device):
self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False)
else:
actual, expected = self._inf_nan_preprocess(list(actual), expected)
self.assertEqual(actual, expected, equal_nan="relaxed", exact_dtype=False)
self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False)

@onlyOnCPUAndCUDA
@dtypes(torch.long, torch.float32, torch.complex64)
Expand Down
5 changes: 1 addition & 4 deletions test/test_unary_ufuncs.py
Expand Up @@ -359,10 +359,7 @@ def test_reference_numerics_extremal(self, device, dtype, op):
tensors = generate_numeric_tensors_extremal(device, dtype,
domain=op.domain)

# https://github.com/pytorch/pytorch/issues/50749
equal_nan = "relaxed" if device.startswith('cuda') else True

self._test_reference_numerics(dtype, op, tensors, equal_nan)
self._test_reference_numerics(dtype, op, tensors)

# Tests for testing (non)contiguity consistency

Expand Down
75 changes: 8 additions & 67 deletions torch/testing/_core.py
Expand Up @@ -6,7 +6,7 @@
import random
import math
import cmath
from typing import cast, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import operator

FileCheck = torch._C.FileCheck
Expand Down Expand Up @@ -78,27 +78,12 @@ def _unravel_index(flat_index, shape):
# Two tensors are "equal" if they are "close", in the sense of torch.allclose.
# The only exceptions are complex tensors and bool tensors.
#
# Complex tensors are "equal" if both the
# real and complex parts (separately) are close. This is divergent from
# torch.allclose's behavior, which compares the absolute values of the
# complex numbers instead.
#
# Using torch.allclose would be a less strict
# comparison that would allow large complex values with
# significant real or imaginary differences to be considered "equal,"
# and would make setting rtol and atol for complex tensors distinct from
# other tensor types.
#
# Bool tensors are equal only if they are identical, regardless of
# the rtol and atol values.
#
# The `equal_nan` can be True or False, which maps to the True or False
# in `torch.allclose`. `equal_nan` can also be "relaxed", which means
# the complex will be compared in the relaxed mode:
# 2 + nan j == 3 + nan j ---> False when equal_nan=True
# True when equal_nan="relaxed"
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type:
assert equal_nan in {True, False, "relaxed"}
# in `torch.allclose`.
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan) -> _compare_return_type:
debug_msg : Optional[str]
# Integer (including bool) comparisons are identity comparisons
# when rtol is zero and atol is less than one
Expand Down Expand Up @@ -129,48 +114,19 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
_unravel_index(greatest_diff_index, a.shape)))
return (False, debug_msg)

# Compares complex tensors' real and imaginary parts separately.
# (see NOTE Test Framework Tensor "Equality")
if a.is_complex():
if equal_nan == "relaxed":
a = a.clone()
b = b.clone()
a.real[a.imag.isnan()] = math.nan
a.imag[a.real.isnan()] = math.nan
b.real[b.imag.isnan()] = math.nan
b.imag[b.real.isnan()] = math.nan

real_result, debug_msg = _compare_tensors_internal(a.real, b.real,
rtol=rtol, atol=atol,
equal_nan=equal_nan)

if not real_result:
debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
return (real_result, debug_msg)

imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag,
rtol=rtol, atol=atol,
equal_nan=equal_nan)

if not imag_result:
debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
return (imag_result, debug_msg)

return (True, None)

# All other comparisons use torch.allclose directly
if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})):
if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
return (True, None)

# Gathers debug info for failed float tensor comparison
# NOTE: converts to float64 to best represent differences
a_flat = a.to(torch.float64).flatten()
b_flat = b.to(torch.float64).flatten()
a_flat = a.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten()
b_flat = b.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten()
diff = torch.abs(a_flat - b_flat)

# Masks close values
# NOTE: this avoids (inf - inf) oddities when computing the difference
close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True}))
close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan)
diff[close] = 0
nans = torch.isnan(diff)
num_nans = nans.sum()
Expand Down Expand Up @@ -212,7 +168,7 @@ def _helper(a, b, s) -> _compare_return_type:

# Special-case for infinity comparisons
# NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
if ((math.isinf(a) or math.isinf(b)) and a != b):
if ((cmath.isinf(a) or cmath.isinf(b)) and a != b):
result = False

msg = None
Expand All @@ -228,21 +184,6 @@ def _helper(a, b, s) -> _compare_return_type:
)
return result, msg

if isinstance(a, complex) or isinstance(b, complex):
a = complex(a)
b = complex(b)

if equal_nan == "relaxed":
if cmath.isnan(a) and cmath.isnan(b):
return (True, None)

result, msg = _helper(a.real, b.real, " the real part ")

if not result:
return (False, msg)

return _helper(a.imag, b.imag, " the imaginary part ")

return _helper(a, b, " ")


Expand Down

0 comments on commit 401bbb2

Please sign in to comment.