[export] Refactor export() and separate the non-strict part. #114697
Conversation
This pull request was exported from Phabricator. Differential Revision: D51604074
…#114697) Summary: Refactor torch.export to separate strict part and non strict part. Adding an option to torch.export called `strict=True`. Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict Differential Revision: D51604074
158810a to 935bdf6 (force-pushed)
This pull request was exported from Phabricator. Differential Revision: D51604074
Overall looks good to me! Except it's a bit large and difficult to review lol. I was distracted by some of the BE modifications and had a hard time relating the old implementation to the new implementation. Left a few minor comments.
    *,
    constraints: Optional[List[Constraint]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    strict: bool = True,
Since "strict" is user-facing, we might need to add warnings and docs for this "strict" keyword. It would be better if we could mention the observable consequences/implications of using strict.
Maybe we should make it "_strict"?
I don't think this matters, because non-strict mode is already unsafe by design, and we just need to fix bugs forward if there are any.
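For illustration, this is roughly how the new keyword would be used from the public entry point (a minimal sketch with a toy module; not part of this PR's test plan):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + x.cos()

# strict=True (the default) traces the module through Dynamo first;
# strict=False takes the new non-strict path that calls aot_export_module
# directly on the module.
ep = export(M(), (torch.randn(3),), strict=False)
print(ep.graph_module.graph)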
gm = res.graph_module

assert orig_out_spec is not None
tensor_constants = lift_constant_tensor_pass(gm, export_graph_signature, params_buffers)
Hmm... why do we need to remove _replace_sym_size_ops_pass?
We moved it into _export_non_strict directly because this pass is pretty generic.
    fake_args,
    _reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs),
    fake_params_buffers,
    transform=_process_user_inputs
Non-blocking: do we really need the transform kwarg? It looks like a wrapper. When I see the word transform, I think of passes over the graph module. Can we just do the wrapping eagerly and then pass the wrapped callable into _export_non_strict? It would be a bit more readable, I think.
Yeah, good point. I might remove this at some point, but right now I need it to support user input mutations, which requires some code to run both before and after aot_export_module, so it has a sandwich structure. I can add a TODO here.
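To make the sandwich structure concrete (a rough sketch only; the wrapper shape and names are illustrative assumptions, not the PR's actual _process_user_inputs):

# Hypothetical sketch: `transform` wraps the aot export call so that code
# running before and after it can share state (the "sandwich").
def _process_user_inputs(aot_export):
    def wrapper(mod, args, **kwargs):
        # ... pre-processing that inspects the original user inputs ...
        gm, sig = aot_export(mod, args, **kwargs)
        # ... post-processing that reuses state from the pre-processing step ...
        return gm, sig
    return wrapper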
test/export/test_export.py (outdated diff)
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
self._test_export_same_as_eager(f, inp)

def test_basic_non_strict(self):
Can we also add a test for fake tensor mode and fake tensor inputs for non-strict mode?
sure.
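Something along these lines might work (a rough sketch; the test name and assertions are assumptions and the real test may need more plumbing):

def test_non_strict_with_fake_inputs(self):  # hypothetical test name
    from torch._subclasses.fake_tensor import FakeTensorMode

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    fake_mode = FakeTensorMode()
    fake_inp = fake_mode.from_tensor(torch.randn(2, 3))
    # Non-strict export runs the module directly, so fake inputs should
    # flow through without Dynamo in the loop.
    ep = torch.export.export(M(), (fake_inp,), strict=False)
    self.assertIsInstance(ep.graph_module, torch.fx.GraphModule)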
torch/_export/__init__.py (outdated diff)
# Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
# to follow the order in orig_args and correctly call module
is this comment relevant?
nice catch
kwargs = kwargs or {}

if not strict:
    assert isinstance(f, torch.nn.Module)
why do we assert this if we just wrap f with another module?
you're right, but I'd prefer to keep it here until we do need it.
        sig.buffers_to_mutate = pytree.tree_map(strip_root, sig.buffers_to_mutate)
        return gm, sig
    return _aot_export_non_strict

ep_non_strict = _export_non_strict(f, args, {}, f.state_dict(), transform=_tuplify_outputs)
Do we need to fakeify the args here? It seems like _export_non_strict takes fake args. Or could we move the _convert_input_to_fake call into _export_non_strict?
I think I will leave this part for later discussion when we're supporting dynamic shapes. Right now it shouldn't matter.
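For reference, fakeifying the args would look roughly like this (a sketch, not the PR's _convert_input_to_fake helper; the mode flag and pytree traversal are assumptions):

from torch._subclasses.fake_tensor import FakeTensorMode
import torch.utils._pytree as pytree

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
# Convert every tensor leaf in (args, kwargs) into a fake tensor.
fake_args, fake_kwargs = pytree.tree_map_only(
    torch.Tensor, fake_mode.from_tensor, (args, kwargs)
)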
if id(dynamo_buffer) in buffer_lookup:
    param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()

if isinstance(f, torch.nn.Module):
For the lines above related to the param_buffer_table, should those also be added to _export_non_strict? We probably want to keep the param/buffer names the same in both cases, right?
param_buffer_table is only needed for Dynamo, which messes up the state dict names. If we're using aot_export_module directly, we don't need to do anything special here.
Oh ok, then maybe we should move that to export_to_torch_ir (not related to this PR)
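For context, the name-restoring table discussed above is built roughly like this (a sketch inferred from the snippet; variable names such as gm_torch_level are illustrative):

# Dynamo renames parameters/buffers, so map its names back to the original
# module's names by matching tensor identity.
buffer_lookup = {}
for name, buf in f.named_buffers():
    buffer_lookup.setdefault(id(buf), []).append(name)

param_buffer_table = {}
for dynamo_name, dynamo_buffer in gm_torch_level.named_buffers():
    if id(dynamo_buffer) in buffer_lookup:
        param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()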
torch/_export/__init__.py (outdated diff)
    o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names
}
graph_signature.user_inputs = list(reversed(new_node_names.values()))  # type: ignore[arg-type]
graph_signature.user_inputs.extend(new_node_names.values())
Why does the reversed not matter here?
I put it at line 625
    super().__init__()
    self._export_root = mod

def forward(self, *args, **kwargs):
Why not just:

is_scalar = False

def forward(self, *args, **kwargs):
    inner = self._export_root(*args, **kwargs)
    if not isinstance(inner, (list, dict, tuple)):
        nonlocal is_scalar
        is_scalar = True
        return (inner,)
    return inner
Sounds like pytree is more foolproof?
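For comparison, a pytree-based version of the wrapper would look roughly like this (a sketch; the class name is an assumption, not the PR's actual _tuplify_outputs code):

import torch.utils._pytree as pytree

class _TuplifyOutputs(torch.nn.Module):  # hypothetical name
    def __init__(self, mod):
        super().__init__()
        self._export_root = mod
        self._out_spec = None

    def forward(self, *args, **kwargs):
        out = self._export_root(*args, **kwargs)
        # Flatten any output structure (scalar, dict, nested list) into a
        # flat tuple; the saved spec can rebuild the original structure later.
        flat, self._out_spec = pytree.tree_flatten(out)
        return tuple(flat)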
935bdf6 to 4fe7912 (force-pushed)
@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
4fe7912 to 908b0de (force-pushed)
@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
torch/_export/__init__.py (outdated diff)
Can you leave a comment here explaining that we don't actually use aot_export_module's out_spec, so it's up to us to manipulate it however we want?
…#114697) Summary: Refactor torch.export to separate strict part and non strict part. Adding an option to torch.export called `strict=True`. Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict Differential Revision: D51604074
908b0de to c4f4800 (force-pushed)
@pytorchbot merge -f "fbgemm errors"

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Merge failed. Reason: This PR has internal changes and must be landed via Phabricator.
@pytorchbot merge -f "fbgemm errors"

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
) Current non-strict test cases (added in #114697) are already supported by strict mode, so they can't demonstrate the incremental value of non-strict mode. How about adding test cases that fail in strict mode but pass in non-strict mode? Test Plan: python test/export/test_export.py -k test_external_call_non_strict_real_tensor Pull Request resolved: #115245 Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17
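As a rough idea of what such a test could exercise (a sketch only; the actual test_external_call_non_strict_real_tensor in #115245 may differ):

class ExternalState:  # plain Python object holding a real tensor (hypothetical)
    def __init__(self):
        self.bias = torch.randn(4)

ext = ExternalState()

class M(torch.nn.Module):
    def forward(self, x):
        # Reaching into external Python state that holds real tensors is the
        # kind of pattern strict (Dynamo) tracing may reject but non-strict
        # export can run through directly.
        return x + ext.bias

ep = torch.export.export(M(), (torch.randn(4),), strict=False)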
…#114697) Summary: Refactor torch.export to separate strict part and non strict part. Adding an option to torch.export called `strict=True`. Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict Pull Request resolved: pytorch#114697 Approved by: https://github.com/ydwu4, https://github.com/tugsbayasgalan
…rch#115245) Current non-strict test cases (added in pytorch#114697) are already supported by strict mode, so it can't demonstrate the incremental value of non-strict mode. How about adding test cases that fail in strict mode but pass in non-strict mode? Test Plan: python test/export/test_export.py -k test_external_call_non_strict_real_tensor Pull Request resolved: pytorch#115245 Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17
Summary: Refactor torch.export to separate the strict part and the non-strict part, adding an option to torch.export called `strict=True`.
Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict
cc @avikchaudhuri @gmagogsfm @tugsbayasgalan @angelayi @suo @ydwu4