Skip to content

Commit

Permalink
Fix ignored_states when they are passed as generators (#102575)
Browse files Browse the repository at this point in the history
This PR fixed the case where ignored_states are passed as generators, not List/Set

Pull Request resolved: #102575
Approved by: https://github.com/awgu
  • Loading branch information
zhaojuanmao authored and pytorchmergebot committed May 31, 2023
1 parent 9f97b7c commit f47ee87
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_ignored_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ def _test_diff_ignored_modules_across_ranks(
{"ignored_modules": layer1_ignored_modules}
if ignore_modules
else {
"ignored_states": {
"ignored_states": (
p for m in layer1_ignored_modules for p in m.parameters()
}
)
}
)
model.layer1 = wrap_cls(model.layer1, **ignore_kwargs)
Expand Down
10 changes: 5 additions & 5 deletions torch/distributed/fsdp/_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,12 @@ def _init_ignored_module_states(
), "Can not pass `ignored_modules` and `ignored_states` at the same time. \
Please either pass `ignored_modules` or `ignored_states`."
ignored_parameters = None
if ignored_states:
ignored_states_set = set(ignored_states)
if isinstance(next(iter(ignored_states), None), torch.nn.Parameter):
ignored_parameters = ignored_states_set
ignored_states_list = list(ignored_states) if ignored_states is not None else []
if ignored_states_list and len(ignored_states_list) > 0:
if isinstance(ignored_states_list[0], torch.nn.Parameter):
ignored_parameters = ignored_states_list
else:
ignored_modules = ignored_states_set
ignored_modules = ignored_states_list
state._ignored_modules = _get_ignored_modules(module, ignored_modules)
state._ignored_params = _get_ignored_params(
module,
Expand Down

0 comments on commit f47ee87

Please sign in to comment.