-
Notifications
You must be signed in to change notification settings - Fork 25.2k
[ONNX] Support 'aten::randint' in torchscript onnx exporter #105089
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Export as 'ONNX::RandomUniform' which produces floating point result, then round it to integer with 'ONNX::Cast'. [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/105089
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit ffd7663: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…t onnx exporter" Export as 'ONNX::RandomUniform' which produces floating point result, then round it to integer with 'ONNX::Cast'. [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I am assuming we don't care testing torch.jit.script scenarios at this point. scalar_type handling usually crash for script as opposed to tracing
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) | ||
if int_dtype != scalar_type: | ||
randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying to understand: why do we need to cast twice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
first time to "floor" floating random number to integer.
second time the dtype
is part of op argument, could be floating number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, thanks!
Stack from ghstack (oldest at bottom):
Export as 'ONNX::RandomUniform' which produces floating point result,
then round it to integer with 'ONNX::Cast'.
Fixes https://github.com/microsoft/onnx-converters-private/issues/173