[FSDP] Ensure that all ranks use the same order to iterate through optimizer states #84654
Conversation
**Background:** Optimizer states are of the type `Dict[int, Dict[str, torch.Tensor]]`, and the order of `dict.items()` is the creation order of the keys. Without checkpointing (state_dict/load_state_dict), the creation order of the keys depends on the implementation of the optimizer (e.g., Adam seems to create `exp_avg` and then `exp_avg_sq`). However, when loading states from a checkpoint, since optimizer states are lazily initialized, the order depends on the user code that reads the state_dict from IO. See the following example:

```
optimizer_state_dict = USER_CODE_TO_READ_STATE_FROM_IO()
optimizer.load_state_dict(optimizer_state_dict)
```

The key order of `optimizer_state_dict` depends on `USER_CODE_TO_READ_STATE_FROM_IO`, and there is no guarantee that the order is the same across ranks.

**What Can Go Wrong?** After the first checkpoint load, the key order of the optimizer state may not be the same on different ranks. When users try to save another checkpoint, `_unflatten_optim_state()` is called to save the optimizer states. Inside `_unflatten_optim_state()`, `dict.items()` is called to iterate over all the local optimizer states, and `all_gather()` is used to gather the local states. Since the iteration order may differ across ranks, the gathered states are not correct. We have seen some models get NaN loss after the second checkpoint load because of this issue.

**What This PR Does:** This PR implements a `sorted_items()` helper that returns `(key, value)` pairs in sorted key order. We can do this because the keys are either integers or strings.

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)
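To make the failure mode concrete, here is a small, hypothetical illustration (plain Python, not code from this PR): `dict` iteration order follows key insertion order, so two ranks that rebuild the same optimizer state from a checkpoint in different key orders will iterate it differently.

```python
# Hypothetical illustration (not from this PR): dict iteration order in Python
# follows insertion order, so ranks that rebuild identical optimizer state in
# different key orders will iterate it in different orders.

def rebuild_state(key_order):
    # Simulates lazily initialized per-parameter optimizer state, where the
    # checkpoint reader happens to yield keys in `key_order`.
    values = {"exp_avg": 0.1, "exp_avg_sq": 0.2, "step": 3}
    return {key: values[key] for key in key_order}

# Rank 0's reader yields exp_avg first...
rank0_state = rebuild_state(["exp_avg", "exp_avg_sq", "step"])
# ...while rank 1's reader yields exp_avg_sq first.
rank1_state = rebuild_state(["exp_avg_sq", "exp_avg", "step"])

# The dicts compare equal, but their iteration orders differ, so a positional
# gather over .items() would pair exp_avg from rank 0 with exp_avg_sq from rank 1.
assert rank0_state == rank1_state
print(list(rank0_state))  # ['exp_avg', 'exp_avg_sq', 'step']
print(list(rank1_state))  # ['exp_avg_sq', 'exp_avg', 'step']
```

For reference, a minimal sketch of what a `sorted_items()` helper could look like; the actual helper added by this PR may differ in name, signature, and location:

```python
from typing import Dict, Iterator, Tuple, TypeVar

K = TypeVar("K", int, str)
V = TypeVar("V")

def sorted_items(dictionary: Dict[K, V]) -> Iterator[Tuple[K, V]]:
    """Yield (key, value) pairs in sorted key order.

    Sorting is well defined here because, at each nesting level, the optimizer
    state keys are homogeneous: either all ints (parameter ids) or all strs
    (state names such as "exp_avg"), so they are mutually comparable.
    """
    for key in sorted(dictionary.keys()):
        yield key, dictionary[key]

# Every rank then iterates the local state in the same order before all_gather:
#     for state_name, tensor in sorted_items(local_state): ...
```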
LGTM!
nice catch!!
@pytorchbot merge