diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py index 8a0f95670b..b4c022e8d8 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/testing.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/testing.py @@ -311,7 +311,9 @@ def _fn(*args, **kwargs): return _fn -def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None): +def make_test_cls_with_patches( + cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x +): DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) DummyTestClass.__qualname__ = DummyTestClass.__name__ @@ -326,7 +328,7 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop= new_fn.__name__ = new_name if xfail_prop is not None and hasattr(fn, xfail_prop): new_fn = unittest.expectedFailure(new_fn) - setattr(DummyTestClass, new_name, new_fn) + setattr(DummyTestClass, new_name, decorator(new_fn)) # NB: Doesn't handle slots correctly, but whatever elif not hasattr(DummyTestClass, name): setattr(DummyTestClass, name, getattr(cls, name)) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 2ecadef60e..48ea0496da 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -75,6 +75,7 @@ graph_break_reasons, maybe_enable_compiled_autograd, ) +import torch._functorch.config from torch._functorch.aot_autograd import set_model_name from torch._inductor import config as inductor_config, metrics from torch._subclasses.fake_tensor import FakeTensorMode @@ -3155,6 +3156,11 @@ def get_example_inputs(self): action="store_true", help="Runs a dynamic shapes version of the benchmark, if available.", ) + parser.add_argument( + "--propagate-real-tensors", + action="store_true", + help="Capture as much data dependent as you can by unsoundly propagating real tensors", + ) parser.add_argument( "--dynamic-batch-only", action="store_true", @@ -3603,6 +3609,11 @@ def run(runner, args, original_dir=None): if args.dynamic_shapes: if not args.dynamic_batch_only: torch._dynamo.config.assume_static_by_default = False + if args.propagate_real_tensors: + # TODO: Separate flag for data dependent + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._functorch.config.fake_tensor_propagate_real_tensors = True if args.specialize_int: torch._dynamo.config.specialize_int = True if args.ci: