
[FSDP] fix: fix for fsdp zero2 validation error #110139

Closed · wants to merge 8 commits

Conversation

Edwiv (Contributor) commented Sep 27, 2023

Problem

When `sharding_strategy` is set to `SHARD_GRAD_OP` and `forward_prefetch` is enabled, the validation run after training sees an incorrect weight shape.

(screenshot of the resulting shape error omitted)

Analysis

When using `SHARD_GRAD_OP`, `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so the handle's `_prefetched` flag is not reset to False after the forward pass.

During normal training, this flag is reset to False in `_post_backward_final_callback`. Validation never runs that hook, so after the first validation iteration the `_prefetched` flag of the prefetched handle stays True.

As a result, the handle skips `_unshard` in the next `_pre_forward_unshard`, and `_prefetch_handle` does not issue a prefetch, which results in an incorrect weight shape.
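A minimal sketch of the kind of setup described above (not the reporter's actual code): the model, sizes, and wrapping policy are illustrative assumptions, and torch.distributed is assumed to be initialized already (e.g. via torchrun).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Assumes the process group is initialized and the CUDA device for this rank is set.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # ZeRO-2 style sharding
    forward_prefetch=True,                              # enables the prefetch path in question
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),     # nested handles so one gets prefetched
)
optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)

x = torch.randn(8, 16, device="cuda")

# One training step: _post_backward_final_callback resets _prefetched here.
fsdp_model(x).sum().backward()
optim.step()
optim.zero_grad()

# Validation: no backward pass, so that callback never runs, and with
# SHARD_GRAD_OP the post-forward reshard does not clear _prefetched either.
# Before this fix, the second eval iteration saw an incorrect weight shape.
fsdp_model.eval()
with torch.no_grad():
    for _ in range(2):
        fsdp_model(x)
```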

pytorch-bot bot commented Sep 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110139

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit ef241bb with merge base 0013611:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) label (release notes category) Sep 27, 2023
linux-foundation-easycla bot commented Sep 27, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

awgu (Contributor) commented Sep 27, 2023

@Edwiv Thanks for catching this!

Do you think you would be interested in adding a unit test or at least providing a repro?

@awgu awgu self-assigned this Sep 27, 2023
Edwiv (Contributor Author) commented Sep 27, 2023

@awgu sure~

@drisspg drisspg added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 27, 2023
awgu (Contributor) commented Sep 27, 2023

@Edwiv feel free to put it in test_fsdp_misc.py:

class TestFSDPMiscMultiProcess(FSDPTest):

(I do not think we have a place for eval tests right now.)
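For orientation, a rough skeleton of what such a test could look like in that class; the method name, decorator, and body are assumptions here, not the test that was eventually added:

```python
class TestFSDPMiscMultiProcess(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_eval_after_train_with_prefetch(self):
        # Wrap a small model with SHARD_GRAD_OP + forward_prefetch, run one
        # training step, then several eval-mode forwards under torch.no_grad()
        # and assert the outputs match a reference (e.g. a DDP copy).
        ...
```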

Edwiv (Contributor Author) commented Sep 28, 2023

@awgu done~
Could you tell me how I should run this unit test in CI?

awgu (Contributor) commented Sep 28, 2023

@Edwiv Thanks for adding the unit test! It will run automatically in CI.

Would it be possible to strengthen the unit test to check the correctness, e.g. with DDP?

rohan-varma (Member) left a comment

thanks for the fix!

loss = loss.sum()
loss.backward()

with torch.no_grad():
Member: are we interested in putting the model back to train after eval and verifying that training still works fine?
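In test code, that suggestion amounts to something like the following after the eval loop; fsdp_model, x, and optim are assumed placeholder names, not the test's actual variables:

```python
# Switch back to training mode after the eval loop and verify that a normal
# forward/backward/optimizer step still runs and produces a finite loss.
fsdp_model.train()
loss = fsdp_model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
assert torch.isfinite(loss).item()
```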

Edwiv (Contributor Author): done~

awgu (Contributor) commented Oct 9, 2023

> Would it be possible to strengthen the unit test to check the correctness, e.g. with DDP?

Sorry, I wanted to bump this in case you missed it @Edwiv!

Edwiv (Contributor Author) commented Oct 9, 2023

@awgu Sorry, I didn't understand how to test with DDP. What should I do to strengthen the unit test?

awgu (Contributor) commented Oct 9, 2023

@Edwiv You can construct an identical model with DDP applied (to implement data-parallel semantics). Then you can run the same training loop for both the DDP and FSDP models and compare each iteration's loss (for example) to check for correctness.
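A sketch of that pattern; the process group is assumed to be initialized already, and the tiny model, learning rate, and shapes are placeholders:

```python
import copy

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

# Start both data-parallel wrappers from identical weights.
base = nn.Linear(16, 16).cuda()
fsdp_model = FSDP(
    copy.deepcopy(base),
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    forward_prefetch=True,
)
ddp_model = DDP(copy.deepcopy(base), device_ids=[torch.cuda.current_device()])

fsdp_optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
ddp_optim = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

for _ in range(5):
    x = torch.randn(8, 16, device="cuda")  # per-rank input
    fsdp_loss = fsdp_model(x).sum()
    ddp_loss = ddp_model(x).sum()
    fsdp_loss.backward()
    ddp_loss.backward()
    fsdp_optim.step()
    ddp_optim.step()
    fsdp_optim.zero_grad()
    ddp_optim.zero_grad()
    # Both wrappers average gradients across ranks, so the losses should match
    # at every iteration; checking per iteration localizes any divergence.
    torch.testing.assert_close(fsdp_loss, ddp_loss)
```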

Edwiv (Contributor Author) commented Oct 10, 2023

@awgu Got it. I've added the DDP comparison to the unit test; please take a look.


torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
for _ in range(5):
Contributor:

nit: We normally run the optimizer in the training loop. Otherwise, since the inputs x and y are not changing, the loss will be the same on every iteration.

Contributor:

nit: fsdp_loss and ddp_loss might be more informative names than loss and loss1.

y = torch.randint(low=0, high=9, size=(8,), device="cuda")
x1 = x.clone().detach().requires_grad_()
y1 = y.clone().detach()
seed = 20231010
Contributor:

We should set the seed to something like self.rank + 1 (anything that gives the data-parallel ranks different seeds); otherwise, the gradient reduction may not really be tested, since averaging identical gradients, (x + x) / 2 = x, looks the same as not reducing at all.
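Concretely, something like the following in the test body; self.rank is the test's data-parallel rank and the input shapes are placeholders:

```python
# Seed each data-parallel rank differently so the ranks generate different
# inputs and the gradient reduction is actually exercised.
seed = self.rank + 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(8, 10, device="cuda", requires_grad=True)
y = torch.randint(low=0, high=9, size=(8,), device="cuda")
```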

loss1.backward()

assert torch.allclose(loss, loss1)
assert torch.allclose(x.grad.data, x1.grad.data)
Contributor:

nit: the .data are not needed.

Comment on lines 218 to 219
assert torch.allclose(fsdp_loss, ddp_loss)
assert torch.allclose(x.grad, x1.grad)
awgu (Contributor) commented Oct 12, 2023:

nit: Would it be possible to check correctness after every iteration? This lets us narrow down / catch bugs much faster.

You can reference: https://github.com/pytorch/pytorch/pull/110948/files#diff-ab5af580410c642dd66ea27656265fbbc1ec6c6713e048a0ef111573eb52286cR195

Otherwise, we are pretty much good to go.

Edwiv (Contributor Author):

I've polished this unit test. By the way, the author of the PR above turned out to be a coworker of mine 😂.

awgu (Contributor) commented Oct 14, 2023

Lint error looks like an infra issue. Let me rebase and merge.

awgu (Contributor) commented Oct 14, 2023

@pytorchbot rebase -s

pytorchmergebot (Collaborator):
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator):
Successfully rebased zyj/fix/fsdp_val_error onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout zyj/fix/fsdp_val_error && git pull --rebase)

awgu (Contributor) commented Oct 14, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Oct 14, 2023
pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

yeounoh pushed a commit to yeounoh/pytorch that referenced this pull request Oct 16, 2023
Labels
ciflow/trunk · Merged · open source · release notes: distributed (fsdp) · triaged

Projects
None yet

Development
Successfully merging this pull request may close these issues: none yet.

6 participants