From aa8d31f74b87492ce8546dd3599db403f12ebf44 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Tue, 18 Mar 2025 12:25:31 -0700 Subject: [PATCH] fix dynamic_shapes spec for moco (#148772) Summary: Fixes https://github.com/pytorch/pytorch/issues/148333 X-link: https://github.com/pytorch/pytorch/pull/148772 Approved by: https://github.com/yushangdi, https://github.com/desertfire Differential Revision: D71412041 --- userbenchmark/dynamo/dynamobench/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 312fc3a6ee..52b4b62c28 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -1408,7 +1408,7 @@ class AOTInductorModelCache: def load(cls, model, example_inputs): import torch._inductor import torch.export._trace - from torch.export.dynamic_shapes import _tree_map_with_path + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path key = weakref.ref(model) if key not in cls.cache: @@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs): else: _register_dataclass_output_as_pytree(example_outputs) - combined_args = tuple(example_args) + tuple(example_kwargs.values()) + combined_args = _combine_args(model, example_args, example_kwargs) dynamic_shapes = _tree_map_with_path( _produce_dynamic_shapes_for_export, combined_args ) @@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs): def export(model, example_inputs): - from torch.export.dynamic_shapes import _tree_map_with_path + from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path example_args, example_kwargs = _normalize_bench_inputs(example_inputs) example_outputs = model(*example_args, **example_kwargs) _register_dataclass_output_as_pytree(example_outputs) - combined_args = tuple(example_args) + tuple(example_kwargs.values()) + combined_args = _combine_args(model, example_args, example_kwargs) dynamic_shapes = _tree_map_with_path( _produce_dynamic_shapes_for_export, combined_args )