[FSDP] Fix for optim state dict #102901

Closed · wants to merge 3 commits · Changes from 1 commit
6 changes: 5 additions & 1 deletion torch/distributed/fsdp/_optim_utils.py
@@ -1349,6 +1349,7 @@
state = (
{} if param_key is None else optim_state_dict["state"][param_key]
)
print(f"RV: calling _Gather_orig_param_state")
Contributor:
#nit: Remove print?

Contributor:
ditto on the others before landing :)

unflat_state = [
_gather_orig_param_state(
fsdp_param_info,
@@ -1497,7 +1498,9 @@
object_list: List[StateInfo] = [
processed_state for _ in range(fsdp_state.world_size)
]
dist.all_gather_object(object_list, processed_state)
assert fsdp_state.world_size == fsdp_state.process_group.size()

Check failure on line 1501 in torch/distributed/fsdp/_optim_utils.py (GitHub Actions / lintrunner / linux-job):
MYPY [union-attr]: Item "None" of "Optional[ProcessGroup]" has no attribute "size"
Contributor:
nit: Add an assert message?
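A possible follow-up sketch (not part of this PR's diff) that would address both this nit and the mypy [union-attr] failure above: bind the Optional[ProcessGroup] to a local, assert on it with a message, and reuse the narrowed value.

# Sketch only: narrowing fsdp_state.process_group lets mypy accept .size(),
# and the assert carries a message as suggested in the review.
pg = fsdp_state.process_group
assert pg is not None, "FSDP state is missing its process group"
assert fsdp_state.world_size == pg.size(), (
    f"world_size ({fsdp_state.world_size}) does not match the "
    f"process group size ({pg.size()})"
)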

print(f"RV: {fsdp_state.process_group.size()} vs {torch.distributed.distributed_c10d._get_default_group().size()}")

Check failure on line 1502 in torch/distributed/fsdp/_optim_utils.py (GitHub Actions / lintrunner / linux-job):
MYPY [union-attr]: Item "None" of "Optional[ProcessGroup]" has no attribute "size"
dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)
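For context, a minimal standalone sketch of why the output list must be sized by the same group that is passed to all_gather_object (the call fills one slot per rank of that group). This assumes torch.distributed has already been initialized via init_process_group; gather_from_group is a hypothetical helper, not part of this PR.

import torch.distributed as dist

def gather_from_group(obj, group=None):
    # Hypothetical helper: size the output list by the target group's
    # world size, then gather. out[i] ends up holding the object
    # contributed by rank i of `group` (the default group when None).
    world_size = dist.get_world_size(group=group)
    out = [None for _ in range(world_size)]
    dist.all_gather_object(out, obj, group=group)
    return out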

# Convert the gathered, pre-processed state of each rank to the original one.
gathered_state: Dict[str, Any] = {}
@@ -1592,6 +1595,7 @@
):
return optim_state

print(f"RV: gathering state")
gathered_state = _all_gather_optim_state(fsdp_state, optim_state)

# Unflatten state values.