[FSDP] fix: fix for fsdp zero2 validation error #110139
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110139
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit ef241bb with merge base 0013611: UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@Edwiv Thanks for catching this! Do you think you would be interested in adding a unit test or at least providing a repro?
@awgu sure~
@Edwiv feel free to put it in test_fsdp_misc.py:
(I do not think we have a place for eval tests right now.)
@awgu done~
@Edwiv Thanks for adding the unit test! It will run automatically in CI. Would it be possible to strengthen the unit test to check the correctness, e.g. with DDP?
thanks for the fix!
loss = loss.sum()
loss.backward()

with torch.no_grad():
are we interested in putting the model back into train mode after eval, and verifying that training still works fine?
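A sketch of the flow this is asking the test to cover (a hypothetical helper; `model`, `optim`, and `x` stand in for the FSDP model, its optimizer, and an input batch):

```python
import torch

def train_eval_train(model, optim, x):
    """Run train -> eval -> train to check that training still works after eval."""
    # Train step before eval.
    model.train()
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()

    # Eval forward: no backward pass runs, so nothing resets prefetch state here.
    model.eval()
    with torch.no_grad():
        model(x)

    # Train again to confirm the eval pass did not leave parameters in a bad shape.
    model.train()
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()
```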
done~
Sorry, I wanted to bump this in case you missed it @Edwiv!
@awgu Sorry, I didn't understand how to test with DDP; what should be done to strengthen the unit test?
@Edwiv You can construct an identical model with DDP applied (to implement data parallel semantics). Then, you can run the same training loop for both the DDP and FSDP models and compare each iteration's loss (for example) to check for correctness.
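Roughly, the suggested parity check could look like the following sketch. It assumes a distributed process group is already initialized (as in the FSDP test harness), one CUDA device per rank, and placeholder names such as `base_model` and `get_batch`:

```python
import copy
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.nn.parallel import DistributedDataParallel as DDP

def check_fsdp_matches_ddp(base_model, get_batch, num_iters=5):
    # Wrap two identical copies of the model so both start from the same weights.
    fsdp_model = FSDP(
        copy.deepcopy(base_model).cuda(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        forward_prefetch=True,
    )
    ddp_model = DDP(copy.deepcopy(base_model).cuda())
    fsdp_optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)
    ddp_optim = torch.optim.SGD(ddp_model.parameters(), lr=1e-2)

    for _ in range(num_iters):
        x, y = get_batch()  # per-rank batch, identical for both models
        fsdp_loss = F.cross_entropy(fsdp_model(x), y)
        ddp_loss = F.cross_entropy(ddp_model(x), y)
        fsdp_loss.backward()
        ddp_loss.backward()
        fsdp_optim.step()
        ddp_optim.step()
        fsdp_optim.zero_grad()
        ddp_optim.zero_grad()
        # DDP provides the reference data-parallel semantics, so the losses
        # should match on every iteration if FSDP is behaving correctly.
        torch.testing.assert_close(fsdp_loss, ddp_loss)
```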
@awgu I got it, I've added the DDP comparison to the unit test, please take a look.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
for _ in range(5):
nit: We normally run the optimizer in the training loop. Otherwise, since the inputs `x` and `y` are not changing, the `loss` should be the same on all iterations.
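For instance, the loop could draw fresh data and step an optimizer each iteration so that the loss actually changes (a sketch; `model` is a placeholder assumed to map an `(8, 4)` batch to at least 10 class logits):

```python
import torch
import torch.nn.functional as F

# Sketch: new data and an optimizer step per iteration, so losses differ across iterations.
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(5):
    x = torch.randn(8, 4, device="cuda")
    y = torch.randint(low=0, high=9, size=(8,), device="cuda")
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optim.step()
    optim.zero_grad()
```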
nit: `fsdp_loss` and `ddp_loss` might be more informative names than `loss` and `loss1`.
y = torch.randint(low=0, high=9, size=(8,), device="cuda")
x1 = x.clone().detach().requires_grad_()
y1 = y.clone().detach()
seed = 20231010
We should set the seed to be something like `self.rank + 1` (anything so that data parallel ranks have different seeds); otherwise, the gradient reduction may not be tested (something like `(x + x) / 2 = x` is the same as not summing and dividing).
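In other words, something along these lines inside the test (a sketch; `self.rank` refers to the rank attribute of the multi-process test case):

```python
# Give each data-parallel rank its own seed so ranks generate different inputs;
# with identical inputs, averaging identical gradients would hide a broken
# reduction, since (g + g) / 2 == g.
seed = self.rank + 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
```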
loss1.backward()

assert torch.allclose(loss, loss1)
assert torch.allclose(x.grad.data, x1.grad.data)
nit: the `.data` are not needed.
assert torch.allclose(fsdp_loss, ddp_loss)
assert torch.allclose(x.grad, x1.grad)
nit: Would it be possible to check correctness after every iteration? This lets us narrow down / catch bugs much faster.
You can reference: https://github.com/pytorch/pytorch/pull/110948/files#diff-ab5af580410c642dd66ea27656265fbbc1ec6c6713e048a0ef111573eb52286cR195
Otherwise, we are pretty much good to go.
I've polished up this UT, and by the way, the author of the PR above turned out to be a coworker of mine 😂.
Lint error looks like an infra issue. Let me rebase and merge.
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 685ad69 to ef241bb.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Problem

When sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch is turned on, validation after training fails with an incorrect weight shape.

<img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e">

# Analysis

When using `SHARD_GRAD_OP`, `free_unsharded_flat_param` in `_post_forward_reshard` is often False, so the handle's `_prefetched` flag is not reset to False after the forward. The normal training phase resets this flag in `_post_backward_final_callback`, but the validation phase does not execute that hook, so after the first validation iteration finishes, the prefetched handle's flag stays True. This causes the handle to skip `_unshard` in the next `_pre_forward_unshard`, and `_prefetch_handle` does not do a prefetch, which results in an incorrect weight shape.

Pull Request resolved: pytorch#110139
Approved by: https://github.com/awgu
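A minimal reproduction of the failure mode described above could look like the following sketch. It assumes a distributed process group has already been initialized (e.g. via torchrun with the nccl backend) and one CUDA device per rank; the module and tensor sizes are placeholders:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

def repro():
    # Several FSDP-wrapped submodules so that forward prefetching of the
    # "next" handle actually happens.
    model = FSDP(
        nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)).cuda(),
        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # zero2
        forward_prefetch=True,
    )
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)
    x = torch.randn(8, 4, device="cuda")

    # Training is fine: _post_backward_final_callback resets each handle's
    # _prefetched flag after the backward pass.
    model.train()
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()

    # Validation: there is no backward, so before this fix the _prefetched
    # flags set during the first eval forward were never cleared (SHARD_GRAD_OP
    # does not free the unsharded flat param in _post_forward_reshard), and the
    # second eval forward could fail with an incorrect weight shape.
    model.eval()
    with torch.no_grad():
        for _ in range(2):
            model(x)
```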