
Conversation

janeyx99 (Contributor) commented Jul 31, 2024

This PR adds a foreach implementation for Adafactor, with the understanding that there are still many ways to improve its runtime performance today (by adding more foreach support). After this PR:

  • We have a foreach flag for Adafactor (see the usage sketch below).
  • It is NOT the default. Why not? It is only slightly faster and uses O(n) more memory, where n is the number of params in your largest param group, and people tend to use Adafactor for memory efficiency.
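A minimal usage sketch (hedged: this assumes the optimizer is exposed as `torch.optim.Adafactor` and accepts the new `foreach` flag; the model and shapes are illustrative):

```python
import torch

# Opt into the foreach implementation explicitly, since it is not the
# default for Adafactor (see the memory tradeoff above).
model = torch.nn.Linear(128, 64)
optimizer = torch.optim.Adafactor(model.parameters(), foreach=True)

loss = model(torch.randn(8, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```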

Next steps:

  • make torch.compile work on it
  • make it faster (by adding more foreach APIs)

Stack from ghstack (oldest at bottom):


pytorch-bot bot commented Jul 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132336

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 97797b5 with merge base 61625a1:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Jul 31, 2024
ghstack-source-id: 1888a78
Pull Request resolved: #132336
janeyx99 marked this pull request as ready for review August 7, 2024 22:11
janeyx99 requested a review from albanD as a code owner August 7, 2024 22:11
janeyx99 added a commit that referenced this pull request Aug 7, 2024
ghstack-source-id: b8ee838
Pull Request resolved: #132336
Comment on lines 532 to 534
torch._foreach_mul_(device_row_vars, beta2_ts) # type: ignore[arg-type]
torch._foreach_mul_(row_means, one_minus_beta2_ts)
torch._foreach_add_(device_row_vars, row_means) # type: ignore[arg-type]
janeyx99 (Contributor Author)

In the future this would be a single

torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)

if we had ScalarList support for _foreach_lerp_'s 3rd arg.
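For reference, a minimal sketch of the identity this relies on (tensor names and shapes are illustrative; a single Python float stands in for one entry of the ScalarList):

```python
import torch

# lerp(a, b, w) == a * (1 - w) + b * w, so with w = 1 - beta2 the
# mul_/mul_/add_ sequence above collapses to one lerp per tensor.
row_var = torch.rand(4, 1)    # stands in for one entry of device_row_vars
row_mean = torch.rand(4, 1)   # stands in for the matching row_means entry
beta2 = 0.99
w = 1.0 - beta2

three_op = row_var * beta2 + row_mean * w  # what the foreach mul/mul/add computes
one_op = torch.lerp(row_var, row_mean, w)  # what a ScalarList lerp would compute
torch.testing.assert_close(three_op, one_op)
```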

albanD (Collaborator)

Do you expect to do these before merging the PR?
Should these improvements be recapped in an issue?

Comment on lines 544 to 546
torch._foreach_mul_(device_col_vars, beta2_ts) # type: ignore[arg-type]
torch._foreach_mul_(col_means, one_minus_beta2_ts)
torch._foreach_add_(device_col_vars, col_means) # type: ignore[arg-type]
janeyx99 (Contributor Author)

torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)

Comment on lines 566 to 568
torch._foreach_mul_(device_variances, beta2_ts) # type: ignore[arg-type]
torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
torch._foreach_add_(device_variances, grads_squared) # type: ignore[arg-type]
janeyx99 (Contributor Author)

torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)

), "row_var and col_var should be defined when grad is multidimensional"
# same as (g * g).mean(dim=-1) without materializing an intermediate the size of g
row_means = [
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
janeyx99 (Contributor Author)

There's no foreach norm support for this type of norm.
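A hedged sketch of the trick the inline comment refers to (illustrative; the square and divide presumably happen on lines elided from this snippet): the squared 2-norm along the last dim, divided by that dim's size, equals `(g * g).mean(dim=-1)` without ever materializing `g * g`:

```python
import torch

g = torch.rand(3, 5)  # stands in for one grad
via_norm = torch.norm(g, dim=-1, keepdim=True).square() / g.size(-1)
via_mean = (g * g).mean(dim=-1, keepdim=True)
torch.testing.assert_close(via_norm, via_mean)
```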

for row_var, col_var in zip(device_row_vars, device_col_vars)
]
row_var_means = [
row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars # type: ignore[union-attr]
janeyx99 (Contributor Author)

There's no foreach mean.
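A sketch of the fallback pattern (tensors illustrative): with no `torch._foreach_mean`, the reduction runs as one kernel launch per tensor via a list comprehension.

```python
import torch

device_row_vars = [torch.rand(4, 1), torch.rand(7, 1)]

# One .mean() kernel per tensor today; a fused foreach mean does not exist.
row_var_means = [rv.mean(dim=-2, keepdim=True) for rv in device_row_vars]
```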

del col_means

var_estimates = [
row_var @ col_var # type: ignore[operator]
janeyx99 (Contributor Author)

no foreach mm lol, probably the bulk of the work
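A hedged sketch of the shape of this step (sizes illustrative): each `row_var` is (r, 1) and each `col_var` is (1, c), so the per-tensor matmul is an outer product, one kernel launch per parameter, since there is no `torch._foreach_mm` to batch them:

```python
import torch

device_row_vars = [torch.rand(4, 1), torch.rand(7, 1)]
device_col_vars = [torch.rand(1, 5), torch.rand(1, 3)]

var_estimates = [
    row_var @ col_var  # (r, 1) @ (1, c) -> (r, c), one launch each
    for row_var, col_var in zip(device_row_vars, device_col_vars)
]
assert var_estimates[0].shape == (4, 5)
```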

Comment on lines +574 to +575
torch._foreach_sqrt_(var_estimates)
torch._foreach_reciprocal_(var_estimates)
janeyx99 (Contributor Author)

This would benefit from a foreach_rsqrt.
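A hedged sketch of the two-pass pattern (tensors illustrative): sqrt then reciprocal each walk the tensor list once; a fused `torch._foreach_rsqrt_` (hypothetical today) would halve the passes:

```python
import torch

var_estimates = [torch.rand(4, 5) + 0.1, torch.rand(3, 3) + 0.1]
expected = [v.rsqrt() for v in var_estimates]  # per-tensor rsqrt does exist

torch._foreach_sqrt_(var_estimates)        # pass 1
torch._foreach_reciprocal_(var_estimates)  # pass 2
torch.testing.assert_close(var_estimates, expected)
```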

for a, update in zip(alphas, updates)
]
torch._foreach_mul_(updates, alphas)
torch._foreach_add_(device_params, updates) # type: ignore[arg-type]
janeyx99 (Contributor Author)

It would be nice to have a foreach_add where the alphas could be a ScalarList.
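A hedged sketch of the wish (names illustrative): today the step needs a `_foreach_mul_` by the per-tensor `alphas` followed by a `_foreach_add_`; a ScalarList-valued `alpha=` (hypothetical today) would fuse it:

```python
import torch

device_params = [torch.rand(3), torch.rand(2)]
updates = [torch.rand(3), torch.rand(2)]
alphas = [-0.01, -0.02]  # illustrative per-tensor step sizes

# Today: two foreach passes, and the mul clobbers `updates`.
torch._foreach_mul_(updates, alphas)
torch._foreach_add_(device_params, updates)

# Hypothetical fused form, if alpha accepted a ScalarList:
# torch._foreach_add_(device_params, updates, alpha=alphas)
```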

albanD (Collaborator) left a comment

Sounds pretty good!
Curious what's the plan for all the future improvements?

device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 # type: ignore[arg-type]
)
else:
torch._foreach_add_(device_state_steps, 1) # type: ignore[arg-type]
albanD (Collaborator)

Suggested change:
- torch._foreach_add_(device_state_steps, 1) # type: ignore[arg-type]
+ torch._foreach_add_(device_state_steps, 1.) # type: ignore[arg-type]

?


janeyx99 (Contributor Author)

@albanD I'm planning to encapsulate all the action items in an issue before landing this PR, including perf wins, compile support, etc.

janeyx99 (Contributor Author)

Perf tracker with all issues: #133367

janeyx99 added a commit that referenced this pull request Aug 14, 2024
ghstack-source-id: b1c1eed
Pull Request resolved: #132336
janeyx99 added the ciflow/trunk and topic: performance labels Aug 14, 2024
albanD (Collaborator) left a comment

Sounds good!

janeyx99 (Contributor Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
