# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neural + LinUCB Policy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional, Sequence, Text, Tuple
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp
from tf_agents.bandits.policies import constraints
from tf_agents.bandits.policies import linalg
from tf_agents.bandits.specs import utils as bandit_spec_utils
from tf_agents.distributions import masked
from tf_agents.policies import tf_policy
from tf_agents.policies import utils as policy_utilities
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step
from tf_agents.typing import types
tfd = tfp.distributions
class NeuralLinUCBPolicy(tf_policy.TFPolicy):
"""Neural LinUCB Policy.
Applies LinUCB on top of an encoding network.
Since LinUCB is a linear method, the encoding network is used to capture the
non-linear relationship between the context features and the expected rewards.
  The policy starts with epsilon-greedy exploration and later switches to
  LinUCB, which explores more efficiently.

  This policy supports both the global-only observation model and the
  global-and-per-arm observation model:
  -- In the global-only case, there is a single observation per time step, and
     every arm has its own reward estimation function.
  -- In the per-arm case, all arms receive individual observations, and the
     reward estimation function is identical for all arms.
Reference:
Carlos Riquelme, George Tucker, Jasper Snoek,
`Deep Bayesian Bandits Showdown: An Empirical Comparison of Bayesian Deep
Networks for Thompson Sampling`, ICLR 2018.
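
  A minimal construction sketch for the global-only case is shown below. The
  observation spec, network architecture and hyperparameters are illustrative
  assumptions, not values prescribed by this module:

    from tf_agents.networks import encoding_network
    from tf_agents.trajectories import time_step as ts

    encoding_dim = 16
    num_actions = 5
    obs_spec = tensor_spec.TensorSpec([8], tf.float32)
    policy = NeuralLinUCBPolicy(
        encoding_network=encoding_network.EncodingNetwork(
            input_tensor_spec=obs_spec, fc_layer_params=(32, encoding_dim)),
        encoding_dim=encoding_dim,
        reward_layer=tf.keras.layers.Dense(num_actions),
        epsilon_greedy=0.1,
        actions_from_reward_layer=tf.constant(True),
        cov_matrix=[tf.eye(encoding_dim) for _ in range(num_actions)],
        data_vector=[tf.zeros([encoding_dim]) for _ in range(num_actions)],
        num_samples=[tf.zeros([], dtype=tf.int32) for _ in range(num_actions)],
        time_step_spec=ts.time_step_spec(obs_spec))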
"""
def __init__(
self,
encoding_network: types.Network,
encoding_dim: int,
reward_layer: tf.keras.layers.Dense,
epsilon_greedy: float,
actions_from_reward_layer: types.Bool,
cov_matrix: Sequence[types.Float],
data_vector: Sequence[types.Float],
num_samples: Sequence[types.Int],
time_step_spec: types.TimeStep,
alpha: float = 1.0,
emit_policy_info: Sequence[Text] = (),
emit_log_probability: bool = False,
accepts_per_arm_features: bool = False,
distributed_use_reward_layer: bool = False,
observation_and_action_constraint_splitter: Optional[
types.Splitter
] = None,
name: Optional[Text] = None,
):
"""Initializes `NeuralLinUCBPolicy`.
Args:
encoding_network: network that encodes the observations.
encoding_dim: (int) dimension of the encoded observations.
reward_layer: final layer that predicts the expected reward per arm. In
case the policy accepts per-arm features, the output of this layer has
to be a scalar. This is because in the per-arm case, all encoded
observations have to go through the same computation to get the reward
estimates. The `num_actions` dimension of the encoded observation is
treated as a batch dimension in the reward layer.
epsilon_greedy: (float) representing the probability of choosing a random
action instead of the greedy action.
actions_from_reward_layer: (boolean variable) whether to get actions from
the reward layer or from LinUCB.
cov_matrix: list of the covariance matrices. There exists one covariance
matrix per arm, unless the policy accepts per-arm features, in which
case this list must have a single element.
data_vector: list of the data vectors. A data vector is a weighted sum of
the observations, where the weight is the corresponding reward. Each arm
has its own data vector, unless the policy accepts per-arm features, in
which case this list must have a single element.
num_samples: list of number of samples per arm. If the policy accepts per-
arm features, this is a single-element list counting the number of
steps.
time_step_spec: A `TimeStep` spec of the expected time_steps.
alpha: (float) non-negative weight multiplying the confidence intervals.
emit_policy_info: (tuple of strings) what side information we want to get
as part of the policy info. Allowed values can be found in
`policy_utilities.PolicyInfo`.
emit_log_probability: (bool) whether to emit log probabilities.
accepts_per_arm_features: (bool) Whether the policy accepts per-arm
features.
      distributed_use_reward_layer: (bool) Whether to pick actions using the
        reward network or LinUCB. This applies only in the distributed training
        setting and plays a role similar to `actions_from_reward_layer` above.
observation_and_action_constraint_splitter: A function used for masking
valid/invalid actions with each state of the environment. The function
takes in a full observation and returns a tuple consisting of 1) the
part of the observation intended as input to the bandit policy and 2)
the mask. The mask should be a 0-1 `Tensor` of shape `[batch_size,
num_actions]`. This function should also work with a `TensorSpec` as
input, and should output `TensorSpec` objects for the observation and
mask.
name: The name of this policy.
"""
policy_utilities.check_no_mask_with_arm_features(
accepts_per_arm_features, observation_and_action_constraint_splitter
)
encoding_network.create_variables()
self._encoding_network = encoding_network
self._reward_layer = reward_layer
self._encoding_dim = encoding_dim
if accepts_per_arm_features and reward_layer.units != 1:
raise ValueError(
'The output dimension of the reward layer must be 1, got {}'.format(
reward_layer.units
)
)
if not isinstance(cov_matrix, (list, tuple)):
raise ValueError('cov_matrix must be a list of matrices (Tensors).')
self._cov_matrix = cov_matrix
if not isinstance(data_vector, (list, tuple)):
raise ValueError('data_vector must be a list of vectors (Tensors).')
self._data_vector = data_vector
if not isinstance(num_samples, (list, tuple)):
      raise ValueError('num_samples must be a list of counts (Tensors).')
self._num_samples = num_samples
self._alpha = alpha
self._actions_from_reward_layer = actions_from_reward_layer
self._epsilon_greedy = epsilon_greedy
self._dtype = self._data_vector[0].dtype
self._distributed_use_reward_layer = distributed_use_reward_layer
if len(cov_matrix) != len(data_vector):
raise ValueError(
'The size of list cov_matrix must match the size of '
'list data_vector. Got {} for cov_matrix and {} '
          'for data_vector'.format(len(self._cov_matrix), len(data_vector))
)
if len(num_samples) != len(cov_matrix):
raise ValueError(
'The size of num_samples must match the size of '
'list cov_matrix. Got {} for num_samples and {} '
          'for cov_matrix'.format(len(self._num_samples), len(cov_matrix))
)
self._accepts_per_arm_features = accepts_per_arm_features
if observation_and_action_constraint_splitter is not None:
context_spec, _ = observation_and_action_constraint_splitter(
time_step_spec.observation
)
else:
context_spec = time_step_spec.observation
if accepts_per_arm_features:
self._num_actions = tf.nest.flatten(
context_spec[bandit_spec_utils.PER_ARM_FEATURE_KEY]
)[0].shape.as_list()[0]
self._num_models = 1
else:
self._num_actions = len(cov_matrix)
self._num_models = self._num_actions
cov_matrix_dim = tf.compat.dimension_value(cov_matrix[0].shape[0])
if self._encoding_dim != cov_matrix_dim:
raise ValueError(
'The dimension of matrix `cov_matrix` must match '
          'encoding dimension {}. '
'Got {} for `cov_matrix`.'.format(self._encoding_dim, cov_matrix_dim)
)
data_vector_dim = tf.compat.dimension_value(data_vector[0].shape[0])
if self._encoding_dim != data_vector_dim:
raise ValueError(
'The dimension of vector `data_vector` must match '
'encoding dimension {}. '
'Got {} for `data_vector`.'.format(
self._encoding_dim, data_vector_dim
)
)
action_spec = tensor_spec.BoundedTensorSpec(
shape=(),
dtype=tf.int32,
minimum=0,
maximum=self._num_actions - 1,
name='action',
)
self._emit_policy_info = emit_policy_info
predicted_rewards_mean = ()
if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
predicted_rewards_mean = tensor_spec.TensorSpec(
[self._num_actions], dtype=tf.float32
)
predicted_rewards_optimistic = ()
if (
policy_utilities.InfoFields.PREDICTED_REWARDS_OPTIMISTIC
in emit_policy_info
):
predicted_rewards_optimistic = tensor_spec.TensorSpec(
[self._num_actions], dtype=tf.float32
)
if accepts_per_arm_features:
chosen_arm_features_info_spec = (
policy_utilities.create_chosen_arm_features_info_spec(
time_step_spec.observation
)
)
info_spec = policy_utilities.PerArmPolicyInfo(
predicted_rewards_mean=predicted_rewards_mean,
predicted_rewards_optimistic=predicted_rewards_optimistic,
chosen_arm_features=chosen_arm_features_info_spec,
)
else:
info_spec = policy_utilities.PolicyInfo(
predicted_rewards_mean=predicted_rewards_mean,
predicted_rewards_optimistic=predicted_rewards_optimistic,
)
super(NeuralLinUCBPolicy, self).__init__(
time_step_spec=time_step_spec,
action_spec=action_spec,
emit_log_probability=emit_log_probability,
observation_and_action_constraint_splitter=(
observation_and_action_constraint_splitter
),
info_spec=info_spec,
name=name,
)
def _variables(self):
all_variables = [
self._cov_matrix,
self._data_vector,
self._num_samples,
self._actions_from_reward_layer,
self._encoding_network.variables,
self._reward_layer.variables,
]
return [
v for v in tf.nest.flatten(all_variables) if isinstance(v, tf.Variable)
]
def _get_actions_from_reward_layer(
self, encoded_observation: types.Float, mask: Optional[types.Tensor]
) -> Tuple[types.Int, types.Float, types.Float]:
# Get the predicted expected reward.
est_mean_reward = tf.reshape(
self._reward_layer(encoded_observation), shape=[-1, self._num_actions]
)
if mask is None:
greedy_actions = tf.argmax(est_mean_reward, axis=-1, output_type=tf.int32)
else:
greedy_actions = policy_utilities.masked_argmax(
est_mean_reward, mask, output_type=tf.int32
)
# Add epsilon greedy on top, if needed.
if self._epsilon_greedy:
batch_size = (
tf.compat.dimension_value(encoded_observation.shape[0])
or tf.shape(encoded_observation)[0]
)
if mask is None:
random_actions = tf.random.uniform(
[batch_size], maxval=self._num_actions, dtype=tf.int32
)
else:
zero_logits = tf.cast(tf.zeros_like(mask), tf.float32)
masked_categorical = masked.MaskedCategorical(
zero_logits, mask, dtype=tf.int32
)
random_actions = masked_categorical.sample()
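      # With probability `epsilon_greedy` the greedy action is replaced by the
      # uniformly sampled (and, if a mask is provided, masked) random action.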
rng = tf.random.uniform([batch_size], maxval=1.0)
cond = tf.greater(rng, self._epsilon_greedy)
chosen_actions = tf.compat.v1.where(cond, greedy_actions, random_actions)
else:
chosen_actions = greedy_actions
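    # When acting from the reward layer there is no closed-form confidence
    # interval, so the mean estimate is returned for both the "mean" and the
    # "optimistic" reward slots.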
return chosen_actions, est_mean_reward, est_mean_reward
def _get_actions_from_linucb(
self, encoded_observation: types.Float, mask: Optional[types.Tensor]
) -> Tuple[types.Int, types.Float, types.Float]:
encoded_observation = tf.cast(encoded_observation, dtype=self._dtype)
p_values = []
est_rewards = []
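    # For each arm k (or the single shared model in the per-arm-features case)
    # this loop solves the ridge-regularized system (A_k + I) z = c for the
    # encoded context c via conjugate gradient, then forms the LinUCB score
    #   b_k^T z + alpha * sqrt(c^T z),
    # where A_k is `cov_matrix[k]` and b_k is `data_vector[k]`.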
for k in range(self._num_actions):
encoded_observation_for_arm = self._get_encoded_observation_for_arm(
encoded_observation, k
)
model_index = policy_utilities.get_model_index(
k, self._accepts_per_arm_features
)
a_inv_x = linalg.conjugate_gradient(
self._cov_matrix[model_index]
+ tf.eye(self._encoding_dim, dtype=self._dtype),
tf.linalg.matrix_transpose(encoded_observation_for_arm),
)
mean_reward_est = tf.einsum(
'j,jk->k', self._data_vector[model_index], a_inv_x
)
est_rewards.append(mean_reward_est)
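      # Confidence width per example: the diagonal of c (A_k + I)^{-1} c^T,
      # i.e. the predictive-variance term of the LinUCB bound.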
ci = tf.reshape(
tf.linalg.tensor_diag_part(
tf.matmul(encoded_observation_for_arm, a_inv_x)
),
[-1, 1],
)
p_values.append(
tf.reshape(mean_reward_est, [-1, 1]) + self._alpha * tf.sqrt(ci)
)
stacked_p_values = tf.squeeze(tf.stack(p_values, axis=-1), axis=[1])
if mask is None:
chosen_actions = tf.argmax(
stacked_p_values, axis=-1, output_type=tf.int32
)
else:
chosen_actions = policy_utilities.masked_argmax(
stacked_p_values, mask, output_type=tf.int32
)
est_mean_reward = tf.cast(tf.stack(est_rewards, axis=-1), tf.float32)
return (
chosen_actions,
est_mean_reward,
tf.cast(stacked_p_values, tf.float32),
)
def _distribution(self, time_step, policy_state):
raise NotImplementedError(
'This policy outputs an action and not a distribution.'
)
def _action(self, time_step, policy_state, seed): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
observation = time_step.observation
if self.observation_and_action_constraint_splitter is not None:
observation, _ = self.observation_and_action_constraint_splitter(
observation
)
mask = constraints.construct_mask_from_multiple_sources(
time_step.observation,
self._observation_and_action_constraint_splitter,
(),
self._num_actions,
)
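    # `mask` is a 0/1 tensor of shape [batch_size, num_actions] combining the
    # splitter mask with any constraints (none are passed here); it may be
    # None when no masking source is available.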
# Pass the observations through the encoding network.
encoded_observation, _ = self._encoding_network(observation)
encoded_observation = tf.cast(encoded_observation, dtype=self._dtype)
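    # Under a tf.distribute strategy the branch is selected by the static
    # `distributed_use_reward_layer` flag; otherwise a `tf.cond` on the
    # `actions_from_reward_layer` variable decides between the reward-layer
    # (epsilon-greedy) path and the LinUCB path.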
if tf.distribute.has_strategy():
if self._distributed_use_reward_layer:
chosen_actions, est_mean_rewards, est_rewards_optimistic = (
self._get_actions_from_reward_layer(encoded_observation, mask)
)
else:
chosen_actions, est_mean_rewards, est_rewards_optimistic = (
self._get_actions_from_linucb(encoded_observation, mask)
)
else:
chosen_actions, est_mean_rewards, est_rewards_optimistic = tf.cond(
self._actions_from_reward_layer,
# pylint: disable=g-long-lambda
lambda: self._get_actions_from_reward_layer(
encoded_observation, mask
),
lambda: self._get_actions_from_linucb(encoded_observation, mask),
)
arm_observations = ()
if self._accepts_per_arm_features:
arm_observations = observation[bandit_spec_utils.PER_ARM_FEATURE_KEY]
policy_info = policy_utilities.populate_policy_info(
arm_observations,
chosen_actions,
est_rewards_optimistic,
est_mean_rewards,
self._emit_policy_info,
self._accepts_per_arm_features,
)
return policy_step.PolicyStep(chosen_actions, policy_state, policy_info)
def _get_encoded_observation_for_arm(
self, encoded_observation: types.Float, arm_index: int
) -> types.Float:
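    """Returns the encoding used by arm `arm_index`.

    In the per-arm case the encoded observation has shape
    `[batch_size, num_actions, encoding_dim]` and the slice for the requested
    arm is returned; otherwise the shared `[batch_size, encoding_dim]` encoding
    is used for every arm.
    """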
if self._accepts_per_arm_features:
return encoded_observation[:, arm_index, :]
else:
return encoded_observation