
[RLlib] APPO/IMPALA: Enable using 2 separate optimizers for policy and vf (and 2 learning rates) on the old API stack. #40927

Conversation


@sven1977 sven1977 commented Nov 3, 2023

APPO/IMPALA: Enable using 2 separate optimizers for policy and value function (and 2 learning rates) on the old API stack.

Note that this feature had already existed for tf/tf2, but not for torch.

  • Added additional learning tests for APPO (torch) and Impala (tf/tf2 + torch).
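
For reference, a minimal sketch of how a user might enable this on the old API stack. The _separate_vf_optimizer / _lr_vf setting names are taken from the pre-existing tf code path and are assumed here to be what the torch policy now honors as well; this is not copied from the merged code:

    from ray.rllib.algorithms.impala import ImpalaConfig

    config = (
        ImpalaConfig()
        .environment("CartPole-v1")
        .framework("torch")
        .training(
            lr=0.0005,                    # learning rate for the policy optimizer
            _separate_vf_optimizer=True,  # build a second optimizer for the value function
            _lr_vf=0.0001,                # learning rate for the value-function optimizer
        )
    )

As discussed in the review thread below, this split only makes sense for models whose policy and value networks do not share parameters.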

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Comment on lines +181 to +203
    # Figure out, which parameters of the model belong to the value
    # function (and which to the policy net).
    dummy_batch = self._lazy_tensor_dict(
        self._get_dummy_batch_from_view_requirements()
    )
    # Zero out all gradients (set to None)
    for param in self.model.parameters():
        param.grad = None
    # Perform a dummy forward pass (through the policy net, which should be
    # separated from the value function in this particular user setup).
    out = self.model(dummy_batch)
    # Perform a (dummy) backward pass to be able to see, which params have
    # gradients and are therefore used for the policy computations (vs vf
    # computations).
    torch.sum(out[0]).backward()  # [0] -> Model returns out and state-outs.
    # Collect policy vs value function params separately.
    policy_params = []
    value_params = []
    for param in self.model.parameters():
        if param.grad is None:
            value_params.append(param)
        else:
            policy_params.append(param)
kouroshHakha (Contributor):

I don't understand the need for this. Why can't you directly index into the model and ask it for .value.parameters() and .policy.parameters()? There should be a better way than treating self.model as a black box with only the knowledge that a forward pass on the model directly uses the parameters belonging to the policy. Also, what if there are shared parameters between the value and policy components? This will lump them into the policy's optimizer, and they won't get updated based on the value function's loss.
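
(A hypothetical sketch of the suggested direct access, purely for illustration; policy_net / value_net are made-up attribute names, since, as discussed below, ModelV2 prescribes no such convention:)

    # Made-up attribute names -- this only works if the model actually exposes them.
    policy_params = list(self.model.policy_net.parameters())
    value_params = list(self.model.value_net.parameters())

    # Shared layers (e.g. a common encoder) would show up in both lists here,
    # whereas the dummy-backward heuristic assigns them to the policy optimizer
    # only, so the vf loss would never update them.
    shared = {id(p) for p in policy_params} & {id(p) for p in value_params}
    assert not shared, "policy and value nets share parameters"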

sven1977 (Contributor, Author):

Good point. The problem here is that the API is NOT defined at all: some users might have self.policy, others self.policy_net, etc.
The only thing required of you, if you want a value function to be present, is to implement the self.value_function() method. Take a look at our torch default models (ModelV2): they all differ in how they store the (separate) value sub-networks. It's quite a mess. I'm with you that this is not how we should normally solve this, but since this is the old API stack, which will be 100% retired very soon, I'm personally fine with it. Suggestions?
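
To make the point concrete, a minimal sketch of the kind of old-stack custom model described here; the attribute name _my_vf_branch is deliberately arbitrary, since the only contract is the value_function() method:

    import torch.nn as nn
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

    class MyOldStackModel(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(
                self, obs_space, action_space, num_outputs, model_config, name
            )
            nn.Module.__init__(self)
            in_size = int(obs_space.shape[0])
            # Policy head and value branch stored under user-chosen names.
            self._pi_head = nn.Linear(in_size, num_outputs)
            self._my_vf_branch = nn.Linear(in_size, 1)
            self._last_flat_in = None

        def forward(self, input_dict, state, seq_lens):
            self._last_flat_in = input_dict["obs_flat"].float()
            return self._pi_head(self._last_flat_in), state

        def value_function(self):
            # The required hook; the policy cannot know which sub-module backs it.
            return self._my_vf_branch(self._last_flat_in).squeeze(-1)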

kouroshHakha (Contributor):

Thanks for the explanation; I figured that might be the reason. We should be explicit about this in the comments.

@@ -229,6 +229,7 @@ def _import_leela_chess_zero():
"DreamerV3": _import_dreamerv3,
"DT": _import_dt,
"IMPALA": _import_impala,
"Impala": _import_impala,
kouroshHakha (Contributor):

Wait, where is this coming from? It will mess with our telemetry.

sven1977 (Contributor, Author):

Let me explain: We added a new test case in this PR, which is IMPALA (separate policy and vf) on CartPole. The new tuned_example file is a Python file (I'm trying to create as few new YAML files as possible nowadays). Hence, in there I'm using the ImpalaConfig() class/object. For whatever reason, it doesn't seem to work well with tune.run_experiments.

I hadn't thought about telemetry. Let me see whether there is a better way that would not break things ...

sven1977 (Contributor, Author):

Ah, ok, this is the culprit here (in rllib/train.py).

        experiments = {
            f"default_{uuid.uuid4().hex}": {
                "run": algo_config.__class__.__name__.replace("Config", ""),
                "env": config.get("env"),
                "config": config,
                "stop": stop,
            }
        }

Ok, let me provide a better fix.

kouroshHakha (Contributor):

Yeah, we should import the name directly from the registry if possible, or avoid tune.run_experiments?

sven1977 (Contributor, Author):

All algo config objects know what their corresponding algo class is, so this is now solved much more elegantly.
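
Presumably (not copied from the merged code), the rllib/train.py snippet above now looks roughly like this, relying on the algo_class attribute that every AlgorithmConfig carries:

        experiments = {
            f"default_{uuid.uuid4().hex}": {
                # Pass the algorithm class itself (Tune accepts a Trainable
                # class), instead of re-deriving a registry string from the
                # config class name ("ImpalaConfig" -> "Impala" != "IMPALA").
                "run": algo_config.algo_class,
                "env": config.get("env"),
                "config": config,
                "stop": stop,
            }
        }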

@kouroshHakha kouroshHakha (Contributor) left a comment

One big comment about how the value vs policy parameters are retrieved.

@sven1977 sven1977 (Contributor, Author) commented Nov 3, 2023

Fixed. @kouroshHakha, thanks for the review! Please take another look.

@kouroshHakha kouroshHakha (Contributor) commented:

The tests are failing. Let's hold off on merging until the issue is resolved. @can-anyscale, can you tell me what is wrong with the tests? All RLlib tests are complaining about a missing gRPC plugin.
https://buildkite.com/ray-project/premerge/builds/10829#_

kouroshHakha and others added 6 commits November 3, 2023 12:30
@sven1977 sven1977 merged commit 8711328 into ray-project:master Nov 4, 2023
24 of 26 checks passed
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Nov 29, 2023