[Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed #90212
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90212
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f151190.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchmergebot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/3625809969
❌ 🤖 pytorchbot command failed:
Try
Looking good. I have some questions about the planner init, and I think we should probably have a plan for the user-facing API consolidation.
def test_distributed_tensor_planner(self) -> None:
    CHECKPOINT_DIR = self.temp_dir

    model = FSDP(torch.nn.Linear(8, 8, device="meta"))
Would FSDP automatically materialize the meta-device Linear layer?
I believe so. Here is an example from the FSDP test test_fully_shard.py:
pytorch/test/distributed/_composable/test_fully_shard.py
Lines 199 to 204 in ce21262
fsdp_wrapped_model = FSDP(
    Model(device="meta"),
    auto_wrap_policy=Model.policy(),
    param_init_fn=_param_init_fn,
    use_orig_params=True,
)
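For anyone following along, here is a minimal sketch of what such a param_init_fn could look like; the function body and target device are illustrative, not the actual helper used in test_fully_shard.py:

```python
import torch
import torch.nn as nn


def _param_init_fn(module: nn.Module) -> None:
    # Materialize the meta-device module on a real device with uninitialized
    # storage, then re-run the built-in init for submodules that support it.
    module.to_empty(device=torch.device("cuda"))
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()
```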
        self.dedup_replicated_tensors = dedup_replicated_tensors
        self.mappings = {}

    def init(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
Is there a reason that we separate __init__ and init as two functions? It's a bit confusing; shall we combine them?
Filed a task regarding renaming the init() method for planners: #90346.
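For context on the split being discussed, here is a rough sketch of how the two phases are meant to be used; the subclass and its extra attribute are hypothetical, and init is the method that #90346 proposes to rename:

```python
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner


class CountingSavePlanner(DefaultSavePlanner):  # hypothetical subclass
    def __init__(self) -> None:
        # Construction-time configuration only: the planner is created once,
        # before any state_dict exists, and passed to the save call.
        super().__init__()
        self.saves_seen = 0  # illustrative extra state

    def init(self, state_dict, is_coordinator: bool) -> None:
        # Per-checkpoint setup: the checkpointing machinery calls this with
        # the live state_dict and this rank's coordinator flag on every save.
        self.saves_seen += 1
        super().init(state_dict, is_coordinator)
```

So __init__ captures configuration, while init binds the same planner instance to a particular checkpoint each time it is reused.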
    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
Same question about init and __init__.
@@ -242,3 +348,79 @@ def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md


def _check_box_overlap(
Are these going to be used by other files, or are they specific to this file? If it's a general util, we should probably move it to a util file instead.
So far it is only used by this file, but I'll do a cleanup of which APIs should be public and which should be private, and move helper functions around.
Filed a task to revisit which APIs should be public and which should be private: #90328.
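For readers curious what a box-overlap helper does here, below is a generic sketch of checking whether two N-dimensional chunks (offsets plus lengths per dimension) intersect; it is not necessarily the exact implementation in default_planner.py:

```python
from typing import Sequence


def boxes_overlap(
    offsets_a: Sequence[int],
    sizes_a: Sequence[int],
    offsets_b: Sequence[int],
    sizes_b: Sequence[int],
) -> bool:
    # Two axis-aligned boxes overlap only if they overlap on every dimension:
    # each box must start before the other one ends.
    for oa, sa, ob, sb in zip(offsets_a, sizes_a, offsets_b, sizes_b):
        if oa + sa <= ob or ob + sb <= oa:
            return False
    return True
```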
@@ -21,7 +21,7 @@ def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
    # Only create temp_dir when rank is 0
    if dist.get_rank() == 0:
        temp_dir = tempfile.mkdtemp()
-       print(f"Using temp directory: {self.temp_dir }")
+       print(f"Using temp directory: {temp_dir}")
Was this not captured by the test before?
Filed a new issue to add a test for these utils: #90327.
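For context, the helper being fixed follows the usual rank-0-creates, everyone-receives pattern; this is a sketch of that pattern, not the exact contents of checkpoint_utils.py:

```python
import tempfile

import torch.distributed as dist


def with_temp_dir(func):
    def wrapper(self, *args, **kwargs):
        # Only rank 0 creates the directory. The print below is the line the
        # PR fixes: it printed self.temp_dir instead of the local temp_dir.
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")
        else:
            temp_dir = None
        # Broadcast the path so every rank reads/writes the same directory.
        object_list = [temp_dir]
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]
        func(self, *args, **kwargs)

    return wrapper
```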
        return super().lookup_tensor(self.translation.get(index, index))


def load_sharded_optimizer_state_dict(
I feel like, as a user of the checkpoint package, I shouldn't have to worry about importing this specific function from this module for my optimizer checkpoint. We should hide this from the user API; ideally it should simply be torch.distributed.checkpoint.save/load of a state_dict, and all of the components, like the sharded optimizer state_dict, should be automatically loaded/handled by the checkpoint package. As part of the beta release, could you document the plan for user API consolidation in an issue, with the detailed work items?
Added an issue to consolidate loading the optimizer state dict into load_state_dict of dist_cp directly: #90325.
)
model_2.load_state_dict(state_dict["model"])

optim_state = load_sharded_optimizer_state_dict(
Shall we consolidate the optimizer state dict loading into load_state_dict of dist_cp directly?
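For reference, end-to-end loading with the function added in this PR looks roughly like the sketch below. CHECKPOINT_DIR, the toy model, and the optimizer are placeholders, and the FSDP flattening helper reflects the API of that era (later releases also provide FSDP.optim_state_dict_to_load):

```python
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

CHECKPOINT_DIR = "/tmp/checkpoint"  # placeholder path

# Placeholder model/optimizer; assumes the process group is already initialized.
model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # 1. Load the model's sharded state_dict from the checkpoint.
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    model.load_state_dict(state_dict["model"])

    # 2. Load the optimizer state, resharded against the model state_dict.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optim",
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )

    # 3. Flatten it into the form the FSDP-wrapped optimizer expects.
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
        optim_state["optim"], model, optim
    )
    optim.load_state_dict(flattened_osd)
```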
@pytorchmergebot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Successfully rebased.
Looks good to me, thanks for addressing comments and tracking issues :)
Update comment. Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…e distributed (pytorch#90212)

This is the last PR for integrating 2D into core distributed. This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoints.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py.
5. Rename the filename for the APIs that should be private. Will organize and clean up further in following PRs. pytorch#90328

Docstring and integration test will be added in the following PRs.

Pull Request resolved: pytorch#90212
Approved by: https://github.com/wanchaol
Fixing import and type errors (change due to pytorch/pytorch#90212).

```
spmd/checkpoint/dt_planner.py:7:0 Undefined import [21]: Could not find a module corresponding to import torch.distributed.checkpoint.dedup_tensors.
spmd/checkpoint/dt_planner.py:87:24 Undefined attribute [16]: Module torch.distributed.checkpoint has no attribute dedup_tensors.
spmd/checkpoint/pg_planner.py:79:8 Illegal annotation target [35]: Target self.original_state_dict cannot be annotated.
spmd/checkpoint/pg_planner.py:79:63 Unused ignore [0]: The pyre-ignore[16] or pyre-fixme[16] comment is not suppressing type errors, please remove it.
spmd/checkpoint/pg_planner.py:89:59 Unused ignore [0]: The pyre-ignore[16] or pyre-fixme[16] comment is not suppressing type errors, please remove it.
```

Remove test/spmd/checkpoint/test_fsdp_optim_state.py, as this has been upstreamed to pt-main in pytorch/pytorch#90212.
After this PR (pytorch/pytorch#90212) landed in PyTorch master, ProcessGroupAwareSavePlanner and DefaultSavePlanner have an overlapping variable (self.original_state_dict), which would result in the bytes values of the state_dict not being loaded correctly. This PR fixes the issue above and re-instates the test for pg_planner.
After this PR (pytorch/pytorch#90212) landed in PyTorch master, the class variable is defined in DefaultSavePlanner. Therefore, we need to pass the correct value to super().__init__(). This PR fixes the issue above and re-instates the test for dt_planner.
This is the last PR for integrating 2D into core distributed.
This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoints.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py.
5. Rename the filename for the APIs that should be private. Will organize and clean up further in following PRs. pytorch#90328
Docstring and integration test will be added in the following PRs.
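To make the description concrete, the save side with the updated default planner might look roughly like the sketch below; the path and hyperparameters are placeholders, and FSDP.sharded_optim_state_dict reflects the FSDP API of that era:

```python
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

CHECKPOINT_DIR = "/tmp/checkpoint"  # placeholder path

# Placeholder model/optimizer; assumes the process group is already initialized.
model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)

# Take one step so the optimizer actually has state to checkpoint.
model(torch.rand(4, 8).cuda()).sum().backward()
optim.step()

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Sharded optimizer state, stored under the same key used when loading.
        "optim": FSDP.sharded_optim_state_dict(model, optim),
    }
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        planner=DefaultSavePlanner(),
    )
```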