Conversation

titaiwangms
Collaborator

Add `dynamo: bool = True` as a switch in `torch.onnx.export` to give users an option to try `torch.onnx.dynamo_export`.
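The dispatch this switch enables can be sketched roughly as follows. This is a hypothetical, simplified sketch: `_legacy_export` and `_dynamo_export` are stand-ins for the TorchScript-based and dynamo-based exporter paths, not real PyTorch functions, and the real signature has many more parameters.

```python
def _legacy_export(model, args, f):
    # Stand-in for the existing TorchScript-based exporter (hypothetical).
    return f"legacy export of {model} to {f}"

def _dynamo_export(model, args, f):
    # Stand-in for torch.onnx.dynamo_export (hypothetical).
    return f"dynamo export of {model} to {f}"

def export(model, args, f=None, *, dynamo=True):
    # The new switch: dynamo=True routes through the dynamo-based exporter,
    # dynamo=False keeps the legacy path.
    if dynamo:
        return _dynamo_export(model, args, f)
    return _legacy_export(model, args, f)

print(export("m", (), "model.onnx"))                 # dynamo path
print(export("m", (), "model.onnx", dynamo=False))   # legacy path
```

Existing call sites keep working unchanged; only the boolean decides which backend runs.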

pytorch-bot bot commented Jun 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127974

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit af535f0 with merge base 22964d1 (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: onnx (torch.onnx related changes that should show up in the release notes) label Jun 4, 2024
Collaborator

@justinchuby left a comment


LGTM thanks!

(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
dynamo=True,
)
self.assertEqual(


Can you say something about how this equality assertion works? Just for my own education; I'm wondering what kind of infrastructure exists for proto-object equality checks.

Collaborator Author


That's a good point. I hadn't verified that, so it might be invalid. I will verify it.

Collaborator


There is an internal equality check, according to https://proto-plus-python.readthedocs.io/en/stable/_modules/proto/message.html. I think it should work when two messages are exactly the same.
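For intuition, field-by-field message equality of this kind can be sketched with stdlib dataclasses. The classes below are hypothetical stand-ins for protobuf messages, not the proto-plus internals referenced above; dataclasses' generated `__eq__` gives the same field-by-field semantics.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for protobuf messages. Equality compares every
# field (recursively through nested messages and repeated fields).
@dataclass
class TensorProto:
    name: str = ""
    dims: list = field(default_factory=list)

@dataclass
class ModelProto:
    producer_name: str = ""
    inputs: list = field(default_factory=list)

a = ModelProto("pytorch", [TensorProto("x", [1, 1, 2])])
b = ModelProto("pytorch", [TensorProto("x", [1, 1, 2])])
c = ModelProto("pytorch", [TensorProto("x", [1, 2, 2])])

assert a == b   # identical messages compare equal
assert a != c   # any differing field breaks equality
```

So `assertEqual` on two proto objects passes only when every field matches, which is what makes it usable in the test above.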

Collaborator Author


It does catch any difference between two protobufs. Here is their code:

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jun 5, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

)
# TODO: check args normalization
args = _decide_input_format(model, args)
kwargs = {}
Collaborator


In a follow-up PR, I would add a `kwargs` parameter to the API and encourage users to use that.

Collaborator Author


torch.onnx.export already has its own way of supporting kwargs. Do we want to add another way of supporting them?

pytorch/torch/onnx/utils.py

Lines 240 to 281 in 4adee71

3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::
args = (
x,
{
"y": input_y,
"z": input_z
}
)
All but the last element of the tuple will be passed as non-keyword arguments,
and named arguments will be set from the last element. If a named argument is
not present in the dictionary, it is assigned the default value, or None if a
default value is not provided.
.. note::
If a dictionary is the last element of the args tuple, it will be
interpreted as containing named arguments. In order to pass a dict as the
last non-keyword arg, provide an empty dict as the last element of the args
tuple. For example, instead of::
torch.onnx.export(
model,
(
x,
# WRONG: will be interpreted as named arguments
{y: z}
),
"test.onnx.pb"
)
Write::
torch.onnx.export(
model,
(
x,
{y: z},
{}
),
"test.onnx.pb"
)

args: Union[Tuple[Any, ...], torch.Tensor],
f: Union[str, io.BytesIO],
f: Optional[Union[str, io.BytesIO]] = None,
export_params: bool = True,
Collaborator


Suggested change
export_params: bool = True,
*,
export_params: bool = True,

as a follow-up
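For reference, a bare `*` in a signature makes every parameter after it keyword-only, so call sites that passed those arguments positionally would fail loudly instead of silently misbinding. A toy sketch (not the real `torch.onnx.export` signature):

```python
def export(model, args, f=None, *, export_params=True):
    # export_params can now only be passed by name.
    return export_params

assert export("m", ()) is True
assert export("m", (), "f.onnx", export_params=False) is False

# Passing export_params positionally is rejected:
try:
    export("m", (), "f.onnx", False)
except TypeError:
    pass
else:
    raise AssertionError("expected TypeError for positional export_params")
```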

keep_initializers_as_inputs: Optional[bool] = None,
custom_opsets: Optional[Mapping[str, int]] = None,
export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
autograd_inlining: Optional[bool] = True,
Collaborator

@justinchuby Jun 5, 2024


Suggested change
autograd_inlining: Optional[bool] = True,
autograd_inlining: Optional[bool] = True,
kwargs: Mapping[str, Any] | None = None,

As a follow-up


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, release notes: onnx (torch.onnx related changes that should show up in the release notes), topic: new features (topic category)


5 participants