/
sample_batch.py
1785 lines (1501 loc) · 65.1 KB
/
sample_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
from functools import partial
import itertools
import sys
from numbers import Number
from typing import Dict, Iterator, Set, Union
from typing import List, Optional
import numpy as np
import tree # pip install dm_tree
from ray.rllib.core.columns import Columns
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
PolicyID,
TensorType,
SampleBatchType,
ViewRequirementsDict,
)
from ray.util import log_once
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default_policy"
@DeveloperAPI
def attempt_count_timesteps(tensor_dict: dict):
    """Attempt to count timesteps based on dimensions of individual elements.

    Returns the first successfully counted number of timesteps.

    We do not attempt to count on INFOS or any state_in_* and state_out_* keys. The
    number of timesteps we count in cases where we are unable to count is zero.

    Args:
        tensor_dict: A SampleBatch or another dict.

    Returns:
        count: The inferred number of timesteps >= 0.
    """
    # Fast path: if usable sequence lengths are present, the timestep count
    # is simply their sum. (Symbolic tf tensors without `.numpy` are not
    # usable here.)
    seq_lens = tensor_dict.get(SampleBatch.SEQ_LENS)
    seq_lens_usable = (
        seq_lens is not None
        and not (tf and tf.is_tensor(seq_lens) and not hasattr(seq_lens, "numpy"))
        and len(seq_lens) > 0
    )
    if seq_lens_usable:
        if torch and torch.is_tensor(seq_lens):
            return seq_lens.sum().item()
        return int(sum(seq_lens))

    # Slow path: find the first column whose leading (batch) dimension we can
    # actually measure.
    for key, value in tensor_dict.items():
        if key == SampleBatch.SEQ_LENS:
            continue
        assert isinstance(key, str), tensor_dict
        # Skip INFOS (we make no assumptions about its content) and any
        # state columns (nesting can potentially mess things up).
        if (
            key == SampleBatch.INFOS
            or key.startswith("state_in_")
            or key.startswith("state_out_")
        ):
            continue
        # Flatten nested (dict/tuple) structures, e.g. nested observations;
        # all leaves share the same batch dimension.
        leaves = tree.flatten(value) if isinstance(value, (dict, tuple)) else [value]
        # TODO: Drop support for lists and Numbers as values.
        # Convert plain Numbers/lists to arrays so `len()` works on them.
        leaves = [
            np.array(leaf) if isinstance(leaf, (Number, list)) else leaf
            for leaf in leaves
        ]
        try:
            inferred = len(leaves[0])
            if inferred:
                return inferred
        except Exception:
            pass

    # Unable to count -> zero.
    return 0
@PublicAPI
class SampleBatch(dict):
"""Wrapper around a dictionary with string keys and array-like values.
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
samples, each with an "obs" and "reward" attribute.
"""
# On rows in SampleBatch:
# Each comment signifies how values relate to each other within a given row.
# A row generally signifies one timestep. Most importantly, at t=0, SampleBatch.OBS
# will usually be the reset-observation, while SampleBatch.ACTIONS will be the
# action based on the reset-observation and so on. This scheme is derived from
# RLlib's sampling logic.
# The following fields have all been moved to `Columns` and are only left here
# for backward compatibility.
OBS = Columns.OBS
ACTIONS = Columns.ACTIONS
REWARDS = Columns.REWARDS
TERMINATEDS = Columns.TERMINATEDS
TRUNCATEDS = Columns.TRUNCATEDS
INFOS = Columns.INFOS
SEQ_LENS = Columns.SEQ_LENS
T = Columns.T
ACTION_DIST_INPUTS = Columns.ACTION_DIST_INPUTS
ACTION_PROB = Columns.ACTION_PROB
ACTION_LOGP = Columns.ACTION_LOGP
VF_PREDS = Columns.VF_PREDS
VALUES_BOOTSTRAPPED = Columns.VALUES_BOOTSTRAPPED
EPS_ID = Columns.EPS_ID
NEXT_OBS = Columns.NEXT_OBS
# Action distribution object.
ACTION_DIST = "action_dist"
# Action chosen before SampleBatch.ACTIONS.
PREV_ACTIONS = "prev_actions"
# Reward received before SampleBatch.REWARDS.
PREV_REWARDS = "prev_rewards"
ENV_ID = "env_id" # An env ID (e.g. the index for a vectorized sub-env).
AGENT_INDEX = "agent_index" # Uniquely identifies an agent within an episode.
# Uniquely identifies a sample batch. This is important to distinguish RNN
# sequences from the same episode when multiple sample batches are
# concatenated (fusing sequences across batches can be unsafe).
UNROLL_ID = "unroll_id"
# RE 3
# This is only computed and used when RE3 exploration strategy is enabled.
OBS_EMBEDS = "obs_embeds"
# Decision Transformer
RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"
# Do not set this key directly. Instead, the values under this key are
# auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys.
DONES = "dones"
# Use SampleBatch.OBS instead.
CUR_OBS = "obs"
    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Note: All args and those kwargs not listed below will be passed
        as-is to the parent dict constructor.

        Args:
            _time_major: Whether data in this sample batch
                is time-major. This is False by default and only relevant
                if the data contains sequences.
            _max_seq_len: The max sequence chunk length
                if the data contains sequences.
            _zero_padded: Whether the data in this batch
                contains sequences AND these sequences are right-zero-padded
                according to the `_max_seq_len` setting.
            _is_training: Whether this batch is used for
                training. If False, batch may be used for e.g. action
                computations (inference).

        Raises:
            KeyError: If a `DONES` key is passed in (no longer supported;
                set TERMINATEDS/TRUNCATEDS instead).
        """
        if SampleBatch.DONES in kwargs:
            raise KeyError(
                "SampleBatch cannot be constructed anymore with a `DONES` key! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )
        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        # Maximum seq len value.
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        # Is already right-zero-padded?
        self.zero_padded = kwargs.pop("_zero_padded", False)
        # Whether this batch is used for training (vs inference).
        self._is_training = kwargs.pop("_is_training", None)
        # Weighted average number of grad updates that have been performed on the
        # policy/ies that were used to collect this batch.
        # E.g.: Two rollout workers collect samples of 50ts each
        # (rollout_fragment_length=50). One of them has a policy that has undergone
        # 2 updates thus far, the other worker uses a policy that has undergone 3
        # updates thus far. The train batch size is 100, so we concatenate these 2
        # batches to a new one that's 100ts long. This new 100ts batch will have its
        # `num_gradient_updates` property set to 2.5 as it's the weighted average
        # (both original batches contribute 50%).
        self.num_grad_updates: Optional[float] = kwargs.pop("_num_grad_updates", None)
        # Call super constructor. This will make the actual data accessible
        # by column name (str) via e.g. self["some-col"].
        dict.__init__(self, *args, **kwargs)
        # Indicates whether, for this batch, sequence lengths should be slices by
        # their index in the batch or by their index as a sequence.
        # This is useful if a batch contains tensors of shape (B, T, ...), where each
        # index of B indicates one sequence. In this case, when slicing the batch,
        # we want one sequence to be slices out per index in B (
        # `_slice_seq_lens_by_batch_index=True`. However, if the padded batch
        # contains tensors of shape (B*T, ...), where each index of B*T indicates
        # one timestep, we want one sequence to be sliced per T steps in B*T (
        # `self._slice_seq_lens_in_B=False`).
        # ._slice_seq_lens_in_B = True is only meant to be used for batches that we
        # feed into Learner._update(), all other places in RLlib are not expected to
        # need this.
        self._slice_seq_lens_in_B = False
        # Key-access bookkeeping (which columns were read/added/deleted).
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
        # Cache of interceptor-transformed column values (see `get_interceptor`).
        self.intercepted_values = {}
        self.get_interceptor = None
        # Clear out None seq-lens.
        seq_lens_ = self.get(SampleBatch.SEQ_LENS)
        if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
            self.pop(SampleBatch.SEQ_LENS, None)
        # Numpyfy seq_lens if list.
        elif isinstance(seq_lens_, list):
            self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32)
        elif (torch and torch.is_tensor(seq_lens_)) or (tf and tf.is_tensor(seq_lens_)):
            # NOTE(review): re-assignment is a no-op for framework tensors;
            # presumably kept for symmetry with the list branch — confirm.
            self[SampleBatch.SEQ_LENS] = seq_lens_
        # Derive `max_seq_len` from the seq-lens data if not explicitly given
        # (skip symbolic tf tensors, whose max cannot be computed eagerly).
        if (
            self.max_seq_len is None
            and seq_lens_ is not None
            and not (tf and tf.is_tensor(seq_lens_))
            and len(seq_lens_) > 0
        ):
            if torch and torch.is_tensor(seq_lens_):
                self.max_seq_len = seq_lens_.max().item()
            else:
                self.max_seq_len = max(seq_lens_)
        # Backward compat: "is_training" used to be a data column.
        if self._is_training is None:
            self._is_training = self.pop("is_training", False)
        for k, v in self.items():
            # TODO: Drop support for lists and Numbers as values.
            # Convert lists of int|float into numpy arrays make sure all data
            # has same length.
            if isinstance(v, (Number, list)) and not k == SampleBatch.INFOS:
                self[k] = np.array(v)
        # Infer the number of timesteps contained in this batch.
        self.count = attempt_count_timesteps(self)
        # A convenience map for slicing this batch into sub-batches along
        # the time axis. This helps reduce repeated iterations through the
        # batch's seq_lens array to find good slicing points. Built lazily
        # when needed.
        self._slice_map = []
    @PublicAPI
    def __len__(self) -> int:
        """Returns the amount of samples (timesteps) in the sample batch.

        `self.count` is computed once in the constructor, so this is O(1).
        """
        return self.count
@PublicAPI
def agent_steps(self) -> int:
"""Returns the same as len(self) (number of steps in this batch).
To make this compatible with `MultiAgentBatch.agent_steps()`.
"""
return len(self)
@PublicAPI
def env_steps(self) -> int:
"""Returns the same as len(self) (number of steps in this batch).
To make this compatible with `MultiAgentBatch.env_steps()`.
"""
return len(self)
    @DeveloperAPI
    def enable_slicing_by_batch_id(self):
        """Makes slicing treat each B-index of (B, T, ...) data as one sequence.

        See the `_slice_seq_lens_in_B` comment in `__init__` for details.
        """
        self._slice_seq_lens_in_B = True
    @DeveloperAPI
    def disable_slicing_by_batch_id(self):
        """Restores timestep-wise slicing for (B*T, ...) shaped data.

        See the `_slice_seq_lens_in_B` comment in `__init__` for details.
        """
        self._slice_seq_lens_in_B = False
@ExperimentalAPI
def is_terminated_or_truncated(self) -> bool:
"""Returns True if `self` is either terminated or truncated at idx -1."""
return self[SampleBatch.TERMINATEDS][-1] or (
SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][-1]
)
@ExperimentalAPI
def is_single_trajectory(self) -> bool:
"""Returns True if this SampleBatch only contains one trajectory.
This is determined by checking all timesteps (except for the last) for being
not terminated AND (if applicable) not truncated.
"""
return not any(self[SampleBatch.TERMINATEDS][:-1]) and (
SampleBatch.TRUNCATEDS not in self
or not any(self[SampleBatch.TRUNCATEDS][:-1])
)
    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
    def concat_samples(samples):
        # Hard-deprecated (error=True): the decorator raises before this body runs.
        pass
    @PublicAPI
    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Concatenates `other` to this one and returns a new SampleBatch.

        Args:
            other: The other SampleBatch object to concat to this one.

        Returns:
            The new SampleBatch, resulting from concating `other` to `self`.

        .. testcode::
            :skipif: True

            import numpy as np
            from ray.rllib.policy.sample_batch import SampleBatch
            b1 = SampleBatch({"a": np.array([1, 2])})
            b2 = SampleBatch({"a": np.array([3, 4, 5])})
            print(b1.concat(b2))

        .. testoutput::

            {"a": np.array([1, 2, 3, 4, 5])}
        """
        # Delegates to the module-level `concat_samples()` (defined elsewhere
        # in this file).
        return concat_samples([self, other])
@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
"""Creates a deep or shallow copy of this SampleBatch and returns it.
Args:
shallow: Whether the copying should be done shallowly.
Returns:
A deep or shallow copy of this SampleBatch object.
"""
copy_ = {k: v for k, v in self.items()}
data = tree.map_structure(
lambda v: (
np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
),
copy_,
)
copy_ = SampleBatch(
data,
_time_major=self.time_major,
_zero_padded=self.zero_padded,
_max_seq_len=self.max_seq_len,
_num_grad_updates=self.num_grad_updates,
)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_
@PublicAPI
def rows(self) -> Iterator[Dict[str, TensorType]]:
"""Returns an iterator over data rows, i.e. dicts with column values.
Note that if `seq_lens` is set in self, we set it to 1 in the rows.
Yields:
The column values of the row in this iteration.
.. testcode::
:skipif: True
from ray.rllib.policy.sample_batch import SampleBatch
batch = SampleBatch({
"a": [1, 2, 3],
"b": [4, 5, 6],
"seq_lens": [1, 2]
})
for row in batch.rows():
print(row)
.. testoutput::
{"a": 1, "b": 4, "seq_lens": 1}
{"a": 2, "b": 5, "seq_lens": 1}
{"a": 3, "b": 6, "seq_lens": 1}
"""
seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1
self_as_dict = {k: v for k, v in self.items()}
for i in range(self.count):
yield tree.map_structure_with_path(
lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens,
self_as_dict,
)
@PublicAPI
def columns(self, keys: List[str]) -> List[any]:
"""Returns a list of the batch-data in the specified columns.
Args:
keys: List of column names fo which to return the data.
Returns:
The list of data items ordered by the order of column
names in `keys`.
.. testcode::
:skipif: True
from ray.rllib.policy.sample_batch import SampleBatch
batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
print(batch.columns(["a", "b"]))
.. testoutput::
[[1], [2]]
"""
# TODO: (sven) Make this work for nested data as well.
out = []
for k in keys:
out.append(self[k])
return out
@PublicAPI
def shuffle(self) -> "SampleBatch":
"""Shuffles the rows of this batch in-place.
Returns:
This very (now shuffled) SampleBatch.
Raises:
ValueError: If self[SampleBatch.SEQ_LENS] is defined.
.. testcode::
:skipif: True
from ray.rllib.policy.sample_batch import SampleBatch
batch = SampleBatch({"a": [1, 2, 3, 4]})
print(batch.shuffle())
.. testoutput::
{"a": [4, 1, 3, 2]}
"""
# Shuffling the data when we have `seq_lens` defined is probably
# a bad idea!
if self.get(SampleBatch.SEQ_LENS) is not None:
raise ValueError(
"SampleBatch.shuffle not possible when your data has "
"`seq_lens` defined!"
)
# Get a permutation over the single items once and use the same
# permutation for all the data (otherwise, data would become
# meaningless).
permutation = np.random.permutation(self.count)
self_as_dict = {k: v for k, v in self.items()}
shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
self.update(shuffled)
# Flush cache such that intercepted values are recalculated after the
# shuffling.
self.intercepted_values = {}
return self
@PublicAPI
def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
"""Splits by `eps_id` column and returns list of new batches.
If `eps_id` is not present, splits by `dones` instead.
Args:
key: If specified, overwrite default and use key to split.
Returns:
List of batches, one per distinct episode.
Raises:
KeyError: If the `eps_id` AND `dones` columns are not present.
.. testcode::
:skipif: True
from ray.rllib.policy.sample_batch import SampleBatch
# "eps_id" is present
batch = SampleBatch(
{"a": [1, 2, 3], "eps_id": [0, 0, 1]})
print(batch.split_by_episode())
# "eps_id" not present, split by "dones" instead
batch = SampleBatch(
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
print(batch.split_by_episode())
# The last episode is appended even if it does not end with done
batch = SampleBatch(
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
print(batch.split_by_episode())
batch = SampleBatch(
{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
print(batch.split_by_episode())
.. testoutput::
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
[{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
"""
assert key is None or key in [SampleBatch.EPS_ID, SampleBatch.DONES], (
f"`SampleBatch.split_by_episode(key={key})` invalid! "
f"Must be [None|'dones'|'eps_id']."
)
def slice_by_eps_id():
slices = []
# Produce a new slice whenever we find a new episode ID.
cur_eps_id = self[SampleBatch.EPS_ID][0]
offset = 0
for i in range(self.count):
next_eps_id = self[SampleBatch.EPS_ID][i]
if next_eps_id != cur_eps_id:
slices.append(self[offset:i])
offset = i
cur_eps_id = next_eps_id
# Add final slice.
slices.append(self[offset : self.count])
return slices
def slice_by_terminateds_or_truncateds():
slices = []
offset = 0
for i in range(self.count):
if self[SampleBatch.TERMINATEDS][i] or (
SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][i]
):
# Since self[i] is the last timestep of the episode,
# append it to the batch, then set offset to the start
# of the next batch
slices.append(self[offset : i + 1])
offset = i + 1
# Add final slice.
if offset != self.count:
slices.append(self[offset:])
return slices
key_to_method = {
SampleBatch.EPS_ID: slice_by_eps_id,
SampleBatch.DONES: slice_by_terminateds_or_truncateds,
}
# If key not specified, default to this order.
key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]
slices = None
if key is not None:
# If key specified, directly use it.
if key == SampleBatch.EPS_ID and key not in self:
raise KeyError(f"{self} does not have key `{key}`!")
slices = key_to_method[key]()
else:
# If key not specified, go in order.
for key in key_resolve_order:
if key == SampleBatch.DONES or key in self:
slices = key_to_method[key]()
break
if slices is None:
raise KeyError(f"{self} does not have keys {key_resolve_order}!")
assert (
sum(s.count for s in slices) == self.count
), f"Calling split_by_episode on {self} returns {slices}"
f"which should in total have {self.count} timesteps!"
return slices
    def slice(
        self, start: int, end: int, state_start=None, state_end=None
    ) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start: Starting index. If < 0, will left-zero-pad.
            end: Ending index.
            state_start: If given (together with `state_end`), directly slice
                all `state_in_x` columns and `seq_lens` by these sequence
                indices instead of deriving them from `start`/`end`.
            state_end: See `state_start`.

        Returns:
            A new SampleBatch, which has a slice of this batch's data.
        """
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            # Negative start -> left-zero-pad all non-state columns by
            # `-start` timesteps, then take [0:end].
            if start < 0:
                data = {
                    k: np.concatenate(
                        [
                            np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype),
                            v[0:end],
                        ]
                    )
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            else:
                data = {
                    k: tree.map_structure(lambda s: s[start:end], v)
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            if state_start is not None:
                assert state_end is not None
                # Caller provided explicit sequence-index bounds for the
                # state columns; slice state_in_0, state_in_1, ... directly.
                state_idx = 0
                state_key = "state_in_{}".format(state_idx)
                while state_key in self:
                    data[state_key] = self[state_key][state_start:state_end]
                    state_idx += 1
                    state_key = "state_in_{}".format(state_idx)
                seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
                # Adjust seq_lens if necessary: the last sequence may be cut
                # short by the [start:end] row slice above.
                data_len = len(data[next(iter(data))])
                if sum(seq_lens) != data_len:
                    assert sum(seq_lens) > data_len
                    seq_lens[-1] = data_len - sum(seq_lens[:-1])
            else:
                # Fix state_in_x data: walk the seq_lens to locate which
                # sequences the [start:end) timestep range overlaps.
                count = 0
                state_start = None
                seq_lens = None
                for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
                    count += seq_len
                    if count >= end:
                        state_idx = 0
                        state_key = "state_in_{}".format(state_idx)
                        if state_start is None:
                            state_start = i
                        while state_key in self:
                            data[state_key] = self[state_key][state_start : i + 1]
                            state_idx += 1
                            state_key = "state_in_{}".format(state_idx)
                        # Truncate the last overlapped sequence to the part
                        # that lies within [start:end).
                        seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [
                            seq_len - (count - end)
                        ]
                        if start < 0:
                            # Account for the left-zero-padding added above.
                            seq_lens[0] += -start
                        diff = sum(seq_lens) - (end - start)
                        if diff > 0:
                            seq_lens[0] -= diff
                        assert sum(seq_lens) == (end - start)
                        break
                    elif state_start is None and count > start:
                        # First sequence overlapping `start`.
                        state_start = i
            return SampleBatch(
                data,
                seq_lens=seq_lens,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
        else:
            # No sequence information -> plain per-timestep slice of all
            # columns. NOTE(review): a negative `start` is NOT zero-padded on
            # this path (Python slice semantics apply) — confirm intended.
            return SampleBatch(
                tree.map_structure(lambda value: value[start:end], self),
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
def _batch_slice(self, slice_: slice) -> "SampleBatch":
"""Helper method to handle SampleBatch slicing using a slice object.
The returned SampleBatch uses the same underlying data object as
`self`, so changing the slice will also change `self`.
Note that only zero or positive bounds are allowed for both start
and stop values. The slice step must be 1 (or None, which is the
same).
Args:
slice_: The python slice object to slice by.
Returns:
A new SampleBatch, however "linking" into the same data
(sliced) as self.
"""
start = slice_.start or 0
stop = slice_.stop or len(self[SampleBatch.SEQ_LENS])
# If stop goes beyond the length of this batch -> Make it go till the
# end only (including last item).
# Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
if stop > len(self):
stop = len(self)
assert start >= 0 and stop >= 0 and slice_.step in [1, None]
# Exclude INFOs from regular array slicing as the data under this column might
# be a list (not good for `tree.map_structure` call).
# Furthermore, slicing does not work when the data in the column is
# singular (not a list or array).
infos = self.pop(SampleBatch.INFOS, None)
data = tree.map_structure(lambda value: value[start:stop], self)
if infos is not None:
data[SampleBatch.INFOS] = infos[start:stop]
return SampleBatch(
data,
_is_training=self.is_training,
_time_major=self.time_major,
_num_grad_updates=self.num_grad_updates,
)
@PublicAPI
def timeslices(
self,
size: Optional[int] = None,
num_slices: Optional[int] = None,
k: Optional[int] = None,
) -> List["SampleBatch"]:
"""Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
Args:
size: The size (in timesteps) of each returned SampleBatch.
num_slices: The number of slices to produce.
k: Deprecated: Use size or num_slices instead. The size
(in timesteps) of each returned SampleBatch.
Returns:
The list of `num_slices` (new) SampleBatches or n (new)
SampleBatches each one of size `size`.
"""
if size is None and num_slices is None:
deprecation_warning("k", "size or num_slices")
assert k is not None
size = k
if size is None:
assert isinstance(num_slices, int)
slices = []
left = len(self)
start = 0
while left:
len_ = left // (num_slices - len(slices))
stop = start + len_
slices.append(self[start:stop])
left -= len_
start = stop
return slices
else:
assert isinstance(size, int)
slices = []
left = len(self)
start = 0
while left:
stop = start + size
slices.append(self[start:stop])
left -= size
start = stop
return slices
    @Deprecated(new="SampleBatch.right_zero_pad", error=True)
    def zero_pad(self, max_seq_len, exclude_states=True):
        # Hard-deprecated (error=True): the decorator raises before this body runs.
        pass
    def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
        """Right (adding zeros at end) zero-pads this SampleBatch in-place.

        This will set the `self.zero_padded` flag to True and
        `self.max_seq_len` to the given `max_seq_len` value.

        Args:
            max_seq_len: The max (total) length to zero pad to.
            exclude_states: If False, also right-zero-pad all
                `state_in_x` data. If True, leave `state_in_x` keys
                as-is.

        Returns:
            This very (now right-zero-padded) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

        .. testcode::
            :skipif: True

            from ray.rllib.policy.sample_batch import SampleBatch
            batch = SampleBatch(
                {"a": [1, 2, 3], "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=4))

            batch = SampleBatch({"a": [1, 2, 3],
                "state_in_0": [1.0, 3.0],
                "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=5))

        .. testoutput::

            {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
            {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
             "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
             "seq_lens": [1, 2]}
        """
        seq_lens = self.get(SampleBatch.SEQ_LENS)
        if seq_lens is None:
            raise ValueError(
                "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
                f"present! SampleBatch={self}"
            )
        # Total padded length: every sequence gets exactly `max_seq_len` slots.
        length = len(seq_lens) * max_seq_len

        def _zero_pad_in_place(path, value):
            # Skip "state_in_..." columns and "seq_lens".
            if (exclude_states is True and path[0].startswith("state_in_")) or path[
                0
            ] == SampleBatch.SEQ_LENS:
                return
            # Generate zero-filled primer of len=max_seq_len.
            if value.dtype == object or value.dtype.type is np.str_:
                f_pad = [None] * length
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype)
            # Fill primer with data: copy each sequence into the start of its
            # own max_seq_len-sized slot.
            f_pad_base = f_base = 0
            for len_ in self[SampleBatch.SEQ_LENS]:
                f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
            assert f_base == len(value), value
            # Update our data in-place, walking `path` down the (possibly
            # nested) column structure to the leaf's parent container.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    curr[p] = f_pad
                curr = curr[p]

        self_as_dict = {k: v for k, v in self.items()}
        tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)
        # Set flags to indicate, we are now zero-padded (and to what extent).
        self.zero_padded = True
        self.max_seq_len = max_seq_len
        return self
@ExperimentalAPI
def to_device(self, device, framework="torch"):
"""TODO: transfer batch to given device as framework tensor."""
if framework == "torch":
assert torch is not None
for k, v in self.items():
self[k] = convert_to_torch_tensor(v, device)
else:
raise NotImplementedError
return self
@PublicAPI
def size_bytes(self) -> int:
"""Returns sum over number of bytes of all data buffers.
For numpy arrays, we use ``.nbytes``. For all other value types, we use
sys.getsizeof(...).
Returns:
The overall size in bytes of the data buffer (all columns).
"""
return sum(
v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
for v in tree.flatten(self)
)
def get(self, key, default=None):
"""Returns one column (by key) from the data or a default value."""
try:
return self.__getitem__(key)
except KeyError:
return default
    @PublicAPI
    def as_multi_agent(self) -> "MultiAgentBatch":
        """Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.

        Returns:
            The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding
            to this SampleBatch.
        """
        # `MultiAgentBatch` is defined further down in this module.
        return MultiAgentBatch({DEFAULT_POLICY_ID: self}, self.count)
    @PublicAPI
    def __getitem__(self, key: Union[str, slice]) -> TensorType:
        """Returns one column (by key) from the data or a sliced new batch.

        Args:
            key: The key (column name) to return or
                a slice object for slicing this SampleBatch.

        Returns:
            The data under the given key or a sliced version of this batch.
        """
        if isinstance(key, slice):
            return self._slice(key)
        # Special key DONES -> Translate to `TERMINATEDS | TRUNCATEDS` to reflect
        # the old meaning of DONES.
        # NOTE(review): despite the comment above, only TERMINATEDS is
        # returned here (truncations are ignored) — confirm intended.
        if key == SampleBatch.DONES:
            return self[SampleBatch.TERMINATEDS]
        # Backward compatibility for when "input-dicts" were used.
        elif key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            return self.is_training
        # Track key access, but only for keys that don't shadow an attribute
        # of this object (e.g. "count").
        if not hasattr(self, key) and key in self:
            self.accessed_keys.add(key)
        value = dict.__getitem__(self, key)
        # Lazily apply the get-interceptor (e.g. numpy -> framework tensor
        # conversion) and memoize the result per key.
        if self.get_interceptor is not None:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value
    @PublicAPI
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key: The column name to set a value for.
            item: The data to insert.

        Raises:
            KeyError: If trying to set the `DONES` key directly (it is
                derived from TERMINATEDS/TRUNCATEDS).
        """
        # Disallow setting DONES key directly.
        if key == SampleBatch.DONES:
            raise KeyError(
                "Cannot set `DONES` anymore in a SampleBatch! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        elif not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return
        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            # Stored as a flag on the object, not as a data column.
            self._is_training = item
            return
        if key not in self:
            self.added_keys.add(key)
        dict.__setitem__(self, key, item)
        # Keep the interceptor cache consistent with the new raw value.
        if key in self.intercepted_values:
            self.intercepted_values[key] = item
@property
def is_training(self):
if self.get_interceptor is not None and isinstance(self._is_training, bool):
if "_is_training" not in self.intercepted_values:
self.intercepted_values["_is_training"] = self.get_interceptor(