New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] SAC on new API stack (w/ EnvRunner and ConnectorV2): SACLearner
and SACTorchLearner
classes.
#42570
Conversation
…ality to 'TorchLearner' to build a trainable 'nn.Parameter' needed for the temperature parameter in 'SAC'. Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
SACLearner
and SACTorchLearner
SACLearner
and SACTorchLearner
classes.
@@ -290,6 +290,9 @@ def build(self) -> None: | |||
flags, so that `_make_module()` can place the created module on the correct | |||
device. After running super() it will wrap the module in a TorchDDPRLModule | |||
if `_distributed` is True. | |||
Note, in inherited classes it is advisable to call the parent's `build()` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, I'm not sure why we don't move the entire code before the super().build()
call below into the c'tor already. All we do here is to determine the GPU device - if any - based on config settings, which are all available already at c'tor time.
Then we don't need this comment here and we should also add the @OverrideToImplementCustomLogic_CallToSuperRecommended
decorator on top of this method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am with you on this.
However, the comment I added there was b/c when calling the build()
in inherited classes at the beginning in the inherited build()
method configure_optimizers_for_module()
gets called before variables were defined (example: variables defined in the build()
of the SACLearner
, then calling super().build()
before defining the variables in the SACLearner
would have led to a case where the Learner.build()
would call all configure_optimizers_for_module
methods that were overriden by the SACTorchLearner
and needed variables defined in the SACLearner
- and these were not yet defined. So first define variables then call build()
and to keep the knowledge I added it here.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, let's test this (and change only, if possible) before we merge ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense. So it's not about whether to call super, but about when to call it.
The call to super must be after(!) all parameters are available (for the optimizers to know).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly! Let's keep this in mind for documentation! I am sure users who derive from this class will run into it.
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @simonsays1980 :)
This PR is coded alongside #42568 and should implement the learner logic needed to transfer 'SAC' over to our new stack.
Why are these changes needed?
Our major RL algorithms should be transferred to our new stack, which offers higher modularity and customizability to users.
Related issue number
Closes #37778
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.