
[ONNX] Support tuples in ScriptModule inputs/outputs #20784

Closed
wants to merge 6 commits from the onnx_tuple_script branch

Conversation

@BowenBao (Collaborator) commented May 21, 2019

  • Add tests after “[ONNX] Fix bug in exporting node with multiple outputs by scripting” (#20256) is merged.

  • Support exporting ScriptModule with inputs/outputs of arbitrarily constructed tuples.

  • Moved the assignment of output shapes to after the graph conversion to ONNX is complete. By then, all tuples in the IR have already been lowered by the pass _jit_pass_lower_all_tuples. If assigning output shapes had to happen before that, we would need to hand-parse the tuple structures in the graph and repeat the same logic as _jit_pass_lower_all_tuples. Handling inputs is easier because all tuple information is encoded within the input tensor type.

  • Swap the order of _jit_pass_lower_all_tuples and _jit_pass_erase_number_types. Ops like prim::TupleIndex rely on the index being a scalar, and _jit_pass_erase_number_types would convert these scalars to tensors.
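Conceptually, lowering tuples means flattening arbitrarily nested tuples of tensors into a flat list of values, together with a record of the nesting so the original structure can be rebuilt. A minimal Python sketch of that idea (the names here are illustrative, not the actual _jit_pass_lower_all_tuples implementation):

```python
def flatten_tuples(value):
    """Recursively flatten nested tuples into a flat list of leaf
    values, plus a spec recording the original tuple structure."""
    if isinstance(value, tuple):
        leaves, spec = [], []
        for v in value:
            sub_leaves, sub_spec = flatten_tuples(v)
            leaves.extend(sub_leaves)
            spec.append(sub_spec)
        return leaves, tuple(spec)
    return [value], None  # leaf (e.g. a tensor)

def unflatten_tuples(leaves, spec):
    """Rebuild the original nested-tuple structure from a flat list."""
    if spec is None:
        return leaves.pop(0)
    return tuple(unflatten_tuples(leaves, s) for s in spec)

# A ScriptModule output like (a, (b, c)) becomes three flat ONNX outputs.
leaves, spec = flatten_tuples(("a", ("b", "c")))
# leaves == ["a", "b", "c"]; spec records the nesting
restored = unflatten_tuples(list(leaves), spec)
```

This also illustrates why output shapes are easier to assign after lowering: once tuples are gone, outputs map one-to-one to flat graph values, with no tuple structure left to parse by hand.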

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: onnx Related to torch.onnx labels May 21, 2019
@BowenBao BowenBao changed the title [ONNX][WIP] Support tuples in ScriptModule inputs/outputs [ONNX] Support tuples in ScriptModule inputs/outputs May 23, 2019

@facebook-github-bot facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member)

Fix the lint error please

@BowenBao (Collaborator, Author)

> Fix the lint error please

Fixed, thanks.


@facebook-github-bot facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@BowenBao BowenBao force-pushed the onnx_tuple_script branch 2 times, most recently from 32d045c to baf2107 Compare June 6, 2019 17:13
@ezyang ezyang requested a review from houseroad June 6, 2019 18:53
@ezyang ezyang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jun 6, 2019

@facebook-github-bot facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

    DimensionedTensorType::create(stack.at(i).toTensor()));

static TypePtr getTensorType(
    const at::Tensor& t,
    bool isDimensionedTensor) {
Member


I find it a bit hard to understand the meaning of isDimensionedTensor. Can we get a better name for it? And add some comments on why we need it.

Collaborator Author


Thanks, changed it to type_kind instead, marking the desired kind of tensor type.
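The rename follows a common refactor: an opaque boolean flag becomes an explicit kind enum, so call sites state intent instead of passing a bare true/false. A hedged Python sketch of the pattern (TypeKind, get_tensor_type, and the returned strings are illustrative stand-ins, not the actual C++ API):

```python
from enum import Enum, auto

class TypeKind(Enum):
    """Illustrative stand-in for the desired kind of tensor type."""
    DIMENSIONED = auto()  # carries only the rank of the tensor
    COMPLETE = auto()     # carries the full sizes

def get_tensor_type(sizes, type_kind):
    # Dispatch on an explicit kind rather than a bool, so a call reads
    # get_tensor_type(t, TypeKind.DIMENSIONED) instead of
    # get_tensor_type(t, True).
    if type_kind is TypeKind.DIMENSIONED:
        return f"DimensionedTensor(rank={len(sizes)})"
    return f"CompleteTensor(sizes={list(sizes)})"
```

Compared with the original bool parameter, the enum also leaves room to add further kinds later without changing every existing call site.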


@houseroad houseroad left a comment


Looks good. Could you address my inline comments?


@facebook-github-bot facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@houseroad houseroad left a comment


Looks good, thanks!

@BowenBao (Collaborator, Author)

> Looks good, thanks!

Thanks! Rebased to resolve the conflict.


@facebook-github-bot facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member)

The test failures are unrelated, so landing.

@dzhulgakov (Collaborator)

cc @zdevito @suo FYI in case we want to expand it to more types and there's already some other convenience function

@facebook-github-bot

@houseroad merged this pull request in a3db284.

Labels
Merged · module: onnx (Related to torch.onnx) · oncall: jit (Add this issue/PR to JIT oncall triage queue) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

7 participants