"""Defines `WorkerSet`: a set of RolloutWorkers with n @ray.remote workers
and (optionally) one local worker."""
import gym
import logging
import importlib.util
from types import FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput, D4RLReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.typing import EnvType, PolicyID, TrainerConfigDict
from ray.tune.registry import registry_contains_input, registry_get_input
# TensorFlow module handles. In the code visible here only `tf1` is used
# (for `tf1.Session` / `tf1.ConfigProto` when creating TF sessions).
# NOTE(review): presumably these are None when TensorFlow is not installed —
# confirm against `try_import_tf`.
tf1, tf, tfv = try_import_tf()
# Module-level logger.
logger = logging.getLogger(__name__)
# Generic type var for foreach_* methods.
T = TypeVar("T")
@DeveloperAPI
class WorkerSet:
    """Set of RolloutWorkers with n @ray.remote workers and one local worker.

    Where n may be 0.
    """

    def __init__(
            self,
            *,
            env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
            validate_env: Optional[Callable[[EnvType], None]] = None,
            policy_class: Optional[Type[Policy]] = None,
            trainer_config: Optional[TrainerConfigDict] = None,
            num_workers: int = 0,
            local_worker: bool = True,
            logdir: Optional[str] = None,
            _setup: bool = True,
    ):
        """Initializes a WorkerSet instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            policy_class: An optional Policy class. If None, PolicySpecs can be
                generated automatically by using the Trainer's default class
                or via a given multi-agent policy config dict.
            trainer_config: Optional dict that extends the common config of
                the Trainer class. If None/empty, the Trainer's
                `COMMON_CONFIG` is used.
            num_workers: Number of remote rollout workers to create.
            local_worker: Whether to create a local (non @ray.remote) worker
                in the returned set as well (default: True). If `num_workers`
                is 0, always create a local worker.
            logdir: Optional logging directory for workers.
            _setup: Whether to setup workers. This is only for testing.
        """
        if not trainer_config:
            from ray.rllib.agents.trainer import COMMON_CONFIG
            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._logdir = logdir
        # Always define the worker attributes, so that `local_worker()` and
        # `remote_workers()` also work on an instance built with
        # `_setup=False`.
        self._local_worker = None
        self._remote_workers = []

        if _setup:
            # Force a local worker if num_workers == 0 (no remote workers).
            # Otherwise, this WorkerSet would be empty.
            if num_workers == 0:
                local_worker = True

            self._local_config = merge_dicts(
                trainer_config,
                {"tf_session_args": trainer_config["local_tf_session_args"]})

            # Create a number of @ray.remote workers.
            self.add_workers(num_workers)

            # Create a local worker, if needed.
            # If num_workers > 0 and we don't have an env on the local worker,
            # get the observation- and action spaces for each policy from
            # the first remote worker (which does have an env).
            if local_worker and self._remote_workers and \
                    not trainer_config.get("create_env_on_driver") and \
                    (not trainer_config.get("observation_space") or
                     not trainer_config.get("action_space")):
                remote_spaces = ray.get(self.remote_workers(
                )[0].foreach_policy.remote(
                    lambda p, pid: (pid, p.observation_space, p.action_space)))
                spaces = {
                    e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                    for e in remote_spaces
                }
                # Try to add the actual env's obs/action spaces.
                # This is best-effort: failure is logged and otherwise ignored.
                try:
                    env_spaces = ray.get(self.remote_workers(
                    )[0].foreach_env.remote(
                        lambda env: (env.observation_space, env.action_space))
                    )[0]
                    spaces["__env__"] = env_spaces
                except Exception:
                    logger.debug(
                        "Could not infer env observation/action spaces from "
                        "remote worker.", exc_info=True)

                logger.info("Inferred observation/action spaces from remote "
                            f"worker (local worker has no env): {spaces}")
            else:
                spaces = None

            if local_worker:
                self._local_worker = self._make_worker(
                    cls=RolloutWorker,
                    env_creator=env_creator,
                    validate_env=validate_env,
                    policy_cls=self._policy_class,
                    worker_index=0,
                    num_workers=num_workers,
                    config=self._local_config,
                    spaces=spaces,
                )

    def local_worker(self) -> RolloutWorker:
        """Returns the local rollout worker (or None if there is none)."""
        return self._local_worker

    def remote_workers(self) -> List[ActorHandle]:
        """Returns a list of remote rollout workers."""
        return self._remote_workers

    def sync_weights(self,
                     policies: Optional[List[PolicyID]] = None,
                     from_worker: Optional[RolloutWorker] = None) -> None:
        """Syncs model weights from the local worker to all remote workers.

        Args:
            policies: Optional list of PolicyIDs to sync weights for.
                If None (default), sync weights to/from all policies.
            from_worker: Optional RolloutWorker instance to sync from.
                If None (default), sync from this WorkerSet's local worker.

        Raises:
            TypeError: If this WorkerSet has no local worker and no
                `from_worker` is given.
        """
        if self.local_worker() is None and from_worker is None:
            raise TypeError(
                "No `local_worker` in WorkerSet, must provide `from_worker` "
                "arg in `sync_weights()`!")

        # Only sync if we have remote workers or `from_worker` is provided.
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker
                       or self.local_worker()).get_weights(policies)
            # Put weights only once into object store and use same object
            # ref to sync to all workers.
            weights_ref = ray.put(weights)
            # Sync to all remote workers in this WorkerSet.
            for to_worker in self.remote_workers():
                to_worker.set_weights.remote(weights_ref)

            # If `from_worker` is provided, also sync to this WorkerSet's
            # local worker.
            if from_worker is not None and self.local_worker() is not None:
                self.local_worker().set_weights(weights)

    def add_workers(self, num_workers: int) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same WorkerSet to add more
        RolloutWorkers to the set. New workers continue the index numbering
        after the already existing remote workers (index 0 is reserved for
        the local worker).

        Args:
            num_workers: The number of remote Workers to add to this
                WorkerSet.
        """
        old_num_workers = len(self._remote_workers)
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        self._remote_workers.extend([
            self._make_worker(
                cls=cls,
                env_creator=self._env_creator,
                validate_env=None,
                policy_cls=self._policy_class,
                # Offset by the existing workers, so repeated calls do not
                # hand out duplicate indices (the index also offsets the
                # per-worker seed).
                worker_index=old_num_workers + i + 1,
                num_workers=old_num_workers + num_workers,
                config=self._remote_config,
            ) for i in range(num_workers)
        ])

    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote workers in this set with the given one.

        Args:
            new_remote_workers: A list of new RolloutWorkers
                (as `ActorHandles`) to use as remote workers.
        """
        self._remote_workers = new_remote_workers

    def stop(self) -> None:
        """Calls `stop` on all rollout workers (including the local one).

        Errors while stopping are logged. Remote worker actors are always
        terminated afterwards, even if stopping failed.
        """
        try:
            # There may be no local worker (e.g. `local_worker=False` or
            # `_from_existing(None, ...)`). Skip it instead of raising, which
            # would also skip the remote workers' `stop()` calls.
            if self.local_worker() is not None:
                self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Calls the given function with each worker instance as arg.

        Args:
            func: The function to call for each worker (as only arg).

        Returns:
             The list of return values of all calls to `func([worker])`.
        """
        local_result = []
        if self.local_worker() is not None:
            local_result = [func(self.local_worker())]
        remote_results = ray.get(
            [w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_worker_with_index(
            self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
        """Calls `func` with each worker instance and worker idx as args.

        The index will be passed as the second arg to the given function.

        Args:
            func: The function to call for each worker and its index
                (as args). The local worker has index 0, all remote workers
                have indices > 0.

        Returns:
             The list of return values of all calls to `func([worker, idx])`.
                The first entry in this list are the results of the local
                worker, followed by all remote workers' results.
        """
        local_result = []
        # Local worker: Index=0.
        if self.local_worker() is not None:
            local_result = [func(self.local_worker(), 0)]
        # Remote workers: Index > 0.
        remote_results = ray.get([
            w.apply.remote(func, i + 1)
            for i, w in enumerate(self.remote_workers())
        ])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Calls `func` with each worker's (policy, PolicyID) tuple.

        Note that in the multi-agent case, each worker may have more than one
        policy.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            The list of return values of func over all workers' policies. The
                length of this list is:
                (num_workers + 1 (local-worker)) *
                [num policies in the multi-agent config dict].
                The local workers' results are first, followed by all remote
                workers' results.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(lambda w: w.foreach_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def trainable_policies(self) -> List[PolicyID]:
        """Returns the list of trainable policy ids.

        Raises:
            NotImplementedError: If this WorkerSet has no local worker.
        """
        if self.local_worker() is not None:
            return self.local_worker().policies_to_train
        else:
            raise NotImplementedError

    @DeveloperAPI
    def foreach_trainable_policy(
            self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
           List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_trainable_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(
                    lambda w: w.foreach_trainable_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment as arg, e.g. a
        gym.Env object.

        Args:
            func: A function - taking an EnvType (normally a gym.Env object)
                as arg and returning a list of lists of return values, one
                value per underlying sub-environment per each worker.

        Returns:
            The list (workers) of lists (sub environments) of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env.remote(func))
        return local_results + ray.get(ray_gets)

    @DeveloperAPI
    def foreach_env_with_context(
            self,
            func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments and env_ctx as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment and the env_context
        as args.

        Args:
            func: A function - taking a BaseEnv object and an EnvContext as
                arg - and returning a list of lists of return values over envs
                of the worker.

        Returns:
            The list (1 item per workers) of lists (1 item per sub-environment)
                of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [
                self.local_worker().foreach_env_with_context(func)
            ]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env_with_context.remote(func))
        return local_results + ray.get(ray_gets)

    @staticmethod
    def _from_existing(
            local_worker: RolloutWorker,
            remote_workers: Optional[List[ActorHandle]] = None
    ) -> "WorkerSet":
        """Builds a WorkerSet around already existing workers.

        Args:
            local_worker: The local worker to use (may be None).
            remote_workers: Optional list of remote worker handles. If None,
                the new set has no remote workers.

        Returns:
            A new WorkerSet wrapping the given workers (no workers are
            created).
        """
        workers = WorkerSet(
            env_creator=None,
            policy_class=None,
            trainer_config={},
            _setup=False)
        workers._local_worker = local_worker
        workers._remote_workers = remote_workers or []
        return workers

    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Creates a single (local or remote) worker via `cls`.

        Args:
            cls: `RolloutWorker` itself (local worker) or the `.remote`
                constructor of a remote RolloutWorker actor class.
            env_creator: Function that returns env given env config.
            validate_env: Optional env validation callable.
            policy_cls: Default Policy class, used for all multi-agent
                PolicySpecs that have no `policy_class` set.
            worker_index: Index of the worker (0 = local worker).
            num_workers: Number of remote workers, passed on to the worker.
            config: The (trainer) config dict to build the worker from.
            spaces: Optional mapping of PolicyIDs to (obs-space, action-space)
                tuples, passed on to the worker.

        Returns:
            The created RolloutWorker (local) or its ActorHandle (remote).
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            # True iff `class_path` is a dotted string whose module part can
            # be found by the import system.
            if isinstance(class_path, str) and "." in class_path:
                module_path, _ = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    logger.warning(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        # Resolve the input creator from `config["input"]`.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(from_config(
                config["input"], ioctx=ioctx)))
        else:
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))

        # Resolve the output creator from `config["output"]`.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # Input evaluation is only configured for non-"sampler" inputs.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        # NOTE: Fills in missing policy classes in-place in the config's
        # "multiagent" -> "policies" dict.
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, PolicySpec)
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls)
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Each worker gets a distinct seed offset by its index.
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker