FSDP init can crash with shared parameters #83052
Labels
high priority
module: fsdp
oncall: distributed
triage review
triaged
🐛 Describe the bug
FSDP initialization can crash when modules with shared parameters are wrapped separately. For example, if one wraps the linear decoder at https://github.com/facebookresearch/multimodal/blob/679f3596e4c44b483c68d4023b24e3c7f77292b3/torchmultimodal/modules/losses/flava.py#L138 separately from the main module and then wraps the main module with the `device_id` argument, an error is raised because the `bias` param is shared. The `bias` param would have already been moved to GPU by the FSDP unit wrapping the linear, but the higher-level wrapper would still expect it to be on CPU, resulting in this error:

`pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py`, line 814 (at commit 9e65e93)
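
A minimal repro sketch of the scenario described above, assuming a hypothetical `Head` module (not the exact FLAVA prediction head) whose decoder `nn.Linear` shares its bias with a module-level `nn.Parameter`, mirroring the linked torchmultimodal code:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Head(nn.Module):
    def __init__(self, hidden_size: int = 8, vocab_size: int = 16) -> None:
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Share the bias between the module and the decoder Linear,
        # as in the FLAVA masked prediction head.
        self.decoder.bias = self.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)


def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    model = Head()  # constructed on CPU
    # Wrap the inner Linear first; this moves its params (including the
    # shared bias) to rank's GPU.
    model.decoder = FSDP(model.decoder, device_id=rank)
    # Wrapping the parent with device_id then finds the shared bias already
    # on GPU while it expects the remaining params to be on CPU, which is
    # where the crash in fully_sharded_data_parallel.py is hit.
    model = FSDP(model, device_id=rank)


if __name__ == "__main__":
    main()
```

Launched with `torchrun --nproc_per_node=<ngpus>`, this should exercise the same shared-parameter path as the FLAVA decoder case.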
Versions
main
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501