
[Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed #90212

Closed
wants to merge 10 commits

Conversation

Contributor

@wz337 wz337 commented Dec 5, 2022

This is the last PR for integrating 2D into core distributed.

This PR does the following:

  1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with an FSDP sharded optimizer state.
  2. Update default_planner.py to support 2D checkpoints.
  3. Add test_fsdp_optim_state.py as a unit test for item 1.
  4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py.
  5. Rename the files for the APIs that should be private. Further organization and cleanup will follow in later PRs. [PT-D][checkpoint] Revisit public and private APIs in distributed checkpoint module #90328

Docstrings and the integration test will be added in follow-up PRs.
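For readers following the changes, below is a minimal sketch of how the new load_sharded_optimizer_state_dict path composes with load_state_dict. The checkpoint directory, the "model"/"optim" keys, the toy model, and the process-group setup are illustrative assumptions rather than part of this PR's diff, and the FSDP flattening helper has changed names across releases:

```python
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Assumes torch.distributed is already initialized and a sharded checkpoint with
# "model" and "optim" keys was previously saved under CHECKPOINT_DIR.
CHECKPOINT_DIR = "/tmp/checkpoint"
model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Load the model portion of the checkpoint first.
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    model.load_state_dict(state_dict["model"])

    # Rebuild the sharded optimizer state against the loaded model state_dict.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optim",
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    # Flatten into the form the FSDP-wrapped optimizer expects; the exact FSDP
    # helper and its argument order differ between PyTorch releases.
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
        optim_state["optim"], model, optim
    )
    optim.load_state_dict(flattened_osd)
```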

@pytorch-bot pytorch-bot bot added the "release notes: distributed (fsdp)" label on Dec 5, 2022

pytorch-bot bot commented Dec 5, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90212

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f151190:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wz337 wz337 changed the title from "Add optimizer update default planner" to "[Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed" on Dec 5, 2022
@wz337 wz337 marked this pull request as ready for review December 6, 2022 00:26
Contributor Author

wz337 commented Dec 6, 2022

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/90212/head returned non-zero exit code 1

```
Rebasing (33/45)
Auto-merging torch/distributed/checkpoint/default_planner.py
CONFLICT (content): Merge conflict in torch/distributed/checkpoint/default_planner.py
Auto-merging torch/distributed/checkpoint/optimizer.py
CONFLICT (add/add): Merge conflict in torch/distributed/checkpoint/optimizer.py
error: could not apply 01baabafae... update import
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 01baabafae... update import
```

Raised by https://github.com/pytorch/pytorch/actions/runs/3625809969


pytorch-bot bot commented Dec 6, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: master

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@wz337 wz337 marked this pull request as draft December 6, 2022 04:30
@wz337 wz337 marked this pull request as ready for review December 6, 2022 05:09
Contributor

@wanchaol wanchaol left a comment

Looking good. I have some questions about the planner init, and I think we should have a plan for consolidating the user-facing API.

```python
def test_distributed_tensor_planner(self) -> None:
    CHECKPOINT_DIR = self.temp_dir

    model = FSDP(torch.nn.Linear(8, 8, device="meta"))
```
Contributor

Would FSDP automatically materialize the meta-device Linear layer?

Contributor Author

I believe so. Here is an example from the FSDP test (test_fully_shard.py):

```python
fsdp_wrapped_model = FSDP(
    Model(device="meta"),
    auto_wrap_policy=Model.policy(),
    param_init_fn=_param_init_fn,
    use_orig_params=True,
)
```
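(`_param_init_fn` is not shown in the quoted test; purely as an illustration, a typical materialization callback for meta-device modules looks roughly like this sketch, not necessarily the test's actual helper:)

```python
import torch

def _param_init_fn(module: torch.nn.Module) -> None:
    # Allocate real storage for the meta-device parameters/buffers, then
    # re-initialize them with the module's own reset logic, if it has one.
    module.to_empty(device=torch.device("cuda"))
    if callable(getattr(module, "reset_parameters", None)):
        module.reset_parameters()
```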

```python
        self.dedup_replicated_tensors = dedup_replicated_tensors
        self.mappings = {}

    def init(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
```
Contributor

Is there a reason we separate __init__ and init into two functions? It's a bit confusing; shall we combine them?

Contributor Author

@wz337 wz337 Dec 6, 2022

Filed a task about renaming the init() method for planners: #90346
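For context, the two methods cover different phases of a save; a rough sketch of how they are called (based on the signatures quoted above, assuming an initialized process group and with the surrounding save flow simplified):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

state_dict = {"weight": torch.rand(4, 4)}  # illustrative state_dict

# __init__ captures static configuration when the planner object is constructed.
planner = DefaultSavePlanner(dedup_replicated_tensors=True)

# init() is called later by the checkpointing entry point, once per checkpoint,
# with the live state_dict and a flag marking which rank coordinates the save.
planner.init(state_dict, is_coordinator=(dist.get_rank() == 0))
```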

```python
    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
```
Contributor

Same question about init and __init__.

```diff
@@ -242,3 +348,79 @@ def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
     plan = _create_default_metadata_only_plan(state_dict)
     _, md = create_default_global_save_plan([plan])
     return md
+
+
+def _check_box_overlap(
```
Contributor

Are these going to be used by other files, or are they specific to this file? If they're general utilities, we should probably move them to a utils file instead.

Contributor Author

So far it is only used in this file, but I'll do a cleanup of which APIs should be public versus private and move the helper functions around.

Contributor Author

Filed a task to revisit which APIs should be public and which should be private: #90328
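For reference, the overlap check itself is a standard axis-aligned box test over chunk offsets and sizes. A minimal sketch of the idea (assuming chunks carry per-dimension offsets and sizes, as ChunkStorageMetadata does; this is not necessarily the exact code in the PR):

```python
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

def check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    # Two n-dimensional boxes overlap unless, along some dimension,
    # one box ends at or before the other begins.
    for dim in range(len(box0.offsets)):
        if box0.offsets[dim] >= box1.offsets[dim] + box1.sizes[dim]:
            return False
        if box1.offsets[dim] >= box0.offsets[dim] + box0.sizes[dim]:
            return False
    return True
```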

```diff
@@ -21,7 +21,7 @@ def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
         # Only create temp_dir when rank is 0
         if dist.get_rank() == 0:
             temp_dir = tempfile.mkdtemp()
-            print(f"Using temp directory: {self.temp_dir }")
+            print(f"Using temp directory: {temp_dir}")
```
Contributor

Was this not caught by the tests before?

Contributor Author

Filed a new issue to add a test for this utility: #90327
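To make the fixed line easier to place, here is a simplified sketch of the decorator pattern in that utility (an illustration of the shape of the helper, not its exact code): the local temp_dir exists on rank 0 before self.temp_dir is assigned, which is what the fix switches to.

```python
import shutil
import tempfile
from functools import wraps

import torch.distributed as dist

def with_temp_dir(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Only rank 0 creates the directory; the path is then broadcast to all ranks.
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")  # the fixed line
        else:
            temp_dir = ""
        object_list = [temp_dir]
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]  # assigned only after the broadcast
        try:
            func(self, *args, **kwargs)
        finally:
            if dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)
    return wrapper
```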

```python
        return super().lookup_tensor(self.translation.get(index, index))


def load_sharded_optimizer_state_dict(
```
Contributor

I feel that, as a user of the checkpoint package, I shouldn't have to import this specific function from this module for my optimizer checkpoint. We should hide this from the user-facing API; ideally it should simply be torch.distributed.checkpoint save/load state_dict, and components like the sharded optimizer state_dict should be loaded/handled automatically by the checkpoint package. As part of the beta release, could you document the plan for user API consolidation in an issue, along with the detailed work items?

Contributor Author

Added an issue to consolidate loading the optimizer state_dict into dist_cp's load_state_dict directly: #90325

```python
    )
    model_2.load_state_dict(state_dict["model"])

    optim_state = load_sharded_optimizer_state_dict(
```
Contributor

Shall we consolidate the optimizer state_dict loading into dist_cp's load_state_dict directly?

Contributor Author

wz337 commented Dec 7, 2022

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased add_optimizer_update_default_planner onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add_optimizer_update_default_planner && git pull --rebase)

Contributor

@wanchaol wanchaol left a comment

Looks good to me, thanks for addressing comments and tracking issues :)

torch/distributed/checkpoint/default_planner.py (review thread outdated; resolved)
Update comment.

Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Contributor Author

wz337 commented Dec 7, 2022

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) on Dec 7, 2022
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
[Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed (pytorch#90212)

This is the last PR for integrating 2D into core distributed.

This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with an FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoints.
3. Add test_fsdp_optim_state.py as a unit test for item 1.
4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py.
5. Rename the files for the APIs that should be private. Further organization and cleanup will follow in later PRs. pytorch#90328

Docstrings and the integration test will be added in follow-up PRs.
Pull Request resolved: pytorch#90212
Approved by: https://github.com/wanchaol
wz337 added a commit to pytorch/PiPPy that referenced this pull request Dec 10, 2022
Fixing import and type errors (changes due to pytorch/pytorch#90212).

```
spmd/checkpoint/dt_planner.py:7:0 Undefined import [21]: Could not find a module corresponding to import torch.distributed.checkpoint.dedup_tensors.
spmd/checkpoint/dt_planner.py:87:24 Undefined attribute [16]: Module torch.distributed.checkpoint has no attribute dedup_tensors.
spmd/checkpoint/pg_planner.py:79:8 Illegal annotation target [35]: Target self.original_state_dict cannot be annotated.
spmd/checkpoint/pg_planner.py:79:63 Unused ignore [0]: The pyre-ignore[16] or pyre-fixme[16] comment is not suppressing type errors, please remove it.
spmd/checkpoint/pg_planner.py:89:59 Unused ignore [0]: The pyre-ignore[16] or pyre-fixme[16] comment is not suppressing type errors, please remove it.
```

Remove test/spmd/checkpoint/test_fsdp_optim_state.py as this has been
upstreamed to pt-main in pytorch/pytorch#90212.
wz337 added a commit to pytorch/PiPPy that referenced this pull request Dec 16, 2022
After this PR (pytorch/pytorch#90212) landed on PyTorch master, ProcessGroupAwareSavePlanner and DefaultSavePlanner have an overlapping variable (self.original_state_dict), which would result in the bytes values of the state_dict not being loaded correctly.

This PR fixes the issue above and reinstates the test for pg_planner.
wz337 added a commit to pytorch/PiPPy that referenced this pull request Dec 16, 2022
After this PR (pytorch/pytorch#90212) landed on PyTorch master, the class variable is now defined in DefaultSavePlanner. Therefore, we need to pass the correct value to super().__init__().

This PR fixes the issue above and reinstates the test for dt_planner.
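A sketch of the fix pattern described in these two follow-ups, assuming a custom planner subclass (the actual PiPPy planners are more involved):

```python
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

class ProcessGroupAwareSavePlanner(DefaultSavePlanner):
    def __init__(
        self,
        flatten_state_dict: bool = True,
        dedup_replicated_tensors: bool = True,
    ) -> None:
        # After pytorch#90212, DefaultSavePlanner defines these attributes itself,
        # so the subclass forwards them to super().__init__() instead of
        # re-assigning its own copies (which previously shadowed the base class).
        super().__init__(
            flatten_state_dict=flatten_state_dict,
            dedup_replicated_tensors=dedup_replicated_tensors,
        )
```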
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: distributed (fsdp)

3 participants