forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 1
/
algorithm.py
2953 lines (2615 loc) · 122 KB
/
algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import concurrent
import copy
import functools
import logging
import math
import os
import pickle
import tempfile
import time
import importlib
from collections import defaultdict
from datetime import datetime
from typing import (
Callable,
Container,
DefaultDict,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
import gym
import numpy as np
import pkg_resources
from packaging import version
import ray
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.actor import ActorHandle
from ray.exceptions import GetTimeoutError, RayActorError, RayError
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.registry import ALGORITHMS as ALL_ALGORITHMS
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,
summarize_episodes,
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_TRAINED_THIS_ITER_COUNTER, # TODO: Backward compatibility.
)
from ray.rllib.execution.parallel_requests import AsyncRequestsManager
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
from ray.rllib.offline import get_offline_io_resource_bundles
from ray.rllib.offline.estimators import (
OffPolicyEstimator,
ImportanceSampling,
WeightedImportanceSampling,
DirectMethod,
DoublyRobust,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
from ray.rllib.utils.annotations import (
DeveloperAPI,
ExperimentalAPI,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
PublicAPI,
override,
)
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED_THIS_ITER,
NUM_ENV_STEPS_TRAINED,
SYNCH_WORKER_WEIGHTS_TIMER,
TRAINING_ITERATION_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import (
AgentID,
AlgorithmConfigDict,
EnvCreator,
EnvInfoDict,
EnvType,
EpisodeID,
PartialAlgorithmConfigDict,
PolicyID,
PolicyState,
ResultDict,
SampleBatchType,
TensorStructType,
TensorType,
)
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.resources import Resources
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trainable import Trainable
from ray.tune.experiment.trial import ExportFormat
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.util import log_once
from ray.util.timer import _Timer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Soft-import TensorFlow; the three returned handles are unpacked into
# `tf1`, `tf` and `tfv` for use elsewhere in this module.
tf1, tf, tfv = try_import_tf()
@DeveloperAPI
def with_common_config(extra_config: PartialAlgorithmConfigDict) -> AlgorithmConfigDict:
    """Merge a partial, user-supplied config into the common default config.

    Args:
        extra_config: A (possibly partial) user-defined config dict. It is
            merged on top of the defaults of a freshly built
            `AlgorithmConfig()` object.

    Returns:
        AlgorithmConfigDict: A plain python dict holding the
        `AlgorithmConfig()` defaults updated with `extra_config`. Unknown
        top-level keys in `extra_config` are allowed.
    """
    common_defaults = AlgorithmConfig().to_dict()
    return Algorithm.merge_trainer_configs(
        common_defaults, extra_config, _allow_unknown_configs=True
    )
@PublicAPI
class Algorithm(Trainable):
    """An RLlib algorithm responsible for optimizing one or more Policies.

    Algorithms contain a WorkerSet under `self.workers`. A WorkerSet is
    normally composed of a single local worker
    (self.workers.local_worker()), used to compute and apply learning updates,
    and optionally one or more remote workers (self.workers.remote_workers()),
    used to generate environment samples in parallel.

    Each worker (remotes or local) contains a PolicyMap, which itself
    may contain either one policy for single-agent training or one or more
    policies for multi-agent training. Policies are synchronized
    automatically from time to time using ray.remote calls. The exact
    synchronization logic depends on the specific algorithm used,
    but this usually happens from local worker to all remote workers and
    after each training update.

    You can write your own Algorithm classes by sub-classing from `Algorithm`
    or any of its built-in sub-classes.
    This allows you to override the `execution_plan` method to implement
    your own algorithm logic. You can find the different built-in
    algorithms' execution plans in their respective main py files,
    e.g. rllib.algorithms.dqn.dqn.py or rllib.algorithms.impala.impala.py.
    Note that `setup()` only builds the `execution_plan` iterator when the
    config's `_disable_execution_plan_api` is False. With that flag set, the
    per-iteration logic is expected in `training_step()` (see `step()`).

    The most important API methods an Algorithm exposes are `train()`,
    `evaluate()`, `save()` and `restore()`.
    """

    # Whether to allow unknown top-level config keys.
    # (Passed as the third argument to `merge_trainer_configs()` in `setup()`.)
    _allow_unknown_configs = False

    # List of top-level keys with value=dict, for which new sub-keys are
    # allowed to be added to the value dict.
    # NOTE(review): presumably read by the config-merge logic
    # (`merge_trainer_configs()`, not visible here) - confirm.
    _allow_unknown_subkeys = [
        "tf_session_args",
        "local_tf_session_args",
        "env_config",
        "model",
        "optimizer",
        "multiagent",
        "custom_resources_per_worker",
        "evaluation_config",
        "exploration_config",
        "replay_buffer_config",
        "extra_python_environs_for_worker",
        "input_config",
        "output_config",
    ]

    # List of top level keys with value=dict, for which we always override the
    # entire value (dict), iff the "type" key in that value dict changes.
    _override_all_subkeys_if_type_changes = [
        "exploration_config",
        "replay_buffer_config",
    ]

    # List of keys that are always fully overridden if present in any dict or sub-dict
    # (i.e. a user-provided value replaces the default value wholesale instead
    # of being merged key-by-key).
    _override_all_key_list = ["off_policy_estimation_methods"]

    # Result keys that presumably serve as progress indicators when reporting
    # (e.g. to Tune). They are not referenced within this chunk - TODO confirm
    # the consumer.
    _progress_metrics = [
        "episode_reward_mean",
        "evaluation/episode_reward_mean",
        "num_env_steps_sampled",
        "num_env_steps_trained",
    ]
@PublicAPI
def __init__(
    self,
    config: Optional[Union[PartialAlgorithmConfigDict, AlgorithmConfig]] = None,
    env: Optional[Union[str, EnvType]] = None,
    logger_creator: Optional[Callable[[], Logger]] = None,
    **kwargs,
):
    """Initializes an Algorithm instance.

    Args:
        config: Algorithm-specific configuration dict, or an
            `AlgorithmConfig` object (converted via `to_dict()`). May be
            partial; it gets merged with the class' default config later,
            in `self.setup()`.
        env: Name of the environment to use (e.g. a gym-registered str),
            a full class path (e.g.
            "ray.rllib.examples.env.random_env.RandomEnv"), or an Env
            class directly. Note that this arg can also be specified via
            the "env" key in `config`. A truthy `env` arg takes precedence
            over `config["env"]`.
        logger_creator: Callable that creates a ray.tune.Logger
            object. If unspecified, a default logger is created that writes
            into a new, uniquely named temp directory under
            `DEFAULT_RESULTS_DIR`.
        **kwargs: Arguments passed to the Trainable base class.

    Raises:
        AssertionError: If, after base-class construction,
            `self.training_iteration` is not an int. That means a sub-class
            overrode it with custom training logic, which belongs in
            `training_step` instead.
    """
    # User provided (partial) config (this may be w/o the default
    # Algorithm's config). Will get merged with AlgorithmConfig()
    # in self.setup().
    config = config or {}
    # Resolve AlgorithmConfig into a plain dict.
    # TODO: In the future, only support AlgorithmConfig objects here.
    if isinstance(config, AlgorithmConfig):
        config = config.to_dict()

    # Convert `env` provided in config into a concrete env creator callable, which
    # takes an EnvContext (config dict) as arg and returns an RLlib supported Env
    # type (e.g. a gym.Env).
    self._env_id, self.env_creator = self._get_env_id_and_creator(
        env or config.get("env"), config
    )
    env_descr = (
        self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id
    )

    # Placeholder for a local replay buffer instance.
    self.local_replay_buffer = None

    # Create a default logger creator if no logger_creator is specified.
    if logger_creator is None:
        # Default logdir prefix containing the agent's name and the
        # env id.
        timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
        logdir_prefix = "{}_{}_{}".format(str(self), env_descr, timestr)
        # `exist_ok=True` instead of an exists()-then-makedirs() check:
        # several Algorithms (e.g. concurrently starting trials) may race to
        # create this shared results dir, and the loser of such a race would
        # otherwise crash with FileExistsError.
        os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
        logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

        # Allow users to more precisely configure the created logger
        # via "logger_config.type".
        if config.get("logger_config") and "type" in config["logger_config"]:

            def default_logger_creator(config):
                """Creates a custom logger with the default prefix."""
                cfg = config["logger_config"].copy()
                cls = cfg.pop("type")
                # Provide default for logdir, in case the user does
                # not specify this in the "logger_config" dict.
                logdir_ = cfg.pop("logdir", logdir)
                return from_config(cls=cls, _args=[cfg], logdir=logdir_)

        # If no `type` given, use tune's UnifiedLogger as last resort.
        else:

            def default_logger_creator(config):
                """Creates a Unified logger with the default prefix."""
                return UnifiedLogger(config, logdir, loggers=None)

        logger_creator = default_logger_creator

    # Metrics-related properties.
    self._timers = defaultdict(_Timer)
    self._counters = defaultdict(int)
    self._episode_history = []
    self._episodes_to_be_collected = []
    self._remote_workers_for_metrics = []

    # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
    self.evaluation_workers: Optional[WorkerSet] = None
    # AsyncRequestsManager for evaluation; `setup()` creates it iff
    # `enable_async_evaluation` is set (to be more robust against eval
    # worker failures).
    self._evaluation_async_req_manager: Optional[AsyncRequestsManager] = None
    # Initialize common evaluation_metrics to nan, before they become
    # available. We want to make sure the metrics are always present
    # (although their values may be nan), so that Tune does not complain
    # when we use these as stopping criteria.
    self.evaluation_metrics = {
        "evaluation": {
            "episode_reward_max": np.nan,
            "episode_reward_min": np.nan,
            "episode_reward_mean": np.nan,
        }
    }

    super().__init__(config=config, logger_creator=logger_creator, **kwargs)

    # Check, whether `training_iteration` is still a tune.Trainable property
    # and has not been overridden by the user in the attempt to implement the
    # algos logic (this should be done now inside `training_step`).
    # Explicit check rather than `assert`: assert statements are stripped
    # when Python runs with -O, which would silently disable this guard.
    if not isinstance(self.training_iteration, int):
        raise AssertionError(
            "Your Algorithm's `training_iteration` seems to be overridden by your "
            "custom training logic! To solve this problem, simply rename your "
            "`self.training_iteration()` method into `self.training_step`."
        )
@OverrideToImplementCustomLogic
@classmethod
def get_default_config(cls) -> AlgorithmConfigDict:
    """Return this Algorithm class' default config as a plain dict.

    Sub-classes override this to supply their own defaults. This base
    implementation converts a fresh `AlgorithmConfig()` object into a dict.
    """
    default_config_obj = AlgorithmConfig()
    return default_config_obj.to_dict()
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Trainable)
def setup(self, config: PartialAlgorithmConfigDict):
    """Sets up this Algorithm (overrides `tune.Trainable.setup()`).

    Steps, in order:

    1. Merge `config` with `self.get_default_config()`. Then validate the
       framework and config, set the global seed, record usage, create the
       callbacks object and apply the log level.
    2. Create a local replay buffer via
       `_create_local_replay_buffer_if_necessary()`.
    3. Translate the deprecated `input_evaluation` key into
       `off_policy_estimation_methods`.
    4. Call the deprecated `_init()` hook. If it raises NotImplementedError
       (i.e. it is not overridden), create the rollout WorkerSet, then
       either sync its weights or build the execution-plan iterator,
       depending on `_disable_execution_plan_api`.
    5. Build the complete evaluation config. If `evaluation_num_workers > 0`
       or `evaluation_interval` is truthy, create the evaluation WorkerSet.
    6. Instantiate all configured off-policy estimators.
    7. Call the `on_algorithm_init` callback.

    Args:
        config: The user-provided (possibly partial) Algorithm config dict.

    Raises:
        AssertionError: If the user's `evaluation_config` sets
            `in_evaluation` to anything other than True.
        ValueError: If an `off_policy_estimation_methods` entry has a `type`
            that is neither a known short name, nor an importable class path,
            nor an OffPolicyEstimator sub-class.
    """
    # Setup our config: Merge the user-supplied config (which could
    # be a partial config dict) with the class' default.
    self.config = self.merge_trainer_configs(
        self.get_default_config(), config, self._allow_unknown_configs
    )
    self.config["env"] = self._env_id

    # Validate the framework settings in config.
    self.validate_framework(self.config)

    # Set the global seed after we have - if necessary - enabled
    # tf eager-execution.
    update_global_seed_if_necessary(self.config["framework"], self.config["seed"])

    self.validate_config(self.config)
    self._record_usage(self.config)

    self.callbacks = self.config["callbacks"]()
    log_level = self.config.get("log_level")
    if log_level in ["WARN", "ERROR"]:
        logger.info(
            "Current log_level is {}. For more information, "
            "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
            "-vv flags.".format(log_level)
        )
    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    # Create local replay buffer if necessary.
    self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
        self.config
    )

    # Create a dict, mapping ActorHandles to sets of open remote
    # requests (object refs). This way, we keep track, of which actors
    # inside this Algorithm (e.g. a remote RolloutWorker) have
    # already been sent how many (e.g. `sample()`) requests.
    self.remote_requests_in_flight: DefaultDict[
        ActorHandle, Set[ray.ObjectRef]
    ] = defaultdict(set)

    self.workers: Optional[WorkerSet] = None
    self.train_exec_impl = None

    # Offline RL settings: translate the deprecated `input_evaluation` key.
    input_evaluation = self.config.get("input_evaluation")
    if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
        ope_dict = {str(ope): {"type": ope} for ope in input_evaluation}
        deprecation_warning(
            old="config.input_evaluation={}".format(input_evaluation),
            new='config["evaluation_config"]'
            '["off_policy_estimation_methods"]={}'.format(
                ope_dict,
            ),
            error=False,
            help="Running OPE during training is not recommended.",
        )
        self.config["off_policy_estimation_methods"] = ope_dict

    # Deprecated way of implementing Algorithm sub-classes (or "templates"
    # via the `build_trainer` utility function).
    # Instead, sub-classes should override the Trainable's `setup()`
    # method and call super().setup() from within that override at some
    # point.
    # Old design: Override `Algorithm._init`.
    _init = False
    try:
        self._init(self.config, self.env_creator)
        _init = True
    # New design: Override `Trainable.setup()` (as intended by tune.Trainable)
    # and do or don't call `super().setup()` from within your override.
    # By default, `super().setup()` will create both worker sets:
    # "rollout workers" for collecting samples for training and - if
    # applicable - "evaluation workers" for evaluation runs in between or
    # parallel to training.
    # TODO: Deprecate `_init()` and remove this try/except block.
    except NotImplementedError:
        pass

    # Only if user did not override `_init()`:
    if _init is False:
        # - Create rollout workers here automatically.
        # - Run the execution plan to create the local iterator to `next()`
        #   in each training iteration.
        # This matches the behavior of using `build_trainer()`, which
        # has been deprecated.
        try:
            self.workers = WorkerSet(
                env_creator=self.env_creator,
                validate_env=self.validate_env,
                policy_class=self.get_default_policy_class(self.config),
                trainer_config=self.config,
                num_workers=self.config["num_workers"],
                local_worker=True,
                logdir=self.logdir,
            )
        # WorkerSet creation possibly fails, if some (remote) workers cannot
        # be initialized properly (due to some errors in the RolloutWorker's
        # constructor).
        except RayActorError as e:
            # In case of an actor (remote worker) init failure, the remote worker
            # may still exist and will be accessible, however, e.g. calling
            # its `sample.remote()` would result in strange "property not found"
            # errors.
            if e.actor_init_failed:
                # Raise the original error here that the RolloutWorker raised
                # during its construction process. This is to enforce transparency
                # for the user (better to understand the real reason behind the
                # failure).
                # - e.args[0]: The RayTaskError (inside the caught RayActorError).
                # - e.args[0].args[2]: The original Exception (e.g. a ValueError due
                #   to a config mismatch) thrown inside the actor.
                raise e.args[0].args[2]
            # In any other case, raise the RayActorError as-is.
            else:
                raise e

        # By default, collect metrics for all remote workers.
        self._remote_workers_for_metrics = self.workers.remote_workers()

        # Function defining one single training iteration's behavior.
        if self.config["_disable_execution_plan_api"]:
            # Ensure remote workers are initially in sync with the local worker.
            self.workers.sync_weights()
        # LocalIterator-creating "execution plan".
        # Only call this once here to create `self.train_exec_impl`,
        # which is a ray.util.iter.LocalIterator that will be `next`'d
        # on each training iteration.
        else:
            self.train_exec_impl = self.execution_plan(
                self.workers, self.config, **self._kwargs_for_execution_plan()
            )

        # Now that workers have been created, update our policies
        # dict in config[multiagent] (with the correct original/
        # unpreprocessed spaces).
        self.config["multiagent"][
            "policies"
        ] = self.workers.local_worker().policy_dict

    # Evaluation WorkerSet setup.
    # User would like to setup a separate evaluation worker set.

    # Update with evaluation settings:
    user_eval_config = copy.deepcopy(self.config["evaluation_config"])

    # The user must not unset "in_evaluation" (it is forced to True below when
    # evaluation is enabled). Explicit raise instead of `assert`, because
    # assert statements are stripped under `python -O`.
    if (
        "in_evaluation" in user_eval_config
        and user_eval_config["in_evaluation"] is not True
    ):
        raise AssertionError(
            "`evaluation_config.in_evaluation` must not be set to anything "
            "other than True (got {}).".format(user_eval_config["in_evaluation"])
        )

    # Merge user-provided eval config with the base config. This makes sure
    # the eval config is always complete, no matter whether we have eval
    # workers or perform evaluation on the (non-eval) local worker.
    eval_config = merge_dicts(self.config, user_eval_config)
    self.config["evaluation_config"] = eval_config

    if self.config.get("evaluation_num_workers", 0) > 0 or self.config.get(
        "evaluation_interval"
    ):
        logger.debug(f"Using evaluation_config: {user_eval_config}.")

        # Validate evaluation config.
        self.validate_config(eval_config)

        # Set the `in_evaluation` flag.
        eval_config["in_evaluation"] = True

        # Evaluation duration unit: episodes.
        # Switch on `complete_episode` rollouts. Also, make sure
        # rollout fragments are short so we never have more than one
        # episode in one rollout.
        if eval_config["evaluation_duration_unit"] == "episodes":
            eval_config.update(
                {
                    "batch_mode": "complete_episodes",
                    "rollout_fragment_length": 1,
                }
            )
        # Evaluation duration unit: timesteps.
        # - Set `batch_mode=truncate_episodes` so we don't perform rollouts
        #   strictly along episode borders.
        # Set `rollout_fragment_length` such that desired steps are divided
        # equally amongst workers or - in "auto" duration mode - set it
        # to a reasonably small number (10), such that a single `sample()`
        # call doesn't take too much time and we can stop evaluation as soon
        # as possible after the train step is completed.
        else:
            eval_config.update(
                {
                    "batch_mode": "truncate_episodes",
                    "rollout_fragment_length": 10
                    if self.config["evaluation_duration"] == "auto"
                    else int(
                        math.ceil(
                            self.config["evaluation_duration"]
                            / (self.config["evaluation_num_workers"] or 1)
                        )
                    ),
                }
            )

        self.config["evaluation_config"] = eval_config

        _, env_creator = self._get_env_id_and_creator(
            eval_config.get("env"), eval_config
        )

        # Create a separate evaluation worker set for evaluation.
        # If evaluation_num_workers=0, use the evaluation set's local
        # worker for evaluation, otherwise, use its remote workers
        # (parallelized evaluation).
        self.evaluation_workers: WorkerSet = WorkerSet(
            env_creator=env_creator,
            validate_env=None,
            policy_class=self.get_default_policy_class(self.config),
            trainer_config=eval_config,
            num_workers=self.config["evaluation_num_workers"],
            # Don't even create a local worker if num_workers > 0.
            local_worker=False,
            logdir=self.logdir,
        )

        if self.config["enable_async_evaluation"]:
            self._evaluation_async_req_manager = AsyncRequestsManager(
                workers=self.evaluation_workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=1,
                return_object_refs=True,
            )
            self._evaluation_weights_seq_number = 0

    # Off-policy estimators (OPE), keyed by the user-given name.
    self.reward_estimators: Dict[str, OffPolicyEstimator] = {}
    # Deprecated short names -> estimator classes.
    ope_types = {
        "is": ImportanceSampling,
        "wis": WeightedImportanceSampling,
        "dm": DirectMethod,
        "dr": DoublyRobust,
    }
    for name, method_config in self.config["off_policy_estimation_methods"].items():
        # Work on a shallow copy. Popping "type" from the original dict would
        # strip it from `self.config`, and possibly from a dict shared with the
        # user's config. Any later read, or another setup with that config,
        # would then fail with a KeyError on "type".
        method_config = dict(method_config)
        method_type = method_config.pop("type")
        if method_type in ope_types:
            deprecation_warning(
                old=method_type,
                new=str(ope_types[method_type]),
                error=False,
            )
            method_type = ope_types[method_type]
        elif isinstance(method_type, str):
            logger.debug("Trying to import from string: " + method_type)
            module_path, _, class_name = method_type.rpartition(".")
            # A bare name (no module part) cannot be imported; report it with
            # the same message as any other unknown type.
            if not module_path:
                raise ValueError(
                    f"Unknown off_policy_estimation type: {method_type}! Must be "
                    "either a class path or a sub-class of ray.rllib."
                    "offline.estimators.off_policy_estimator::OffPolicyEstimator"
                )
            mod = importlib.import_module(module_path)
            method_type = getattr(mod, class_name)
        if isinstance(method_type, type) and issubclass(
            method_type, OffPolicyEstimator
        ):
            policy = self.get_policy()
            gamma = self.config["gamma"]
            self.reward_estimators[name] = method_type(
                policy, gamma, **method_config
            )
        else:
            raise ValueError(
                f"Unknown off_policy_estimation type: {method_type}! Must be "
                "either a class path or a sub-class of ray.rllib."
                "offline.estimators.off_policy_estimator::OffPolicyEstimator"
            )

    # Run `on_algorithm_init` callback after initialization is done.
    self.callbacks.on_algorithm_init(algorithm=self)
# TODO: Deprecated hook. Sub-classes should instead override `setup()`
#  and call `super().setup()` from it. That gives them the default setup
#  (env, workers, config, ...) plus their own logic. Sub-classes that need
#  none of the default setup can simply skip the super() call.
def _init(self, config: AlgorithmConfigDict, env_creator: EnvCreator) -> None:
    """Old-style, deprecated initialization hook for Algorithm sub-classes.

    `setup()` calls this before creating any workers. If it raises
    NotImplementedError (as this base version always does), `setup()`
    creates the rollout WorkerSet itself. A sub-class that overrides this
    hook is therefore responsible for creating its own workers.

    Args:
        config: The merged Algorithm config dict.
        env_creator: The env creator callable resolved in `__init__()`.

    Raises:
        NotImplementedError: Always (base implementation).
    """
    raise NotImplementedError()
@OverrideToImplementCustomLogic
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
    """Return the Policy class to fall back on for this Algorithm.

    RolloutWorkers' PolicyMaps use this class whenever the user did not
    provide a policy class in a (single- or multi-agent) PolicySpec.

    Experimental. `setup()` consults it when creating the rollout WorkerSet
    (only if the sub-class does not override `_init()` to create its own
    WorkerSet) and when creating the evaluation WorkerSet.

    Args:
        config: The Algorithm's config dict (not used by this base
            implementation).

    Returns:
        The instance's `_policy_class` attribute, or None if it has none.
    """
    try:
        return self._policy_class
    except AttributeError:
        return None
@override(Trainable)
def step(self) -> ResultDict:
    """Implements the main `train()` logic: one full training iteration.

    Runs one training iteration via `_run_one_training_iteration()`. If
    evaluation is due this iteration and `evaluation_parallel_to_training`
    is set, it uses `_run_one_training_iteration_and_evaluation_in_parallel()`
    instead. Sequential evaluation is run afterwards if due. The method then
    syncs worker filters, compiles the iteration's metrics (when
    `_disable_execution_plan_api` is set) and, if configured, updates the
    envs' tasks via `env_task_fn`. Any handling of worker failures happens
    inside those helpers (not visible here).

    Override this method in your Algorithm sub-classes if you would like to
    customize the whole iteration (e.g. handle worker failures yourself).
    Otherwise, override only `training_step()` to implement the core
    algorithm logic.

    Returns:
        The results dict with stats/infos on sampling, training,
        and - if required - evaluation.

    Raises:
        ValueError: If `env_task_fn` is set but not callable.
    """
    # Do we have to run `self.evaluate()` this iteration?
    # `self.iteration` gets incremented after this function returns,
    # meaning that e.g. the first time this function is called,
    # self.iteration will be 0.
    # A falsy `evaluation_interval` (None or 0) means "no periodic
    # evaluation". This matches `setup()`, which also tests this key for
    # truthiness, and avoids a ZeroDivisionError in the modulo for 0.
    eval_interval = self.config["evaluation_interval"]
    evaluate_this_iter = bool(eval_interval) and (
        (self.iteration + 1) % eval_interval == 0
    )

    # Results dict for training (and if applicable: evaluation).
    results: ResultDict = {}

    local_worker = (
        self.workers.local_worker()
        if hasattr(self.workers, "local_worker")
        else None
    )

    # Parallel eval + training: Kick off evaluation-loop and parallel train() call.
    if evaluate_this_iter and self.config["evaluation_parallel_to_training"]:
        (
            results,
            train_iter_ctx,
        ) = self._run_one_training_iteration_and_evaluation_in_parallel()
    # - No evaluation necessary, just run the next training iteration.
    # - We have to evaluate in this training iteration, but no parallelism ->
    #   evaluate after the training iteration is entirely done.
    else:
        results, train_iter_ctx = self._run_one_training_iteration()

    # Sequential: Train (already done above), then evaluate.
    if evaluate_this_iter and not self.config["evaluation_parallel_to_training"]:
        results.update(self._run_one_evaluation(train_future=None))

    # Attach latest available evaluation results to train results,
    # if necessary.
    if not evaluate_this_iter and self.config["always_attach_evaluation_results"]:
        assert isinstance(
            self.evaluation_metrics, dict
        ), "Trainer.evaluate() needs to return a dict."
        results.update(self.evaluation_metrics)

    if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
        # Sync filters on workers.
        self._sync_filters_if_needed(
            from_worker=self.workers.local_worker(),
            workers=self.workers,
            timeout_seconds=self.config[
                "sync_filters_on_rollout_workers_timeout_s"
            ],
        )
        # Collect worker metrics and combine them with `results`.
        if self.config["_disable_execution_plan_api"]:
            episodes_this_iter, self._episodes_to_be_collected = collect_episodes(
                local_worker,
                self._remote_workers_for_metrics,
                self._episodes_to_be_collected,
                timeout_seconds=self.config["metrics_episode_collection_timeout_s"],
            )
            results = self._compile_iteration_results(
                episodes_this_iter=episodes_this_iter,
                step_ctx=train_iter_ctx,
                iteration_results=results,
            )

    # Check `env_task_fn` for possible update of the env's task.
    if self.config["env_task_fn"] is not None:
        if not callable(self.config["env_task_fn"]):
            raise ValueError(
                "`env_task_fn` must be None or a callable taking "
                "[train_results, env, env_ctx] as args!"
            )

        def fn(env, env_context, task_fn):
            # Ask the user fn for the next task; only switch if it changed.
            new_task = task_fn(results, env, env_context)
            cur_task = env.get_task()
            if cur_task != new_task:
                env.set_task(new_task)

        fn = functools.partial(fn, task_fn=self.config["env_task_fn"])
        self.workers.foreach_env_with_context(fn)

    return results
@PublicAPI
def evaluate(
self,
duration_fn: Optional[Callable[[int], int]] = None,
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
Args:
duration_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
episodes left to run. It's used to find out whether
evaluation should continue.
"""
# Call the `_before_evaluate` hook.
self._before_evaluate()
# Sync weights to the evaluation WorkerSet.
if self.evaluation_workers is not None:
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
)
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.evaluation_workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)
if self.config["custom_eval_function"]:
logger.info(
"Running custom eval function {}".format(
self.config["custom_eval_function"]
)
)
metrics = self.config["custom_eval_function"](self, self.evaluation_workers)
if not metrics or not isinstance(metrics, dict):
raise ValueError(
"Custom eval function must return "
"dict of metrics, got {}.".format(metrics)
)
else:
if (
self.evaluation_workers is None
and self.workers.local_worker().input_reader is None
):
raise ValueError(
"Cannot evaluate w/o an evaluation worker set in "
"the Trainer or w/o an env on the local worker!\n"
"Try one of the following:\n1) Set "
"`evaluation_interval` >= 0 to force creating a "
"separate evaluation worker set.\n2) Set "
"`create_env_on_driver=True` to force the local "
"(non-eval) worker to have an environment to "
"evaluate on."
)
# How many episodes/timesteps do we need to run?
# In "auto" mode (only for parallel eval + training): Run as long
# as training lasts.
unit = self.config["evaluation_duration_unit"]
eval_cfg = self.config["evaluation_config"]
rollout = eval_cfg["rollout_fragment_length"]
num_envs = eval_cfg["num_envs_per_worker"]
auto = self.config["evaluation_duration"] == "auto"
duration = (
self.config["evaluation_duration"]
if not auto
else (self.config["evaluation_num_workers"] or 1)
* (1 if unit == "episodes" else rollout)
)
agent_steps_this_iter = 0
env_steps_this_iter = 0
# Default done-function returns True, whenever num episodes
# have been completed.
if duration_fn is None:
def duration_fn(num_units_done):
return duration - num_units_done
logger.info(f"Evaluating current policy for {duration} {unit}.")
metrics = None
all_batches = []
# No evaluation worker set ->
# Do evaluation using the local worker. Expect error due to the
# local worker not having an env.
if self.evaluation_workers is None:
# If unit=episodes -> Run n times `sample()` (each sample
# produces exactly 1 episode).
# If unit=ts -> Run 1 `sample()` b/c the
# `rollout_fragment_length` is exactly the desired ts.
iters = duration if unit == "episodes" else 1
for _ in range(iters):
batch = self.workers.local_worker().sample()
agent_steps_this_iter += batch.agent_steps()
env_steps_this_iter += batch.env_steps()
if self.reward_estimators:
all_batches.append(batch)
metrics = collect_metrics(
self.workers.local_worker(),
keep_custom_metrics=eval_cfg["keep_per_episode_custom_metrics"],
timeout_seconds=eval_cfg["metrics_episode_collection_timeout_s"],
)
# Evaluation worker set only has local worker.
elif self.config["evaluation_num_workers"] == 0:
# If unit=episodes -> Run n times `sample()` (each sample
# produces exactly 1 episode).
# If unit=ts -> Run 1 `sample()` b/c the
# `rollout_fragment_length` is exactly the desired ts.
iters = duration if unit == "episodes" else 1
for _ in range(iters):
batch = self.evaluation_workers.local_worker().sample()
agent_steps_this_iter += batch.agent_steps()
env_steps_this_iter += batch.env_steps()
if self.reward_estimators:
all_batches.append(batch)
# Evaluation worker set has n remote workers.
else:
# How many episodes have we run (across all eval workers)?
num_units_done = 0
_round = 0
while True:
units_left_to_do = duration_fn(num_units_done)
if units_left_to_do <= 0:
break
_round += 1
try:
batches = ray.get(
[
w.sample.remote()
for i, w in enumerate(
self.evaluation_workers.remote_workers()
)
if i * (1 if unit == "episodes" else rollout * num_envs)
< units_left_to_do
],
timeout=self.config["evaluation_sample_timeout_s"],
)
except GetTimeoutError:
logger.warning(
"Calling `sample()` on your remote evaluation worker(s) "
"resulted in a timeout (after the configured "
f"{self.config['evaluation_sample_timeout_s']} seconds)! "
"Try to set `evaluation_sample_timeout_s` in your config"
" to a larger value."
+ (
" If your episodes don't terminate easily, you may "
"also want to set `evaluation_duration_unit` to "
"'timesteps' (instead of 'episodes')."
if unit == "episodes"
else ""
)
)
break
_agent_steps = sum(b.agent_steps() for b in batches)
_env_steps = sum(b.env_steps() for b in batches)
# 1 episode per returned batch.
if unit == "episodes":
num_units_done += len(batches)
# Make sure all batches are exactly one episode.
for ma_batch in batches:
ma_batch = ma_batch.as_multi_agent()
for batch in ma_batch.policy_batches.values():
assert np.sum(batch[SampleBatch.DONES])
# n timesteps per returned batch.
else:
num_units_done += (
_agent_steps if self._by_agent_steps else _env_steps
)
if self.reward_estimators:
all_batches.extend(batches)
agent_steps_this_iter += _agent_steps
env_steps_this_iter += _env_steps
logger.info(
f"Ran round {_round} of parallel evaluation "
f"({num_units_done}/{duration if not auto else '?'} "
f"{unit} done)"
)
if metrics is None:
metrics = collect_metrics(
self.evaluation_workers.local_worker(),
self.evaluation_workers.remote_workers(),
keep_custom_metrics=self.config["keep_per_episode_custom_metrics"],
timeout_seconds=eval_cfg["metrics_episode_collection_timeout_s"],
)
metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
# TODO: Remove this key at some point. Here for backward compatibility.
metrics["timesteps_this_iter"] = env_steps_this_iter
if self.reward_estimators:
# Compute off-policy estimates
metrics["off_policy_estimator"] = {}
total_batch = concat_samples(all_batches)
for name, estimator in self.reward_estimators.items():
estimates = estimator.estimate(total_batch)
metrics["off_policy_estimator"][name] = estimates
# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}
# Also return the results here for convenience.
return self.evaluation_metrics
@ExperimentalAPI
def _evaluate_async(
self,
duration_fn: Optional[Callable[[int], int]] = None,
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.
Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.
Args:
duration_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
episodes left to run. It's used to find out whether
evaluation should continue.
"""
# How many episodes/timesteps do we need to run?
# In "auto" mode (only for parallel eval + training): Run as long
# as training lasts.
unit = self.config["evaluation_duration_unit"]
eval_cfg = self.config["evaluation_config"]
rollout = eval_cfg["rollout_fragment_length"]
num_envs = eval_cfg["num_envs_per_worker"]
auto = self.config["evaluation_duration"] == "auto"
duration = (
self.config["evaluation_duration"]
if not auto
else (self.config["evaluation_num_workers"] or 1)
* (1 if unit == "episodes" else rollout)
)
# Call the `_before_evaluate` hook.
self._before_evaluate()
# Put weights only once into object store and use same object
# ref to synch to all workers.
self._evaluation_weights_seq_number += 1
weights_ref = ray.put(self.workers.local_worker().get_weights())
# TODO(Jun): Make sure this cannot block for e.g. 1h. Implement solution via
# connectors.
self._sync_filters_if_needed(
from_worker=self.workers.local_worker(),
workers=self.evaluation_workers,
timeout_seconds=eval_cfg.get("sync_filters_on_rollout_workers_timeout_s"),
)
if self.config["custom_eval_function"]:
raise ValueError(
"`custom_eval_function` not supported in combination "
"with `enable_async_evaluation=True` config setting!"
)
if self.evaluation_workers is None and (
self.workers.local_worker().input_reader is None
or self.config["evaluation_num_workers"] == 0