[ONNX] Support tuples in ScriptModule inputs/outputs #20784
Conversation
Force-pushed from 31913e0 to 57b3715.
Force-pushed from 72a71bb to 91aa7d4.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix the lint error, please.
Force-pushed from 91aa7d4 to 5b1da10.
Fixed, thanks.
Force-pushed from 32d045c to baf2107.
torch/csrc/jit/script/init.cpp (outdated)

```cpp
      DimensionedTensorType::create(stack.at(i).toTensor()));
static TypePtr getTensorType(
    const at::Tensor& t,
    bool isDimensionedTensor) {
```
I find it a bit hard to understand the meaning of `isDimensionedTensor`. Can we get a better name for it? And add some comments on why we need it.
Thanks, changed to `type_kind` instead, marking the desired kind of tensor type.
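A hypothetical Python analogue of the rename, to illustrate the design choice (the enum values and function below are assumptions for illustration, not the actual C++ code from this patch):

```python
from enum import Enum

# Hypothetical analogue of the change discussed above (names assumed):
# an ambiguous bool flag is replaced by an explicit kind enum, so call
# sites state which kind of tensor type they want.
class TypeKind(Enum):
    DimensionedTensorType = 1  # only dtype/rank information
    CompleteTensorType = 2     # full concrete shape information

def get_tensor_type(tensor, type_kind):
    # `type_kind` marks the desired kind of tensor type to build,
    # replacing the unclear `isDimensionedTensor` bool parameter.
    if type_kind is TypeKind.DimensionedTensorType:
        return ("DimensionedTensorType", tensor)
    return ("CompleteTensorType", tensor)

# A call like get_tensor_type(t, TypeKind.CompleteTensorType) is
# self-documenting, unlike get_tensor_type(t, False).
```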
Looks good. Could you address my inline comments?
Force-pushed from baf2107 to a0b42ae.
Looks good, thanks!
Force-pushed from d9e9558 to 9f53df3.
Thanks! Rebased to resolve the conflict.
The test failures are unrelated, so landing.
@houseroad merged this pull request in a3db284.
- Add tests after [ONNX] Fix bug in exporting node with multiple outputs by scripting #20256 is merged.
- Support exporting ScriptModule with inputs/outputs of arbitrarily constructed tuples.
- Moved the assignment of output shapes to after the graph conversion to ONNX is completed. By then, all tuples in the IR have already been lowered by the `_jit_pass_lower_all_tuples` pass. If assigning output shapes had to happen before that, we would need to hand-parse the tuple structures in the graph and repeat the same logic as `_jit_pass_lower_all_tuples`. Handling inputs is easier because all tuple information is encoded within the input tensor type.
- Swap the order of `_jit_pass_lower_all_tuples` and `_jit_pass_erase_number_types`. Ops like `prim::TupleIndex` rely on the index being a scalar, and `_jit_pass_erase_number_types` would convert such scalars to tensors.
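Conceptually, lowering tuples means flattening arbitrarily nested tuples into a flat list of leaf values, after which shapes can be assigned per leaf. A minimal sketch of that idea (this is an illustration only, not the actual `_jit_pass_lower_all_tuples` implementation, which operates on JIT IR nodes):

```python
def flatten_tuples(value):
    """Recursively flatten arbitrarily nested tuples into a flat list of leaves."""
    if isinstance(value, tuple):
        flat = []
        for v in value:
            flat.extend(flatten_tuples(v))
        return flat
    # Non-tuple leaf (e.g. a tensor in the real pass) stays as a single element.
    return [value]

# A ScriptModule output such as (a, (b, c)) is lowered to [a, b, c],
# so output shapes can then be assigned per flattened tensor.
print(flatten_tuples(("a", ("b", "c"), (("d",), "e"))))  # → ['a', 'b', 'c', 'd', 'e']
```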