from collections import defaultdict
import logging
import time
import tree # pip install dm_tree
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
import numpy as np
from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
from ray.rllib.env.external_env import ExternalEnvWrapper
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import unbatch, get_original_space
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
AgentID,
EnvActionType,
EnvID,
EnvInfoDict,
EnvObsType,
MultiAgentDict,
MultiEnvDict,
PolicyID,
PolicyOutputType,
SampleBatchType,
StateBatches,
TensorStructType,
)
from ray.util.debug import log_once
if TYPE_CHECKING:
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
MIN_LARGE_BATCH_THRESHOLD = 1000
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
MS_TO_SEC = 1000.0  # Multiplier for converting summed seconds to milliseconds.
class _PerfStats:
"""Sampler perf stats that will be included in rollout metrics."""
def __init__(self, ema_coef: Optional[float] = None):
        # If not None, run in Exponential Moving Average mode instead of
        # accumulating a global average. Stats are then updated as:
        #   updated = (1 - ema_coef) * old + ema_coef * new
        # This generally provides more responsive stats about sampler
        # performance.
# TODO(jungong) : make ema the default (only) mode if it works well.
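        # Worked example (hypothetical numbers): with ema_coef=0.2, an old
        # stat of 10.0 and a new sample of 20.0 yield
        #   updated = (1 - 0.2) * 10.0 + 0.2 * 20.0 = 12.0.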
self.ema_coef = ema_coef
self.iters = 0
self.raw_obs_processing_time = 0.0
self.inference_time = 0.0
self.action_processing_time = 0.0
self.env_wait_time = 0.0
self.env_render_time = 0.0
def incr(self, field: str, value: Union[int, float]):
if field == "iters":
self.iters += value
return
# All the other fields support either global average or ema mode.
if self.ema_coef is None:
# Global average.
self.__dict__[field] += value
else:
self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
field
] + self.ema_coef * value
def _get_avg(self):
        # Mean multiplier (1000 = sec -> ms).
factor = MS_TO_SEC / self.iters
return {
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * factor,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * factor,
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (only measured if rendering is enabled;
            # off by default).
"mean_env_render_ms": self.env_render_time * factor,
}
def _get_ema(self):
# In EMA mode, stats are already (exponentially) averaged,
# hence we only need to do the sec -> ms conversion here.
return {
# Raw observation preprocessing.
"mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
# Computing actions through policy.
"mean_inference_ms": self.inference_time * MS_TO_SEC,
# Processing actions (to be sent to env, e.g. clipping).
"mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
# Waiting for environment (during poll).
"mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
            # Environment rendering (only measured if rendering is enabled;
            # off by default).
"mean_env_render_ms": self.env_render_time * MS_TO_SEC,
}
def get(self):
if self.ema_coef is None:
return self._get_avg()
else:
return self._get_ema()
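# Usage sketch for _PerfStats (hypothetical timings): in global-average mode,
# `get()` divides the accumulated seconds by `iters` and converts to ms.
#
#   stats = _PerfStats()  # ema_coef=None -> global-average mode
#   stats.incr("iters", 2)
#   stats.incr("inference_time", 0.004)  # seconds
#   stats.get()["mean_inference_ms"]  # -> 0.004 * 1000 / 2 = 2.0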
class _NewDefaultDict(defaultdict):
def __missing__(self, env_id):
ret = self[env_id] = self.default_factory(env_id)
return ret
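# Unlike a plain defaultdict, _NewDefaultDict passes the missing key to the
# default factory. Minimal sketch (hypothetical factory):
#
#   d = _NewDefaultDict(lambda env_id: f"builder-{env_id}")
#   d[3]  # -> "builder-3", created and cached on first access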
def _build_multi_agent_batch(
episode_id: int,
batch_builder: _PolicyCollectorGroup,
large_batch_threshold: int,
multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
"""Build MultiAgentBatch from a dict of _PolicyCollectors.
Args:
env_steps: total env steps.
policy_collectors: collected training SampleBatchs by policy.
Returns:
Always returns a sample batch in MultiAgentBatch format.
"""
ma_batch = {}
for pid, collector in batch_builder.policy_collectors.items():
if collector.agent_steps <= 0:
continue
if batch_builder.agent_steps > large_batch_threshold and log_once(
"large_batch_warning"
):
logger.warning(
"More than {} observations in {} env steps for "
"episode {} ".format(
batch_builder.agent_steps, batch_builder.env_steps, episode_id
)
+ "are buffered in the sampler. If this is more than you "
"expected, check that that you set a horizon on your "
"environment correctly and that it terminates at some "
"point. Note: In multi-agent environments, "
"`rollout_fragment_length` sets the batch size based on "
"(across-agents) environment steps, not the steps of "
"individual agents, which can result in unexpectedly "
"large batches."
+ (
"Also, you may be waiting for your Env to "
"terminate (batch_mode=`complete_episodes`). Make sure "
"it does at some point."
if not multiple_episodes_in_batch
else ""
)
)
ma_batch[pid] = collector.build()
# Create the multi agent batch.
return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
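# Sketch of the result (hypothetical contents): a builder holding 10 env steps
# across two policies yields roughly
#
#   MultiAgentBatch(
#       policy_batches={"policy_0": <SampleBatch>, "policy_1": <SampleBatch>},
#       env_steps=10,
#   )
#
# Policies whose collectors saw no agent steps are skipped entirely.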
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
"""Batch a list of input SampleBatches into a single SampleBatch.
Args:
eval_data: list of SampleBatches.
Returns:
single batched SampleBatch.
"""
inference_batch = concat_samples(eval_data)
if "state_in_0" in inference_batch:
batch_size = len(eval_data)
inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
return inference_batch
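# Sketch of the SEQ_LENS fix-up above (hypothetical shapes): concatenating N
# single-step batches that carry RNN state ("state_in_0") produces one batch of
# N rows; setting SEQ_LENS to N ones makes downstream RNN code treat each row
# as its own length-1 sequence, e.g.:
#
#   batched = _batch_inference_sample_batches([sb1, sb2, sb3])
#   batched[SampleBatch.SEQ_LENS]  # -> np.array([1, 1, 1], dtype=np.int32)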
@DeveloperAPI
class EnvRunnerV2:
"""Collect experiences from user environment using Connectors."""
def __init__(
self,
worker: "RolloutWorker",
base_env: BaseEnv,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
perf_stats: _PerfStats,
rollout_fragment_length: int = 200,
count_steps_by: str = "env_steps",
        render: Optional[bool] = None,
):
"""
Args:
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks: User callbacks to run on episode events.
perf_stats: Record perf stats into this object.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
render: Whether to try to render the environment after each
step.
"""
self._worker = worker
if isinstance(base_env, ExternalEnvWrapper):
raise ValueError(
"Policies using the new Connector API do not support ExternalEnv."
)
self._base_env = base_env
self._multiple_episodes_in_batch = multiple_episodes_in_batch
self._callbacks = callbacks
self._perf_stats = perf_stats
self._rollout_fragment_length = rollout_fragment_length
self._count_steps_by = count_steps_by
self._render = render
# May be populated for image rendering.
self._simple_image_viewer: Optional[
"SimpleImageViewer"
] = self._get_simple_image_viewer()
# Keeps track of active episodes.
self._active_episodes: Dict[EnvID, EpisodeV2] = {}
self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
self._new_batch_builder
)
self._large_batch_threshold: int = (
max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
if self._rollout_fragment_length != float("inf")
else DEFAULT_LARGE_BATCH_THRESHOLD
)
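        # Worked example (hypothetical config): rollout_fragment_length=200
        # gives a warning threshold of max(1000, 200 * 10) = 2000 buffered
        # agent steps; an infinite fragment length (e.g. with
        # batch_mode="complete_episodes") falls back to 5000.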
def _get_simple_image_viewer(self):
"""Maybe construct a SimpleImageViewer instance for episode rendering."""
# Try to render the env, if required.
if not self._render:
return None
try:
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
return SimpleImageViewer()
except (ImportError, ModuleNotFoundError):
self._render = False # disable rendering
logger.warning(
"Could not import gymnasium.envs.classic_control."
"rendering! Try `pip install gymnasium[all]`."
)
return None
def _call_on_episode_start(self, episode, env_id):
# Call each policy's Exploration.on_episode_start method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
environment=self._base_env,
episode=episode,
tf_sess=p.get_session(),
)
# Call `on_episode_start()` callback.
self._callbacks.on_episode_start(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
env_index=env_id,
episode=episode,
)
def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
"""Create a new batch builder.
We create a _PolicyCollectorGroup based on the full policy_map
as the batch builder.
"""
return _PolicyCollectorGroup(self._worker.policy_map)
def run(self) -> Iterator[SampleBatchType]:
"""Samples and yields training episodes continuously.
Yields:
Object containing state, action, reward, terminal condition,
and other fields as dictated by `policy`.
"""
while True:
outputs = self.step()
for o in outputs:
yield o
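    # Usage sketch (hypothetical driver code): `run()` is an infinite
    # generator, so callers pull a bounded number of outputs from it, e.g.:
    #
    #   runner = EnvRunnerV2(worker, base_env, True, callbacks, _PerfStats())
    #   for batch in itertools.islice(runner.run(), 10):
    #       ...  # consume up to 10 sample batches / rollout metrics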
def step(self) -> List[SampleBatchType]:
"""Samples training episodes by stepping through environments."""
self._perf_stats.incr("iters", 1)
t0 = time.time()
# Get observations from all ready agents.
# types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
(
unfiltered_obs,
rewards,
terminateds,
truncateds,
infos,
off_policy_actions,
) = self._base_env.poll()
env_poll_time = time.time() - t0
# Process observations and prepare for policy evaluation.
t1 = time.time()
# types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
# List[Union[RolloutMetrics, SampleBatchType]]
active_envs, to_eval, outputs = self._process_observations(
unfiltered_obs=unfiltered_obs,
rewards=rewards,
terminateds=terminateds,
truncateds=truncateds,
infos=infos,
)
self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)
        # Do batched policy eval (across vectorized envs).
t2 = time.time()
# types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = self._do_policy_eval(to_eval=to_eval)
self._perf_stats.incr("inference_time", time.time() - t2)
# Process results and update episode state.
t3 = time.time()
actions_to_send: Dict[
EnvID, Dict[AgentID, EnvActionType]
] = self._process_policy_eval_results(
active_envs=active_envs,
to_eval=to_eval,
eval_results=eval_results,
off_policy_actions=off_policy_actions,
)
self._perf_stats.incr("action_processing_time", time.time() - t3)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
t4 = time.time()
self._base_env.send_actions(actions_to_send)
self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)
self._maybe_render()
return outputs
def _get_rollout_metrics(
self, episode: EpisodeV2, policy_map: Dict[str, Policy]
) -> List[RolloutMetrics]:
"""Get rollout metrics from completed episode."""
# TODO(jungong) : why do we need to handle atari metrics differently?
# Can we unify atari and normal env metrics?
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
        if atari_metrics is not None:
            # RolloutMetrics is a namedtuple; _replace() returns a new instance
            # rather than mutating in place, so re-assign the results.
            atari_metrics = [
                m._replace(custom_metrics=episode.custom_metrics)
                for m in atari_metrics
            ]
            return atari_metrics
# Create connector metrics
connector_metrics = {}
active_agents = episode.get_agents()
for agent in active_agents:
policy_id = episode.policy_for(agent)
policy = episode.policy_map[policy_id]
connector_metrics[policy_id] = policy.get_connector_metrics()
# Otherwise, return RolloutMetrics for the episode.
return [
RolloutMetrics(
episode_length=episode.length,
episode_reward=episode.total_reward,
agent_rewards=dict(episode.agent_rewards),
custom_metrics=episode.custom_metrics,
perf_stats={},
hist_data=episode.hist_data,
media=episode.media,
connector_metrics=connector_metrics,
)
]
def _process_observations(
self,
unfiltered_obs: MultiEnvDict,
rewards: MultiEnvDict,
terminateds: MultiEnvDict,
truncateds: MultiEnvDict,
infos: MultiEnvDict,
) -> Tuple[
Set[EnvID],
Dict[PolicyID, List[AgentConnectorDataType]],
List[Union[RolloutMetrics, SampleBatchType]],
]:
"""Process raw obs from env.
Group data for active agents by policy. Reset environments that are done.
Args:
unfiltered_obs: The unfiltered, raw observations from the BaseEnv
(vectorized, possibly multi-agent). Dict of dict: By env index,
then agent ID, then mapped to actual obs.
rewards: The rewards MultiEnvDict of the BaseEnv.
terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
infos: The MultiEnvDict of infos dicts of the BaseEnv.
Returns:
A tuple of:
A list of envs that were active during this step.
AgentConnectorDataType for active agents for policy evaluation.
SampleBatches and RolloutMetrics for completed agents for output.
"""
# Output objects.
        # Note that we need to explicitly track which envs were active during
        # this round, so we know which envs we must send (at least an empty)
        # action dict to.
        # We can not get this from the _active_episodes or to_eval lists because:
        # 1. Not every env is required to step during every single `poll()`.
        # 2. to_eval only contains data for agents that are still active; an env
        #    may be active even though all of its agents are done this step.
active_envs: Set[EnvID] = set()
to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
# For each (vectorized) sub-environment.
# types: EnvID, Dict[AgentID, EnvObsType]
for env_id, env_obs in unfiltered_obs.items():
# Check for env_id having returned an error instead of a multi-agent
# obs dict. This is how our BaseEnv can tell the caller to `poll()` that
# one of its sub-environments is faulty and should be restarted (and the
# ongoing episode should not be used for training).
if isinstance(env_obs, Exception):
assert terminateds[env_id]["__all__"] is True, (
f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
"as observation, the terminateds[__all__] flag must also be set to "
"True!"
)
# all_agents_obs is an Exception here.
# Drop this episode and skip to next.
self._handle_done_episode(
env_id=env_id,
env_obs_or_exception=env_obs,
is_done=True,
active_envs=active_envs,
to_eval=to_eval,
outputs=outputs,
)
continue
if env_id not in self._active_episodes:
episode: EpisodeV2 = self.create_episode(env_id)
self._active_episodes[env_id] = episode
else:
episode: EpisodeV2 = self._active_episodes[env_id]
# If this episode is brand-new, call the episode start callback(s).
# Note: EpisodeV2s are initialized with length=-1 (before the reset).
if not episode.has_init_obs():
self._call_on_episode_start(episode, env_id)
            # Check episode termination conditions.
            all_agents_done = (
                terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]
            )
            if not all_agents_done:
                active_envs.add(env_id)
# Special handling of common info dict.
episode.set_last_info("__common__", infos[env_id].get("__common__", {}))
# Agent sample batches grouped by policy. Each set of sample batches will
# go through agent connectors together.
sample_batches_by_policy = defaultdict(list)
# Whether an agent is terminated or truncated.
agent_terminateds = {}
agent_truncateds = {}
for agent_id, obs in env_obs.items():
assert agent_id != "__all__"
policy_id: PolicyID = episode.policy_for(agent_id)
agent_terminated = bool(
terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
)
agent_terminateds[agent_id] = agent_terminated
agent_truncated = bool(
truncateds[env_id]["__all__"]
or truncateds[env_id].get(agent_id, False)
)
agent_truncateds[agent_id] = agent_truncated
# A completely new agent is already done -> Skip entirely.
if not episode.has_init_obs(agent_id) and (
agent_terminated or agent_truncated
):
continue
                values_dict = {
                    # Episode length starts at -1 before the initial obs is
                    # added. After the reset obs at t=0, `episode.length`
                    # tracks the index of the current timestep.
                    SampleBatch.T: episode.length,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
# Last action (SampleBatch.ACTIONS) column will be populated by
# StateBufferConnector.
# Reward received after taking action at timestep t.
SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
# After taking action=a, did we reach terminal?
SampleBatch.TERMINATEDS: agent_terminated,
# Was the episode truncated artificially
# (e.g. b/c of some time limit)?
SampleBatch.TRUNCATEDS: agent_truncated,
SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
SampleBatch.NEXT_OBS: obs,
}
# Queue this obs sample for connector preprocessing.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# The entire episode is done.
if all_agents_done:
                # Check whether any agents have not received their last obs
                # yet. If so, create fake last observations for them (the
                # environment is not required to provide these when
                # terminateds[__all__]==True or truncateds[__all__]==True).
for agent_id in episode.get_agents():
# If the latest obs we got for this agent is done, or if its
# episode state is already done, nothing to do.
if (
agent_terminateds.get(agent_id, False)
or agent_truncateds.get(agent_id, False)
or episode.is_done(agent_id)
):
continue
policy_id: PolicyID = episode.policy_for(agent_id)
policy = self._worker.policy_map[policy_id]
# Create a fake observation by sampling the original env
# observation space.
obs_space = get_original_space(policy.observation_space)
# Although there is no obs for this agent, there may be
# good rewards and info dicts for it.
# This is the case for e.g. OpenSpiel games, where a reward
# is only earned with the last step, but the obs for that
# step is {}.
reward = rewards[env_id].get(agent_id, 0.0)
info = infos[env_id].get(agent_id, {})
values_dict = {
SampleBatch.T: episode.length,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
# TODO(sven): These should be the summed-up(!) rewards since the
# last observation received for this agent.
SampleBatch.REWARDS: reward,
SampleBatch.TERMINATEDS: True,
SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
SampleBatch.INFOS: info,
SampleBatch.NEXT_OBS: obs_space.sample(),
}
# Queue these fake obs for connector preprocessing too.
sample_batches_by_policy[policy_id].append((agent_id, values_dict))
# Run agent connectors.
for policy_id, batches in sample_batches_by_policy.items():
policy: Policy = self._worker.policy_map[policy_id]
# Collected full MultiAgentDicts for this environment.
assert (
policy.agent_connectors
), "EnvRunnerV2 requires agent connectors to work."
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(env_id, agent_id, data)
for agent_id, data in batches
]
# For all agents mapped to policy_id, run their data
# through agent_connectors.
processed = policy.agent_connectors(acd_list)
for d in processed:
# Record transition info if applicable.
if not episode.has_init_obs(d.agent_id):
episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
init_infos=d.data.raw_dict[SampleBatch.INFOS],
t=d.data.raw_dict[SampleBatch.T],
)
else:
episode.add_action_reward_done_next_obs(
d.agent_id, d.data.raw_dict
)
# Need to evaluate next actions.
if not (
all_agents_done
or agent_terminateds.get(d.agent_id, False)
or agent_truncateds.get(d.agent_id, False)
or episode.is_done(d.agent_id)
):
# Add to eval set if env is not done and this particular agent
# is also not done.
item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
to_eval[policy_id].append(item)
# Finished advancing episode by 1 step, mark it so.
episode.step()
# Exception: The very first env.poll() call causes the env to get reset
# (no step taken yet, just a single starting observation logged).
# We need to skip this callback in this case.
if episode.length > 0:
# Invoke the `on_episode_step` callback after the step is logged
# to the episode.
self._callbacks.on_episode_step(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode,
env_index=env_id,
)
# Episode is terminated/truncated for all agents
# (terminateds[__all__] == True or truncateds[__all__] == True).
if all_agents_done:
# _handle_done_episode will build a MultiAgentBatch for all
# the agents that are done during this step of rollout in
# the case of _multiple_episodes_in_batch=False.
self._handle_done_episode(
env_id,
env_obs,
terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
active_envs,
to_eval,
outputs,
)
            # Try to build a batch with the data collected so far.
if self._multiple_episodes_in_batch:
sample_batch = self._try_build_truncated_episode_multi_agent_batch(
self._batch_builders[env_id], episode
)
if sample_batch:
outputs.append(sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
return active_envs, to_eval, outputs
def _build_done_episode(
self,
env_id: EnvID,
is_done: bool,
outputs: List[SampleBatchType],
):
"""Builds a MultiAgentSampleBatch from the episode and adds it to outputs.
Args:
env_id: The env id.
is_done: Whether the env is done.
outputs: The list of outputs to add the
"""
episode: EpisodeV2 = self._active_episodes[env_id]
batch_builder = self._batch_builders[env_id]
episode.postprocess_episode(
batch_builder=batch_builder,
is_done=is_done,
check_dones=is_done,
)
        # If we are not allowed to pack the next episode into the same
# SampleBatch (batch_mode=complete_episodes) -> Build the
# MultiAgentBatch from a single episode and add it to "outputs".
# Otherwise, just postprocess and continue collecting across
# episodes.
if not self._multiple_episodes_in_batch:
ma_sample_batch = _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
if ma_sample_batch:
outputs.append(ma_sample_batch)
# SampleBatch built from data collected by batch_builder.
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]
def __process_resetted_obs_for_eval(
self,
env_id: EnvID,
obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
episode: EpisodeV2,
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
):
"""Process resetted obs through agent connectors for policy eval.
Args:
env_id: The env id.
obs: The Resetted obs.
episode: New episode.
to_eval: List of agent connector data for policy eval.
"""
per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
# types: AgentID, EnvObsType
for agent_id, raw_obs in obs[env_id].items():
policy_id: PolicyID = episode.policy_for(agent_id)
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
for policy_id, agents_obs in per_policy_resetted_obs.items():
policy = self._worker.policy_map[policy_id]
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(
env_id,
agent_id,
{
SampleBatch.NEXT_OBS: obs,
SampleBatch.INFOS: infos,
SampleBatch.T: episode.length,
},
)
for agent_id, obs in agents_obs
]
# Call agent connectors on these initial obs.
processed = policy.agent_connectors(acd_list)
for d in processed:
episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
init_infos=d.data.raw_dict[SampleBatch.INFOS],
t=d.data.raw_dict[SampleBatch.T],
)
to_eval[policy_id].append(d)
def _handle_done_episode(
self,
env_id: EnvID,
env_obs_or_exception: MultiAgentDict,
is_done: bool,
active_envs: Set[EnvID],
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
outputs: List[SampleBatchType],
) -> None:
"""Handle an all-finished episode.
Add collected SampleBatch to batch builder. Reset corresponding env, etc.
Args:
env_id: Environment ID.
env_obs_or_exception: Last per-environment observation or Exception.
env_infos: Last per-environment infos.
is_done: If all agents are done.
active_envs: Set of active env ids.
to_eval: Output container for policy eval data.
outputs: Output container for collected sample batches.
"""
if isinstance(env_obs_or_exception, Exception):
episode_or_exception: Exception = env_obs_or_exception
# Tell the sampler we have got a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
else:
episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
# Add rollout metrics.
outputs.extend(
self._get_rollout_metrics(
episode_or_exception, policy_map=self._worker.policy_map
)
)
# Output the collected episode after adding rollout metrics so that we
# always fetch metrics with RolloutWorker before we fetch samples.
# This is because we need to behave like env_runner() for now.
self._build_done_episode(env_id, is_done, outputs)
        # Clean up and delete the post-processed episode now that we have
        # collected its data.
self.end_episode(env_id, episode_or_exception)
# Create a new episode instance (before we reset the sub-environment).
new_episode: EpisodeV2 = self.create_episode(env_id)
# The sub environment at index `env_id` might throw an exception
# during the following `try_reset()` attempt. If configured with
# `restart_failed_sub_environments=True`, the BaseEnv will restart
# the affected sub environment (create a new one using its c'tor) and
# must reset the recreated sub env right after that.
# Should the sub environment fail indefinitely during these
# repeated reset attempts, the entire worker will be blocked.
# This would be ok, b/c the alternative would be the worker crashing
# entirely.
while True:
resetted_obs, resetted_infos = self._base_env.try_reset(env_id)
if (
resetted_obs is None
or resetted_obs == ASYNC_RESET_RETURN
or not isinstance(resetted_obs[env_id], Exception)
):
break
else:
# Report a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
# Reset connector state if this is a hard reset.
for p in self._worker.policy_map.cache.values():
p.agent_connectors.reset(env_id)
        # Register the new episode if this was not an async reset return.
        # If the reset is async, we will get its result in some future poll.
if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
self._active_episodes[env_id] = new_episode
self._call_on_episode_start(new_episode, env_id)
self.__process_resetted_obs_for_eval(
env_id,
resetted_obs,
resetted_infos,
new_episode,
to_eval,
)
            # Step after adding the initial obs. This will give us 0 env and
            # agent steps.
new_episode.step()
active_envs.add(env_id)
def create_episode(self, env_id: EnvID) -> EpisodeV2:
"""Creates a new EpisodeV2 instance and returns it.
Calls `on_episode_created` callbacks, but does NOT reset the respective
sub-environment yet.
Args:
env_id: Env ID.
Returns:
The newly created EpisodeV2 instance.
"""
# Make sure we currently don't have an active episode under this env ID.
assert env_id not in self._active_episodes
# Create a new episode under the same `env_id` and call the
# `on_episode_created` callbacks.
new_episode = EpisodeV2(
env_id,
self._worker.policy_map,
self._worker.policy_mapping_fn,
worker=self._worker,
callbacks=self._callbacks,
)
# Call `on_episode_created()` callback.
self._callbacks.on_episode_created(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
env_index=env_id,
episode=new_episode,
)
return new_episode
def end_episode(
self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
):
"""Cleans up an episode that has finished.
Args:
env_id: Env ID.
episode_or_exception: Instance of an episode if it finished successfully.
Otherwise, the exception that was thrown,
"""
# Signal the end of an episode, either successfully with an Episode or
# unsuccessfully with an Exception.
self._callbacks.on_episode_end(
worker=self._worker,
base_env=self._base_env,
policies=self._worker.policy_map,
episode=episode_or_exception,
env_index=env_id,
)
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in self._worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=self._base_env,
episode=episode_or_exception,
tf_sess=p.get_session(),
)
if isinstance(episode_or_exception, EpisodeV2):
episode = episode_or_exception
if episode.total_agent_steps == 0:
                # Zero recorded agent steps means that, throughout the episode,
                # all observations were empty (i.e. there was no agent in the env).
msg = (
f"Data from episode {episode.episode_id} does not show any agent "
f"interactions. Hint: Make sure for at least one timestep in the "
f"episode, env.step() returns non-empty values."
)
raise ValueError(msg)
# Clean up the episode and batch_builder for this env id.
if env_id in self._active_episodes:
del self._active_episodes[env_id]
def _try_build_truncated_episode_multi_agent_batch(
self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
) -> Union[None, SampleBatch, MultiAgentBatch]:
# Measure batch size in env-steps.
if self._count_steps_by == "env_steps":
built_steps = batch_builder.env_steps
ongoing_steps = episode.active_env_steps
# Measure batch-size in agent-steps.
else:
built_steps = batch_builder.agent_steps
ongoing_steps = episode.active_agent_steps
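        # Worked example (hypothetical counts): with rollout_fragment_length=200,
        # built_steps=150 already in the builder and ongoing_steps=50 in the
        # still-active episode, 150 + 50 >= 200 triggers the build below.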
# Reached the fragment-len -> We should build an MA-Batch.
if built_steps + ongoing_steps >= self._rollout_fragment_length:
if self._count_steps_by != "agent_steps":
assert built_steps + ongoing_steps == self._rollout_fragment_length, (
f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
f"rollout_fragment_length ({self._rollout_fragment_length})."
)
            # If we reached the fragment-len only because of the still-ongoing
            # episode -> postprocess that episode first.
if built_steps < self._rollout_fragment_length:
episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
# If builder has collected some data,
# build the MA-batch and add to return values.
if batch_builder.agent_steps > 0:
return _build_multi_agent_batch(
episode.episode_id,
batch_builder,
self._large_batch_threshold,
self._multiple_episodes_in_batch,
)
            # Builder contains no agent steps:
# We have reached the rollout-fragment length w/o any agent
# steps! Warn that the environment may never request any
# actions from any agents.
elif log_once("no_agent_steps"):
logger.warning(
"Your environment seems to be stepping w/o ever "
"emitting agent observations (agents are never "
"requested to act)!"
)
return None