FSDP + DTensor Loss Flatlines Randomly #117471

Closed
mvpatel2000 opened this issue Jan 14, 2024 · 5 comments
Labels
module: dtensor, module: fsdp, triaged
Milestone

Comments

@mvpatel2000
Contributor

mvpatel2000 commented Jan 14, 2024

🐛 Describe the bug

We have been training with DTensor on torch nightlies (in anticipation of 2.2) and very often see the loss flatline. We do not see this at all on the current nightly (as of 4 days ago), so at this point we are very confident there is a regression/bug in the current 2.2 release candidate that breaks FSDP training (at least with DTensor).
Our best guess is that one of the two linked PRs fixes it.


Versions

Torch 2.2 branch

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu @kwen2501 @wanchaol @XilunWu @tianyu-l
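For anyone trying to reproduce this, one rough way to flag the flatline symptom automatically is to check whether the recent loss values have stopped moving. This is a minimal sketch, not PyTorch API and not code from the reporter's run: `is_flatlined`, `window`, and `rel_tol` are hypothetical names and thresholds, and it assumes the per-step losses are available as plain Python floats.

```python
import math
from typing import Sequence


def is_flatlined(losses: Sequence[float], window: int = 50, rel_tol: float = 1e-4) -> bool:
    """Return True if the last `window` losses have a spread (max - min) no larger
    than `rel_tol` times the magnitude of their mean, i.e. the curve stopped moving.

    Hypothetical helper: `window` and `rel_tol` are illustrative defaults only.
    """
    if window < 2 or len(losses) < window:
        return False  # not enough history to judge
    recent = list(losses[-window:])
    if not all(math.isfinite(x) for x in recent):
        return False  # NaN/inf is a different failure mode; check it separately
    spread = max(recent) - min(recent)
    scale = max(abs(sum(recent) / window), 1e-12)  # guard against a zero mean
    return spread / scale <= rel_tol


# Usage: a loss that drops and then sticks at one value is flagged.
print(is_flatlined([3.0, 2.5] + [2.3] * 50))
```

A relative rather than absolute tolerance is used so the same threshold works regardless of the loss scale; non-finite values are skipped because a NaN loss is a separate symptom.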

Skylion007 added this to the 2.2.0 milestone Jan 15, 2024
@Skylion007
Collaborator

We confirmed this affects the 2.2.0 final RC.

@atalman
Contributor

atalman commented Jan 15, 2024

@mvpatel2000
Contributor Author

@atalman Unfortunately, I do not have a minimal repro, nor am I able to share the code for this run at this time :(

We run a transformer model with DTensor + FSDP (passing in a device mesh). The only unusual thing we do is that some weights are manually wrapped as DTensors and presharded before FSDP. I'm pretty sure this won't matter, so reproducing it on your end shouldn't be too hard, but I'm not 100% confident.

@wanchaol
Contributor

@atalman I just checked our release branch: in addition to #117020, we'll also need #116122 to resolve the merge conflicts.

I can also confirm that I hit similar numeric issues (not a loss flatline, but a NaN loss, which looks similar to the issue @mvpatel2000 ran into). These two fixes resolve the NaN problem for me, so it would be great if we could include both in our release branch :)
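Since the NaN variant of this problem shows up at a specific step, a tiny helper that reports the first non-finite loss can make it easier to compare runs with and without #117020 and #116122. This is again a hypothetical sketch (`first_nonfinite_step` is not a PyTorch function), assuming the losses are logged as Python floats:

```python
import math
from typing import Optional, Sequence


def first_nonfinite_step(losses: Sequence[float]) -> Optional[int]:
    """Hypothetical helper: index of the first NaN or +/-inf loss, or None if all are finite."""
    for step, value in enumerate(losses):
        if not math.isfinite(value):
            return step
    return None


# Usage: the third logged loss is NaN, so index 2 is reported.
print(first_nonfinite_step([1.0, 0.9, float("nan"), 0.8]))
```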

awgu added the triaged, module: fsdp, and module: dtensor labels Jan 16, 2024
atalman modified the milestones: 2.2.0, 2.2.1 Jan 18, 2024
@mvpatel2000
Contributor Author

Fixed in dev


5 participants