[optim] include nn.Parameter as foreach supported #95811
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95811
Note: Links to docs will display an error until the docs builds have completed. ✅ No Failures as of commit 3add0bd. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: f92e1ac06f2d1c151ebefa45e6a5b444afcf7c30 Pull Request resolved: #95811
@@ -15,6 +15,7 @@
 __all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
 _global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
 _global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
+_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
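The diff above hints at the bug being fixed: the foreach default checks exact types against this list, so `nn.Parameter` (a `Tensor` subclass) fell through when only `torch.Tensor` was listed. A torch-free sketch of that exact-type check, using stand-in class names rather than PyTorch's actual code:

```python
# Stand-in classes; in PyTorch these would be torch.Tensor and
# torch.nn.parameter.Parameter.
class Tensor: ...
class Parameter(Tensor): ...

def default_to_foreach(tensors, supported):
    # Exact type match, not isinstance: subclasses must be listed
    # explicitly, so arbitrary subclasses do not silently opt in.
    return all(type(t) in supported for t in tensors)

# Before the PR: Parameter (a Tensor subclass) failed the check, so real
# model parameters never got the faster foreach path.
print(default_to_foreach([Parameter()], [Tensor]))             # False
# After the PR: Parameter is listed explicitly.
print(default_to_foreach([Parameter()], [Tensor, Parameter]))  # True
```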
nit: make it a tuple instead so it's not accidentally mutated?
This is actually intentional! This PR was a realization that the defaulting was way too conservative and did not allow models to be included for faster foreach. We want to allow users to be able to add their own subclasses here if they know that foreach will not break on those tensor subclasses.
@janeyx99 What is the recommended way to add to this list? Should we directly append like the following in trainer code:
torch.optim.optimizer._foreach_supported_types.append(...)
@janeyx99 I have the same question @awgu raised. Do you know how users are adding their own subclasses to the list?
Due to a circular dependency, instead of adding DTensor here, I am first trying out what Andrew suggested. Appending the tensor subclass to the list seems to have no effect at a later stage, so I am wondering how users are doing this, if you are aware.
You may be the first user to try this--what Andrew suggests seems okay for now barring the fact that it's a private API. Once we want to productionize this, we should work together to make this an exposed public API.
It sounds like appending is not working?
> It sounds like appending is not working?
My bad, that was false information! I just double-checked, and append seems to work fine in the unit test locally!
This PR is a result of a realization that models are NOT subscribed to the foreach defaulting, as has been claimed in our documentation for months now. BIG OOPS. Pull Request resolved: pytorch#95811 Approved by: https://github.com/albanD
* [optim] Widen the cases for defaulting to foreach (#95820) Big OOP correction continued. Also added a test this time to verify the defaulting was as expected. The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes CPU tensors. Since foreach _can_ handle CPU tensors, this should not introduce breakage. Pull Request resolved: #95820 Approved by: https://github.com/albanD
Stack from ghstack (oldest at bottom):