
Conversation


@fegin fegin commented Sep 7, 2022

Stack from ghstack (oldest at bottom):

**Background:**
Optimizer states are of type `Dict[int, Dict[str, torch.Tensor]]`, and the order of `dict.items()` is the creation order of the keys. Without checkpointing (`state_dict`/`load_state_dict`), the creation order of the keys depends on the implementation of the optimizer (e.g., Adam seems to create `exp_avg` then `exp_avg_sq`). However, when loading states from a checkpoint, the optimizer states are lazily initialized, so the order depends on the user code that reads the `state_dict` from IO. See the following example:

```
optimizer_state_dict = USER_CODE_TO_READ_STATE_FROM_IO()
optimizer.load_state_dict(optimizer_state_dict)
```

The key order of `optimizer_state_dict` depends on `USER_CODE_TO_READ_STATE_FROM_IO`, and there is no guarantee that the order is the same across ranks.
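For concreteness, here is a small hypothetical illustration (not code from this PR) of how two ranks can hold the same logical state yet see a different `dict.items()` order, since Python dicts iterate in insertion order:

```
import torch

# Hypothetical illustration: the same per-parameter state built in two
# different key orders (e.g., by an optimizer step vs. a checkpoint load)
# iterates in two different orders.
state_after_step = {"exp_avg": torch.zeros(2), "exp_avg_sq": torch.zeros(2)}
state_after_load = {"exp_avg_sq": torch.zeros(2), "exp_avg": torch.zeros(2)}

print([k for k, _ in state_after_step.items()])  # ['exp_avg', 'exp_avg_sq']
print([k for k, _ in state_after_load.items()])  # ['exp_avg_sq', 'exp_avg']
```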

**What Can Go Wrong?**
After the first checkpoint load, the key order of the optimizer states may not be the same on different ranks. When users try to save another checkpoint, `_unflatten_optim_state()` is called to save the optimizer states. Inside `_unflatten_optim_state()`, `dict.items()` is called to iterate over all the local optimizer states, and `all_gather()` is used to gather the local states. Since the order may differ across ranks, the gathered states are not correct.

We have seen some models get NaN loss after the second checkpoint load because of this issue.
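A simplified, hypothetical two-rank sketch (not the real `_unflatten_optim_state()` code) of why this corrupts the gathered states: the gather lines values up by position, not by key, so a different iteration order on one rank pairs its values with the wrong state names.

```
# Hypothetical two-rank sketch: values are collected positionally across
# ranks, so rank 1's different dict.items() order mislabels its values.
rank0_state = {"exp_avg": "rank0.exp_avg", "exp_avg_sq": "rank0.exp_avg_sq"}
rank1_state = {"exp_avg_sq": "rank1.exp_avg_sq", "exp_avg": "rank1.exp_avg"}

# The keys from rank 0 are assumed to describe every rank's value at the
# same position in the gathered list.
for (key, v0), (_, v1) in zip(rank0_state.items(), rank1_state.items()):
    print(f"{key}: {[v0, v1]}")
# exp_avg: ['rank0.exp_avg', 'rank1.exp_avg_sq']     <- wrong pairing
# exp_avg_sq: ['rank0.exp_avg_sq', 'rank1.exp_avg']  <- wrong pairing
```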

**What This PR Does**
This PR implements a `sorted_items()` helper that returns the `(key, value)` pairs sorted by key. We can do this because the keys are either integers or strings.
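A minimal sketch of what such a helper might look like (the name follows the PR, but the body here is illustrative rather than the exact implementation):

```
from typing import Dict, Iterator, Tuple, TypeVar

K = TypeVar("K", int, str)
V = TypeVar("V")

def sorted_items(dictionary: Dict[K, V]) -> Iterator[Tuple[K, V]]:
    """Yield (key, value) pairs in ascending key order.

    Optimizer state keys are either all ints (parameter ids) or all strs
    (state names), so sorting is well defined and identical on every rank.
    """
    for key in sorted(dictionary.keys()):
        yield key, dictionary[key]
```

Iterating with `sorted_items()` instead of `dict.items()` makes every rank feed its local states to `all_gather()` in the same key order, regardless of how each rank's state dict was populated.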

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)


facebook-github-bot commented Sep 7, 2022

🔗 Helpful links

✅ No Failures (32 Pending)

As of commit 69f5148 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

fegin added a commit that referenced this pull request Sep 7, 2022
…timizer states

ghstack-source-id: 166684083
Pull Request resolved: #84654
@facebook-github-bot facebook-github-bot added the `oncall: distributed` label Sep 7, 2022
pytorch-bot bot commented Sep 7, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84654

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 4 Pending

As of commit c0274ac:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fegin added a commit that referenced this pull request Sep 7, 2022
…timizer states

Pull Request resolved: #84654

ghstack-source-id: 166712574

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)
fegin added a commit that referenced this pull request Sep 8, 2022
…timizer states

Pull Request resolved: #84654

ghstack-source-id: 166751531

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)

@awgu awgu left a comment


LGTM!

@zhaojuanmao
Contributor

nice catch!!


fegin commented Sep 9, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!


github-actions bot commented Sep 9, 2022

Hey @fegin.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Sep 9, 2022
…timizer states (#84654)

Summary:
Pull Request resolved: #84654

ghstack-source-id: 166751531

Test Plan: CI

Reviewed By: awgu

Differential Revision: D39315184

fbshipit-source-id: 8089527d5ae609a41b76b68bc567256167810d03
rdspring1 pushed a commit to rdspring1/pytorch that referenced this pull request Sep 10, 2022
…timizer states (pytorch#84654)


Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)
Pull Request resolved: pytorch#84654
Approved by: https://github.com/awgu
@facebook-github-bot facebook-github-bot deleted the gh/fegin/24/head branch September 12, 2022 14:20
Labels
cla signed, Merged, oncall: distributed, release notes: distributed (fsdp)