-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
replay_ops.py
197 lines (160 loc) · 6.82 KB
/
replay_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from typing import List, Any, Optional
import random
from ray.actor import ActorHandle
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
MultiAgentReplayBuffer
from ray.rllib.execution.common import \
STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType
class StoreToReplayBuffer:
    """Callable that stores data into a local replay buffer or replay actors.

    If constructed with a local replay buffer, data will be stored into that
    buffer. If constructed with a list of replay actor handles, each batch is
    stored into one actor chosen uniformly at random.

    This should be used with the .for_each() operator on a rollouts iterator.
    The batch that was stored is returned unchanged.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> rollouts = ParallelRollouts(...)
        >>> store_op = rollouts.for_each(StoreToReplayBuffer(actors=actors))
        >>> next(store_op)
        SampleBatch(...)
    """

    def __init__(
            self,
            *,
            local_buffer: Optional[MultiAgentReplayBuffer] = None,
            actors: Optional[List[ActorHandle]] = None,
    ):
        """
        Args:
            local_buffer: The local replay buffer to store the data into.
            actors: An optional list of replay actors to use instead of
                `local_buffer`. An empty list counts as "not given".

        Raises:
            ValueError: If neither or both of `local_buffer` and a non-empty
                `actors` list are given.
        """
        # The buffer's presence is checked with `is not None`, not truthiness,
        # so a buffer object that happens to be falsy (e.g. one defining
        # `__len__` that is currently empty) is still accepted.
        has_local = local_buffer is not None
        has_actors = bool(actors)
        if has_local == has_actors:
            raise ValueError(
                "Exactly one of `local_buffer` or `actors` must be given, "
                "not both or neither!")
        if has_local:
            self.local_actor = local_buffer
            self.replay_actors = None
        else:
            self.local_actor = None
            self.replay_actors = actors

    def __call__(self, batch: SampleBatchType) -> SampleBatchType:
        """Store `batch` in the configured destination and return it.

        Args:
            batch: The sample batch to store.

        Returns:
            The same `batch` object, so this can be chained in `.for_each()`.
        """
        if self.local_actor is not None:
            self.local_actor.add_batch(batch)
        else:
            # Pick one actor at random; the result of `.remote()` is
            # not waited on here.
            actor = random.choice(self.replay_actors)
            actor.add_batch.remote(batch)
        return batch
def Replay(*,
           local_buffer: Optional[MultiAgentReplayBuffer] = None,
           actors: Optional[List[ActorHandle]] = None,
           num_async: int = 4) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Exactly one of this and `actors`
            must be specified.
        actors: List of replay actors. Exactly one of this and
            `local_buffer` must be specified; an empty list counts as not
            specified.
        num_async: In async mode, the max number of async requests in flight
            per actor. Only used when `actors` is given.

    Returns:
        An iterator over replayed sample batches.

    Raises:
        ValueError: If neither or both of `local_buffer` and a non-empty
            `actors` list are given.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors)
        >>> next(replay_op)
        SampleBatch(...)
    """
    # Check the buffer with `is not None` so a falsy buffer object (e.g. one
    # defining `__len__` that is currently empty) still counts as given.
    has_local = local_buffer is not None
    has_actors = bool(actors)
    if has_local == has_actors:
        raise ValueError(
            "Exactly one of local_buffer and actors must be given.")

    if has_actors:
        replay = from_actors(actors)
        # Drop None results (presumably what an actor returns when it has
        # nothing to replay yet).
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.replay()
            if item is None:
                # Buffer had nothing to return; yield the "not ready" marker
                # instead of a batch (presumably so iteration continues
                # without emitting a value).
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())
class WaitUntilTimestepsElapsed:
    """Predicate that turns True once enough timesteps have been sampled.

    Each call reads the ``STEPS_SAMPLED_COUNTER`` counter from the shared
    iterator metrics and reports whether it is strictly greater than the
    configured target. The item passed to the call is ignored.
    """

    def __init__(self, target_num_timesteps: int):
        # Threshold that the sampled-step counter must strictly exceed.
        self.target_num_timesteps = target_num_timesteps

    def __call__(self, item: Any) -> bool:
        sampled_steps = _get_shared_metrics().counters[STEPS_SAMPLED_COUNTER]
        return self.target_num_timesteps < sampled_steps
# TODO(ekl) deprecate this in favor of the replay_sequence_length option.
class SimpleReplayBuffer:
    """Fixed-capacity store of whole sample batches, sampled uniformly.

    Batches are appended until ``num_slots`` are held; after that each new
    batch overwrites the slot at a cursor that advances cyclically, so the
    oldest batch is replaced first.
    """

    def __init__(self,
                 num_slots: int,
                 replay_proportion: Optional[float] = None):
        """Create an empty SimpleReplayBuffer.

        Args:
            num_slots (int): Number of batches to store in total. A value
                <= 0 means nothing is ever stored.
            replay_proportion (Optional[float]): Accepted for signature
                compatibility; not used by this class.
        """
        self.num_slots = num_slots
        self.replay_batches = []
        self.replay_index = 0

    def add_batch(self, sample_batch: SampleBatchType) -> None:
        """Insert `sample_batch`, evicting the oldest batch when full."""
        warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
        if self.num_slots <= 0:
            return
        if len(self.replay_batches) >= self.num_slots:
            # Full: overwrite the slot under the cursor, then advance it.
            self.replay_batches[self.replay_index] = sample_batch
            self.replay_index = (self.replay_index + 1) % self.num_slots
        else:
            self.replay_batches.append(sample_batch)

    def replay(self) -> SampleBatchType:
        """Return one stored batch chosen uniformly at random.

        Raises IndexError (from `random.choice`) if nothing is stored.
        """
        return random.choice(self.replay_batches)
class MixInReplay:
    """This operator adds replay to a stream of experiences.

    It takes input batches, and returns a list of batches that include replayed
    data as well. The number of replayed batches is determined by the
    configured replay proportion. The max age of a batch is determined by the
    number of replay slots.
    """

    def __init__(self, num_slots: int, replay_proportion: float):
        """Initialize MixInReplay.

        Args:
            num_slots (int): Number of batches to store in total.
            replay_proportion (float): The input batch will be returned
                and an additional number of batches proportional to this value
                will be added as well.

        Raises:
            ValueError: If replay_proportion > 0 but num_slots <= 0. In that
                case the buffer would never hold anything to replay.

        Examples:
            # replay proportion 2:1
            >>> replay_op = rollouts.for_each(
            ...     MixInReplay(100, replay_proportion=2))
            >>> print(next(replay_op))
            [SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]

            # replay proportion 0:1, replay disabled
            >>> replay_op = rollouts.for_each(
            ...     MixInReplay(100, replay_proportion=0))
            >>> print(next(replay_op))
            [SampleBatch(<input>)]
        """
        # A non-positive slot count means the buffer stores nothing, so any
        # replay attempt would fail on an empty buffer; reject it up front.
        if replay_proportion > 0 and num_slots <= 0:
            raise ValueError(
                "You must set num_slots > 0 if replay_proportion > 0.")
        self.replay_buffer = SimpleReplayBuffer(num_slots)
        self.replay_proportion = replay_proportion

    def __call__(self, sample_batch: SampleBatchType) -> List[SampleBatchType]:
        """Store `sample_batch` and return it plus some replayed batches.

        For replay proportion p, the result contains `sample_batch` followed
        by floor(p) replayed batches. It then contains one more replayed batch
        with probability p - floor(p), so the expected replay count is p.
        The batch just added is itself eligible to be replayed.

        Args:
            sample_batch: The incoming batch.

        Returns:
            A list whose first element is `sample_batch`.
        """
        # Put in replay buffer if enabled.
        self.replay_buffer.add_batch(sample_batch)
        # Proportional replay: one replay per whole unit of the proportion,
        # plus one more with probability equal to the fractional part.
        output_batches = [sample_batch]
        remaining = self.replay_proportion
        while random.random() < remaining:
            remaining -= 1
            output_batches.append(self.replay_buffer.replay())
        return output_batches