Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,7 @@ class AOTInductorModelCache:
def load(cls, model, example_inputs):
import torch._inductor
import torch.export._trace
from torch.export.dynamic_shapes import _tree_map_with_path
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path

key = weakref.ref(model)
if key not in cls.cache:
Expand All @@ -1428,7 +1428,7 @@ def load(cls, model, example_inputs):
else:
_register_dataclass_output_as_pytree(example_outputs)

combined_args = tuple(example_args) + tuple(example_kwargs.values())
combined_args = _combine_args(model, example_args, example_kwargs)
dynamic_shapes = _tree_map_with_path(
_produce_dynamic_shapes_for_export, combined_args
)
Expand All @@ -1449,13 +1449,13 @@ def load(cls, model, example_inputs):


def export(model, example_inputs):
from torch.export.dynamic_shapes import _tree_map_with_path
from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path

example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
example_outputs = model(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs)

combined_args = tuple(example_args) + tuple(example_kwargs.values())
combined_args = _combine_args(model, example_args, example_kwargs)
dynamic_shapes = _tree_map_with_path(
_produce_dynamic_shapes_for_export, combined_args
)
Expand Down
Loading