[mta] Implement fused SGD #94791
Conversation
Force-pushed from 38cd2bd to b3945f1.
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as stale.
test/optim/test_optim.py (outdated)
(optim.SGD,),
[
    {"lr": 0.1, "momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n}
    for d, w, n in itertools.product((0.0, 0.5), (0.0, 0.5), (False,))
Minor nit but it seems like this entry and the next entry in the list could be combined by adding another tuple (0.0, 0.5) for momentum.
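A rough sketch of the combined entry being suggested (purely illustrative; the variable name is made up, and nesterov stays False so every momentum/dampening/weight_decay combination remains valid):

from itertools import product

# Hypothetical combined parameterization: fold the momentum=0.0 and momentum=0.5
# entries into a single comprehension by also iterating over momentum.
sgd_configs = [
    {"lr": 0.1, "momentum": m, "dampening": d, "weight_decay": w, "nesterov": n}
    for m, d, w, n in product((0.0, 0.5), (0.0, 0.5), (0.0, 0.5), (False,))
]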
I am NOT done reviewing this! Will continue reviewing later
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as stale.
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
aten/src/ATen/native/native_functions.yaml
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
  variants: function

- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
Similar to fused Adam and AdamW, we should allow Tensor lr as well.
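For reference, a hedged sketch of what Tensor lr support could look like from the Python side, mirroring fused Adam(W), which already accepts a 0-dim tensor lr (this assumes the Tensor-lr overload eventually lands for fused SGD; it is not part of the schema above):

import torch

# Sketch only: requires a CUDA device, and assumes fused SGD accepts a 0-dim
# tensor lr the way fused Adam/AdamW do.
model = torch.nn.Linear(8, 8, device="cuda")
lr = torch.tensor(0.1, device="cuda")  # tensor lr instead of a Python float

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5, fused=True)
model(torch.randn(4, 8, device="cuda")).sum().backward()
opt.step()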
expected_scales,
expected_growth_trackers,
expected_grad_vals):
for data, scale, growth_tracker, grad_val in zip(input_vals, expected_scales, expected_growth_trackers, expected_grad_vals):
is this purely stylistic?
  # because JIT can't handle Optionals nor fancy conditionals when scripting
  if not torch.jit.is_scripting():
-     _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+     fused, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
Could you add a comment here similar to in adam(w):
Note that we default to foreach and pass False to use_fused. This is not a mistake--we want to give the fused implementation bake-in time before making it the default, even if it is typically faster.
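Something along these lines, perhaps (a self-contained sketch rather than the PR's actual sgd() body; it reuses the `_default_to_fused_or_foreach` helper from torch.optim.optimizer shown in the diff above):

import torch
from torch.optim.optimizer import _default_to_fused_or_foreach

def _choose_impl(params, foreach=None, fused=None):
    if fused is None and foreach is None:
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            # Note that we default to foreach and pass False to use_fused.
            # This is not a mistake: we want to give the fused implementation
            # bake-in time before making it the default, even if it is
            # typically faster.
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    return fused, foreach

# e.g. on CPU tensors this falls back to the plain for-loop implementation
print(_choose_impl([torch.randn(2, 2)]))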
Overall looks pretty good. The high-level feedback I have, in addition to the inline comments:
- We now support Tensor lr for fused Adam(W), and we should strive to do the same for SGD. It doesn't have to land as part of this PR if that would make it too large, but then we should be explicit that the fused implementation does not accept a Tensor lr and add the Tensor-lr overloads in a follow-up PR.
- Could you get some benchmarks showing the wins?
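For the benchmark ask, a micro-benchmark along these lines would show the per-step win (a sketch, assuming a CUDA device and the fused=True flag this PR adds; the parameter shapes are arbitrary):

import torch

def bench_sgd(fused, steps=100):
    params = [torch.randn(1024, 1024, device="cuda", requires_grad=True) for _ in range(32)]
    for p in params:
        p.grad = torch.randn_like(p)
    opt = torch.optim.SGD(params, lr=0.1, momentum=0.9, fused=fused, foreach=not fused)
    for _ in range(10):  # warmup
        opt.step()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(steps):
        opt.step()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # ms per optimizer step

print(f"foreach: {bench_sgd(False):.3f} ms/step, fused: {bench_sgd(True):.3f} ms/step")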
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
    CUDA: _fused_sgd_with_momentum_kernel_cuda_
  autogen: _fused_sgd_with_momentum, _fused_sgd_with_momentum.out

- func: _fused_sgd_with_momentum_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
can this be an overload instead?
or like, an optional value?
    for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
] + [
    {"lr": 0.1, "momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
    for d, w, n in product((0.0,), (0.0, 0.5), (True,))
Should nesterov=False be covered here too?
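i.e., a sketch of the suggested coverage (nesterov=True still requires dampening == 0, so only the nesterov values change):

from itertools import product

fused_sgd_configs = [
    {"lr": 0.1, "momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
    for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
] + [
    {"lr": 0.1, "momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
    for d, w, n in product((0.0,), (0.0, 0.5), (False, True))
]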
test/optim/test_optim.py (outdated)
    for lr, d, w, n in itertools.product((0.1, torch.tensor(0.1)), (0.0, 0.5), (0.0, 0.5), (False,))
] + [
    {"lr": lr, "momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n}
    for lr, d, w, n in itertools.product((0.1, torch.tensor(0.1)), (0.0,), (0.0, 0.5), (True,))
nesterov=False too?
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Depends on #116583. Related: #94791. Pull Request resolved: #116585. Approved by: https://github.com/janeyx99
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @avikchaudhuri @gmagogsfm