[Checkpoint][2D][2/N] Add traverse for distributed checkpoint to core distributed #89398

Closed · wants to merge 9 commits

Conversation
@wz337 (Contributor) commented on Nov 21, 2022:

This PR moves traverse and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint.

It is used when flattening nested dicts and flattening sharded tensors.

Docstring and comments will be added in the following PRs.

Test:

python3 test/distributed/_tensor/parallel/test_2d_parallel.py

and CI
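
For readers new to the utility, here is a minimal, hypothetical sketch of the kind of traversal described above: walk a nested state dict, invoke a visitor on every leaf, and build a flat path-to-value mapping. The names `traverse_state_dict` and `flatten_state_dict` and their exact behavior are assumptions for illustration, not the actual torch.distributed.checkpoint API; the module in this PR also has to decide how ShardedTensor and DTensor values are treated (see the review comments below), which is omitted here.

```python
# Hypothetical sketch only; function names and behavior are illustrative,
# not the actual torch.distributed.checkpoint implementation.
from collections.abc import Mapping
from typing import Callable, Dict, Tuple, Union

PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
STATE_DICT_ITEM = object


def traverse_state_dict(
    state_dict: Mapping,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
) -> None:
    """Call ``visitor(path, value)`` for every leaf of a nested state dict."""

    def _traverse(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            for key, child in value.items():
                _traverse(path + (key,), child)
        elif isinstance(value, (list, tuple)):
            for idx, child in enumerate(value):
                _traverse(path + (idx,), child)
        else:
            visitor(path, value)

    _traverse((), state_dict)


def flatten_state_dict(state_dict: Mapping) -> Dict[str, STATE_DICT_ITEM]:
    """Flatten a nested state dict into a {dotted.path: value} mapping."""
    flattened: Dict[str, STATE_DICT_ITEM] = {}

    def _collect(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        flattened[".".join(str(p) for p in path)] = value

    traverse_state_dict(state_dict, _collect)
    return flattened


if __name__ == "__main__":
    nested = {"model": {"layer1": {"weight": [1.0, 2.0]}}, "step": 3}
    print(flatten_state_dict(nested))
    # {'model.layer1.weight.0': 1.0, 'model.layer1.weight.1': 2.0, 'step': 3}
```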

pytorch-bot bot commented Nov 21, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89398

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9182656:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wz337 changed the title from "add_traverse" to "[Checkpoint][2D][1/N] Add traverse for distributed checkpoint" (Nov 21, 2022)
wz337 changed the title from "[Checkpoint][2D][1/N] Add traverse for distributed checkpoint" to "[Checkpoint][2D][1/N] Move traverse for distributed checkpoint to core distributed" (Nov 21, 2022)
wz337 changed the title from "[Checkpoint][2D][1/N] Move traverse for distributed checkpoint to core distributed" to "[Checkpoint][2D][2/N] Move traverse for distributed checkpoint to core distributed" (Nov 21, 2022)
wz337 changed the title from "[Checkpoint][2D][2/N] Move traverse for distributed checkpoint to core distributed" to "[Checkpoint][2D][2/N] Add traverse for distributed checkpoint to core distributed" (Nov 21, 2022)
wz337 requested a review from wanchaol (Nov 21, 2022, 16:05)
wz337 marked this pull request as ready for review (Nov 21, 2022, 16:05)
```python
    data,
)
self.assertEqual(
    data[
```

Contributor:

Is there a reason why all these asserts are formatted this way? Can we shorten them to fewer lines?

wz337 (Contributor Author):

Will reformat and clean this up. I am taking this directly from Tau. I remember Rodrigo mentioned that he purposely formatted it this way to pass CI. Not sure about the details. lol
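
As a purely hypothetical illustration of the reviewer's point (the real fixture data is not shown in the excerpt above), the wrapped assertion style can usually be collapsed once the expression fits the formatter's line limit:

```python
import unittest


class _FormatExample(unittest.TestCase):
    # Hypothetical fixture; not the actual test data from this PR.
    data = {"model.weight": [1.0, 2.0]}

    def test_wrapped(self) -> None:
        # Before: the call is wrapped across several lines to satisfy the linter.
        self.assertEqual(
            self.data[
                "model.weight"
            ],
            [1.0, 2.0],
        )

    def test_compact(self) -> None:
        # After: the same check collapsed onto one line within the length limit.
        self.assertEqual(self.data["model.weight"], [1.0, 2.0])


if __name__ == "__main__":
    unittest.main()
```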

```python
    STATE_DICT_TYPE,
)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor as DT
```

Contributor:

nit: let's just call it DTensor? I feel DT is a bit hard to read.
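
For illustration, the suggestion amounts to importing the class under its full name; `_is_sharded_value` below is a hypothetical helper, added only to show how the unabbreviated name reads in an isinstance check:

```python
# Illustrative only; `_is_sharded_value` is a hypothetical helper, not part of the PR.
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor  # instead of `import DTensor as DT`


def _is_sharded_value(value: object) -> bool:
    # The full class name makes the intent of the check obvious at the call site.
    return isinstance(value, (ShardedTensor, DTensor))
```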

```python
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")

STATE_DICT_ITEM = object
```

Contributor:

Why do we make this a type alias for object?

wz337 (Contributor Author):

My gut feeling is that this is for readability, since we are traversing a state dict here.
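
A small illustration of that readability argument: with the alias, visitor signatures say what they operate on, even though STATE_DICT_ITEM is just `object` at runtime. The `VisitorFn` alias and `_example_visitor` below are assumptions for the sketch, not names from the module.

```python
from typing import Callable, Tuple, Union

PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
# At runtime this is just `object`, but in signatures it documents that the
# callback receives entries of a state dict rather than arbitrary values.
STATE_DICT_ITEM = object

VisitorFn = Callable[[OBJ_PATH, STATE_DICT_ITEM], None]


def _example_visitor(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
    # The annotations carry intent only; they add no runtime checking.
    print(path, type(value).__name__)
```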

```python
_print_nested(
    value._local_tensor,
    f"{padding}\t",
    "(offset ???) ",
```

Contributor:

Is this `???` intentional?

wz337 (Contributor Author):

I think he was trying to print the offset for a given _local_tensor here, but we don't have an API for it. I am removing this for now. To my knowledge, we don't have anything for this yet, right?

Contributor:

Do you mean a nested ShardedTensor + DTensor? I wasn't aware of that. Could you run the 2-D tests to make sure removing it does not break anything? Thanks!

wz337 (Contributor Author):

Added a TODO here to revisit this. Removing it for now as it doesn't break the test_2d_parallel.py tests.

@wanchaol (Contributor) left a review:

lgtm, thanks for fixing the lint. Have one more comment.


```python
def _print_nested(
    value: STATE_DICT_ITEM,
    padding: str = "",
```

Contributor:

What does padding mean here? I don't see it being used anywhere; shall we remove this arg?

wz337 (Contributor Author):

Seems redundant. Removing it as well. Thanks for pointing it out!
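
For context, here is a hypothetical version of such a nested pretty-printer that does thread an indentation prefix through its recursion; in the PR the argument was simply never used, so dropping it is the right call.

```python
# Hypothetical sketch; not the code from this PR, where the arg was unused.
from collections.abc import Mapping


def _print_nested(value: object, prefix: str = "") -> None:
    """Print a nested container, indenting one tab per nesting level."""
    if isinstance(value, Mapping):
        for key, child in value.items():
            print(f"{prefix}{key}:")
            _print_nested(child, prefix + "\t")
    elif isinstance(value, (list, tuple)):
        for idx, child in enumerate(value):
            print(f"{prefix}[{idx}]:")
            _print_nested(child, prefix + "\t")
    else:
        print(f"{prefix}{value!r}")


if __name__ == "__main__":
    _print_nested({"model": {"weight": [1.0, 2.0]}, "step": 3})
```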

@wz337 (Contributor Author) commented on Nov 22, 2022:

@pytorchmergebot merge

pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Nov 22, 2022.
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
… distributed (pytorch#89398)

This PR moves traverse and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint.

This is used when flattening nested dicts and flattening sharded tensors.

Docstring and comments will be added in the following PRs.

Test:
```
python3 test/distributed/_tensor/parallel/test_2d_parallel.py
```
and CI
Pull Request resolved: pytorch#89398
Approved by: https://github.com/wanchaol

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged

3 participants