Can't load on rank 0 only with set_optimizer_state_dict
#125177
Labels: module: distributed_checkpoint, oncall: distributed, triaged
🐛 Describe the bug
To avoid CPU OOMs, our training library loads monolithic checkpoints only on rank 0 and broadcasts them to all other ranks (a pattern PyTorch checkpointing supports). When migrating to the new distributed state-dict APIs, we hit an error with this approach in the function below:
torch/distributed/checkpoint/state_dict.py, line 582 at commit ae13c7e
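The pattern looks roughly like this (a minimal sketch; the model/optimizer construction and checkpoint path are placeholders, and I'm assuming StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) is the relevant configuration):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_optimizer_state_dict,
)

# Sketch only: build_model_and_optimizer() and the checkpoint path are placeholders.
model, optimizer = build_model_and_optimizer()

# Only rank 0 reads the monolithic checkpoint (to avoid CPU OOMs); the other
# ranks pass None and expect to receive the values via broadcast from rank 0.
if dist.get_rank() == 0:
    full_ckpt = torch.load("checkpoint.pt", map_location="cpu")
    optim_state = full_ckpt["optim"]
else:
    optim_state = None

# Errors on non-zero ranks, because the split step in state_dict.py still runs
# on the None optimizer state dict.
set_optimizer_state_dict(
    model,
    optimizer,
    optim_state_dict=optim_state,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
```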
In our code, optim_state_dict is only non-None on rank 0, as rank0_only is set to True in PyTorch code. Currently, I need to do something like this:
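```python
# Sketch with placeholder names: ckpt is the monolithic checkpoint loaded on
# rank 0 and None on all other ranks, so fall back to an empty dict before
# handing it to set_optimizer_state_dict.
set_optimizer_state_dict(
    model,
    optimizer,
    optim_state_dict=ckpt["optim"] if ckpt is not None else {},
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
```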
which is not ideal. I think the split function should simply not run when the context manager is rank0_only, but I am not sure here.
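Conceptually, what I'm imagining is a guard along these lines inside the loading path (hypothetical names; info.broadcast_from_rank0 and _split_optim_state_dict stand in for whatever the internal flag and split helper are actually called):

```python
# Hypothetical sketch of the proposed behavior inside
# torch/distributed/checkpoint/state_dict.py; names are placeholders.
if info.broadcast_from_rank0 and dist.get_rank() != 0:
    # Non-zero ranks have nothing to split; the real values arrive via
    # broadcast from rank 0, so skip the split entirely.
    optim_state_dict = {}
else:
    optim_state_dict = _split_optim_state_dict(model, optim, optim_state_dict, info)
```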
Versions
PyTorch 2.3
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC