[dynamo exporter] Fix dynamic shapes with DynamicCache #1832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 6 commits · Jun 19, 2025
4 changes: 2 additions & 2 deletions olive/common/config_utils.py
@@ -371,7 +371,7 @@ def convert_configs_to_dicts(config: Any) -> Any:


 def get_the_flattened_and_tree_spec(
-    dynamic_shapes: Union[dict[str, Any], list[Any]], leave_is_str: bool = False
+    dynamic_shapes: Union[dict[str, Any], list[Any]], leaf_is_str: bool = False
 ) -> tuple[list[Any], Any]:
     """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree."""
     # More info: https://github.com/pytorch/pytorch/blob/48203bec636692e1a9140fe7f23ba1323b19550d/torch/utils/_pytree.py#L985
@@ -395,4 +395,4 @@ def is_axes_with_int_key(x) -> bool:
         and all(isinstance(k, int) and (v is None or isinstance(v, (str, int))) for k, v in x.items())
     ) or (isinstance(x, (list, tuple)) and all(v is None or isinstance(v, (str, int)) for v in x))

-    return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key if leave_is_str else is_axes_with_int_key)
+    return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key if leaf_is_str else is_axes_with_int_key)
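Context for the rename above: `leaf_is_str` selects the `is_leaf` predicate that decides where flattening stops, so axes dicts written with string keys survive as whole leaves. A minimal sketch, illustrative only and not code from this PR, with a predicate mirroring the one implied by the diff:

```python
from torch.utils import _pytree

dynamic_shapes = {"input_ids": {"0": "batch", "1": "seq"}, "attention_mask": {"0": "batch"}}

def is_axes_with_str_key(x):
    # An "axes" node: a dict of str axis -> axis name (or None), kept whole as one leaf.
    # The outer dict also has str keys, but its values are dicts, so it is not a leaf.
    return isinstance(x, dict) and all(
        isinstance(k, str) and (v is None or isinstance(v, (str, int))) for k, v in x.items()
    )

flat, spec = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes_with_str_key)
print(flat)  # [{'0': 'batch', '1': 'seq'}, {'0': 'batch'}] -- one leaf per axes dict
# tree_unflatten rebuilds the original nesting from leaves + TreeSpec.
assert _pytree.tree_unflatten(flat, spec) == dynamic_shapes
```

The returned TreeSpec is what `convert_dynamic_shapes` below uses to rebuild the nested structure after normalizing each leaf.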
5 changes: 4 additions & 1 deletion olive/model/config/io_config.py
@@ -91,7 +91,10 @@ def convert_dynamic_shapes(cls, v):
         if not v:
             return v

-        flattened, tree_spec = get_the_flattened_and_tree_spec(v, leave_is_str=True)
+        # dict: {axis: axis_name} -> {int(axis): axis_name}
+        # list/tuple: [axis_name] -> [axis_name]
+
+        flattened, tree_spec = get_the_flattened_and_tree_spec(v, leaf_is_str=True)
         new_flattened = []
         for axes in flattened:
             if isinstance(axes, dict):
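The two new comments describe the per-leaf normalization that follows: JSON-sourced configs can only carry string keys, so dict axes arrive as {"0": ...} and must be rewritten with integer keys, while list/tuple axes are already positional and pass through unchanged. A hedged sketch of just that step (`normalize_axes` is a hypothetical name for the loop body, not Olive's actual API):

```python
def normalize_axes(axes):
    # Hypothetical helper illustrating one iteration of the validator loop.
    if isinstance(axes, dict):
        return {int(k): v for k, v in axes.items()}  # "0" -> 0
    return axes  # list/tuple form is positional already; returned unchanged

print(normalize_axes({"0": "batch_size", "1": "sequence_length"}))  # {0: 'batch_size', 1: 'sequence_length'}
print(normalize_axes(["batch_size", None]))                         # ['batch_size', None]
```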
36 changes: 17 additions & 19 deletions olive/passes/onnx/conversion.py
@@ -265,14 +265,12 @@ def _export_pytorch_model(
         io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
             io_config.dynamic_shapes, dummy_inputs, dummy_kwargs, pytorch_model
         )
-
         # there might be multiple files created during export, so we need to track the dir
         # if there are other processes writing to the same dir, we might end up deleting files created by
         # other processes
         with tempfile.TemporaryDirectory(dir=tempdir, prefix="olive_tmp") as tmp_dir:
             tmp_dir_path = Path(tmp_dir)
             tmp_model_path = resolve_onnx_path(tmp_dir_path)
-
             onnx_program = torch.onnx.export(  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
                 pytorch_model,
                 dummy_inputs,
@@ -682,32 +680,32 @@ def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, dummy_kwargs, model):

     from torch.utils import _pytree

-    flat_dynamic_shapes, _ = get_the_flattened_and_tree_spec(dynamic_shapes)
-
-    # dict: {axis: axis_name} -> {int(axis): axis_name}
-    # list/tuple: [axis_name] -> [axis_name]
-    new_dynamic_shapes = [
-        {int(k): v for k, v in axes.items()} if isinstance(axes, dict) else axes for axes in flat_dynamic_shapes
-    ]
-
+    # Align tree spec only for not transformers.Cache.
+    if len(dummy_inputs) == 0:
+        for k, v in dummy_kwargs.items():
+            if not isinstance(v, transformers.Cache):
+                input_tree_spec = _pytree.tree_flatten(v)[1]
+                flatten_dynamic_shapes = get_the_flattened_and_tree_spec(dynamic_shapes[k], leaf_is_str=False)[0]
+                dynamic_shapes[k] = _pytree.tree_unflatten(flatten_dynamic_shapes, input_tree_spec)
+    else:
+        for i, v in enumerate(dummy_inputs):
+            if not isinstance(v, transformers.Cache):
+                input_tree_spec = _pytree.tree_flatten(v)[1]
+                flatten_dynamic_shapes = get_the_flattened_and_tree_spec(dynamic_shapes[i], leaf_is_str=False)[0]
+                dynamic_shapes[i] = _pytree.tree_unflatten(flatten_dynamic_shapes, input_tree_spec)
     # The input can only be either args or kwargs according to line 237.
     if len(dummy_inputs) == 0:
         # dummy_inputs is empty, so it must be kwargs
-        _, tree_structure = get_the_flattened_and_tree_spec(dummy_kwargs, leave_is_str=False)
-        unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)
-
         # NOTE: dynamic_shapes need to follow the same model.forward signature when it's referring to kwargs.
         param_order = list(inspect.signature(model.forward).parameters)
         # Sort io_config.dynamic_shapes based on this order
-        unflatten_dynamic_shapes = collections.OrderedDict(
-            sorted(unflatten_dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
+        dynamic_shapes = collections.OrderedDict(
+            sorted(dynamic_shapes.items(), key=lambda item: param_order.index(item[0]))
         )
         dummy_kwargs = collections.OrderedDict(
             sorted(dummy_kwargs.items(), key=lambda item: param_order.index(item[0]))
         )
-        return unflatten_dynamic_shapes, dummy_inputs, dummy_kwargs
+        return dynamic_shapes, dummy_inputs, dummy_kwargs
     # If dynamic_shapes and dummy_inputs are both list/tuple, we don't need to sort.
     # dummy_inputs is args
-    _, tree_structure = get_the_flattened_and_tree_spec(dummy_inputs, leave_is_str=False)
-    unflatten_dynamic_shapes = _pytree.tree_unflatten(new_dynamic_shapes, tree_structure)
-    return unflatten_dynamic_shapes, dummy_inputs, dummy_kwargs
+    return dynamic_shapes, dummy_inputs, dummy_kwargs
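This is the heart of the fix: each non-Cache input's axes spec is re-nested to match that input's own pytree structure, while a `transformers.Cache` (such as `DynamicCache`) keeps the user-supplied nesting untouched. A standalone sketch of one loop iteration, with a hypothetical helper and a simplified `is_leaf` predicate, not the pass's actual code:

```python
import torch
import transformers
from torch.utils import _pytree

def align_one(shape_spec, dummy_value):
    # Hypothetical helper mirroring one iteration of the alignment loop above.
    if isinstance(dummy_value, transformers.Cache):
        return shape_spec  # DynamicCache etc.: leave the user's nesting as-is
    # Take the *structure* from the real input and the *leaves* from the spec.
    input_tree_spec = _pytree.tree_flatten(dummy_value)[1]
    flat_spec = _pytree.tree_flatten(shape_spec, is_leaf=lambda x: isinstance(x, dict))[0]
    return _pytree.tree_unflatten(flat_spec, input_tree_spec)

t = torch.randn(2, 4)
# The user wrote a tuple of axes dicts but the dummy input is a list:
# the spec is re-nested into a list so the two trees match.
print(align_one(({1: "x_axis"}, {0: "batch"}), [t, t]))  # [{1: 'x_axis'}, {0: 'batch'}]
```

The kwargs branch afterwards sorts both `dynamic_shapes` and `dummy_kwargs` into `model.forward`'s parameter order since, per the NOTE in the code, dynamic shapes for keyword inputs must follow the forward signature.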
4 changes: 2 additions & 2 deletions test/unit_test/passes/onnx/test_conversion.py
@@ -270,12 +270,12 @@ def forward(
             {"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
             None,
         ],
-        (
+        [
             {0: "axis_batch", 1: "x_axis"},
             ({1: "x_axis"}, {0: "axis_batch"}),
             {"a": {0: "axis_batch"}, "b": {1: "x_axis"}},
             None,
-        ),
+        ],
         _get_simulate_torch_float_tensor_inputs(return_tuple=True),
     ),
     (
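For reference, a rough end-to-end sketch of how a validated spec like the ones in this test reaches the exporter, assuming a recent PyTorch (roughly 2.6+) where `torch.onnx.export` accepts `dynamo=True`, `dynamic_shapes`, and string axis names; the toy model is a placeholder, not Olive code:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Axes dict uses int keys and string dim names, matching the normalized spec above.
onnx_program = torch.onnx.export(
    Toy(),
    (torch.randn(2, 4),),
    dynamo=True,
    dynamic_shapes={"x": {0: "batch_size"}},
)
print(type(onnx_program))  # an ONNXProgram, which can be written out via .save(...)
```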