Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
ezyang committed May 1, 2024
1 parent 9d2b0a6 commit 8bb7f84
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/test_fake_tensor.py
Expand Up @@ -832,6 +832,7 @@ def make_propagate_real_tensors_cls(cls):
"_propagate_real_tensors",
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
xfail_prop="_expected_failure_propagate_real_tensors",
decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
)
cls.__file__ = __file__
cls.__module__ = __name__
Expand Down Expand Up @@ -1351,8 +1352,7 @@ def to_fake_tensor(x):
self.assertTrue(failed)


# Propagate real tensors doesn't work with fake-on-fake
@expectedFailurePropagateRealTensors
@expectedFailurePropagateRealTensors # Propagate real tensors doesn't work with fake-on-fake
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
Expand Down
6 changes: 4 additions & 2 deletions torch/_dynamo/testing.py
Expand Up @@ -311,7 +311,9 @@ def _fn(*args, **kwargs):
return _fn


def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None):
def make_test_cls_with_patches(
cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
DummyTestClass.__qualname__ = DummyTestClass.__name__

Expand All @@ -326,7 +328,7 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=
new_fn.__name__ = new_name
if xfail_prop is not None and hasattr(fn, xfail_prop):
new_fn = unittest.expectedFailure(new_fn)
setattr(DummyTestClass, new_name, new_fn)
setattr(DummyTestClass, new_name, decorator(new_fn))
# NB: Doesn't handle slots correctly, but whatever
elif not hasattr(DummyTestClass, name):
setattr(DummyTestClass, name, getattr(cls, name))
Expand Down

0 comments on commit 8bb7f84

Please sign in to comment.