[RLlib] Exploration API: Policy changes needed for forward pass noisifications. #7798
Conversation
```diff
@@ -95,17 +92,8 @@ def compute_actions(self,
                 **kwargs):
        return list(state_batches[0]), state_batches, {}

    def learn_on_batch(self, samples):
```
Consider keeping these; they are there to make it clear these are no-ops.
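For example, explicit no-op overrides like these document that a heuristic policy neither learns nor has weights. A sketch only; the class name and exact method bodies are assumptions based on the diff context above:

```python
from ray.rllib.policy.policy import Policy


class AlwaysSameHeuristic(Policy):  # hypothetical heuristic policy
    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        return list(state_batches[0]), state_batches, {}

    def learn_on_batch(self, samples):
        # Explicit no-op: this policy never learns from samples.
        return {}

    def get_weights(self):
        # Explicit no-op: there are no weights to checkpoint.
        return {}

    def set_weights(self, weights):
        # Explicit no-op: nothing to restore.
        pass
```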
```python
run_heuristic_vs_learned(use_lstm=False)
# run_with_custom_entropy_loss()
parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=1000)
```
Consider keeping parser args at the top of the file by convention.
👍
Done.
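For illustration, the convention being asked for puts the parser at module level, right after the imports. A sketch, not the file's exact contents:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=1000)

# ... function definitions (run_same_policy, etc.) go here ...

if __name__ == "__main__":
    args = parser.parse_args()
```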
```python
run_same_policy(args)
run_heuristic_vs_learned(args, use_lstm=True)
run_heuristic_vs_learned(args, use_lstm=False)
run_with_custom_entropy_loss(args)
```
Hmm I can see this being a bit confusing to run as an example since there are four different runs in the output.
Could we at least add print()s in between?
Yeah, I just fixed that test case. For some reason, 3 of these have always been commented out and weren't working. I'll add the print()s.
Done.
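A sketch of what the fixed `__main__` block could look like with `print()`s in between. The `run_*` calls and `args` mirror the diff above; the parser and the exact messages are assumptions:

```python
if __name__ == "__main__":
    args = parser.parse_args()

    print("Running with the same policy for both players ...")
    run_same_policy(args)

    print("Running heuristic vs. learned (LSTM) ...")
    run_heuristic_vs_learned(args, use_lstm=True)

    print("Running heuristic vs. learned (no LSTM) ...")
    run_heuristic_vs_learned(args, use_lstm=False)

    print("Running with custom entropy loss ...")
    run_with_custom_entropy_loss(args)
```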
```python
            timestep=timestep)
    else:
        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(
```
Shouldn't the hook be called in all cases, including `action_sampler_fn`?
Not sure. I was thinking of `action_sampler_fn` as a completely custom way of doing things, so the user would have to apply exploration him/herself. That being said: no one is using `action_sampler_fn` right now anyway, so I guess it doesn't matter much.
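A minimal sketch (not RLlib's actual code) of the branching being discussed. `action_sampler_fn`, `exploration`, `before_compute_actions`, and `get_exploration_action` mirror the names in the diff and this PR; everything else is illustrative:

```python
def compute_actions(policy, obs_batch, timestep=None, **kwargs):
    if policy.action_sampler_fn is not None:
        # Fully custom path: the user-supplied sampler must apply any
        # exploration itself -- no hook is called automatically.
        return policy.action_sampler_fn(policy, policy.model, obs_batch,
                                        **kwargs)
    # Default path: let the Exploration object prepare/noisify the model
    # (e.g. ParameterNoise) before the forward pass ...
    policy.exploration.before_compute_actions(timestep=timestep)
    # ... then build the ActionDistribution and sample through the
    # Exploration object (see `get_exploration_action` in this PR).
    dist_inputs, _ = policy.model({"obs": obs_batch})
    action_dist = policy.dist_class(dist_inputs, policy.model)
    return policy.exploration.get_exploration_action(
        action_distribution=action_dist, timestep=timestep)
```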
Looks good, a few minor comments.
Cool, thanks! Will fix these.
Test FAILed.
Test FAILed.
Test PASSed.
@ericl This can be merged now. Tests are all ok.
Test PASSed.
LGTM. Looks like some merge conflicts, though.
Merged, waiting for re-testing.
@ericl Please merge. Re-tests ok after merging.
Test FAILed.
This PR contains the following Policy changes to allow different types of forward-pass noisifications:

- `action_sampler_fn`: Fully customized action-sampling behavior, returning only actions and logp. If provided, the Policy will not automatically use its Exploration object; exploration has to be done (customized) inside the `action_sampler_fn` itself!
- `action_distribution_fn`: Customized distribution-inputs and -class generator. The Policy's Exploration is used automatically (`before_compute_actions` and `get_exploration_action`). This removes the need for a `log_likelihood_fn` and thus simplifies DQN and SAC.
- All `Policy.compute_actions()` methods now return (by default and if available) the extra-action-fetches key `ACTION_DIST_INPUTS`. This is currently only used by the ParameterNoise class (but should be useful information for other cases as well).
- The ActionDistribution object is now passed directly into `Exploration.get_exploration_action()`. Before, distribution-inputs and distribution-class were passed in separately.
- Each Exploration class requires the Model upon construction. Hence, the Model is no longer passed into e.g. `Exploration.get_exploration_action`. This requires the Model to be generated before the Exploration object in all Policy classes.
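For illustration, a minimal sketch contrasting the two hooks from a user's perspective. Only the hook names (`action_sampler_fn`, `action_distribution_fn`, `before_compute_actions`, `get_exploration_action`) come from this PR; the model-call convention, the `Categorical` import path, the three-tuple return shape, and the helper names are assumptions:

```python
from ray.rllib.models.tf.tf_action_dist import Categorical


# Variant 1: fully custom sampling. The Policy will NOT call its
# Exploration object -- any noise must be applied in here.
def my_action_sampler_fn(policy, model, obs_batch, **kwargs):
    dist_inputs, _ = model({"obs": obs_batch})
    # `my_custom_noisy_sample` is a hypothetical helper.
    actions, logp = my_custom_noisy_sample(dist_inputs)
    return actions, logp


# Variant 2: only generate distribution inputs and class. The Policy
# automatically calls `exploration.before_compute_actions()` before the
# forward pass and `exploration.get_exploration_action()` on the
# resulting ActionDistribution object.
def my_action_distribution_fn(policy, model, obs_batch, **kwargs):
    dist_inputs, _ = model({"obs": obs_batch})
    return dist_inputs, Categorical, []  # inputs, dist class, state-outs
```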
. This requires the Model being generated before the Exploration object in all Policy classes.Checks
scripts/format.sh
to lint the changes in this PR.