Closed
Labels
better-engineering, high priority, module: fsdp, oncall: distributed, triage review, triaged
Description
🐛 Describe the bug
If we have the following setup:
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10)
        self.b = nn.Linear(10, 10)

    def forward(self, x):
        return self.b(self.a(x))


model = MyModel()
fsdp = FSDP(
    model,
    auto_wrap_policy=always_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
    device_id=torch.cuda.current_device(),
)
```
we hit the error:
```
RuntimeError: Module on rank 1 is given device_id argument cuda:1, but is on cpu. Either move module before FSDP init or omit device_id argument.
```
This seems to be because the root FSDP unit does not manage any parameters of its own, so when it checks whether the module should be moved to the given device_id, it ends up looking at an FSDP submodule's FlatParameter, which is on CPU (as expected with offload_params=True), and we throw an error here:
```python
if param is not None and param.device != self.device_id:
```
The proper fix should be to bypass this check when the parameter we end up with is a FlatParameter, i.e. one already managed by a nested FSDP unit.
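A minimal sketch of what that bypass could look like, using a hypothetical stand-in `_check_module_device` for the internal check (the helper name, the way the parameter is retrieved, and the class-name test for FlatParameter are all illustrative, not the actual FSDP internals):
```python
import torch
import torch.nn as nn


def _is_flat_param(param: nn.Parameter) -> bool:
    # Identify an FSDP-managed FlatParameter by class name rather than by
    # importing the class, since its internal module path has changed across
    # PyTorch releases.
    return type(param).__name__ == "FlatParameter"


def _check_module_device(module: nn.Module, device_id: torch.device) -> None:
    # Hypothetical, simplified stand-in for the check FSDP runs when a
    # `device_id` argument is passed; not the actual internal function.
    param = next(module.parameters(), None)
    if param is not None and _is_flat_param(param):
        # The parameter is already managed by a nested FSDP unit, whose device
        # is governed by that unit's own config (e.g. CPU offload), so skip
        # the device check instead of raising.
        return
    if param is not None and param.device != device_id:
        raise RuntimeError(
            f"Module is given device_id argument {device_id}, but is on "
            f"{param.device}. Either move module before FSDP init or omit "
            "device_id argument."
        )
```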
The PyTorch Lightning integration has hit this issue.
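Until a fix lands, the error message itself points at two workarounds: move the module to the target device before FSDP init, or omit the device_id argument. A sketch of the latter (untested, reusing `model` from the repro above):
```python
# Untested workaround: drop `device_id` so the device check never fires.
fsdp = FSDP(
    model,
    auto_wrap_policy=always_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
)
```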
Versions
main
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501