-
-
Notifications
You must be signed in to change notification settings - Fork 535
/
categorical.py
executable file
·204 lines (164 loc) · 8.54 KB
/
categorical.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright 2018 Tensorforce Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from tensorforce import TensorforceError, util
from tensorforce.core import layer_modules, Module
from tensorforce.core.distributions import Distribution
class Categorical(Distribution):
    """
    Categorical distribution, for discrete integer actions (specification key: `categorical`).

    Args:
        name (string): Distribution name
            (<span style="color:#0000C0"><b>internal use</b></span>).
        action_spec (specification): Action specification
            (<span style="color:#0000C0"><b>internal use</b></span>).
        embedding_shape (iter[int > 0]): Embedding shape
            (<span style="color:#0000C0"><b>internal use</b></span>).
        infer_states_value (bool): Whether to infer the state value from state-action values as
            softmax denominator (<span style="color:#00C000"><b>default</b></span>: true).
        summary_labels ('all' | iter[string]): Labels of summaries to record
            (<span style="color:#00C000"><b>default</b></span>: inherit value of parent module).
    """

    def __init__(
        self, name, action_spec, embedding_shape, infer_states_value=True, summary_labels=None
    ):
        super().__init__(
            name=name, action_spec=action_spec, embedding_shape=embedding_shape,
            summary_labels=summary_labels
        )

        input_spec = dict(type='float', shape=self.embedding_shape)
        num_values = self.action_spec['num_values']

        # Number of action elements the linear heads produce per (non-final) embedding position.
        if len(self.embedding_shape) == 1:
            # Flat embedding: one output per element of the (flattened) action shape.
            size = util.product(xs=self.action_spec['shape'])
        elif len(self.embedding_shape) < 1 or len(self.embedding_shape) > 3:
            raise TensorforceError.unexpected()
        elif self.embedding_shape[:-1] == self.action_spec['shape'][:-1]:
            # Embedding positions cover all but the last action axis.
            size = self.action_spec['shape'][-1]
        elif self.embedding_shape[:-1] == self.action_spec['shape']:
            # One embedding position per action element.
            size = 1
        else:
            raise TensorforceError.unexpected()

        # Per-action-value outputs ("deviations"): num_values entries per action element.
        self.deviations = self.add_module(
            name='deviations', module='linear', modules=layer_modules,
            size=(size * num_values), input_spec=input_spec
        )
        if infer_states_value:
            # States value is derived from the action values via logsumexp in tf_parametrize.
            self.value = None
        else:
            # Separate states-value head, combined as value + deviation - mean(deviation).
            self.value = self.add_module(
                name='value', module='linear', modules=layer_modules, size=size,
                input_spec=input_spec
            )

        Module.register_tensor(
            name=(self.name + '-probabilities'),
            spec=dict(type='float', shape=(self.action_spec['shape'] + (num_values,))),
            batched=True
        )

    def tf_parametrize(self, x, mask):
        """
        Computes the distribution parameters from an embedding.

        Args:
            x: Embedding tensor (presumably shaped (batch,) + embedding_shape — confirm).
            mask: Boolean tensor used in `tf.where` against the action values, which are
                reshaped to (batch,) + action shape + (num_values,); entries that are False
                are replaced by the most negative float (i.e. treated as invalid actions).

        Returns:
            Tuple (logits, probabilities, states_value, action_values), where logits are the
            log of the softmax probabilities clipped from below at util.epsilon.
        """
        epsilon = tf.constant(value=util.epsilon, dtype=util.tf_dtype(dtype='float'))
        shape = (-1,) + self.action_spec['shape'] + (self.action_spec['num_values'],)
        value_shape = (-1,) + self.action_spec['shape'] + (1,)

        # Deviations
        action_values = self.deviations.apply(x=x)
        action_values = tf.reshape(tensor=action_values, shape=shape)
        min_float = tf.fill(
            dims=tf.shape(input=action_values), value=util.tf_dtype(dtype='float').min
        )

        # States value
        if self.value is None:
            # Mask first so invalid actions do not contribute to the logsumexp.
            action_values = tf.where(condition=mask, x=action_values, y=min_float)
            states_value = tf.reduce_logsumexp(input_tensor=action_values, axis=-1)
        else:
            states_value = self.value.apply(x=x)
            # Always reshape: for multi-dimensional embeddings the value head may output
            # (batch,) + action shape, which would otherwise broadcast the last action axis
            # against num_values below. The element count always matches value_shape.
            states_value = tf.reshape(tensor=states_value, shape=value_shape)
            action_values = states_value + action_values - tf.math.reduce_mean(
                input_tensor=action_values, axis=-1, keepdims=True
            )
            states_value = tf.squeeze(input=states_value, axis=-1)
            action_values = tf.where(condition=mask, x=action_values, y=min_float)

        # Softmax for corresponding probabilities
        probabilities = tf.nn.softmax(logits=action_values, axis=-1)

        # "Normalized" logits, clipped to avoid log(0) for masked / vanishing probabilities
        logits = tf.math.log(x=tf.maximum(x=probabilities, y=epsilon))

        Module.update_tensor(name=(self.name + '-probabilities'), tensor=probabilities)

        return logits, probabilities, states_value, action_values

    def tf_sample(self, parameters, temperature):
        """
        Samples actions from the distribution.

        Args:
            parameters: Tuple as returned by tf_parametrize.
            temperature: Sampling temperature; below util.epsilon the most likely action
                (argmax of the logits) is returned instead of a sample.

        Returns:
            Integer action tensor (dtype util.tf_dtype('int')).
        """
        logits, probabilities, _, _ = parameters

        # Summary: probabilities averaged over all action-shape axes
        summary_probs = probabilities
        for _ in range(len(self.action_spec['shape'])):
            summary_probs = tf.math.reduce_mean(input_tensor=summary_probs, axis=1)

        logits, probabilities = self.add_summary(
            label=('distributions', 'categorical'), name='probabilities', tensor=summary_probs,
            pass_tensors=(logits, probabilities), enumerate_last_rank=True
        )

        one = tf.constant(value=1.0, dtype=util.tf_dtype(dtype='float'))
        epsilon = tf.constant(value=util.epsilon, dtype=util.tf_dtype(dtype='float'))

        # Deterministic: maximum likelihood action
        definite = tf.argmax(input=logits, axis=-1)
        definite = tf.dtypes.cast(x=definite, dtype=util.tf_dtype('int'))

        # Set logits to minimal value
        min_float = tf.fill(dims=tf.shape(input=logits), value=util.tf_dtype(dtype='float').min)
        # Clamp the divisor: a zero temperature would otherwise produce inf/NaN logits. This only
        # affects the sampled branch, which is discarded whenever temperature < epsilon anyway.
        logits = logits / tf.maximum(x=temperature, y=epsilon)
        logits = tf.where(condition=(probabilities < epsilon), x=min_float, y=logits)

        # Non-deterministic: sample action using Gumbel distribution (Gumbel-max trick)
        uniform_distribution = tf.random.uniform(
            shape=tf.shape(input=logits), minval=epsilon, maxval=(one - epsilon),
            dtype=util.tf_dtype(dtype='float')
        )
        gumbel_distribution = -tf.math.log(x=-tf.math.log(x=uniform_distribution))
        sampled = tf.argmax(input=(logits + gumbel_distribution), axis=-1)
        sampled = tf.dtypes.cast(x=sampled, dtype=util.tf_dtype('int'))

        return tf.where(condition=(temperature < epsilon), x=definite, y=sampled)

    def tf_log_probability(self, parameters, action):
        """
        Returns the (epsilon-clipped) log-probability of the given action.

        Args:
            parameters: Tuple as returned by tf_parametrize.
            action: Integer action tensor indexing the last (num_values) axis of the logits.
        """
        logits, _, _, _ = parameters

        # tf.gather requires int32/int64 indices
        if util.tf_dtype(dtype='int') not in (tf.int32, tf.int64):
            action = tf.dtypes.cast(x=action, dtype=tf.int32)

        logits = tf.gather(
            params=logits, indices=tf.expand_dims(input=action, axis=-1), batch_dims=-1
        )

        return tf.squeeze(input=logits, axis=-1)

    def tf_entropy(self, parameters):
        """Returns the entropy -sum(p * log p) over the last (num_values) axis."""
        logits, probabilities, _, _ = parameters

        return -tf.reduce_sum(input_tensor=(probabilities * logits), axis=-1)

    def tf_kl_divergence(self, parameters1, parameters2):
        """Returns KL(p1 || p2) = sum(p1 * (log p1 - log p2)) over the last axis."""
        logits1, probabilities1, _, _ = parameters1
        logits2, _, _, _ = parameters2

        log_prob_ratio = logits1 - logits2

        return tf.reduce_sum(input_tensor=(probabilities1 * log_prob_ratio), axis=-1)

    def tf_action_value(self, parameters, action=None):
        """
        Returns the (masked) state-action values.

        Args:
            parameters: Tuple as returned by tf_parametrize.
            action: Optional integer action tensor; if given, only the value of that action
                is returned, otherwise the values of all num_values actions.
        """
        _, _, _, action_values = parameters

        if action is not None:
            # tf.gather requires int32/int64 indices
            if util.tf_dtype(dtype='int') not in (tf.int32, tf.int64):
                action = tf.dtypes.cast(x=action, dtype=tf.int32)
            action = tf.expand_dims(input=action, axis=-1)
            action_values = tf.gather(params=action_values, indices=action, batch_dims=-1)
            action_values = tf.squeeze(input=action_values, axis=-1)

        return action_values

    def tf_states_value(self, parameters):
        """Returns the states value computed in tf_parametrize."""
        _, _, states_value, _ = parameters

        return states_value