
[RLlib] Fix policy_to_train logic for new API stack generically for all algorithms. #41529

Merged
merged 27 commits on Dec 9, 2023

Conversation

@sven1977 (Contributor) commented Nov 30, 2023

This PR solves a couple of problems related to APPO/IMPALA on the new API stack:

  • It fixes the policy_to_train logic for the new API stack, generically for all algorithms. No more algorithm-specific logic is required in each algorithm's training_step method, as the main filter step moves completely into the LearnerGroup.
  • It renames "is_module_trainable" to the clearer "should_module_be_updated_fn", which indicates that control lies at a much higher level than the module itself, namely in the LearnerGroup, which decides which batches to filter out before sending them to the n Learner workers (see the config sketch below).
  • It cleans up and enhances our self-play example script by setting a stricter learning goal (via the league size rather than the win rate, which did NOT capture learning progress, since any win rate can be achieved vs. a random opponent) and by allowing runs with APPO (as well as PPO). The learning criterion is now a little stricter, requiring at least the first win-rate threshold to be reached (at which point a new Policy/RLModule is inserted into the league).
  • It fixes a bug in APPO/IMPALA on the new stack related to vtrace and sequence zero-padding.
  • INFOs (on the new stack) now get properly padded with empty dicts (instead of None).

The new API stack should now work properly with APPO and IMPALA, as additional CI test cases prove.
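
For context, a minimal sketch of how the policies-to-train setting looks on the user side of a multi-agent/self-play config; the env name, policy IDs, and mapping function below are placeholders, not the exact example-script code. With this PR, the LearnerGroup applies this filter generically, instead of each algorithm's training_step doing it:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("connect_four_env")  # placeholder env registration name
    .multi_agent(
        # Two policies/RLModules live in the league ...
        policies={"main", "random"},
        # ... but only "main" is ever updated; "random" merely acts.
        policies_to_train=["main"],
        # Placeholder mapping; the real self-play example re-maps policies
        # as new league members get added.
        policy_mapping_fn=lambda agent_id, episode, worker=None, **kwargs: (
            "main" if agent_id == 0 else "random"
        ),
    )
)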

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

main = "examples/self_play_with_open_spiel.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
srcs = ["examples/self_play_with_open_spiel.py"],
- args = ["--framework=tf", "--env=connect_four", "--win-rate-threshold=0.9", "--num-episodes-human-play=0", "--as-test", "--min-win-rate=0.6"]
+ args = ["--framework=tf", "--env=connect_four", "--win-rate-threshold=0.9", "--num-episodes-human-play=0", "--as-test", "--min-league-size=3"]
sven1977 (Contributor, Author):

See below on why min-win-rate does not make sense.

@@ -778,7 +778,27 @@ def setup(self, config: AlgorithmConfig) -> None:
modules_to_load=modules_to_load,
rl_module_ckpt_dirs=rl_module_ckpt_dirs,
)
# sync the weights from the learner group to the rollout workers
# Setup proper policies-to-train/should-module-be-updated functions
sven1977 (Contributor, Author):

The ground truth for "is module trainable" (now renamed to the clearer "should_module_be_updated_fn") is now the LearnerGroup. This is because, in the future, we might not need this information on the EnvRunners (RolloutWorkers) at all, since these are only concerned with sampling, not with training/updating (separation of concerns).
Either way, for now, both the LearnerGroup AND the RolloutWorkers carry this information properly at all times.
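
A minimal conceptual sketch of what this means (assumed helper name and simplified logic, not the actual LearnerGroup code): the LearnerGroup holds should_module_be_updated_fn and drops the sub-batches of non-updated modules before anything is sent to the Learner workers:

from ray.rllib.policy.sample_batch import MultiAgentBatch

def filter_batch_before_update(ma_batch, should_module_be_updated_fn):
    # Keep only those module sub-batches whose ModuleID passes the filter fn.
    kept = {
        module_id: sub_batch
        for module_id, sub_batch in ma_batch.policy_batches.items()
        if should_module_be_updated_fn(module_id, ma_batch)
    }
    # The env-step count stays the same; only per-module training data is dropped.
    return MultiAgentBatch(kept, env_steps=ma_batch.env_steps())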

@@ -23,7 +23,7 @@
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig, ModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
sven1977 (Contributor, Author):

Moved this into utils.typing for better structure of the lib.

@@ -169,8 +169,6 @@ def training_step(self) -> ResultDict:
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

# Updating the policy.
is_module_trainable = self.workers.local_worker().is_policy_to_train
sven1977 (Contributor, Author):

These algo-specific "hacks" are not needed anymore, which makes the algorithm-specific training_step code a little easier to read.

@@ -437,21 +437,6 @@ def training_step(self) -> ResultDict:
if self.config._enable_new_api_stack:
# TODO (Kourosh) Clearly define what train_batch_size
# vs. sgd_minibatch_size and num_sgd_iter is in the config.
# TODO (Kourosh) Do this inside the Learner so that we don't have to do
sven1977 (Contributor, Author):

Same: see BC above.

# lengths in B. See SampleBatch for more information.
if (
self.module[pid].is_stateful()
or policy_batch.get("seq_lens") is not None
sven1977 (Contributor, Author):

Bug fix: For APPO, we do have SEQ_LENS and need to slice on the B-axis (not the T-axis), but the Module might still not be stateful.

Contributor:

Does policy_batch.get("seq_lens") is not None cover self.module[pid].is_stateful(), or not?

sven1977 (Contributor, Author) commented Dec 8, 2023:

I do think these are different concepts. You might have a stateful model that does NOT operate on (time-series) sequences (e.g., DreamerV3; well, it does operate on sequences but does NOT need the seq_lens key), or vice versa (like transformers, which operate on sequences but statelessly).
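
To make the two independent cases concrete, here is a tiny self-contained illustration (plain Python, not RLlib code) of the check under discussion:

def needs_b_axis_handling(is_stateful: bool, has_seq_lens: bool) -> bool:
    # Either condition alone already requires zero-padding/slicing along the
    # batch (B) axis; neither condition implies the other.
    return is_stateful or has_seq_lens

# Stateless model operating on sequences (e.g. a transformer): seq_lens present.
assert needs_b_axis_handling(is_stateful=False, has_seq_lens=True)
# Stateful model that does not rely on a seq_lens key (e.g. DreamerV3-style).
assert needs_b_axis_handling(is_stateful=True, has_seq_lens=False)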

@@ -51,10 +51,13 @@ def _get_backend_config(learner_class: Type["Learner"]) -> str:
return backend_config


def _is_module_trainable(module_id: ModuleID, batch: MultiAgentBatch) -> bool:
"""Default implemntation for is_module_trainable()
def _default_should_module_be_updated_fn(
sven1977 (Contributor, Author):

Cleaner (more descriptive) naming.

With the old name, one could be confused into thinking that the module itself is not trainable (frozen weights, etc.). However, the control here happens at a much higher level: the LearnerGroup decides which single-agent batches to filter out for modules that should not be updated, before(!) even sending the individual batches to the individual Learner workers.
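
As a hedged sketch of what the renamed default boils down to (the signature follows the diff above; the body is assumed and the real RLlib implementation may differ in details):

from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.typing import ModuleID

def _default_should_module_be_updated_fn(
    module_id: ModuleID, batch: MultiAgentBatch
) -> bool:
    # By default, the LearnerGroup updates every module that appears in the
    # incoming multi-agent batch, i.e., nothing is filtered out.
    return True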

@@ -433,25 +436,6 @@ def remove_module(self, module_id: ModuleID) -> None:
refs.append(ref)
ray.get(refs)

def set_weights(self, weights: Mapping[str, Any]) -> None:
sven1977 (Contributor, Author):

Moved this to the right place: getter before setter.

"""
if self.is_local:
return self._learner.get_state()
self._learner.set_module_state(weights)
Contributor:

I assume these are exactly copy-pasted.

learner_state = self._get_results(results)[0]
return {
"learner_state": learner_state,
"should_module_be_updated_fn": self.should_module_be_updated_fn,
Contributor:

I see.

# If container given, construct a simple callable returning True
# if the ModuleID is found in the list/set of IDs.
elif not callable(should_module_be_updated_fn):
assert isinstance(should_module_be_updated_fn, (list, set, tuple)), (
Contributor:

Don't raise assertion errors for things the user has to take action on. Only use assertions for internal violations of assumptions, i.e., things that point to a bug. That way, when people report an assertion error, we immediately know we have a bug and not a misuse.

sven1977 (Contributor, Author):

You are right! Fixed.
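
A sketch of the agreed-upon pattern (hypothetical helper name and error wording; the merged code may differ): user-provided values get validated with a proper exception, while assert stays reserved for internal invariants:

def _coerce_should_module_be_updated_fn(value):
    # A callable is used as-is: fn(module_id, batch) -> bool.
    if callable(value):
        return value
    # Anything else must be a container of ModuleIDs.
    if not isinstance(value, (list, set, tuple)):
        raise ValueError(
            "`should_module_be_updated_fn` must be a callable or a "
            f"list/set/tuple of ModuleIDs, got {type(value).__name__}!"
        )
    module_ids = set(value)
    # Wrap the container into a simple membership-check callable.
    return lambda module_id, batch=None: module_id in module_ids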


self._should_module_be_updated_fn = should_module_be_updated_fn

# TODO (sven): Why did we chose to re-invent the wheel here and provide load/save
Contributor:

Fair. I think it's this philosophy that things should be usable stand-alone.

rllib/algorithms/algorithm.py: outdated review thread (resolved).
)

ray.init(
num_cpus=args.num_cpus or None,
Contributor:

Oh OK, I think this is the pattern we should use for the setup calls of all tests then.
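
The pattern being referred to, as a short self-contained sketch (the argparse default of 0 is an assumption; the point is the "or None" idiom on num_cpus):

import argparse
import ray

parser = argparse.ArgumentParser()
parser.add_argument("--num-cpus", type=int, default=0)
args = parser.parse_args()

# "or None" turns the unset/0 default into None, letting Ray auto-detect the
# available CPUs; an explicit --num-cpus pins the run to exactly that many.
ray.init(num_cpus=args.num_cpus or None)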

rllib/policy/rnn_sequencing.py: outdated review thread (resolved).
sven1977 and others added 2 commits December 8, 2023 21:05
Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
@sven1977 merged commit 2307fd1 into ray-project:master on Dec 9, 2023
10 of 14 checks passed