[Checkpoint][2D][2/N] Add traverse for distributed checkpoint to core distributed #89398
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89398
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 9182656: This comment was automatically generated by Dr. CI and updates every 15 minutes.
    data,
)
self.assertEqual(
    data[
Is there a reason why all those asserts are formatted this way? Can we shorten them to fewer lines?
Will reformat and clean this up. I am taking this directly from Tau. I remember Rodrigo mentioned that he purposely formatted it this way to pass CI. Not sure about the details. lol
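For illustration only, a hedged sketch of the condensed formatting the reviewer is asking for (the dictionary key and tensor values below are made up, not taken from the actual test):

```python
import torch
from torch.testing._internal.common_utils import TestCase

class ExampleTest(TestCase):
    def test_condensed_assert(self) -> None:
        # Hypothetical data; the point is just that the subscript and the
        # expected value fit on one line instead of being split across several.
        data = {"weight": torch.ones(4, 4)}
        self.assertEqual(data["weight"], torch.ones(4, 4))
```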
    STATE_DICT_TYPE,
)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor as DT
nit: let's just call it DTensor? I feel DT is a bit unreadable.
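In other words, something like the following (a sketch of the suggested rename, assuming every later use of `DT` is updated accordingly):

```python
# Suggested rename: import under the full name instead of aliasing to DT.
from torch.distributed._tensor import DTensor
```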
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")

STATE_DICT_ITEM = object
Why do we make this a type alias for object?
My gut feeling is that this is for readability, since we are traversing a state dict here.
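As a rough illustration of the readability argument (a minimal sketch; the `PATH_ITEM` definition and the visitor signature below are assumptions, not necessarily what this PR uses):

```python
from typing import Callable, Tuple, Union

# These aliases mirror the ones in the diff above; PATH_ITEM = Union[str, int]
# is assumed here only to make the example self-contained.
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
STATE_DICT_ITEM = object

# With the alias, the signature reads as "visit a state dict item" rather than
# "visit an object", even though the runtime type is identical.
def visit(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
    print(path, type(value))
```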
_print_nested(
    value._local_tensor,
    f"{padding}\t",
    "(offset ???) ",
Is this ??? intentional?
I think he is trying to print the offset for a given _local_tensor here but doesn't have an API for it. I am removing this for now. To my knowledge, we don't have anything for this yet, right?
Do you mean a nested ST + DT? I wasn't aware of it. Could you test this with the 2-D tests to make sure removing it does not break anything? Thanks!
Added a TODO here to revisit this. Removing it for now as it doesn't break the test_2d_parallel.py tests.
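If the offset ever needs to be printed again, one hedged option for the ShardedTensor case (this does not cover the DTensor `_local_tensor` path being discussed, and `st` below is just a placeholder name) is to read it from the shard metadata:

```python
from torch.distributed._shard.sharded_tensor import ShardedTensor

def print_local_offsets(st: ShardedTensor) -> None:
    # Each local shard carries its global offsets and sizes in its metadata.
    for shard in st.local_shards():
        print(f"(offset {shard.metadata.shard_offsets}) ", shard.tensor.size())
```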
lgtm, thanks for fixing the lint. Have one more comment.

def _print_nested(
    value: STATE_DICT_ITEM,
    padding: str = "",
What does padding mean here? I didn't see it being used anywhere; shall we remove this arg?
Seems redundant. Removing it as well. Thanks for pointing it out!
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
… distributed (pytorch#89398)
Pull Request resolved: pytorch#89398
Approved by: https://github.com/wanchaol
This PR moves traverse and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint.
This is used when flattening nested dicts and flattening sharded tensors.
Docstrings and comments will be added in the following PRs.
Test: `python3 test/distributed/_tensor/parallel/test_2d_parallel.py` and CI
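For readers unfamiliar with the helper being moved, here is a rough sketch of what such a state-dict traversal does (a minimal illustration under assumed names; it is not the exact implementation now living in torch.distributed.checkpoint):

```python
from collections.abc import Mapping
from typing import Callable, Tuple, Union

# Assumed aliases, mirroring the ones shown in the diff earlier in this PR.
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
STATE_DICT_ITEM = object

def traverse_state_dict(
    state_dict: Mapping,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
) -> None:
    # Walk nested dicts/lists/tuples and call the visitor on every leaf,
    # handing it the path (a tuple of keys/indices) plus the leaf value.
    def _traverse(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            for key, child in value.items():
                _traverse(path + (key,), child)
        elif isinstance(value, (list, tuple)):
            for idx, child in enumerate(value):
                _traverse(path + (idx,), child)
        else:
            visitor(path, value)

    _traverse((), state_dict)

# Usage: prints ('a', 'b') 1 and then ('c', 0) 2.
traverse_state_dict({"a": {"b": 1}, "c": [2]}, lambda p, v: print(p, v))
```

The flattened path/value pairs produced this way are what the nested-dict and sharded-tensor flattening mentioned above builds on.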