[ZeroRedundancyOptimizer] Elastic and pytorch compatible state dict #52760
Conversation
💊 Dr. CI: As of commit d02adb1, CI looks good so far — no failures detected.
o.step()
self.assertEqual(x, torch.tensor([0.9], device=DEVICE))

def test_local_state_dict(self):
no more "local state dict" concept, a bit error prone and not elastic, there was no user that I know of for this in fairscale
else:
    # Dispatch this rank's state dictionary to the wrapped shard optimizer
    self.load_local_state_dict(ZeroRedundancyOptimizer.rank_local_state_dict(self.rank, state_dict))
# NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
This makes ZeroRedundancyOptimizer compatible with a normal PyTorch checkpoint.
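For context, the v1.5 workaround mentioned above amounts to remapping state entries keyed by `id(param)` at save time onto the linear indices used by current torch.optim state dicts. The helper below is an illustrative sketch only: the function name and the exact keying are assumptions for the example, not code from this PR.

```python
# Illustrative sketch: remap an optimizer state dict whose "state" entries
# were keyed by id(param) at save time into linear-index keys. The helper
# name and the assumption about the old keying are hypothetical.
from typing import Any, Dict


def remap_state_keys_to_indices(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # param_groups list the per-group parameter keys in a deterministic
    # order, which defines the save-time-key -> linear-index mapping.
    old_keys = [k for group in state_dict["param_groups"] for k in group["params"]]
    key_to_index = {old: i for i, old in enumerate(old_keys)}

    return {
        "state": {key_to_index[k]: v for k, v in state_dict["state"].items()},
        "param_groups": [
            {**group, "params": [key_to_index[k] for k in group["params"]]}
            for group in state_dict["param_groups"]
        ],
    }
```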
LGTM! Added some minor comments. The main one is whether we need the v1.5 workaround.
OrderedDict()
)  # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._param_to_index: Dict[int, int] = {}
(This predates this PR.) Curious: what if we directly use Tensor instead of id(param) as the key, would that result in a wrong mapping? If so, shall we add a comment here to mention that?
I saw the comments below. Would I be correct to assume the concern with using Tensor as the map key is that a Tensor's hash depends on the Tensor implementation, which might 1) depend on the Tensor's value and 2) be more expensive than id()?
Hmm, good question. I didn't benchmark that; maybe it would be nicer to move to Tensor for consistency? The context is that PyTorch state-dict code typically uses id(param), so I did the same. When trying to "cache" this it moved into a more generic place, and I agree it does look inconsistent with the line just above.
It just returns the id, haha:

Lines 599 to 602 in 7a178a8:

def __hash__(self):
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__hash__, (self,), self)
    return id(self)
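As a quick sanity check (not from the PR itself): for a plain tensor with no `__torch_function__` override, hashing falls back to the object id, so a dict keyed by Tensor behaves like a dict keyed by id(param).

```python
# Demonstrates that Tensor hashing falls back to id(), so Tensor-keyed and
# id()-keyed dicts resolve to the same entries for the same parameter object.
import torch

p = torch.nn.Parameter(torch.zeros(3))
assert hash(p) == id(p)

by_tensor = {p: 0}
by_id = {id(p): 0}
assert by_tensor[p] == by_id[id(p)]
```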
Summary: [ZeroRedundancyOptimizer] Elastic and pytorch compatible state dict
Test Plan: CircleCI / unit tests
Differential Revision: D26703501
fbshipit-source-id: 9f42e16735a192eaca24a8ed9108b50cd13460c3
Summary: Same as #52760, which I could not get to land. I just could not live with ghstack/ghimport randomly breaking things (I break enough of them myself), so this is a fresh copy without the ghstack shenanigans. I'm hopeful that this can land relatively bug free, and am sorry for the duplication. What this does:
- call the common_utils test runner instead of unittest, since that seems to be how it should be done
- change the state returned by ZeroRedundancyOptimizer to be PyTorch compliant, which has the added benefit of being elastic (world-size independent)

Pull Request resolved: #52960
Reviewed By: mrshenli
Differential Revision: D26710932
Pulled By: blefaudeux
fbshipit-source-id: 1d914bc9221442ba1bb2b48f5df10c313e674ece
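For reference, below is a small, single-process sketch (my illustration, not part of the PR) of the "PyTorch compliant" layout referred to above: a vanilla torch.optim state dict keys its state by linear parameter indices rather than by rank or by id, which is what makes it world-size independent.

```python
# Toy example showing the standard torch.optim state_dict layout:
# state keyed by linear parameter indices, matching param_groups["params"].
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

opt.zero_grad()
model(torch.randn(1, 4)).sum().backward()
opt.step()

sd = opt.state_dict()
print(sd["param_groups"][0]["params"])  # e.g. [0, 1] -- indices, not ids
print(sorted(sd["state"].keys()))       # per-param state keyed by the same indices
```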