Skip to content

Commit

Permalink
robustify parametrize default name (#113856)
Browse files Browse the repository at this point in the history
#113340 was reverted initially due to a bad default parametrization name. The test looked like

```python
@common_utils.parametrize(
    "type_fn",
    [
        type,
        lambda obj: obj.__class__,
    ],
)
def test_access_class_method_from_user_class(self, type_fn):
```

This is a valid parametrization, but results in these default test names:

```bash
❯ pytest test/dynamo/test_export.py -k test_access_class_method_from_user_class --co -q
test/dynamo/test_export.py::ExportTests::test_access_class_method_from_user_class_type_fn_<class 'type'>
test/dynamo/test_export.py::ExportTests::test_access_class_method_from_user_class_type_fn_<function ExportTests_<lambda> at 0x7f3be5de0c10>
```

Ignoring the whitespace in the test names, which can lead to other issues down the line, the problem in #113340 was that the lambda parameter included a memory address. IIUC, internally, the tests are not collected and run in the same process. Meaning, the address of the lambda and in turn the test name is no longer valid on the runner. This is fixed earlier in the stack by giving the parametrization an explicit name with `subtest`, but this PR is about preventing issues in the default case.

`pytest` solves this by simply using the name of the parameter plus its index as id in the test name:

```python
import pytest

class Foo:
    def __repr__(self):
        return str(id(self))

@pytest.mark.parametrize(
    "bar",
    [
        pytest.param(type),
        pytest.param(lambda obj: obj.__class__),
        pytest.param(Foo()),
    ],
)
def test_foo(bar):
    pass
```

```
❯ pytest main.py --co -q
main.py::test_foo[type]
main.py::test_foo[<lambda>]
main.py::test_foo[bar2]
```

`pytest` has better defaults for `type` and `lambda` than we do, but is has a safe default for custom objects.

This PR aligns our default test name with `pytest`. Using the parametrization from above again, we now collect

```bash
❯ pytest test/dynamo/test_export.py -k test_access_class_method_from_user_class --co -q
test/dynamo/test_export.py::ExportTests::test_access_class_method_from_user_class_type_fn0
test/dynamo/test_export.py::ExportTests::test_access_class_method_from_user_class_type_fn1
```

which might not be as expressive at first glance, but at least prevents bugs.
Pull Request resolved: #113856
Approved by: https://github.com/malfet, https://github.com/huydhn
ghstack dependencies: #113855
  • Loading branch information
pmeier authored and pytorchmergebot committed Nov 16, 2023
1 parent 03bebd9 commit 769f924
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
27 changes: 27 additions & 0 deletions test/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,33 @@ def test_two_things_default_names(self, device, x, y):
test_names = _get_test_names_for_test_class(device_cls)
self.assertEqual(expected_test_names, test_names)

def test_default_name_non_primitive(self, device):
device = self.device_type

class TestParametrized(TestCase):
@parametrize("x", [1, .5, "foo", object()])
def test_default_names(self, device, x):
pass

@parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
def test_two_things_default_names(self, device, x, y):
pass

instantiate_device_type_tests(TestParametrized, locals(), only_for=device)

device_cls = locals()[f'TestParametrized{device.upper()}']
expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
'{}.test_default_names_x_1_{}',
'{}.test_default_names_x_0_5_{}',
'{}.test_default_names_x_foo_{}',
'{}.test_default_names_x3_{}',
'{}.test_two_things_default_names_x_1_y0_{}',
'{}.test_two_things_default_names_x1_y_0_5_{}',
'{}.test_two_things_default_names_x2_y2_{}')
)
test_names = _get_test_names_for_test_class(device_cls)
self.assertEqual(expected_test_names, test_names)

def test_name_fn(self, device):
device = self.device_type

Expand Down
21 changes: 11 additions & 10 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,29 +459,30 @@ def __init__(self, arg_str, arg_values, name_fn=None):
self.arg_values = arg_values
self.name_fn = name_fn

def _formatted_str_repr(self, name, value):
def _formatted_str_repr(self, idx, name, value):
""" Returns a string representation for the given arg that is suitable for use in test function names. """
if isinstance(value, torch.dtype):
return dtype_name(value)
elif isinstance(value, torch.device):
return str(value)
# Can't use isinstance as it would cause a circular import
elif value.__class__.__name__ == 'OpInfo' or value.__class__.__name__ == 'ModuleInfo':
elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
return value.formatted_name
else:
# Include name and value separated by underscore.
elif isinstance(value, (int, float, str)):
return f"{name}_{str(value).replace('.', '_')}"
else:
return f"{name}{idx}"

def _default_subtest_name(self, values):
return '_'.join([self._formatted_str_repr(a, v) for a, v in zip(self.arg_names, values)])
def _default_subtest_name(self, idx, values):
return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)])

def _get_subtest_name(self, values, explicit_name=None):
def _get_subtest_name(self, idx, values, explicit_name=None):
if explicit_name:
subtest_name = explicit_name
elif self.name_fn:
subtest_name = self.name_fn(*values)
else:
subtest_name = self._default_subtest_name(values)
subtest_name = self._default_subtest_name(idx, values)
return subtest_name

def _parametrize_test(self, test, generic_cls, device_cls):
Expand All @@ -494,7 +495,7 @@ def _parametrize_test(self, test, generic_cls, device_cls):
# * A tuple of values with one for each arg. For a single arg, a single item is expected.
# * A subtest instance with arg_values matching the previous.
values = check_exhausted_iterator = object()
for values in self.arg_values:
for idx, values in enumerate(self.arg_values):
maybe_name = None

decorators = []
Expand All @@ -519,7 +520,7 @@ def test_wrapper(*args, **kwargs):

param_kwargs = dict(zip(self.arg_names, values))

test_name = self._get_subtest_name(values, explicit_name=maybe_name)
test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)

def decorator_fn(_, decorators=decorators):
return decorators
Expand Down

0 comments on commit 769f924

Please sign in to comment.