Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] allow register dataclass as pytree node #106160

Closed
wants to merge 1 commit into from
Closed

Conversation

ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Jul 27, 2023

In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed.

Motivation:

HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable.

This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow princinpled extensibility, we think extending pytree may be a good option.

Implementation:

@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and jax-2371, which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export.

We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export.

Also added some tests.

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106160

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8eaea18:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ydwu4 ydwu4 requested review from tugsbayasgalan, gmagogsfm and angelayi and removed request for gmagogsfm July 27, 2023 20:08
@ydwu4 ydwu4 requested a review from zhxchen17 July 27, 2023 20:08
@ydwu4 ydwu4 added release notes: export ciflow/trunk Trigger trunk jobs on your pull request labels Jul 27, 2023
@ydwu4 ydwu4 changed the title Allow to register dataclass as pytree node [export] add entrypoint in export to register dataclass as pytree node Jul 27, 2023
@ydwu4 ydwu4 changed the title [export] add entrypoint in export to register dataclass as pytree node [export] allow register dataclass as pytree node Jul 27, 2023
@zou3519 zou3519 self-requested a review July 27, 2023 21:31
@ydwu4
Copy link
Contributor Author

ydwu4 commented Jul 28, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

bobby-palmer pushed a commit to bobby-palmer/pytorch that referenced this pull request Jul 29, 2023
In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed.

## Motivation:
HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable.

This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow princinpled extensibility, we think extending pytree may be a good option.

## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](google/jax#2371 (comment)), which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export.

We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export.

Also added some tests.

Pull Request resolved: pytorch#106160
Approved by: https://github.com/zhxchen17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: export
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants