Skip to content

Commit

Permalink
[RLlib] Issue 20920 (partial solution): contrib/MADDPG + pettingzoo coop-pong-v4 not working. (#21452)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 committed Jan 10, 2022
1 parent f8244a4 commit b10d553
Show file tree
Hide file tree
Showing 20 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions rllib/agents/a3c/a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["entropy_coeff"] < 0:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/ars/ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/cql/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(SACTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_default_policy_class(self,

@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:

# Call super's validation method.
super().validate_config(config)

if config["model"]["custom_model"]:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

# Update effective batch size to include n-step
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/dqn/r2d2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def validate_config(self, config: TrainerConfigDict) -> None:
Rewrites rollout_fragment_length to take into account burn-in and
max_seq_len truncation.
"""
# Call super's validation method.
super().validate_config(config)

if config["replay_sequence_length"] != -1:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/dqn/simple_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def get_default_config(cls) -> TrainerConfigDict:
def validate_config(self, config: TrainerConfigDict) -> None:
"""Checks and updates the config based on settings.
"""
# Call super's validation method.
super().validate_config(config)

if config["exploration_config"]["type"] == "ParameterNoise":
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

config["action_repeat"] = config["env_config"]["frame_skip"]
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/es/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/maml/maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/marwil/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(MARWILTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["beta"] != 0.0:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
3 changes: 3 additions & 0 deletions rllib/agents/mbmpo/mbmpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MB-MPO!")
if config["framework"] != "torch":
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/qmix/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["framework"] != "torch":
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/sac/rnnsac.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(SACTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["replay_sequence_length"] != -1:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(DQNTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["use_state_preprocessor"] != DEPRECATED_VALUE:
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/slateq/slateq.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def get_default_config(cls) -> TrainerConfigDict:

@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

if config["num_gpus"] > 1:
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/trainer_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _init(self, config: TrainerConfigDict,

@override(Trainer)
def validate_config(self, config: PartialTrainerConfigDict):
# Call super (Trainer) validation method first.
# Call super's validation method.
Trainer.validate_config(self, config)
# Then call user defined one, if any.
if validate_config is not None:
Expand Down
2 changes: 2 additions & 0 deletions rllib/contrib/maddpg/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def validate_config(self, config: TrainerConfigDict) -> None:
This hook is called explicitly prior to TrainOneStep() in the execution
setups for DQN and APEX.
"""
# Call super's validation method.
super().validate_config(config)

def f(batch, workers, config):
policies = dict(workers.local_worker()
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/worker_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def valid_module(class_path):
ma_policies = config["multiagent"]["policies"]
if ma_policies:
for pid, policy_spec in ma_policies.copy().items():
assert isinstance(policy_spec, (PolicySpec, list, tuple))
assert isinstance(policy_spec, PolicySpec)
# Class is None -> Use `policy_cls`.
if policy_spec.policy_class is None:
ma_policies[pid] = ma_policies[pid]._replace(
Expand Down

0 comments on commit b10d553

Please sign in to comment.