[RLlib] Policies get/set_state fixes and enhancements. #16354

Merged · 8 commits · Jun 15, 2021
2 changes: 1 addition & 1 deletion python/ray/tune/trainable.py
@@ -315,7 +315,7 @@ def get_state(self):
def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint.

Subclasses should override ``_save()`` instead to save state.
Subclasses should override ``save_checkpoint()`` instead to save state.
This method dumps additional metadata alongside the saved path.

Args:
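
For context on the corrected docstring: ``save()`` wraps the checkpointing logic, while subclasses provide the actual state via ``save_checkpoint()`` and restore it via ``load_checkpoint()``. A minimal sketch of that pattern, not taken from this PR; the class and metric names are illustrative:

from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.train_steps = 0

    def step(self):
        self.train_steps += 1
        return {"my_metric": self.train_steps}

    def save_checkpoint(self, tmp_checkpoint_dir):
        # Returning a dict lets Tune handle serialization; writing files
        # into `tmp_checkpoint_dir` and returning the dir also works.
        return {"train_steps": self.train_steps}

    def load_checkpoint(self, checkpoint):
        self.train_steps = checkpoint["train_steps"]
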
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1286,6 +1286,13 @@ py_test(
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
)

py_test(
name = "policy/tests/test_policy",
tags = ["policy"],
size = "medium",
srcs = ["policy/tests/test_policy.py"]
)

py_test(
name = "policy/tests/test_sample_batch",
tags = ["policy"],
4 changes: 2 additions & 2 deletions rllib/agents/ddpg/tests/test_apex_ddpg.py
@@ -31,7 +31,7 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):

# Test per-worker scale distribution.
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
lambda p, _: p.get_exploration_state())
scale = [i["cur_scale"] for i in infos]
expected = [
0.4**(1 + (i + 1) / float(config["num_workers"] - 1) * 7)
@@ -46,7 +46,7 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
# Test again per-worker scale distribution
# (should not have changed).
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
lambda p, _: p.get_exploration_state())
scale = [i["cur_scale"] for i in infos]
check(scale, [0.0] + expected)

2 changes: 1 addition & 1 deletion rllib/agents/dqn/apex.py
@@ -222,7 +222,7 @@ def add_apex_metrics(result: dict) -> dict:
replay_stats = ray.get(replay_actors[0].stats.remote(
config["optimizer"].get("debug")))
exploration_infos = workers.foreach_trainable_policy(
lambda p, _: p.get_exploration_info())
lambda p, _: p.get_exploration_state())
result["info"].update({
"exploration_infos": exploration_infos,
"learner_queue": learner_thread.learner_queue_size.stats(),
4 changes: 2 additions & 2 deletions rllib/agents/dqn/tests/test_apex_dqn.py
@@ -43,7 +43,7 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):

# Test per-worker epsilon distribution.
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
lambda p, _: p.get_exploration_state())
expected = [0.4, 0.016190862, 0.00065536]
check([i["cur_epsilon"] for i in infos], [0.0] + expected)

@@ -55,7 +55,7 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
# Test again per-worker epsilon distribution
# (should not have changed).
infos = trainer.workers.foreach_policy(
lambda p, _: p.get_exploration_info())
lambda p, _: p.get_exploration_state())
check([i["cur_epsilon"] for i in infos], [0.0] + expected)

trainer.stop()
4 changes: 2 additions & 2 deletions rllib/evaluation/rollout_worker.py
@@ -1083,7 +1083,7 @@ def get_filters(self, flush_after: bool = False) -> dict:
return return_filters

@DeveloperAPI
def save(self) -> str:
def save(self) -> bytes:
filters = self.get_filters(flush_after=True)
state = {
pid: self.policy_map[pid].get_state()
@@ -1092,7 +1092,7 @@ def restore(self, objs: str) -> None:
return pickle.dumps({"filters": filters, "state": state})

@DeveloperAPI
def restore(self, objs: str) -> None:
def restore(self, objs: bytes) -> None:
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
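
The corrected annotations reflect what these methods actually exchange: ``save()`` returns a pickled byte string (filters plus per-policy state) and ``restore()`` consumes it. A hedged usage sketch, not from the PR; ``trainer`` is assumed to be an already built RLlib Trainer:

# Illustrative round-trip of a worker's state.
local_worker = trainer.workers.local_worker()

blob = local_worker.save()       # bytes: pickle.dumps({"filters": ..., "state": ...})
assert isinstance(blob, bytes)

# Later, e.g. on a freshly created worker with the same policy setup:
local_worker.restore(blob)
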
18 changes: 12 additions & 6 deletions rllib/policy/eager_tf_policy.py
@@ -617,8 +617,8 @@ def apply_gradients(self, gradients):
for g in gradients], self.model.trainable_variables()))

@override(Policy)
def get_exploration_info(self):
return _convert_to_numpy(self.exploration.get_info())
def get_exploration_state(self):
return _convert_to_numpy(self.exploration.get_state())

@override(Policy)
def get_weights(self, as_dict=False):
@@ -637,18 +637,20 @@ def set_weights(self, weights):

@override(Policy)
def get_state(self):
state = {"_state": super().get_state()}
state = super().get_state()
if self._optimizer and \
len(self._optimizer.variables()) > 0:
state["_optimizer_variables"] = \
self._optimizer.variables()
# Add exploration state.
state["_exploration_state"] = self.exploration.get_state()
return state

@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
optimizer_vars = state.get("_optimizer_variables", None)
if optimizer_vars and self._optimizer.variables():
logger.warning(
"Cannot restore an optimizer's state for tf eager! Keras "
@@ -658,8 +660,11 @@ def set_state(self, state):
for opt_var, value in zip(self._optimizer.variables(),
optimizer_vars):
opt_var.assign(value)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
self.exploration.set_state(state=state["_exploration_state"])
# Then the Policy's (NN) weights.
super().set_state(state["_state"])
super().set_state(state)

def variables(self):
"""Return the list of all savable variables for this policy."""
@@ -698,9 +703,10 @@ def loss_initialized(self):
def export_model(self, export_dir):
pass

# TODO: (sven) Deprecate this in favor of `save()`.
@override(Policy)
def export_checkpoint(self, export_dir):
pass
deprecation_warning("export_checkpoint", "save")

def _get_is_training_placeholder(self):
return tf.convert_to_tensor(self._is_training)
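
After this change the eager policy no longer nests the base state under a ``"_state"`` key; everything lives in one flat dict. A rough sketch of the resulting layout (keys taken from the code above, values illustrative; ``policy`` is assumed to be an eager-mode TF policy obtained via ``trainer.get_policy()``):

state = policy.get_state()
# Roughly:
# {
#     "weights": ...,                 # NN weights (from the base Policy)
#     "global_timestep": ...,         # from the base Policy
#     "_optimizer_variables": [...],  # only if the optimizer has variables
#     "_exploration_state": {...},    # whatever self.exploration.get_state() returns
# }
policy.set_state(state)  # restores optimizer, exploration state, and weights
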
45 changes: 29 additions & 16 deletions rllib/policy/policy.py
@@ -10,6 +10,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
@@ -424,7 +425,7 @@ def set_weights(self, weights: ModelWeights) -> None:
raise NotImplementedError

@DeveloperAPI
def get_exploration_info(self) -> Dict[str, TensorType]:
def get_exploration_state(self) -> Dict[str, TensorType]:
"""Returns the current exploration information of this policy.

This information depends on the policy's Exploration object.
@@ -433,7 +434,12 @@ def get_exploration_info(self) -> Dict[str, TensorType]:
Dict[str, TensorType]: Serializable information on the
`self.exploration` object.
"""
return self.exploration.get_info()
return self.exploration.get_state()

# TODO: (sven) Deprecate this method.
def get_exploration_info(self) -> Dict[str, TensorType]:
deprecation_warning("get_exploration_info", "get_exploration_state")
return self.get_exploration_state()

@DeveloperAPI
def is_recurrent(self) -> bool:
@@ -464,22 +470,28 @@ def get_initial_state(self) -> List[TensorType]:

@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
"""Saves all local state.
"""Returns all local state.

Returns:
Union[Dict[str, TensorType], List[TensorType]]: Serialized local
state.
"""
return self.get_weights()
state = {
"weights": self.get_weights(),
"global_timestep": self.global_timestep,
}
return state

@DeveloperAPI
def set_state(self, state: object) -> None:
"""Restores all local state.
"""Restores all local state to the provided `state`.

Args:
state (obj): Serialized local state.
state (object): The new state to set this policy to. Can be
obtained by calling `Policy.get_state()`.
"""
self.set_weights(state)
self.set_weights(state["weights"])
self.global_timestep = state["global_timestep"]

@DeveloperAPI
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
@@ -506,15 +518,6 @@ def export_model(self, export_dir: str) -> None:
"""
raise NotImplementedError

@DeveloperAPI
def export_checkpoint(self, export_dir: str) -> None:
"""Export Policy checkpoint to local directory.

Args:
export_dir (str): Local writable directory.
"""
raise NotImplementedError

@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
"""Imports Policy from local file.
@@ -810,6 +813,16 @@ def _update_model_view_requirements_from_init_state(self):
view_reqs["state_out_{}".format(i)] = ViewRequirement(
space=space, used_for_training=True)

# TODO: (sven) Deprecate this in favor of `save()`.
def export_checkpoint(self, export_dir: str) -> None:
"""Export Policy checkpoint to local directory.

Args:
export_dir (str): Local writable directory.
"""
deprecation_warning("export_checkpoint", "save")
raise NotImplementedError


def clip_action(action, action_space):
"""Clips all actions in `flat_actions` according to the given Spaces.
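
The base implementation now returns a dict (weights plus ``global_timestep``) instead of bare weights. Subclasses that carry extra state can follow the same pattern as the TF policies in this PR: extend the dict in ``get_state()`` and consume their own keys in ``set_state()`` before delegating to ``super()``. A minimal sketch; the class and key names are made up and the rest of the Policy implementation is omitted:

from ray.rllib.policy.policy import Policy

class MyStatefulPolicy(Policy):
    def get_state(self):
        state = super().get_state()           # {"weights": ..., "global_timestep": ...}
        state["my_counter"] = getattr(self, "my_counter", 0)
        return state

    def set_state(self, state):
        self.my_counter = state.get("my_counter", 0)
        super().set_state(state)              # restores weights + global_timestep
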
43 changes: 43 additions & 0 deletions rllib/policy/tests/test_policy.py
@@ -0,0 +1,43 @@
import unittest

import ray
from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.utils.test_utils import check, framework_iterator


class TestPolicy(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_policy_save_restore(self):
config = DEFAULT_CONFIG.copy()
for _ in framework_iterator(config):
trainer = DQNTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
state1 = policy.get_state()
trainer.train()
state2 = policy.get_state()
check(
state1["_exploration_state"]["last_timestep"],
state2["_exploration_state"]["last_timestep"],
false=True)
check(
state1["global_timestep"],
state2["global_timestep"],
false=True)
# Reset policy to its original state and compare.
policy.set_state(state1)
state3 = policy.get_state()
# Make sure everything is the same.
check(state1, state3)


if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
25 changes: 20 additions & 5 deletions rllib/policy/tf_policy.py
@@ -14,6 +14,7 @@
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -478,8 +479,13 @@ def apply_gradients(self, gradients: ModelGradients) -> None:

@override(Policy)
@DeveloperAPI
def get_exploration_state(self) -> Dict[str, TensorType]:
return self.exploration.get_state(sess=self.get_session())

# TODO: (sven) Deprecate this method.
def get_exploration_info(self) -> Dict[str, TensorType]:
return self.exploration.get_info(sess=self.get_session())
deprecation_warning("get_exploration_info", "get_exploration_state")
return self.get_exploration_state()

@override(Policy)
@DeveloperAPI
@@ -500,17 +506,24 @@ def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
len(self._optimizer_variables.variables) > 0:
state["_optimizer_variables"] = \
self._sess.run(self._optimizer_variables.variables)
# Add exploration state.
state["_exploration_state"] = \
self.exploration.get_state(self.get_session())
return state

@override(Policy)
@DeveloperAPI
def set_state(self, state) -> None:
state = state.copy() # shallow copy
def set_state(self, state: dict) -> None:
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
optimizer_vars = state.get("_optimizer_variables", None)
if optimizer_vars:
self._optimizer_variables.set_weights(optimizer_vars)
# Then the Policy's (NN) weights.
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
self.exploration.set_state(
state=state["_exploration_state"], sess=self.get_session())

# Set the Policy's (NN) weights.
super().set_state(state)

@override(Policy)
@@ -527,12 +540,14 @@ def export_model(self, export_dir: str) -> None:
graph=self._sess.graph))
builder.save()

# TODO: (sven) Deprecate this in favor of `save()`.
@override(Policy)
@DeveloperAPI
def export_checkpoint(self,
export_dir: str,
filename_prefix: str = "model") -> None:
"""Export tensorflow checkpoint to export_dir."""
deprecation_warning("export_checkpoint", "save")
try:
os.makedirs(export_dir)
except OSError as e:
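
Taken together, these changes mean a normal Trainer checkpoint round-trip now carries each (TF) policy's ``global_timestep`` and exploration state along with the weights. A hedged end-to-end sketch, not from the PR; the minimal config is illustrative:

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(config={"num_workers": 0}, env="CartPole-v0")
trainer.train()
checkpoint = trainer.save()                  # Trainable.save() -> checkpoint path

restored = DQNTrainer(config={"num_workers": 0}, env="CartPole-v0")
restored.restore(checkpoint)
# For the policies touched in this PR, restored.get_policy().get_state()
# should now also contain "global_timestep" and "_exploration_state".
ray.shutdown()
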