Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] Fix tree spec matching behavior. #109679

Closed
wants to merge 1 commit into from
Closed

Conversation

zhxchen17
Copy link
Contributor

Summary:

Test Plan:
Internal test.
Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109679

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6d70453 with merge base 6e3a747 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

@angelayi angelayi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just a cleaner way of doing things? Or is there actually a bug that it's fixing?

@zhxchen17
Copy link
Contributor Author

Is this just a cleaner way of doing things? Or is there actually a bug that it's fixing?

@angelayi yeah I'm trying to fix an internal test. The current behavior of tree_flatten_spec doesn't not actually check whether we're using up all the elements in a container. For example, the following will pass today:

_, spec = pytree.tree_flatten([1])
res = fx_pytree.tree_flatten_spec([1, 2, 3], spec)
assert len(res) == 1

In this case, pytree will drop the rest of elements that're not mentioned in the original spec. Now this might be intended for most cases, for our case we also need to check whether we have an exact match on pytree.
Also, I realized that simply comparing pytree spec might not work since dicts are unordered, therefore I added an BC preserving option to the tree_flatten_spec function.

@facebook-github-bot
Copy link
Contributor

@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Copy link
Contributor

@angelayi angelayi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can also add some tests in OSS?

Summary:

Test Plan:
Internal test.
Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot
Copy link
Contributor

@zhxchen17 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhxchen17
Copy link
Contributor Author

Added unit test

@zhxchen17
Copy link
Contributor Author

@pytorchbot merge -f "unblocking internal server model enablement"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@zou3519
Copy link
Contributor

zou3519 commented Sep 21, 2023

In this case, pytree will drop the rest of elements that're not mentioned in the original spec. Now this might be intended for most cases, for our case we also need to check whether we have an exact match on pytree.

This is not intended. We should fix tree_flatten_spec instead of adding this new API.

@zhxchen17
Copy link
Contributor Author

In this case, pytree will drop the rest of elements that're not mentioned in the original spec. Now this might be intended for most cases, for our case we also need to check whether we have an exact match on pytree.

This is not intended. We should fix tree_flatten_spec instead of adding this new API.

@zou3519 To be clear I'm not adding a new API, I'm adding a new flag to existing tree_flatten_spec without trying to break the existing code.

@zou3519
Copy link
Contributor

zou3519 commented Sep 21, 2023

@zhxchen17 this PR modified the register_pytree_flatten_spec to take in two functions, a flatten_fn_spec function and an flatten_fn_exact_match_spec function. Assuming that the correct semantics are "flatten AND check that the number of elements match", then the better design is to fix the existing flatten_fn_spec to also check that the number of elements match.

def register_pytree_flatten_spec(
    typ: Any,
    flatten_fn_spec: FlattenFuncSpec,
    flatten_fn_exact_match_spec: 

@zhxchen17
Copy link
Contributor Author

@zhxchen17 this PR modified the register_pytree_flatten_spec to take in two functions, a flatten_fn_spec function and an flatten_fn_exact_match_spec function. Assuming that the correct semantics are "flatten AND check that the number of elements match", then the better design is to fix the existing flatten_fn_spec to also check that the number of elements match.

def register_pytree_flatten_spec(
    typ: Any,
    flatten_fn_spec: FlattenFuncSpec,
    flatten_fn_exact_match_spec: 

@zou3519 The problem of fixing pytree_flatten_spec entirely is, it's almost impossible to land a change like this with unknown number of clients who is already relying on the undesirable behavior.

For example, we can have a torch packaged program who make a lot of calls to pytree_flatten_spec() and the number of elements doesn't match in production systems, and I also saw a torch xla test depending on the previous more permissive behavior.

If we really want to fix the behavior for all existing cases, I can help with a second PR to flip the exect_structural_match to True by default and observe all the test failures and fix them incrementally (or simply remove the flag and do the same).
Otherwise I can also revert this PR if you have more concerns than this and I'm fine with this.

Which makes more sense to you?

@zou3519
Copy link
Contributor

zou3519 commented Sep 21, 2023

If we really want to fix the behavior for all existing cases, I can help with a second PR to flip the exect_structural_match to True by default and observe all the test failures and fix them incrementally (or simply remove the flag and do the same).

We should at least attempt a yolo fix via a follow-up PR that flips the default and seeing what breaks. If it's too much, then yeah, it's probably not worth cleaning this up, but the hypothesis is that tree_flatten_spec silently ignoring too many elements is not expected behavior

@zhxchen17
Copy link
Contributor Author

following up in #109841

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants