C++ API handle optimizer defaults #161825
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161825
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 12 Pending as of commit f5c263c with merge base 322091d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for working on the fix, this is an interesting bug indeed. I left a comment on the overall approach. Furthermore, since these tests don't run on CI, could you post a paste of the C++ test results run locally?
test/cpp/api/optim.cpp
Outdated
ASSERT_NEAR(group1_opts.lr(), 0.002, 1e-6); // Inherited
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
ASSERT_NEAR(group1_opts.eps(), 1e-12, 1e-15); // Inherited
ASSERT_NEAR(group1_opts.weight_decay(), 0.11, 1e-6); // Preserved
How come these can't be ASSERT_EQ?
Changed. The remaining uses in the serialization tests are left as-is.
test/cpp/api/optim.cpp
Outdated
}

TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
  torch::manual_seed(0);
is this important to the test? the actual params won't matter, right?
right. removed
| "You must override it in your subclass of torch::optim::OptimizerCloneableOptions<YourOptimizerOptions>."); | ||
| } | ||
|
|
||
| void OptimizerOptions::overwrite_from(const OptimizerOptions& source) { |
Hi! Some high-level questions:
How come we need a whole new overwrite_from API?
From the top level, I would expect us to fix the base class so that the user-specified defaults override the original defaults and are then used in add_param_group, without the need for adding a new API.
Maybe we don't :). I've tried to provide a fix to the base class without adding a new API, but I don't see a way to do it without one new virtual function call.
Force-pushed from a041b6f to d205fa6
Hi @janeyx99, thanks very much for your feedback. I've taken each of your comments into consideration. Here are the local optimizer tests (including the new ones). Is there a better way with gtest to see what the test actually ran (not just pass/fail)? Dropping the filter and running all 1020 tests passes too.
Hmm, the reason I was hesitant about the first approach was that it required modifying every optimizer, which this new approach unfortunately still requires. If that is unavoidable, I think it is okay to have as simple a solution as possible, but have it be an internal detail rather than something users can see. To that effect, I would prefer the solution with as few additions to the public API surface and the lowest complexity. If the original approach was cleaner, then we can have a private version of it.
Thanks for your comments, @janeyx99. That makes sense! I certainly agree that users should not have to be aware of these implementation details. I wasn't sure how much the runtime performance cost impacted your review. I'll revisit and simplify.
Force-pushed from d205fa6 to 0c5cf7f
Hi @janeyx99, here is my preferred approach. It does everything in optimizer.h/.cpp, doesn't introduce any new API, and most of the work is done at compile time. I've tried to follow the template metaprogramming style used in c10. C++20 concepts could smooth out some of the boilerplate once they become available in PyTorch. Local tests are passing. Please let me know what you think. Thanks!
Much nicer! Can we privatize all the helpers?
Force-pushed from 0c5cf7f to 1f7aed3
Hi @janeyx99, I've changed the helpers and hope I've answered your questions. The proposed approach is to use the constructor defaults as a comparison baseline to detect user intent, and then inherit from the optimizer defaults for unspecified fields.
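To make the intended semantics concrete, here is a rough sketch of the behavior from the user's side. It is not the exact test code in this PR; the values are illustrative (taken from the test snippet earlier in this thread), and only existing public C++ API calls (Adam, AdamOptions, OptimizerParamGroup, add_param_group, param_groups) are used.

```cpp
#include <torch/torch.h>
#include <memory>

// Sketch: a second param group that only sets weight_decay should inherit
// lr/betas/eps from the user's optimizer defaults rather than from the
// AdamOptions() constructor defaults.
void merge_example() {
  using namespace torch::optim;

  auto p0 = torch::randn({2, 2}, torch::requires_grad());
  auto p1 = torch::randn({2, 2}, torch::requires_grad());

  // Optimizer defaults chosen by the user.
  Adam optimizer(
      {p0}, AdamOptions(0.002).betas(std::make_tuple(0.8, 0.88)).eps(1e-12));

  // New group: only weight_decay is specified explicitly.
  optimizer.add_param_group(OptimizerParamGroup(
      {p1}, std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11))));

  auto& group1_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
  // Expected after this change:
  //   group1_opts.lr()           == 0.002        (inherited)
  //   group1_opts.betas()        == {0.8, 0.88}  (inherited)
  //   group1_opts.eps()          == 1e-12        (inherited)
  //   group1_opts.weight_decay() == 0.11         (preserved)
  (void)group1_opts;
}
```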
};

template <typename Derived>
// Forward declarations for optimizer option types
Can we make the following classes and structs private as well?
Since these are forward declarations, I'm inclined not to change the style to use the prefix (it would require touching all the optimizer files). I do have to change them to structs (which resolves the inconsistency causing the clang build failure).
Please do privatize as much as possible so we are not inadvertently growing our API surface.
// SFINAE field detection - detects optimizer fields using public accessor methods
template <class T, class Enable = void>
struct has_lr : std::false_type {};
These structs too
These helper structs are in the private section of the class OptimizerCloneableOptions.
Do you just want me to prefix everything I've added for the implementation with an underscore? Is that just a style convention (happy to follow), or is there some codegen- or Python-binding-specific transformation?
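For readers following the thread, here is a hedged sketch of the detection pattern being discussed. The `has_lr` primary template comes from the diff hunk above; the positive specialization and the merge helper below are illustrative and may differ from what the PR actually does (it assumes the options type is default-constructible).

```cpp
#include <type_traits>
#include <utility>

// Primary template: assume the options type has no lr() accessor.
template <class T, class Enable = void>
struct has_lr : std::false_type {};

// Specialization chosen via SFINAE when T exposes a public lr() accessor.
template <class T>
struct has_lr<T, std::void_t<decltype(std::declval<const T&>().lr())>>
    : std::true_type {};

// Illustrative use inside a generic merge helper: only fields the options
// type actually has are compared against the constructor defaults.
template <class OptionsT>
void merge_lr(OptionsT& user, const OptionsT& optimizer_defaults) {
  if constexpr (has_lr<OptionsT>::value) {
    if (user.lr() == OptionsT().lr()) {   // still the constructor default?
      user.lr(optimizer_defaults.lr());   // inherit the optimizer default
    }
  }
}
```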
Thank you so much for following through on this change! We are very close to the end!
Force-pushed from 1f7aed3 to 41731f4
OK, I think this is ready. Thanks for your feedback and help, @janeyx99!
@pytorchbot merge

PR targets viable/strict rather than main, refusing merge request

@pytorchbot merge -r main

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

Successfully rebased

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Hey, apologies for the revert, but your PR is causing undefined symbol errors when linking the cross-platform build targets internally at Meta (similar errors for AdamW, Adagrad, RMSprop, and LBFGS). We probably need to move the function.

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).
@pytorchbot successfully started a revert job. Check the current status here.

This reverts commit f332017. Reverted #161825 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](#161825 (comment)))

@stmcgovern your PR has been successfully reverted.
This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.
Thanks for the information, @izaitsevfb. I'll move the function as you suggest and open another PR.
Addresses PyTorch issue pytorch#141884 by implementing automatic parameter group inheritance that achieves Python-C++ API parity without breaking changes.
- Uses comparison-based merging to infer user intent vs. default inheritance
- C++17 SFINAE patterns following PyTorch conventions (matches c10/util/TypeTraits.h)
- Adds comprehensive tests for optimizer parameter group inheritance
Force-pushed from f5c263c to 5ae6650
Follow-on PR is #165182
Pull Request resolved: pytorch#161825. Approved by: https://github.com/janeyx99
Fixes #141884
This fixes the issue for all optimizers and parameter options.
A member function `overwrite_from` is added to the optimizer base class. Each optimizer then implements this function, comparing its accepted parameters to the defaults. A SFINAE approach to handle the different optimizer parameters generically (in optimizer.h only) was evaluated, but I think this is easier to review and maintain.
This mirrors the Python API up to one edge case. An example of the edge case is provided below.
Python can distinguish between 1) Key not present in dict = "not specified" and 2) Key present in dict = "explicitly set". The C++ implementation cannot.
The issue hinges on whether to track if a particular parameter was set explicitly by the user (the discrepancy arises when the constructor default is explicitly passed in).
Tracking this seems like it would take more intervention than it is worth (modifying TORCH_ARG to keep track, using std::optional for the parameter types, or bitset tracking) and was not pursued in the current PR. I'm happy to alter the design if appropriate.
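As an illustration of the description above (not necessarily the exact code in this PR), a per-optimizer `overwrite_from` could look roughly like this for Adam; field names follow the existing public AdamOptions accessors.

```cpp
#include <torch/optim/adam.h>  // AdamOptions / OptimizerOptions

// Sketch only: the real implementation may differ in structure and details.
void AdamOptions::overwrite_from(const OptimizerOptions& source) {
  const auto& optimizer_defaults = static_cast<const AdamOptions&>(source);
  const AdamOptions ctor_defaults;  // constructor defaults as the baseline

  // A field still equal to its constructor default is treated as
  // "unspecified" and inherits the user's optimizer default.
  if (lr() == ctor_defaults.lr()) {
    lr(optimizer_defaults.lr());
  }
  if (betas() == ctor_defaults.betas()) {
    betas(optimizer_defaults.betas());
  }
  if (eps() == ctor_defaults.eps()) {
    eps(optimizer_defaults.eps());
  }
  if (weight_decay() == ctor_defaults.weight_decay()) {
    weight_decay(optimizer_defaults.weight_decay());
  }
  if (amsgrad() == ctor_defaults.amsgrad()) {
    amsgrad(optimizer_defaults.amsgrad());
  }
}
```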
Example of edge case hinging on CONSTRUCTOR DEFAULTS vs OPTIMIZER DEFAULTS
1. CONSTRUCTOR DEFAULTS:
These are the values you get when calling AdamOptions()
AdamOptions().lr() = 0.001
AdamOptions().weight_decay() = 0
AdamOptions().eps() = 1e-08
2. OPTIMIZER DEFAULTS:
These are the values the user chose when creating the optimizer
User's optimizer defaults:
optimizer.lr() = 0.005
optimizer.weight_decay() = 0.1
optimizer.eps() = 1e-07
3. THE PROBLEM SCENARIO:
User wants to add a parameter group with explicit weight_decay=0.0
User sets: weight_decay(0)
4. THE CONFUSION:
Constructor default weight_decay: 0
User's explicit weight_decay: 0
Are they equal? YES
Since they're equal, our overwrite_from() logic thinks:
"User didn't set weight_decay explicitly, use optimizer default"
5. CURRENT BEHAVIOR:
Final weight_decay: 0.1
User expected: 0
Match? ❌ NO
=== KEY INSIGHT ===
Constructor defaults are built into the C++ class definition.
Optimizer defaults are chosen by the user at runtime. We want to respect the user intention.
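A hedged end-to-end illustration of this edge case, using the values from the example above; it relies only on the existing public C++ optimizer API, and the commented outcome reflects the comparison-based merge described in this PR.

```cpp
#include <torch/torch.h>
#include <memory>

void edge_case() {
  using namespace torch::optim;

  auto p0 = torch::randn({2, 2}, torch::requires_grad());
  auto p1 = torch::randn({2, 2}, torch::requires_grad());

  // Optimizer defaults chosen by the user at runtime.
  Adam optimizer({p0}, AdamOptions(0.005).weight_decay(0.1).eps(1e-7));

  // The user explicitly sets weight_decay to 0, which happens to equal the
  // AdamOptions() constructor default.
  optimizer.add_param_group(OptimizerParamGroup(
      {p1}, std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.0))));

  auto& group1_opts =
      static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
  // Because 0 equals the constructor default, the merge cannot distinguish
  // "explicitly set to 0" from "not specified", so the group inherits
  // weight_decay = 0.1 instead of the 0 the user intended.
  (void)group1_opts;
}
```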