Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split TestAsserts by functionality #58919

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
198 changes: 102 additions & 96 deletions test/test_testing.py
Expand Up @@ -835,6 +835,22 @@ def test_quantized_support(self):
with self.assertRaises(UsageError):
fn()

def test_type_inequality(self):
actual = torch.empty(2)
expected = actual.tolist()

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, str(type(expected))):
fn()

def test_unknown_type(self):
actual = "0"
expected = "0"

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(UsageError, str(type(actual))):
fn()

def test_mismatching_shape(self):
actual = torch.empty(())
expected = actual.clone().reshape((1,))
Expand Down Expand Up @@ -951,46 +967,71 @@ def test_matching_nan_with_equal_nan(self):
for fn in assert_close_with_inputs(actual, expected):
fn(equal_nan=True)

def test_mismatching_complex_nan_with_equal_nan(self):
actual = torch.tensor(complex(1, float("NaN")))
expected = torch.tensor(complex(float("NaN"), 1))
def test_numpy(self):
tensor = torch.rand(2, 2, dtype=torch.float32)
actual = tensor.numpy()
expected = actual.copy()

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaises(AssertionError):
fn(equal_nan=True)
fn()

def test_mismatching_complex_nan_with_equal_nan_relaxed(self):
actual = torch.tensor(complex(1, float("NaN")))
expected = torch.tensor(complex(float("NaN"), 1))
def test_scalar(self):
number = torch.randint(10, size=()).item()
for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
check_dtype = type(actual) is type(expected)

for fn in assert_close_with_inputs(actual, expected):
fn(check_dtype=check_dtype)


class TestAssertCloseMultiDevice(TestCase):
@deviceCountAtLeast(1)
def test_mismatching_device(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.empty((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "device"):
fn()

@deviceCountAtLeast(1)
def test_mismatching_device_no_check(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.rand((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_close_with_inputs(actual, expected):
fn(check_device=False)

for fn in assert_close_with_inputs(actual, expected):
fn(equal_nan="relaxed")

def test_mismatching_values_msg_mismatches(self):
instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")


class TestAssertCloseErrorMessage(TestCase):
def test_mismatched_elements(self):
actual = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 2, 5, 6])

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
fn()

def test_mismatching_values_msg_abs_diff(self):
def test_abs_diff(self):
actual = torch.tensor([[1, 2], [3, 4]])
expected = torch.tensor([[1, 2], [5, 4]])

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
fn()

def test_mismatching_values_msg_rel_diff(self):
def test_rel_diff(self):
actual = torch.tensor([[1, 2], [3, 4]])
expected = torch.tensor([[1, 4], [3, 4]])

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
fn()

def test_mismatching_values_zero_div_zero(self):
def test_zero_div_zero(self):
actual = torch.tensor([1.0, 0.0])
expected = torch.tensor([2.0, 0.0])

Expand All @@ -1000,23 +1041,7 @@ def test_mismatching_values_zero_div_zero(self):
with self.assertRaisesRegex(AssertionError, "((?!nan).)*"):
fn()

def test_mismatching_values_msg_complex_real(self):
actual = torch.tensor(complex(0, 1))
expected = torch.tensor(complex(1, 1))

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the real part")):
fn()

def test_mismatching_values_msg_complex_imag(self):
actual = torch.tensor(complex(1, 0))
expected = torch.tensor(complex(1, 1))

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the imaginary part")):
fn()

def test_mismatching_values_msg_rtol(self):
def test_rtol(self):
rtol = 1e-3

actual = torch.tensor(1)
Expand All @@ -1028,7 +1053,7 @@ def test_mismatching_values_msg_rtol(self):
):
fn(rtol=rtol, atol=0.0)

def test_mismatching_values_msg_atol(self):
def test_atol(self):
atol = 1e-3

actual = torch.tensor(1)
Expand All @@ -1040,6 +1065,31 @@ def test_mismatching_values_msg_atol(self):
):
fn(rtol=0.0, atol=atol)

def test_msg_str(self):
msg = "Custom error message!"

actual = torch.tensor(1)
expected = torch.tensor(2)

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=msg)

def test_msg_callable(self):
msg = "Custom error message!"

def make_msg(actual, expected, trace):
return msg

actual = torch.tensor(1)
expected = torch.tensor(2)

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=make_msg)


class TestAssertCloseContainer(TestCase):
def test_sequence_mismatching_len(self):
actual = (torch.empty(()),)
expected = ()
Expand Down Expand Up @@ -1074,82 +1124,38 @@ def test_mapping_mismatching_values_msg(self):
with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
torch.testing.assert_close(actual, expected)

def test_type_inequality(self):
actual = torch.empty(2)
expected = actual.tolist()

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, str(type(expected))):
fn()

def test_unknown_type(self):
actual = "0"
expected = "0"
class TestAssertCloseComplex(TestCase):
def test_mismatching_nan_with_equal_nan(self):
actual = torch.tensor(complex(1, float("NaN")))
expected = torch.tensor(complex(float("NaN"), 1))

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(UsageError, str(type(actual))):
fn()
with self.assertRaises(AssertionError):
fn(equal_nan=True)

def test_numpy(self):
tensor = torch.rand(2, 2, dtype=torch.float32)
actual = tensor.numpy()
expected = actual.copy()
def test_mismatching_nan_with_equal_nan_relaxed(self):
actual = torch.tensor(complex(1, float("NaN")))
expected = torch.tensor(complex(float("NaN"), 1))

for fn in assert_close_with_inputs(actual, expected):
fn()

def test_scalar(self):
number = torch.randint(10, size=()).item()
for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
check_dtype = type(actual) is type(expected)

for fn in assert_close_with_inputs(actual, expected):
fn(check_dtype=check_dtype)

def test_msg_str(self):
msg = "Custom error message!"
fn(equal_nan="relaxed")

actual = torch.tensor(1)
expected = torch.tensor(2)
def test_mismatching_values_msg_real(self):
actual = torch.tensor(complex(0, 1))
expected = torch.tensor(complex(1, 1))

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=msg)

def test_msg_callable(self):
msg = "Custom error message!"

def make_msg(actual, expected, trace):
return msg
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the real part")):
fn()

actual = torch.tensor(1)
expected = torch.tensor(2)
def test_mismatching_values_msg_imag(self):
actual = torch.tensor(complex(1, 0))
expected = torch.tensor(complex(1, 1))

for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, msg):
fn(msg=make_msg)


class TestAssertCloseMultiDevice(TestCase):
@deviceCountAtLeast(1)
def test_mismatching_device(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.empty((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_close_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "device"):
fn()

@deviceCountAtLeast(1)
def test_mismatching_device_no_check(self, devices):
for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
actual = torch.rand((), device=actual_device)
expected = actual.clone().to(expected_device)
for fn in assert_close_with_inputs(actual, expected):
fn(check_device=False)


instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")
with self.assertRaisesRegex(AssertionError, re.escape("The failure occurred for the imaginary part")):
fn()


if __name__ == '__main__':
Expand Down