FSDP + DTensor Loss Flatlines Randomly #117471
Labels: module: dtensor, module: fsdp, triaged
🐛 Describe the bug
We have been training with DTensor off torch nightly builds (in anticipation of 2.2), and we very often see the loss flatline. We do not see this at all on the current nightly (as of 4 days ago), and at this point we are very confident there is a regression/bug in the current release candidate for 2.2 that breaks FSDP training (at least with DTensor).
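For context, the failure mode looks roughly like the sketch below. This is a hypothetical minimal loop, not our actual training code: the toy model, mesh shape, and hyperparameters are all placeholders, and the 1D `device_mesh` stands in for our real DTensor-based parallelism config.

```python
# Hypothetical minimal repro sketch (placeholders, not our real setup).
# Launch with: torchrun --nproc_per_node=<ngpus> repro.py
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Toy 1D mesh; in our actual run the parameters end up on a
    # DTensor path rather than this plain single-dim sharding.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    model = nn.Sequential(
        nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)
    ).cuda()
    model = FSDP(model, device_mesh=mesh, use_orig_params=True)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for step in range(200):
        x = torch.randn(8, 1024, device="cuda")
        loss = model(x).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
        if dist.get_rank() == 0:
            # On the 2.2 RC the loss stops moving after some steps;
            # on the current nightly it keeps decreasing.
            print(step, loss.item())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```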
Our best guess is that one of the two linked PRs fixes it:
To be safe, I would personally also want to include the no-grad bug fix:
Versions
Torch 2.2 branch
cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu @kwen2501 @wanchaol @XilunWu @tianyu-l