Skip to content

Commit

Permalink
split TestAsserts by functionality
Browse files Browse the repository at this point in the history
Instead of having one large `TestAsserts` test case, we split of tests
for self-contained functionality like container or complex checking into
separate test cases. That makes it a lot easier to keep an overview over
what is tested.

ghstack-source-id: f93caedb2d4cb044447362c0ed6f27657ca52780
Pull Request resolved: #58919
  • Loading branch information
pmeier committed May 25, 2021
1 parent 30657d3 commit 8472541
Showing 1 changed file with 122 additions and 113 deletions.
235 changes: 122 additions & 113 deletions test/test_testing.py
Expand Up @@ -826,6 +826,22 @@ def test_quantized_support(self):
with self.assertRaises(UsageError):
fn()

def test_type_inequality(self):
actual = torch.empty(2)
expected = actual.tolist()

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, str(type(expected))):
fn()

def test_unknown_type(self):
actual = "0"
expected = "0"

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(UsageError, str(type(actual))):
fn()

def test_mismatching_shape(self):
actual = torch.empty(())
expected = actual.clone().reshape((1,))
Expand Down Expand Up @@ -871,17 +887,12 @@ def test_mismatching_values(self):
with self.assertRaises(AssertionError):
fn()

def test_assert_equal(self):
def test_matching(self):
actual = torch.tensor(1)
expected = actual.clone()

torch.testing.assert_equal(actual, expected)

def test_assert_close(self):
actual = torch.tensor(1.0)
expected = actual.clone()

torch.testing.assert_close(actual, expected)
for fn in assert_fns_with_inputs(actual, expected):
fn()

def test_assert_close_only_rtol(self):
actual = torch.empty(())
Expand Down Expand Up @@ -942,72 +953,71 @@ def test_assert_close_equal_nan(self):
for inputs in make_assert_inputs(a, b):
torch.testing.assert_close(*inputs, equal_nan=True)

def test_assert_close_equal_nan_complex(self):
a = torch.tensor(complex(1, float("NaN")))
b = torch.tensor(complex(float("NaN"), 1))
def test_numpy(self):
tensor = torch.rand(2, 2, dtype=torch.float32)
actual = tensor.numpy()
expected = actual.copy()

for inputs in make_assert_inputs(a, b):
with self.assertRaises(AssertionError):
torch.testing.assert_close(*inputs, equal_nan=True)
for fn in assert_fns_with_inputs(actual, expected):
fn()

def test_assert_close_equal_nan_complex_relaxed(self):
a = torch.tensor(complex(1, float("NaN")))
b = torch.tensor(complex(float("NaN"), 1))
def test_scalar(self):
number = torch.randint(10, size=()).item()
for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
check_dtype = type(actual) is type(expected)

for inputs in make_assert_inputs(a, b):
torch.testing.assert_close(*inputs, equal_nan="relaxed")
for fn in assert_fns_with_inputs(actual, expected):
fn(check_dtype=check_dtype)

def test_mismatching_values_msg_mismatches(self):

class TestAssertsMultiDevice(TestCase):
@deviceCountAtLeast(1)
def test_mismatching_device(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.empty((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "device"):
fn()

@deviceCountAtLeast(1)
def test_mismatching_device_no_check(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.rand((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_fns_with_inputs(actual, expected):
fn(check_device=False)


instantiate_device_type_tests(TestAssertsMultiDevice, globals(), only_for="cuda")


class TestAssertsErrorMessage(TestCase):
def test_mismatches(self):
actual = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 2, 5, 6])

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
fn()

def test_mismatching_values_msg_abs_diff(self):
def test_max_abs_diff(self):
actual = torch.tensor([[1, 2], [3, 4]])
expected = torch.tensor([[1, 2], [5, 4]])

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
fn()

def test_mismatching_values_msg_rel_diff(self):
def test_max_rel_diff(self):
actual = torch.tensor([[1, 2], [3, 4]])
expected = torch.tensor([[1, 4], [3, 4]])

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
fn()

def test_mismatching_values_zero_div_zero(self):
actual = torch.tensor([1.0, 0.0])
expected = torch.tensor([2.0, 0.0])

for fn in assert_fns_with_inputs(actual, expected):
# Although it looks complicated, this regex just makes sure that the word 'nan' is not part of the error
# message. That would happen if the 0 / 0 is used for the mismatch computation although it matches.
with self.assertRaisesRegex(AssertionError, "((?!nan).)*"):
fn()

def test_mismatching_values_msg_complex_real(self):
actual = torch.tensor(complex(0, 1))
expected = torch.tensor(complex(1, 1))

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the real part")):
fn()

def test_mismatching_values_msg_complex_imag(self):
actual = torch.tensor(complex(1, 0))
expected = torch.tensor(complex(1, 1))

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the imaginary part")):
fn()

def test_assert_close_mismatching_values_msg_rtol(self):
def test_assert_close_rtol(self):
rtol = 1e-3

actual = torch.tensor(1)
Expand All @@ -1019,7 +1029,7 @@ def test_assert_close_mismatching_values_msg_rtol(self):
):
torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)

def test_assert_close_mismatching_values_msg_atol(self):
def test_assert_close_atol(self):
atol = 1e-3

actual = torch.tensor(1)
Expand All @@ -1031,6 +1041,42 @@ def test_assert_close_mismatching_values_msg_atol(self):
):
torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)

def test_mismatching_values_zero_div_zero(self):
actual = torch.tensor([1.0, 0.0])
expected = torch.tensor([2.0, 0.0])

for fn in assert_fns_with_inputs(actual, expected):
# Although it looks complicated, this regex just makes sure that the word 'nan' is not part of the error
# message. That would happen if a matching 0.0 pair is used for the mismatch computation although it
# matches.
with self.assertRaisesRegex(AssertionError, "((?!nan).)*"):
fn()

def test_msg_str(self):
msg = "Custom error message!"

actual = torch.tensor(1)
expected = torch.tensor(2)

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=msg)

def test_msg_callable(self):
msg = "Custom error message!"

def make_msg(actual, expected, trace):
return msg

actual = torch.tensor(1)
expected = torch.tensor(2)

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=make_msg)


class TestAssertsContainer(TestCase):
def test_sequence_mismatching_len(self):
actual = (torch.empty(()),)
expected = ()
Expand Down Expand Up @@ -1069,82 +1115,45 @@ def test_mapping_mismatching_values_msg(self):
with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
fn(actual, expected)

def test_type_inequality(self):
actual = torch.empty(2)
expected = actual.tolist()

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, str(type(expected))):
fn()

def test_unknown_type(self):
actual = "0"
expected = "0"

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(UsageError, str(type(actual))):
fn()

def test_numpy(self):
tensor = torch.rand(2, 2, dtype=torch.float32)
actual = tensor.numpy()
expected = actual.copy()
class TestAssertsComplex(TestCase):
def test_matching(self):
actual = torch.tensor(complex(1, 2))
expected = actual.clone()

for fn in assert_fns_with_inputs(actual, expected):
fn()

def test_scalar(self):
number = torch.randint(10, size=()).item()
for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
check_dtype = type(actual) is type(expected)

for fn in assert_fns_with_inputs(actual, expected):
fn(check_dtype=check_dtype)

def test_msg_str(self):
msg = "Custom error message!"

actual = torch.tensor(1)
expected = torch.tensor(2)
def test_assert_close_equal_nan(self):
a = torch.tensor(complex(1, float("NaN")))
b = torch.tensor(complex(float("NaN"), 1))

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=msg)
for inputs in make_assert_inputs(a, b):
with self.assertRaises(AssertionError):
torch.testing.assert_close(*inputs, equal_nan=True)

def test_msg_callable(self):
msg = "Custom error message!"
def test_assert_close_equal_nan_relaxed(self):
a = torch.tensor(complex(1, float("NaN")))
b = torch.tensor(complex(float("NaN"), 1))

def make_msg(actual, expected, trace):
return msg
for inputs in make_assert_inputs(a, b):
torch.testing.assert_close(*inputs, equal_nan="relaxed")

actual = torch.tensor(1)
expected = torch.tensor(2)
def test_mismatching_values_msg_real(self):
actual = torch.tensor(complex(0, 1))
expected = torch.tensor(complex(1, 1))

for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=make_msg)


class TestAssertsMultiDevice(TestCase):
@deviceCountAtLeast(1)
def test_mismatching_device(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.empty((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "device"):
fn()

@deviceCountAtLeast(1)
def test_mismatching_device_no_check(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.rand((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_fns_with_inputs(actual, expected):
fn(check_device=False)
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the real part")):
fn()

def test_mismatching_values_msg_imag(self):
actual = torch.tensor(complex(1, 0))
expected = torch.tensor(complex(1, 1))

instantiate_device_type_tests(TestAssertsMultiDevice, globals(), only_for="cuda")
for fn in assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the imaginary part")):
fn()


if __name__ == '__main__':
Expand Down

0 comments on commit 8472541

Please sign in to comment.