/
critic_network.py
169 lines (149 loc) · 7.5 KB
/
critic_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# coding=utf-8
# Copyright 2018 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Critic/Q network to use with DDPG agents."""
import gin
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.networks import network
from tf_agents.networks import utils
@gin.configurable
class CriticNetwork(network.Network):
  """Critic/Q network mapping (observation, action) pairs to scalar Q-values.

  Observations and actions are encoded by separate MLP towers, concatenated,
  and passed through a joint MLP ending in a single-unit value head. Intended
  for use with DDPG-style agents.
  """

  def __init__(self,
               input_tensor_spec,
               observation_conv_layer_params=None,
               observation_fc_layer_params=None,
               observation_dropout_layer_params=None,
               action_fc_layer_params=None,
               action_dropout_layer_params=None,
               joint_fc_layer_params=None,
               joint_dropout_layer_params=None,
               activation_fn=tf.nn.relu,
               output_activation_fn=None,
               kernel_initializer=None,
               last_kernel_initializer=None,
               name='CriticNetwork'):
    """Creates an instance of `CriticNetwork`.

    Args:
      input_tensor_spec: A tuple of (observation, action) each a nest of
        `tensor_spec.TensorSpec` representing the inputs.
      observation_conv_layer_params: Optional list of convolution layer
        parameters for observations, where each item is a length-three tuple
        indicating (num_units, kernel_size, stride).
      observation_fc_layer_params: Optional list of fully connected parameters
        for observations, where each item is the number of units in the layer.
      observation_dropout_layer_params: Optional list of dropout layer
        parameters, each item is the fraction of input units to drop or a
        dictionary of parameters according to the keras.Dropout documentation.
        The additional parameter `permanent`, if set to True, allows to apply
        dropout at inference for approximated Bayesian inference. The dropout
        layers are interleaved with the fully connected layers; there is a
        dropout layer after each fully connected layer, except if the entry in
        the list is None. This list must have the same length of
        observation_fc_layer_params, or be None.
      action_fc_layer_params: Optional list of fully connected parameters for
        actions, where each item is the number of units in the layer.
      action_dropout_layer_params: Optional list of dropout layer parameters,
        each item is the fraction of input units to drop or a dictionary of
        parameters according to the keras.Dropout documentation. The additional
        parameter `permanent`, if set to True, allows to apply dropout at
        inference for approximated Bayesian inference. The dropout layers are
        interleaved with the fully connected layers; there is a dropout layer
        after each fully connected layer, except if the entry in the list is
        None. This list must have the same length of action_fc_layer_params, or
        be None.
      joint_fc_layer_params: Optional list of fully connected parameters after
        merging observations and actions, where each item is the number of units
        in the layer.
      joint_dropout_layer_params: Optional list of dropout layer parameters,
        each item is the fraction of input units to drop or a dictionary of
        parameters according to the keras.Dropout documentation. The additional
        parameter `permanent`, if set to True, allows to apply dropout at
        inference for approximated Bayesian inference. The dropout layers are
        interleaved with the fully connected layers; there is a dropout layer
        after each fully connected layer, except if the entry in the list is
        None. This list must have the same length of joint_fc_layer_params, or
        be None.
      activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
      output_activation_fn: Activation function for the last layer. This can be
        used to restrict the range of the output. For example, one can pass
        tf.keras.activations.sigmoid here to restrict the output to be bounded
        between 0 and 1.
      kernel_initializer: kernel initializer for all layers except for the value
        regression layer. If None, a VarianceScaling initializer will be used.
      last_kernel_initializer: kernel initializer for the value regression
        layer. If None, a RandomUniform initializer will be used.
      name: A string representing name of the network.

    Raises:
      ValueError: If `observation_spec` or `action_spec` contains more than one
        observation.
    """
    super(CriticNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec,
        state_spec=(),
        name=name)

    observation_spec, action_spec = input_tensor_spec

    # This network only handles flat (single-entry) observation/action nests.
    if len(tf.nest.flatten(observation_spec)) > 1:
      raise ValueError('Only a single observation is supported by this network')

    flat_action_spec = tf.nest.flatten(action_spec)
    if len(flat_action_spec) > 1:
      raise ValueError('Only a single action is supported by this network')
    self._single_action_spec = flat_action_spec[0]

    # Default initializers: uniform fan-in variance scaling for hidden layers
    # and a small uniform range for the final value layer. Use the TF2-native
    # tf.keras initializer for both (the hidden-layer default previously went
    # through the deprecated tf.compat.v1 alias; the sampled distribution,
    # uniform in [-sqrt(3*scale/fan_in), sqrt(3*scale/fan_in)], is identical).
    if kernel_initializer is None:
      kernel_initializer = tf.keras.initializers.VarianceScaling(
          scale=1. / 3., mode='fan_in', distribution='uniform')
    if last_kernel_initializer is None:
      last_kernel_initializer = tf.keras.initializers.RandomUniform(
          minval=-0.003, maxval=0.003)

    # TODO(kbanoop): Replace mlp_layers with encoding networks.
    # Separate encoding towers for observations and actions; their outputs are
    # concatenated before the joint MLP in call().
    self._observation_layers = utils.mlp_layers(
        observation_conv_layer_params,
        observation_fc_layer_params,
        observation_dropout_layer_params,
        activation_fn=activation_fn,
        kernel_initializer=kernel_initializer,
        name='observation_encoding')

    self._action_layers = utils.mlp_layers(
        None,
        action_fc_layer_params,
        action_dropout_layer_params,
        activation_fn=activation_fn,
        kernel_initializer=kernel_initializer,
        name='action_encoding')

    self._joint_layers = utils.mlp_layers(
        None,
        joint_fc_layer_params,
        joint_dropout_layer_params,
        activation_fn=activation_fn,
        kernel_initializer=kernel_initializer,
        name='joint_mlp')

    # Single-unit value head; output_activation_fn can bound its range.
    self._joint_layers.append(
        tf.keras.layers.Dense(
            1,
            activation=output_activation_fn,
            kernel_initializer=last_kernel_initializer,
            name='value'))

  def call(self, inputs, step_type=(), network_state=(), training=False):
    """Returns (q_values, network_state) for a batch of (observation, action).

    Args:
      inputs: A tuple of (observations, actions); each nest is flattened and
        only its single entry is used (enforced in __init__).
      step_type: Unused; part of the Network call signature.
      network_state: Passed through unchanged (this network is stateless).
      training: Whether to run layers (e.g. dropout) in training mode.

    Returns:
      A tuple of a rank-1 tensor of Q-values (one per batch element) and
      `network_state`.
    """
    observations, actions = inputs
    del step_type  # unused.
    # Cast to float32 so integer-typed specs still feed dense layers.
    observations = tf.cast(tf.nest.flatten(observations)[0], tf.float32)
    for layer in self._observation_layers:
      observations = layer(observations, training=training)

    actions = tf.cast(tf.nest.flatten(actions)[0], tf.float32)
    for layer in self._action_layers:
      actions = layer(actions, training=training)

    # Merge the two encodings along the feature axis for the joint MLP.
    joint = tf.concat([observations, actions], 1)
    for layer in self._joint_layers:
      joint = layer(joint, training=training)

    # Squeeze the trailing unit dimension: shape [batch, 1] -> [batch].
    return tf.reshape(joint, [-1]), network_state