Conversation

@zhxchen17 (Contributor) commented Nov 28, 2023:

Summary: Refactor torch.export to separate the strict part from the non-strict part, adding an option to torch.export called `strict=True`.

Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict
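
For orientation, a minimal usage sketch of the new flag. This assumes the torch.export.export signature shown in the review hunks below; treat it as an illustration rather than the PR's final API.

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x.sin() + x.cos()

    args = (torch.randn(3, 4),)

    # Default: strict mode traces the module through TorchDynamo.
    ep_strict = export(M(), args, strict=True)

    # strict=False skips Dynamo and traces the module's Python code directly.
    ep_non_strict = export(M(), args, strict=False)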

cc @avikchaudhuri @gmagogsfm @tugsbayasgalan @angelayi @suo @ydwu4

@pytorch-bot bot commented Nov 28, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114697

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (12 Unrelated Failures)

As of commit c4f4800 with merge base 3b7d60b:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot commented:

This pull request was exported from Phabricator. Differential Revision: D51604074

@zhxchen17 zhxchen17 requested review from angelayi, suo and ydwu4 November 28, 2023 20:17
@zhxchen17 zhxchen17 added the topic: not user facing topic category label Nov 28, 2023
zhxchen17 added a commit to zhxchen17/pytorch that referenced this pull request Nov 28, 2023
…#114697)

Summary:

Refactor torch.export to separate the strict part from the non-strict part, adding an option to torch.export called `strict=True`.

Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict

Differential Revision: D51604074
@facebook-github-bot commented:

This pull request was exported from Phabricator. Differential Revision: D51604074

@ydwu4 (Contributor) left a comment:


Overall looks good to me! Except it's a bit large and difficult to review, lol. I was distracted by some of the BE modifications and had a hard time relating the old implementation to the new implementation. Left a few minor comments.

*,
constraints: Optional[List[Constraint]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
strict: bool = True,
Reviewer comment:

Since "strict" is user-facing, we might need to add warnings and docs for this "strict" keyword. It would be better if we could mention the observable consequences/implications of using strict.

Reviewer comment:

Maybe we should make it "_strict"?

@zhxchen17 (Author) replied:

I don't think this matters, because non-strict mode is already unsafe by design, and we just need to fix forward any bugs that come up.

gm = res.graph_module

assert orig_out_spec is not None
tensor_constants = lift_constant_tensor_pass(gm, export_graph_signature, params_buffers)
Reviewer comment:

Hmm... why do we need to remove _replace_sym_size_ops_pass?

@zhxchen17 (Author) replied:

We moved it into _export_non_strict directly, because this pass is pretty generic.

fake_args,
_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs),
fake_params_buffers,
transform=_process_user_inputs
@ydwu4 commented Nov 28, 2023:

Non-blocking: do we really need the transform kwarg? It looks like a wrapper. Seeing the word transform, I thought it was some pass over the graph module. Can we just do the wrapping eagerly and then pass the wrapped module into _export_non_strict? That would be a bit more readable, I guess.

@zhxchen17 (Author) replied:

Yeah, good point. I might remove this at some point, but right now I need it to support user input mutations, which requires some code to run before and after aot_export_module, so it has a sandwich structure. I can add a TODO here.
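
For readers following along, a minimal sketch of that sandwich structure. The names _tuplify_outputs and _aot_export_non_strict appear in the diff hunks below, but these bodies are illustrative assumptions, not the PR's actual implementation.

    def _tuplify_outputs(aot_export):
        # The `transform` kwarg wraps the underlying export function, so code
        # can run both before and after aot_export_module: the "sandwich".
        def _aot_export_non_strict(mod, args, **kwargs):
            # pre-processing: e.g. wrap `mod` so its outputs are always a tuple
            gm, sig = aot_export(mod, args, **kwargs)
            # post-processing: e.g. strip the wrapper's prefix from the signature
            return gm, sig

        return _aot_export_non_strict

    # _export_non_strict then calls roughly:
    #   gm, sig = transform(aot_export_module)(mod, fake_args, trace_joint=False)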

inp = ([torch.ones(1, 3)], torch.ones(1, 3))
self._test_export_same_as_eager(f, inp)

def test_basic_non_strict(self):
@ydwu4 commented Nov 29, 2023:

Can we also add a test for fake tensor mode and fake tensor inputs for non-strict mode?

@zhxchen17 (Author) replied:

Sure.
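
For reference, such a test might look roughly like this; the test name and body here are assumptions, not necessarily what was added:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    def test_non_strict_fake_tensor_input(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        fake_mode = FakeTensorMode()
        fake_x = fake_mode.from_tensor(torch.randn(2, 2))

        # Non-strict export with a fake input; strict=False skips Dynamo.
        ep = torch.export.export(M(), (fake_x,), strict=False)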

Comment on lines 655 to 656
# Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
# to follow the order in orig_args and correctly call module
Reviewer comment:

Is this comment still relevant?

@zhxchen17 (Author) replied:

Nice catch.

kwargs = kwargs or {}

if not strict:
assert isinstance(f, torch.nn.Module)
Reviewer comment:

Why do we assert this if we just wrap f with another module?

@zhxchen17 (Author) replied:

You're right, but I'd prefer to keep it here until we actually need to relax it.

sig.buffers_to_mutate = pytree.tree_map(strip_root, sig.buffers_to_mutate)
return gm, sig
return _aot_export_non_strict
ep_non_strict = _export_non_strict(f, args, {}, f.state_dict(), transform=_tuplify_outputs)
Reviewer comment:

Do we need to fakeify the args? It seems like _export_non_strict takes fake args? Or we could move the _convert_input_to_fake call into _export_non_strict?

@zhxchen17 (Author) replied:

I think I will leave this part for later discussion when we're supporting dynamic shapes. Right now it shouldn't matter.

if id(dynamo_buffer) in buffer_lookup:
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()

if isinstance(f, torch.nn.Module):
Reviewer comment:

For the lines above related to the param_buffer_table, should that also be added to _export_non_strict? Since we probably want to keep the param/buffer names the same in both cases, right?

@zhxchen17 (Author) replied:

param_buffer_table is only needed for Dynamo, which messes up the state dict. If we're using aot_export_module directly, we don't need to do anything special here.
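
To spell out the mechanism under discussion, a sketch of the id-based matching the quoted hunk performs. build_param_buffer_table is a hypothetical helper name; the real code does this inline.

    def build_param_buffer_table(original_module, dynamo_module):
        # Dynamo renames parameters/buffers, so match each dynamo-side buffer
        # back to its original fully-qualified name by object identity.
        buffer_lookup = {}
        for name, buf in original_module.named_buffers():
            buffer_lookup.setdefault(id(buf), []).append(name)

        param_buffer_table = {}
        for dynamo_name, dynamo_buffer in dynamo_module.named_buffers():
            if id(dynamo_buffer) in buffer_lookup:
                param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()
        return param_buffer_table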

Reviewer comment:

Oh, OK. Then maybe we should move that into export_to_torch_ir (not related to this PR).

o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names
}
graph_signature.user_inputs = list(reversed(new_node_names.values())) # type: ignore[arg-type]
graph_signature.user_inputs.extend(new_node_names.values())
Reviewer comment:

Why does the reversed not matter here?

@zhxchen17 (Author) replied:

I put it at line 625

super().__init__()
self._export_root = mod

def forward(self, *args, **kwargs):
Reviewer comment:

Why not just:

is_scalar = False  # lives in an enclosing scope, hence the `nonlocal` below
def forward(self, *args, **kwargs):
    inner = self._export_root(*args, **kwargs)
    if not isinstance(inner, (list, dict, tuple)):
        nonlocal is_scalar
        is_scalar = True
        return (inner,)  # wrap the single output in a tuple
    return inner

@zhxchen17 (Author) replied:

Sounds like pytree is more foolproof?
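
For contrast with the hand-written check above, a sketch of the pytree approach. TupleOutputWrapper is a hypothetical name, not the PR's actual wrapper.

    import torch
    import torch.utils._pytree as pytree

    class TupleOutputWrapper(torch.nn.Module):
        # Flatten whatever the root module returns so the traced output is
        # always a tuple, and keep the spec to restore the structure later.
        def __init__(self, mod):
            super().__init__()
            self._export_root = mod
            self.out_spec = None

        def forward(self, *args, **kwargs):
            out = self._export_root(*args, **kwargs)
            flat, spec = pytree.tree_flatten(out)
            self.out_spec = spec  # records scalars and arbitrary nesting alike
            return tuple(flat)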

@facebook-github-bot commented:
@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot commented:
@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Reviewer comment:

Can you leave a comment here explaining that we don't actually use aot_export_module's out_spec, so it is up to us to manipulate it however we want?

@zhxchen17 (Author) commented:
@pytorchbot merge -f "fbgemm errors"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team: raised by workflow job.

@zhxchen17 (Author) commented:

@pytorchbot merge -f "fbgemm errors"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team: raised by workflow job.

@zhxchen17 (Author) commented:

@pytorchbot merge -f "fbgemm errors"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Dec 7, 2023
)

Current non-strict test cases (added in #114697) are already supported by strict mode, so they can't demonstrate the incremental value of non-strict mode. How about adding test cases that fail in strict mode but pass in non-strict mode?

Test Plan:
python test/export/test_export.py -k test_external_call_non_strict_real_tensor
Pull Request resolved: #115245
Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17
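
A hedged guess at what such a test looks like: the real one is test_external_call_non_strict_real_tensor in test/export/test_export.py, per the commit note above; this sketch only illustrates the idea.

    import torch

    class ExternalHelper:
        def add(self, x):
            return x + x

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.helper = ExternalHelper()

        def forward(self, x):
            # An opaque call into arbitrary Python may trip strict (Dynamo-based)
            # export, while non-strict mode simply runs the code under tracing.
            return self.helper.add(x)

    # Passes in non-strict mode even if the strict version of the export fails.
    ep = torch.export.export(M(), (torch.randn(3),), strict=False)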
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
…#114697)

Summary: Refactor torch.export to separate the strict part from the non-strict part, adding an option to torch.export called `strict=True`.

Test Plan: buck2 test mode/opt caffe2/test:test_export -- -r non_strict

Pull Request resolved: pytorch#114697
Approved by: https://github.com/ydwu4, https://github.com/tugsbayasgalan
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
…rch#115245)

Current non-strict test cases (added in pytorch#114697) are already supported by strict mode, so they can't demonstrate the incremental value of non-strict mode. How about adding test cases that fail in strict mode but pass in non-strict mode?

Test Plan:
python test/export/test_export.py -k test_external_call_non_strict_real_tensor
Pull Request resolved: pytorch#115245
Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17