diff --git a/test/test_testing.py b/test/test_testing.py index a7f0f3bc76db7..feb408773f4c8 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1870,6 +1870,33 @@ def test_two_things_default_names(self, device, x, y): test_names = _get_test_names_for_test_class(device_cls) self.assertEqual(expected_test_names, test_names) + def test_default_name_non_primitive(self, device): + device = self.device_type + + class TestParametrized(TestCase): + @parametrize("x", [1, .5, "foo", object()]) + def test_default_names(self, device, x): + pass + + @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())]) + def test_two_things_default_names(self, device, x, y): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()[f'TestParametrized{device.upper()}'] + expected_test_names = sorted(name.format(device_cls.__name__, device) for name in ( + '{}.test_default_names_x_1_{}', + '{}.test_default_names_x_0_5_{}', + '{}.test_default_names_x_foo_{}', + '{}.test_default_names_x3_{}', + '{}.test_two_things_default_names_x_1_y0_{}', + '{}.test_two_things_default_names_x1_y_0_5_{}', + '{}.test_two_things_default_names_x2_y2_{}') + ) + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(expected_test_names, test_names) + def test_name_fn(self, device): device = self.device_type diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 9e16d7174560c..4f63b41e48036 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -459,29 +459,30 @@ def __init__(self, arg_str, arg_values, name_fn=None): self.arg_values = arg_values self.name_fn = name_fn - def _formatted_str_repr(self, name, value): + def _formatted_str_repr(self, idx, name, value): """ Returns a string representation for the given arg that is suitable for use in test function names. """ if isinstance(value, torch.dtype): return dtype_name(value) elif isinstance(value, torch.device): return str(value) # Can't use isinstance as it would cause a circular import - elif value.__class__.__name__ == 'OpInfo' or value.__class__.__name__ == 'ModuleInfo': + elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}: return value.formatted_name - else: - # Include name and value separated by underscore. + elif isinstance(value, (int, float, str)): return f"{name}_{str(value).replace('.', '_')}" + else: + return f"{name}{idx}" - def _default_subtest_name(self, values): - return '_'.join([self._formatted_str_repr(a, v) for a, v in zip(self.arg_names, values)]) + def _default_subtest_name(self, idx, values): + return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)]) - def _get_subtest_name(self, values, explicit_name=None): + def _get_subtest_name(self, idx, values, explicit_name=None): if explicit_name: subtest_name = explicit_name elif self.name_fn: subtest_name = self.name_fn(*values) else: - subtest_name = self._default_subtest_name(values) + subtest_name = self._default_subtest_name(idx, values) return subtest_name def _parametrize_test(self, test, generic_cls, device_cls): @@ -494,7 +495,7 @@ def _parametrize_test(self, test, generic_cls, device_cls): # * A tuple of values with one for each arg. For a single arg, a single item is expected. # * A subtest instance with arg_values matching the previous. values = check_exhausted_iterator = object() - for values in self.arg_values: + for idx, values in enumerate(self.arg_values): maybe_name = None decorators = [] @@ -519,7 +520,7 @@ def test_wrapper(*args, **kwargs): param_kwargs = dict(zip(self.arg_names, values)) - test_name = self._get_subtest_name(values, explicit_name=maybe_name) + test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name) def decorator_fn(_, decorators=decorators): return decorators