diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 312fc3a6ee..52b4b62c28 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -1408,7 +1408,7 @@ class AOTInductorModelCache: def load(cls, model, example_inputs): import torch._inductor import torch.export._trace - from torch.export.dynamic_shapes import _tree_map_with_path + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path key = weakref.ref(model) if key not in cls.cache: @@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs): else: _register_dataclass_output_as_pytree(example_outputs) - combined_args = tuple(example_args) + tuple(example_kwargs.values()) + combined_args = _combine_args(model, example_args, example_kwargs) dynamic_shapes = _tree_map_with_path( _produce_dynamic_shapes_for_export, combined_args ) @@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs): def export(model, example_inputs): - from torch.export.dynamic_shapes import _tree_map_with_path + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path example_args, example_kwargs = _normalize_bench_inputs(example_inputs) example_outputs = model(*example_args, **example_kwargs) _register_dataclass_output_as_pytree(example_outputs) - combined_args = tuple(example_args) + tuple(example_kwargs.values()) + combined_args = _combine_args(model, example_args, example_kwargs) dynamic_shapes = _tree_map_with_path( _produce_dynamic_shapes_for_export, combined_args )