Get rid of dim_groups attribute from DeviceMesh #103105
Conversation
This PR gets rid of the dim_groups attribute from DeviceMesh. The main motivation is that c10d should store the process groups during their creation, not DeviceMesh; DeviceMesh should just handle ranks correctly. This could enable DTensor to become picklable (torch.save/load could be possible), which I will try in the next PR. [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103105
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 5de352e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Yeah, we're starting to make DeviceMesh picklable, thanks for the work!! Some suggestions (I could be wrong). Let me know what you think ;-)
So basically, if we store a ProcessGroup as an attribute it makes the object unpicklable, but now what we store is List[Tuple[str, List[int]]], so it works. Is this understanding correct?
@fduwjj That is correct. For some reason
Yep, it's impossible to pickle a process group. The previous ShardedTensor approach was fragile and not good for state_dict, as it needed to add a state_dict hook just for the sake of the process group.
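The distinction discussed above can be shown with a minimal stand-alone sketch (not PyTorch's actual code; a `threading.Lock` stands in for a live ProcessGroup handle, and the class and attribute names here are illustrative): an object holding a live handle fails to pickle, while one holding only `List[Tuple[str, List[int]]]` metadata round-trips fine.

```python
import pickle
import threading

class MeshWithHandle:
    """Stores a live, unpicklable handle (stand-in for a ProcessGroup)."""
    def __init__(self, ranks):
        self.ranks = ranks
        self.group = threading.Lock()  # like a ProcessGroup: cannot be pickled

class MeshWithMetadata:
    """Stores only metadata, List[Tuple[str, List[int]]], as in the PR."""
    def __init__(self, dim_group_infos):
        self.dim_group_infos = dim_group_infos  # e.g. [("nccl", [0, 1])]

# Pickling the handle-holding object raises TypeError.
try:
    pickle.dumps(MeshWithHandle([0, 1]))
    handle_ok = True
except TypeError:
    handle_ok = False

# The metadata-only object survives a pickle round trip.
meta = MeshWithMetadata([("nccl", [0, 1]), ("nccl", [2, 3])])
restored = pickle.loads(pickle.dumps(meta))
```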
LGTM!
LGTM, hopefully this will not add too much perf overhead.
Hmmm, what do you mean by perf overhead here? I don't think this would have any perf implications.
Stack from ghstack (oldest at bottom):
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d should store the process groups during their
creation, not DeviceMesh; DeviceMesh should just handle ranks correctly.
This could enable DTensor to become picklable (torch.save/load could be
possible), which I will try in the next PR.
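The idea of "let c10d store the process groups, the mesh stores only ranks" can be sketched as follows. This is not the PR's actual implementation: `_GROUP_REGISTRY`, `register_group`, and `DeviceMeshSketch` are hypothetical names, with a plain dict standing in for c10d's internal process-group store. The mesh pickles freely because it holds only tags and ranks, and it re-resolves the live group from the registry on demand after loading.

```python
import pickle

# Hypothetical global registry standing in for c10d's process-group store.
_GROUP_REGISTRY = {}

def register_group(tag, group):
    """Record a live group under a string tag (as c10d would at creation time)."""
    _GROUP_REGISTRY[tag] = group

class DeviceMeshSketch:
    """Holds only picklable metadata: [(tag, ranks), ...] per mesh dimension."""
    def __init__(self, dim_group_infos):
        self.dim_group_infos = dim_group_infos

    def get_dim_group(self, dim):
        # Resolve the live group lazily instead of storing it as an attribute.
        tag, _ranks = self.dim_group_infos[dim]
        return _GROUP_REGISTRY[tag]

# A sentinel object plays the role of a live ProcessGroup.
fake_group = object()
register_group("mesh_dim_0", fake_group)

mesh = DeviceMeshSketch([("mesh_dim_0", [0, 1])])
restored = pickle.loads(pickle.dumps(mesh))  # round-trips: only metadata is stored
```

After unpickling, `restored.get_dim_group(0)` returns the same live group, since the registry, not the mesh, owns it.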