Open
Labels: module: optimizer, oncall: pt2, triaged
Description
🚀 The feature, motivation and pitch
The single-tensor version of Adafactor has already landed in #129905, and the foreach version is in progress in #132336.
This issue is to track torch.compile support:
- Ensure single-tensor Adafactor runs under torch.compile and is tested
- Ensure foreach Adafactor runs under torch.compile and is tested
- Foreach Adafactor relies on the grouping logic returning a non-None dtype so that it can compute eps1 in the default case (when eps1 is not set). Historically, torch.compile has skipped the grouping logic, since inductor already handles grouping, and has returned None for both device and dtype. This is currently the most visible obstacle to compile support for the foreach implementation.
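To make the eps1 problem concrete, here is a pure-Python simulation of the control flow. It is not PyTorch code. The helpers `group_by_device_and_dtype` and `resolve_eps1` are hypothetical, dicts stand in for tensors, and `FINFO_EPS` reproduces `torch.finfo(dtype).eps` for four floating dtypes. The sketch assumes the compile path returns a single `(None, None)` group as described above. Reading the dtype back from the group is shown only as one illustrative workaround, not necessarily what the foreach implementation does.

```python
from collections import defaultdict

# Machine epsilon per dtype; these equal torch.finfo(dtype).eps for these dtypes.
FINFO_EPS = {
    "float16": 2.0 ** -10,
    "bfloat16": 2.0 ** -7,
    "float32": 2.0 ** -23,
    "float64": 2.0 ** -52,
}


def group_by_device_and_dtype(tensors, compiling):
    """Hypothetical stand-in for the optimizer's grouping helper."""
    if compiling:
        # Compile path as described in the issue: one group, no device/dtype metadata.
        return {(None, None): list(tensors)}
    groups = defaultdict(list)
    for t in tensors:
        groups[(t["device"], t["dtype"])].append(t)
    return dict(groups)


def resolve_eps1(eps1, group_dtype, group):
    """Return eps1, deriving the default from the group's dtype when unset."""
    if eps1 is not None:
        return eps1  # user-provided value needs no dtype
    if group_dtype is None:
        # Hypothetical fallback: infer the dtype from the group's members.
        # Only safe when every tensor in the group shares one dtype.
        dtypes = {t["dtype"] for t in group}
        if len(dtypes) != 1:
            raise ValueError("cannot infer a single dtype to compute eps1")
        group_dtype = dtypes.pop()
    return FINFO_EPS[group_dtype]
```

In eager mode each group carries its dtype, so the default eps1 is well defined per group. In the simulated compile path a mixed-dtype parameter list collapses into one `(None, None)` group, and no single default eps1 exists unless the caller passes one explicitly.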
Alternatives
No response
Additional context
Adafactor is our first param-wise optimizer (neither pointwise nor global). The eager foreach implementation still leaves a lot of room for improvement, which supporting more foreach ops could address. Compile support would be pretty cool, though.
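"Param-wise" can be illustrated with Adafactor's factored second-moment estimate for a 2D parameter, following the rank-1 row/column factorization from the Adafactor paper. This is a simplified, pure-Python sketch that assumes no decay, no eps, and no update clipping; `factored_second_moment` is a hypothetical name, not a torch.optim API. Each element's estimate depends on the row and column statistics of its own parameter, so the update is not pointwise, yet it never mixes statistics across different parameters, so it is not global either.

```python
def factored_second_moment(grad):
    """Rank-1 estimate of grad**2 built from row and column means.

    Simplified sketch: no running-average decay, no eps, no clipping.
    """
    n_rows, n_cols = len(grad), len(grad[0])
    sq = [[g * g for g in row] for row in grad]
    # Only n_rows + n_cols statistics are kept instead of n_rows * n_cols.
    row_mean = [sum(r) / n_cols for r in sq]
    col_mean = [sum(sq[i][j] for i in range(n_rows)) / n_rows for j in range(n_cols)]
    overall = sum(row_mean) / n_rows  # mean of all squared gradients
    # V[i][j] = row_mean[i] * col_mean[j] / overall
    return [
        [row_mean[i] * col_mean[j] / overall for j in range(n_cols)]
        for i in range(n_rows)
    ]
```

For a rank-1 squared gradient such as `[[1, 4], [4, 16]]` the estimate is exact. Changing a single gradient entry shifts the estimates for other entries of the same parameter, which is why a per-element (pointwise) kernel is not enough and why the foreach path needs per-parameter reductions.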
cc @vincentqb @jbschlosser @albanD @crcrpar @ezyang @chauhang @penguinwu