[fsdp][torch.compile] FSDP changes #115497

Closed
wants to merge 15 commits

Conversation

voznesenskym
Collaborator

@voznesenskym voznesenskym commented Dec 10, 2023


pytorch-bot bot commented Dec 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115497

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2c65d1a with merge base 87ea6fb:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@voznesenskym voznesenskym mentioned this pull request Dec 12, 2023
@albanD albanD added oncall: distributed Add this issue/PR to distributed oncall triage queue and removed module: distributed labels Dec 13, 2023
cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
@fegin fegin added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Dec 14, 2023
@voznesenskym voznesenskym requested a review from albanD December 16, 2023 18:41
```
@@ -441,7 +441,9 @@ def _pre_forward_unshard(
    if not handle._prefetched:
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._needs_pre_forward_unshard = False
    state._device_handle.current_stream().wait_stream(state._unshard_stream)
    # Don't wait during trace
```
Collaborator
How is the trace valid if it contains stream information but not stream synchronization?
I guess we don't care for now since it's ignored by the backend, but it will most likely need fixing.

Collaborator Author
Yes, this is sound for now, but the longer-term version of this will need to teach Inductor about .wait_stream and plumb the whole thing through.

Contributor

@zejun-chen zejun-chen Feb 29, 2024
Hi, @voznesenskym @albanD

> Yes, this is sound for now, but the longer-term version of this will need to teach Inductor about .wait_stream and plumb the whole thing through.

May I know why the stream methods are bypassed here? Does it mean that, in the longer-term plan, streams and other runtime primitives such as events will be compiled through torch.compile (Dynamo + AOTAutograd + Inductor)? Is there a plan for stream/event support across the whole torch.compile stack?

Thank you. :)
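To make the exchange above concrete, here is a minimal sketch, not the PR's actual code, of the pattern the "Don't wait during trace" comment describes: issue the cross-stream wait only in eager mode and skip it while Dynamo is tracing, since the compiled graph currently ignores stream operations. The names state, _device_handle, and _unshard_stream come from the diff above; the compile-time check via torch._dynamo.is_compiling is an assumption about how such a guard could be written.

```
import torch


def _maybe_wait_for_unshard(state) -> None:
    # Sketch only: skip the stream synchronization while torch.compile is
    # tracing, because the backend currently ignores stream ops anyway
    # (see discussion above); Inductor would later need to learn
    # .wait_stream for this to be fully sound.
    if torch._dynamo.is_compiling():
        # Don't wait during trace.
        return
    # Eager mode: make the current stream wait on the unshard stream.
    state._device_handle.current_stream().wait_stream(state._unshard_stream)
```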

```
            and hasattr(handle, "_ran_pre_backward_hook")
            and handle._ran_pre_backward_hook
        ):
            log.warning("%s %s", id(state), "Not Running pre backward! Already Ran!")
```
Collaborator
Should this really be a warning?

Collaborator Author
Debug is better
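A one-line sketch of the follow-up the author agrees to (assumed, not taken from a later commit): demote the message from warning to debug so that hitting the already-ran hook repeatedly does not spam user logs.

```
# Assumed follow-up: same message, emitted at debug level instead of warning.
log.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
```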

```
@@ -327,3 +333,7 @@ def _replace_by_prefix(
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
```
Collaborator
Is this missing an allow_in_graph? Or is it not used?

Collaborator Author
It doesn't need it for now. It will in the future, but defining the helper now makes it easier for me to track where I will eventually need to add it.
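For context, here is a plausible sketch of the helper under discussion and of what the reviewer's allow_in_graph suggestion would look like. The body of _data_ptr_allocated is an assumption inferred from its name and signature, not the verified implementation from this PR.

```
import torch


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
    # Assumed implementation: treat a tensor as "allocated" when its
    # storage has a non-null data pointer (e.g. not freed by resharding).
    return tensor.untyped_storage().data_ptr() > 0


# The reviewer's suggestion, if Dynamo ever needs to trace through this
# helper instead of graph-breaking on it, would be along these lines:
torch._dynamo.allow_in_graph(_data_ptr_allocated)
```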

Collaborator

@albanD albanD left a comment
SGTM!

@voznesenskym
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 19, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
@facebook-github-bot facebook-github-bot deleted the gh/voznesenskym/294/head branch December 23, 2023 15:22
weifengpy added a commit to weifengpy/pytorch that referenced this pull request Jan 10, 2024
flat_param_part_view is unused in the pytorch repo: https://fburl.com/ssaomd7x

It became unused after the refactoring in pytorch#115497.

Before that, the original code was as below. Since flat_param is 1D, we do
not need .view for reshaping:
```
self.flat_param.data = padded_unsharded_flat_param[
    : unsharded_size.numel()
].view(
    unsharded_size
)
```

unit test: pytest test/distributed/fsdp/test_fsdp_core.py
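Based on the reasoning in the commit message, the simplified assignment would look roughly like the sketch below. This is an inference from the message, not a verbatim excerpt of the refactored code: since flat_param is already 1-D, slicing the padded buffer to the unsharded element count is enough and the .view call can be dropped.

```
# Assumed simplified form: flat_param is 1-D, so no reshape is needed.
self.flat_param.data = padded_unsharded_flat_param[: unsharded_size.numel()]
```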
pytorchmergebot pushed a commit that referenced this pull request Jan 11, 2024

Pull Request resolved: #117082
Approved by: https://github.com/awgu, https://github.com/wconstab, https://github.com/Skylion007
Labels
ciflow/inductor · ciflow/periodic · ciflow/trunk · Merged · oncall: distributed · release notes: distributed (ddp)
5 participants