
Conversation

awgu (Collaborator) commented Jun 1, 2022

Stack from ghstack:

This enables the first argument to the optimizer constructor, `optim_input`, to have different orders across ranks, e.g. if the parameters in one parameter group are permuted. This requires modifications to `full_optim_state_dict()`, `shard_full_optim_state_dict()`, and `scatter_full_optim_state_dict()`.

The high-level algorithmic change is that the state dicts stay keyed by unflattened parameter name through sharding/unsharding and flattening/unflattening, and are only rekeyed by parameter ID, according to each rank's own `optim_input`, at the very end.

Because this PR adds non-parameter-specific collectives to `full_optim_state_dict()`, it also adds a `group=None` argument to the method so that those common collectives have a default process group to run over.
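
As an illustration of what this permits, here is a minimal usage sketch (not taken from the PR; `MyModel` and the training loop are placeholders, and the process group is assumed to already be initialized): each rank passes a differently ordered parameter list to its optimizer, consolidates the optimizer state through FSDP, and later re-shards it for its own `optim_input`.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(MyModel().cuda())  # MyModel is a placeholder nn.Module

# Each rank may pass the parameters to the optimizer in a different order.
params = list(model.parameters())
optim_input = params if dist.get_rank() == 0 else list(reversed(params))
optim = torch.optim.Adam(optim_input, lr=1e-3)

# ... train for some steps ...

# Consolidate the full optimizer state dict (keyed by unflattened parameter names).
full_osd = FSDP.full_optim_state_dict(model, optim, optim_input)

# When resuming, scatter rank 0's full state dict; the result is rekeyed by
# parameter ID according to this rank's own optim_input, so it loads directly.
full_osd = full_osd if dist.get_rank() == 0 else None
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model, optim_input)
optim.load_state_dict(sharded_osd)
```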

facebook-github-bot (Contributor) commented Jun 1, 2022


❌ 1 New Failure

As of commit edef523 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-06-03T00:31:28.4348951Z processing existing schema:  text(__torch__.torch.classes.profiling.SourceRef _0) -> str _0
2022-06-03T00:31:28.4349925Z processing existing schema:  count(__torch__.torch.classes.profiling.InstructionStats _0) -> int _0
2022-06-03T00:31:28.4351072Z processing existing schema:  duration_ns(__torch__.torch.classes.profiling.InstructionStats _0) -> int _0
2022-06-03T00:31:28.4352467Z processing existing schema:  source(__torch__.torch.classes.profiling.SourceStats _0) -> __torch__.torch.classes.profiling.SourceRef _0
2022-06-03T00:31:28.4354316Z processing existing schema:  line_map(__torch__.torch.classes.profiling.SourceStats _0) -> Dict(int, __torch__.torch.classes.profiling.InstructionStats) _0
2022-06-03T00:31:28.4355251Z processing existing schema:  __init__(__torch__.torch.classes.profiling._ScriptProfile _0) -> NoneType _0
2022-06-03T00:31:28.4356591Z processing existing schema:  enable(__torch__.torch.classes.profiling._ScriptProfile _0) -> NoneType _0
2022-06-03T00:31:28.4357508Z processing existing schema:  disable(__torch__.torch.classes.profiling._ScriptProfile _0) -> NoneType _0
2022-06-03T00:31:28.4359248Z processing existing schema:  _dump_stats(__torch__.torch.classes.profiling._ScriptProfile _0) -> __torch__.torch.classes.profiling.SourceStats[] _0
2022-06-03T00:31:28.4360379Z processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> NoneType _0
2022-06-03T00:31:28.4361196Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-06-03T00:31:28.4361568Z 
2022-06-03T00:31:28.4361642Z Broken ops: [
2022-06-03T00:31:28.4362147Z 	aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
2022-06-03T00:31:28.4362729Z 	aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
2022-06-03T00:31:28.4363215Z 	aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
2022-06-03T00:31:28.4363774Z 	aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
2022-06-03T00:31:28.4364178Z 	aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor
2022-06-03T00:31:28.4364543Z 	aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)
2022-06-03T00:31:28.4364763Z ]
2022-06-03T00:31:28.5508838Z + cleanup


@facebook-github-bot added the `oncall: distributed` label Jun 1, 2022
awgu pushed a commit that referenced this pull request Jun 1, 2022
ghstack-source-id: 761c531
Pull Request resolved: #78599
@awgu changed the title from "[FSDP] Allow diff optim_input across ranks" to "[FSDP] Allow different optim_input orders across ranks" Jun 1, 2022
@rohan-varma self-requested a review June 1, 2022 15:10
rohan-varma (Contributor) left a comment:

LGTM! Let's test to ensure that it is fixed for the use case we're looking at. Thanks so much for the quick fix!

raise RuntimeError(
    "FSDP currently requires each rank to have at least the "
    "optimizer states needed by rank 0's optimizer but some ranks "
    "are missing some of those states"
)
A reviewer (Contributor) commented:

Will it also be useful to log the missing keys?

awgu (Collaborator, Author) replied:

I can add that at the cost of an `all_gather_object()`.

awgu (Collaborator, Author) commented:

Update: The error now looks like:

RuntimeError: FSDP currently requires each rank to have at least the optimizer states needed by rank 0's optimizer but some ranks are missing some of those states
Rank 1 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
Rank 2 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
Rank 3 is missing states for the parameters: [('block2.2.weight', 'block2.2.bias_module0.bias', 'block2.2.bias_module1.bias')]
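
As a rough sketch of that approach (assumed helper and variable names, not the PR's actual code), each rank could contribute its missing unflattened parameter names via `all_gather_object()`, and the error message could then be assembled from the gathered lists:

```python
import torch.distributed as dist

def _raise_missing_states_error(missing_keys, group=None):
    """Hypothetical helper: `missing_keys` is this rank's list of
    unflattened-parameter-name tuples that have no optimizer state."""
    world_size = dist.get_world_size(group)
    gathered = [None for _ in range(world_size)]
    # The extra collective mentioned above.
    dist.all_gather_object(gathered, missing_keys, group=group)
    lines = [
        f"Rank {rank} is missing states for the parameters: {keys}"
        for rank, keys in enumerate(gathered) if keys
    ]
    # Called only after a mismatch has been detected, so raise unconditionally.
    raise RuntimeError(
        "FSDP currently requires each rank to have at least the "
        "optimizer states needed by rank 0's optimizer but some ranks "
        "are missing some of those states\n" + "\n".join(lines)
    )
```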

awgu pushed a commit that referenced this pull request Jun 1, 2022
ghstack-source-id: 191ec79
Pull Request resolved: #78599
awgu pushed a commit that referenced this pull request Jun 3, 2022
ghstack-source-id: bdfdc48
Pull Request resolved: #78599
awgu (Collaborator, Author) commented Jun 3, 2022

Backward compatibility error due to broken ops seems unrelated.

awgu (Collaborator, Author) commented Jun 3, 2022

@pytorchbot merge

github-actions bot (Contributor) commented Jun 3, 2022

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@awgu added the `release notes: distributed (fsdp)` and `topic: improvements` labels Jun 3, 2022
facebook-github-bot pushed a commit that referenced this pull request Jun 3, 2022
Summary:
Pull Request resolved: #78599

Approved by: https://github.com/rohan-varma

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4615738a3d5aee256a9f1929c846dcae7d20a041

Reviewed By: rohan-varma

Differential Revision: D36798482

fbshipit-source-id: de05b37db8ed41b6cf11a9fc526a0e03a98f570d
zhaojuanmao (Contributor) left a comment:

cc @fegin as sharded optimizer states are built on top of it

@facebook-github-bot deleted the gh/awgu/51/head branch June 6, 2022 14:17
Labels: cla signed, Merged, oncall: distributed, release notes: distributed (fsdp), topic: improvements