Skip to content

Support dynamic shapes for aten_unfold #2407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2025
Merged

Conversation

xenova
Copy link
Contributor

@xenova xenova commented Jun 20, 2025

While converting a new model that I'd like to add to Transformers.js, I ran into #2309, indicating that dynamic shapes aren't currently supported for aten_unfold:

  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/core.py", line 8662, in aten_unfold
    low_indices = range(0, dim_size, step)
TypeError: 'SymbolicDim' object cannot be interpreted as an integer

So, I dug a bit into the code and with some help from Claude, I got a version which works for my use-case (output matches exactly)! 👍

Code to reproduce (adapted from pytorch/pytorch#112844 (comment))

import torch

class SpecMaker(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.unfold(x, -1, 512, 160)

specmodel = SpecMaker()
input = torch.rand(32000 * 10)
spec = specmodel(input)
input_batch = torch.stack([input, input])
spec_batch = specmodel(input_batch)

onnx_program = torch.onnx.export(
    specmodel,
    (input_batch,),
    f="/tmp/model.onnx",
    dynamic_shapes=[{0: "dim_x",1:"length"}],
    input_names=["input"],
    output_names=["output"],
    dynamo=True,
    report=True,
)

Logs (before)

(base) ➜  onnxscript git:(main) ✗ python testing/unfold.py
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ❌
[torch.onnx] Export report has been saved to 'onnx_export_2025-06-20_14-08-52-474773_conversion.md'.
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 519, in _handle_call_function_node_with_lowering
    outputs = onnx_function(*onnx_args, **onnx_kwargs)
  File ".../onnxscript/onnxscript/values.py", line 625, in __call__
    return self.func(*args, **kwargs)
           ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File ".../onnxscript/onnxscript/function_libs/torch_lib/ops/core.py", line 8660, in aten_unfold
    low_indices = range(0, dim_size, step)
TypeError: 'SymbolicDim' object cannot be interpreted as an integer

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 707, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<6 lines>...
        node_name_to_local_functions=node_name_to_local_functions,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 521, in _handle_call_function_node_with_lowering
    raise _errors.GraphConstructionError(
        f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'"
    ) from e
torch.onnx._internal.exporter._errors.GraphConstructionError: Error when calling function 'TracedOnnxFunction(<function aten_unfold at 0x120baa7a0>)' with args '[SymbolicTensor(name='x', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s0), SymbolicDim(s1)])), -1, 512, 160]' and kwargs '{}'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1373, in export
    onnx_program = _exported_program_to_onnx_program(
        decomposed_program, registry=registry
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1007, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
        fx_graph,
    ...<4 lines>...
        registry=registry,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 733, in _translate_fx_graph
    raise _errors.ConversionError(
        f"Error when translating node {node.format_node()}. See the stack trace for more information."
    ) from e
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %unfold : [num_users=1] = call_function[target=torch.ops.aten.unfold.default](args = (%x, -1, 512, 160), kwargs = {}). See the stack trace for more information.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../onnxscript/testing/unfold.py", line 15, in <module>
    onnx_program = torch.onnx.export(
        specmodel,
    ...<7 lines>...
        # verbose=True,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/__init__.py", line 364, in export
    return _compat.export_compat(
           ~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<19 lines>...
        fallback=fallback,
        ^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_compat.py", line 120, in export_compat
    onnx_program = _core.export(
        model,
    ...<11 lines>...
        verbose=verbose,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1419, in export
    raise _errors.ConversionError(
    ...<3 lines>...
    ) from e
torch.onnx._internal.exporter._errors.ConversionError: Failed to convert the exported program to an ONNX model. This is step 3/3 of exporting the model to ONNX. Next steps:
- If there is a missing ONNX function, implement it and register it to the registry.
- If there is an internal error during ONNX conversion, debug the error and summit a PR to PyTorch.
- Create an error report with `torch.onnx.export(..., report=True)`, and save the ExportedProgram as a pt2 file. Create an issue in the PyTorch GitHub repository against the *onnx* component. Attach the error report and the pt2 model.
Error report has been saved to 'onnx_export_2025-06-20_14-08-52-474773_conversion.md'.

## Exception summary

<class 'TypeError'>: 'SymbolicDim' object cannot be interpreted as an integer
⬆️
<class 'torch.onnx._internal.exporter._errors.GraphConstructionError'>: Error when calling function 'TracedOnnxFunction(<function aten_unfold at 0x120baa7a0>)' with args '[SymbolicTensor(name='x', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s0), SymbolicDim(s1)])), -1, 512, 160]' and kwargs '{}'
⬆️
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %unfold : [num_users=1] = call_function[target=torch.ops.aten.unfold.default](args = (%x, -1, 512, 160), kwargs = {}). See the stack trace for more information.

(Refer to the full stack trace above for more information.)

Logs (after)

(base) ➜  onnxscript git:(main) ✗ python testing/unfold.py
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Export report has been saved to 'onnx_export_2025-06-20_14-11-27-804730_success.md'.
Applied 1 of general pattern rewrite rules.

Closes #2309. cc @justinchuby

@justinchuby
Copy link
Collaborator

Amazing, thanks you!

@justinchuby justinchuby requested a review from Copilot June 20, 2025 18:21
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Jun 20, 2025
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces support for dynamic shapes in the aten_unfold operator to resolve conversion errors when using symbolic dimensions with torch.onnx.export.

  • Refactored the handling of dim_size by replacing direct tensor indexing with op.Shape and op.Gather.
  • Reworked the window generation logic using op.Range, broadcasting, and reordering the output with op.Transpose to maintain shape consistency.
Comments suppressed due to low confidence (2)

onnxscript/function_libs/torch_lib/ops/core.py:8701

  • [nitpick] Consider adding a clarifying inline comment explaining the permutation logic used in op.Transpose to improve maintainability and assist future reviewers.
        perm.append(perm.pop(dimension + 1))

onnxscript/function_libs/torch_lib/ops/core.py:8669

  • Ensure that op.Div performs floor division to correctly compute output_size as per '(input_size - kernel_size) // stride + 1'. If it doesn't, consider using an explicit floor division operator or integer casting.
        )

Copy link

codecov bot commented Jun 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 70.39%. Comparing base (03ab4c5) to head (0b7b0ee).
Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2407   +/-   ##
=======================================
  Coverage   70.38%   70.39%           
=======================================
  Files         199      199           
  Lines       25223    25226    +3     
  Branches     2686     2686           
=======================================
+ Hits        17753    17757    +4     
+ Misses       6541     6540    -1     
  Partials      929      929           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@xenova
Copy link
Contributor Author

xenova commented Jun 20, 2025

Thanks for the review 👍 Changes made ✅

Copy link
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! The code is very skilled at using op.Gather and broadcasting!

@titaiwangms titaiwangms merged commit 38871a5 into microsoft:main Jun 20, 2025
26 of 32 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: torchlib Related to the torch/aten function lib in development
Projects
Development

Successfully merging this pull request may close these issues.

[torchlib] Support dynamic shapes for aten_unfold
3 participants