Add Adafactor foreach impl #132336
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132336
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 97797b5 with merge base 61625a1:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR adds the foreach impl for Adafactor, knowing that there are many ways to improve its runtime perf today (by adding more foreach support).

After this PR:
- we have a foreach flag for Adafactor
- it is NOT the default. Why not? It is only slightly faster and uses O(n) more memory, where n is the number of params in your max param group. People tend to use Adafactor for memory efficiency.

Next steps:
- make torch.compile possible on it
- make it faster (by adding more foreach APIs)
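A minimal usage sketch of the new flag, assuming the class and keyword are spelled Adafactor and foreach in torch/optim/_adafactor.py as the diff paths below suggest (illustrative only, not taken from the PR):

import torch
from torch.optim._adafactor import Adafactor  # private module touched by this PR

model = torch.nn.Linear(128, 64)

# foreach=True opts into the batched (foreach) implementation added here;
# leaving it unset keeps the single-tensor path, which is the memory-lean default.
optim = Adafactor(model.parameters(), lr=1e-2, foreach=True)

loss = model(torch.randn(32, 128)).sum()
loss.backward()
optim.step()
optim.zero_grad()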
torch/optim/_adafactor.py
Outdated
torch._foreach_mul_(device_row_vars, beta2_ts)  # type: ignore[arg-type]
torch._foreach_mul_(row_means, one_minus_beta2_ts)
torch._foreach_add_(device_row_vars, row_means)  # type: ignore[arg-type]
In the future this would be a
torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)
if we had ScalarList support for the 3rd arg of _foreach_lerp
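For context, a hedged sketch of why this collapses to a lerp: the mul/mul/add sequence above is an exponential moving average, which per tensor is exactly a lerp weighted by one_minus_beta2 (names reuse the snippet above; the loop is illustrative, not the PR's code):

# row_var = beta2 * row_var + (1 - beta2) * row_mean
# is the same as row_var.lerp_(row_mean, 1 - beta2), since
# lerp_(end, w) computes (1 - w) * self + w * end.
for row_var, row_mean, one_minus_beta2 in zip(
    device_row_vars, row_means, one_minus_beta2_ts
):
    row_var.lerp_(row_mean, one_minus_beta2)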
Do you expect to do these before merging the PR?
Should these improvements be recapped in an issue?
torch/optim/_adafactor.py
Outdated
torch._foreach_mul_(device_col_vars, beta2_ts)  # type: ignore[arg-type]
torch._foreach_mul_(col_means, one_minus_beta2_ts)
torch._foreach_add_(device_col_vars, col_means)  # type: ignore[arg-type]
torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)
torch/optim/_adafactor.py
Outdated
torch._foreach_mul_(device_variances, beta2_ts)  # type: ignore[arg-type]
torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
torch._foreach_add_(device_variances, grads_squared)  # type: ignore[arg-type]
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
), "row_var and col_var should be defined when grad is multidimensional" | ||
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g | ||
row_means = [ | ||
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads |
no foreach norm support for this type of norm
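A hedged, standalone sketch of the trick the comment above describes (not the PR's code): the squared L2 norm along the last dim, divided by that dim's size, equals (g * g).mean(dim=-1) without ever materializing the g * g intermediate; the squaring and division presumably happen in the elided lines that follow the norm call.

import torch

g = torch.randn(4, 8)
direct = (g * g).mean(dim=-1, keepdim=True)                         # materializes g * g
via_norm = torch.norm(g, dim=-1, keepdim=True).pow(2) / g.size(-1)  # no intermediate
assert torch.allclose(direct, via_norm)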
torch/optim/_adafactor.py
Outdated
    for row_var, col_var in zip(device_row_vars, device_col_vars)
]
row_var_means = [
    row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars  # type: ignore[union-attr]
no foreach mean
torch/optim/_adafactor.py
Outdated
del col_means

var_estimates = [
    row_var @ col_var  # type: ignore[operator]
no foreach mm lol, probably the bulk of the work
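For context, a hedged sketch of what this outer product is building, reusing names from the snippet above; the normalization by row_var_means is an assumption based on the row means computed earlier, following the factored second-moment estimate from the Adafactor paper (not a claim about the PR's exact code):

# Each factored variance estimate is an outer product of the row and column
# statistics, rescaled so its magnitude tracks the row statistics:
#   V_hat ≈ (row_var @ col_var) / row_var.mean(dim=-2, keepdim=True)
# With no torch._foreach_mm available, the matmul stays a Python-level loop.
var_estimates = [
    row_var @ col_var
    for row_var, col_var in zip(device_row_vars, device_col_vars)
]
torch._foreach_div_(var_estimates, row_var_means)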
torch._foreach_sqrt_(var_estimates)
torch._foreach_reciprocal_(var_estimates)
would benefit from a foreach_rsqrt
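A hedged sketch of what that fusion would buy (torch._foreach_rsqrt_ is hypothetical here; the per-tensor equivalence is just rsqrt_):

# sqrt followed by reciprocal is rsqrt, so the two foreach passes above could
# become one kernel launch per group. Per tensor, the equivalence is:
for v in var_estimates:
    v.rsqrt_()  # same result as v.sqrt_(); v.reciprocal_()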
torch/optim/_adafactor.py
Outdated
    for a, update in zip(alphas, updates)
]
torch._foreach_mul_(updates, alphas)
torch._foreach_add_(device_params, updates)  # type: ignore[arg-type]
would be nice to have a foreach_add where the alphas could be a scalarlist
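A hedged sketch of the fusion being asked for (the alpha=ScalarList signature is hypothetical; the per-tensor loop shows the intended semantics, reusing names from the snippet above):

# Hypothetical fused call, if _foreach_add_ accepted a ScalarList alpha:
#   torch._foreach_add_(device_params, updates, alpha=alphas)
# Today the same thing takes a foreach mul plus a foreach add, or per tensor:
for p, u, a in zip(device_params, updates, alphas):
    p.add_(u, alpha=a)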
Sounds pretty good!
Curious what's the plan for all the future improvements?
torch/optim/_adafactor.py
Outdated
    device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0  # type: ignore[arg-type]
)
else:
    torch._foreach_add_(device_state_steps, 1)  # type: ignore[arg-type]
Suggested change:
- torch._foreach_add_(device_state_steps, 1)  # type: ignore[arg-type]
+ torch._foreach_add_(device_state_steps, 1.)  # type: ignore[arg-type]
?
torch/optim/_adafactor.py
Outdated
torch._foreach_mul_(device_row_vars, beta2_ts)  # type: ignore[arg-type]
torch._foreach_mul_(row_means, one_minus_beta2_ts)
torch._foreach_add_(device_row_vars, row_means)  # type: ignore[arg-type]
Do you expect to do these before merging the PR?
Should these improvements be recapped in an issue?
@albanD I'm planning to encapsulate all the action items in an issue before landing this PR, including perf wins, compile support, etc.
Perf tracker with all issues: #133367
Sounds good!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.