/
recurrent_net.py
283 lines (239 loc) · 10.8 KB
/
recurrent_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import logging
import tree # pip install dm_tree
from typing import Dict, List, Tuple
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType
# Resolve TensorFlow through RLlib's import helper. The three names are,
# presumably, the TF1 compat module, the TF module and the TF major version
# -- confirm against ray.rllib.utils.framework.try_import_tf. Only `tf` is
# used in this file.
tf1, tf, tfv = try_import_tf()
# Module-level logger; LSTMWrapper prints its Keras model summary when the
# INFO level is enabled.
logger = logging.getLogger(__name__)
@DeveloperAPI
class RecurrentNetwork(TFModelV2):
    """Helper class to simplify implementing RNN models with TFModelV2.

    Instead of implementing forward(), you can implement forward_rnn() which
    takes batches with the time dimension added already.

    Here is an example implementation for a subclass
    ``MyRNNClass(RecurrentNetwork)``::

        def __init__(self, *args, **kwargs):
            super(MyRNNClass, self).__init__(*args, **kwargs)
            self.cell_size = 256

            # Define input layers
            input_layer = tf.keras.layers.Input(
                shape=(None, self.obs_space.shape[0]))
            state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ))
            state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ))
            seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

            # Send to LSTM cell
            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                self.cell_size, return_sequences=True, return_state=True,
                name="lstm")(
                    inputs=input_layer,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            output_layer = tf.keras.layers.Dense(...)(lstm_out)

            # Create the RNN model
            self.rnn_model = tf.keras.Model(
                inputs=[input_layer, seq_in, state_in_h, state_in_c],
                outputs=[output_layer, state_h, state_c])
            self.rnn_model.summary()
    """

    @override(ModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass.

        Args:
            input_dict: Input tensors; only ``input_dict["obs_flat"]`` is read
                here and is handed to ``add_time_dimension`` as the padded,
                flat batch.
            state: List of recurrent state tensors, passed to forward_rnn().
            seq_lens: 1d tensor holding input sequence lengths. Must not be
                None.

        Returns:
            (output, new_state): forward_rnn()'s output reshaped to
            [-1, num_outputs], and the new state list it returned.
        """
        assert seq_lens is not None
        flat_inputs = input_dict["obs_flat"]
        # Fold the flat batch into [B, T, ...] based on seq_lens (presumably
        # the flat batch is [B*T, ...] -- see add_time_dimension).
        inputs = add_time_dimension(
            padded_inputs=flat_inputs, seq_lens=seq_lens, framework="tf"
        )
        output, new_state = self.forward_rnn(
            inputs,
            state,
            seq_lens,
        )
        # Collapse the time dimension back into the batch dimension.
        return tf.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        """Call the model with the given input tensors and state.

        Args:
            inputs: observation tensor with shape [B, T, obs_size].
            state: list of state tensors, each with shape [B, size] (one
                state vector per sequence, not per time step).
            seq_lens: 1d tensor holding input sequence lengths.

        Returns:
            (outputs, new_state): The model output tensor of shape
            [B, T, num_outputs] and the list of new state tensors each with
            shape [B, size].

        Sample implementation for the ``MyRNNClass`` example::

            def forward_rnn(self, inputs, state, seq_lens):
                model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
                return model_out, [h, c]
        """
        raise NotImplementedError("You must implement this for a RNN model")

    def get_initial_state(self) -> List[TensorType]:
        """Get the initial recurrent state values for the model.

        Returns:
            list of np.array objects, if any

        Sample implementation for the ``MyRNNClass`` example::

            def get_initial_state(self):
                return [
                    np.zeros(self.cell_size, np.float32),
                    np.zeros(self.cell_size, np.float32),
                ]
        """
        raise NotImplementedError("You must implement this for a RNN model")
@DeveloperAPI
class LSTMWrapper(RecurrentNetwork):
    """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.

    The wrapped model's output is optionally concatenated with the previous
    action and previous reward. It is then fed through a single Keras LSTM
    layer, followed by two linear Dense heads: ``logits`` (``num_outputs``
    units) and ``values`` (1 unit).
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Builds the Keras LSTM model on top of the wrapped model's output.

        Args:
            obs_space: Observation space.
            action_space: Action space; used to size the prev-action input.
            num_outputs: Number of units in the ``logits`` output head.
            model_config: Model config. This constructor reads
                ``"lstm_cell_size"``, ``"lstm_use_prev_action"`` and
                ``"lstm_use_prev_reward"``.
            name: Model name.
        """
        super(LSTMWrapper, self).__init__(
            obs_space, action_space, None, model_config, name
        )
        # At this point, self.num_outputs is the number of nodes coming
        # from the wrapped (underlying) model. In other words, self.num_outputs
        # is the input size for the LSTM layer.
        # If None, use the flattened size of the observation space.
        if self.num_outputs is None:
            # `np.product` was removed in NumPy 2.0; `np.prod` is equivalent.
            self.num_outputs = int(np.prod(self.obs_space.shape))

        self.cell_size = model_config["lstm_cell_size"]
        self.use_prev_action = model_config["lstm_use_prev_action"]
        self.use_prev_reward = model_config["lstm_use_prev_reward"]

        # Total width of a flattened/one-hot'd action. It is the sum over all
        # (possibly nested) sub-spaces of the action space. Cast to plain
        # Python ints so NumPy scalar types don't leak into layer shapes.
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += int(space.n)
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to input to LSTM.
        if self.use_prev_action:
            self.num_outputs += self.action_dim
        if self.use_prev_reward:
            self.num_outputs += 1

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, self.num_outputs), name="inputs"
        )

        # Set self.num_outputs to the number of output nodes desired by the
        # caller of this constructor.
        self.num_outputs = num_outputs

        state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Send the inputs through a single LSTM layer. The sequence mask marks
        # time steps beyond each sequence's length as padding.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=input_layer,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # Linear logits and value heads applied to every LSTM output step.
        logits = tf.keras.layers.Dense(
            self.num_outputs, activation=tf.keras.activations.linear, name="logits"
        )(lstm_out)
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # Create the RNN model
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )
        # Print out model summary in INFO logging mode.
        if logger.isEnabledFor(logging.INFO):
            self._rnn_model.summary()

        # Add prev-a/r to this model's view, if required. Use the same flags
        # that sized the LSTM input above so the two can never disagree.
        if self.use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            )
        if self.use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

    @override(RecurrentNetwork)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Runs the wrapped model, appends prev-a/r, then runs the LSTM.

        NOTE: ``input_dict["obs_flat"]`` is overwritten in place with the
        (concatenated) wrapped-model output before delegating to
        ``RecurrentNetwork.forward``.

        Args:
            input_dict: Input tensors. When enabled, PREV_ACTIONS and
                PREV_REWARDS are read from it.
            state: List of LSTM state tensors [h, c].
            seq_lens: 1d tensor holding input sequence lengths. Must not be
                None.

        Returns:
            (logits, new_state) as produced by ``RecurrentNetwork.forward``.
        """
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        # (`_wrapped_forward` is not defined in this file; presumably it is
        # attached when this wrapper is combined with the wrapped model.)
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (in their original form as
            # have been sent to environment):
            # Flatten/one-hot into 1D array.
            if self.model_config["_disable_action_flattening"]:
                prev_a_r.append(
                    flatten_inputs_to_1d_tensor(
                        prev_a,
                        spaces_struct=self.action_space_struct,
                        time_axis=False,
                    )
                )
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here.
            else:
                if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                    prev_a = one_hot(prev_a, self.action_space)
                prev_a_r.append(
                    tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
                )
        # Prev rewards.
        if self.use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Push everything through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)

    @override(RecurrentNetwork)
    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        """Runs the Keras LSTM model and caches the value-head output.

        Returns:
            (logits, [h, c]): the logits head output and the new LSTM state.
        """
        model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self) -> List[np.ndarray]:
        """Returns zero-initialized h and c vectors of length ``cell_size``."""
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        """Returns the value-head output of the last forward pass, flattened.

        ``self._value_out`` is only set in ``forward_rnn``, so a forward pass
        must have run before this is called.
        """
        return tf.reshape(self._value_out, [-1])