Implement double Q for SQL
hartikainen committed Feb 1, 2019
1 parent df9c535 commit d35509e
Showing 1 changed file with 77 additions and 48 deletions.
softlearning/algorithms/sql.py (125 changes: 77 additions & 48 deletions)
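The change below implements double Q-learning for SQL: two Q-function approximators are maintained, each with its own target network, and the element-wise minimum over the target networks is used when bootstrapping, which counteracts overestimation bias. An illustrative sketch (not part of this commit's diff) of that minimum, with hypothetical values:

import numpy as np

# Hypothetical target-Q estimates for 2 next states and 3 sampled action particles each.
q1_next = np.array([[1.2, 0.8, 0.3],
                    [0.5, 0.9, 0.1]])
q2_next = np.array([[1.0, 1.1, 0.2],
                    [0.4, 0.7, 0.6]])

# Element-wise minimum over the two target networks; the diff computes the same
# quantity with tf.reduce_min over the tuple of target-Q outputs.
min_q_next = np.minimum(q1_next, q2_next)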
@@ -32,7 +32,7 @@ def __init__(
self,
env,
pool,
Q,
Qs,
policy,
plotter=None,
policy_lr=1E-3,
@@ -56,7 +56,9 @@ def __init__(
Args:
env (`SoftlearningEnv`): Environment object used for training.
pool (`PoolBase`): Replay pool to add gathered samples to.
Q: Q-function approximator.
Qs: Q-function approximators. The min of these
approximators will be used. Usage of at least two Q-functions
improves performance by reducing overestimation bias.
policy: A policy function approximator.
plotter (`QFPolicyPlotter`): Plotter instance to be used for
visualizing Q-function during training.
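An illustrative sketch (not part of this commit's diff) of how a caller could construct the `Qs` pair. The `build_Q` helper, layer sizes, input dimensions, and the remaining `SQL(...)` arguments are assumptions for illustration only:

import tensorflow as tf

def build_Q(observation_dim, action_dim, hidden_units=(256, 256)):
    observations = tf.keras.layers.Input(shape=(observation_dim,))
    actions = tf.keras.layers.Input(shape=(action_dim,))
    out = tf.keras.layers.Concatenate()([observations, actions])
    for units in hidden_units:
        out = tf.keras.layers.Dense(units, activation='relu')(out)
    out = tf.keras.layers.Dense(1)(out)
    return tf.keras.Model([observations, actions], out)

# Two independent approximators; the algorithm bootstraps from their minimum.
Qs = tuple(build_Q(observation_dim=17, action_dim=6) for _ in range(2))
# algorithm = SQL(env=env, pool=pool, Qs=Qs, policy=policy, ...)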
@@ -86,8 +88,8 @@ def __init__(

self.env = env
self.pool = pool
self._Q = Q
self._Q_target = tf.keras.models.clone_model(self._Q)
self._Qs = Qs
self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)
self.policy = policy
self.plotter = plotter
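An illustrative sketch (not part of this commit's diff) of the target-network construction a few lines above: `tf.keras.models.clone_model` reproduces a network's architecture with freshly initialized weights, so each Q-function gets its own independent target network. The tiny stand-in models and the hard weight sync below are hypothetical:

import tensorflow as tf

# Tiny stand-in Q-networks, only to make the sketch self-contained.
def tiny_Q():
    inputs = tf.keras.layers.Input(shape=(3,))
    return tf.keras.Model(inputs, tf.keras.layers.Dense(1)(inputs))

Qs = (tiny_Q(), tiny_Q())

# clone_model copies the architecture but not the weights.
Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

# Hypothetical hard sync of target weights to the current Q weights; how the
# commit initializes its targets is outside the hunks shown here.
for Q, Q_target in zip(Qs, Q_targets):
    Q_target.set_weights(Q.get_weights())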

@@ -119,16 +121,17 @@ def __init__(
self._create_svgd_update()

if use_saved_Q:
saved_Q_params = Q.get_param_values()
saved_Q_weights = tuple(Q.get_weights() for Q in self._Qs)
if use_saved_policy:
saved_policy_params = policy.get_param_values()
saved_policy_weights = policy.get_weights()

self._session.run(tf.global_variables_initializer())

if use_saved_Q:
self._Q.set_param_values(saved_Q_params)
for Q, Q_weights in zip(self._Qs, saved_Q_weights):
Q.set_weights(Q_weights)
if use_saved_policy:
self.policy.set_param_values(saved_policy_params)
self.policy.set_weights(saved_policy_weights)

def _create_placeholders(self):
"""Create all necessary placeholders."""
@@ -178,47 +181,69 @@ def _create_td_update(self):
target_actions, (tf.shape(self._next_observations_ph)[0], 1, 1))
target_actions = tf.reshape(target_actions, (-1, *self._action_shape))

Q_next_target = self._Q_target([next_observations, target_actions])
Q_next_target = tf.reshape(
Q_next_target, (-1, self._value_n_particles))
Q_next_targets = tuple(
Q([next_observations, target_actions])
for Q in self._Q_targets)

assert_shape(Q_next_target, (None, self._value_n_particles))
min_Q_next_targets = tf.reduce_min(Q_next_targets, axis=0)

self._Q_values = self._Q([self._observations_ph, self._actions_ph])
assert_shape(self._Q_values, [None, 1])
assert_shape(min_Q_next_targets, (None, 1))

min_Q_next_target = tf.reshape(
min_Q_next_targets, (-1, self._value_n_particles))

assert_shape(min_Q_next_target, (None, self._value_n_particles))

# Equation 10:
next_value = tf.reduce_logsumexp(Q_next_target, keepdims=True, axis=1)
next_value = tf.reduce_logsumexp(
min_Q_next_target, keepdims=True, axis=1)
assert_shape(next_value, [None, 1])

# Importance weights add just a constant to the value.
next_value -= tf.log(tf.cast(self._value_n_particles, tf.float32))
next_value += np.prod(self._action_shape) * np.log(2)

# \hat Q in Equation 11:
ys = tf.stop_gradient(
Q_target = tf.stop_gradient(
self._reward_scale
* self._rewards_ph
+ (1 - self._terminals_ph)
* self._discount
* next_value)
assert_shape(ys, [None, 1])
assert_shape(Q_target, [None, 1])

# Equation 11:
bellman_residual = tf.losses.mean_squared_error(
labels=ys, predictions=self._Q_values, weights=0.5)
Q_values = self._Q_values = tuple(
Q([self._observations_ph, self._actions_ph])
for Q in self._Qs)

self._Q_optimizer = tf.train.AdamOptimizer(
learning_rate=self._Q_lr,
name='Q_optimizer'
)
for Q_value in self._Q_values:
assert_shape(Q_value, [None, 1])

if self._train_Q:
td_train_op = self._Q_optimizer.minimize(
loss=bellman_residual, var_list=self._Q.trainable_variables)
self._training_ops.append(td_train_op)
# Equation 11:
Q_losses = self._Q_losses = tuple(
tf.losses.mean_squared_error(
labels=Q_target, predictions=Q_value, weights=0.5)
for Q_value in Q_values)

self._bellman_residual = bellman_residual
if self._train_Q:
self._Q_optimizers = tuple(
tf.train.AdamOptimizer(
learning_rate=self._Q_lr,
name='{}_{}_optimizer'.format(Q._name, i)
) for i, Q in enumerate(self._Qs))
Q_training_ops = tuple(
tf.contrib.layers.optimize_loss(
Q_loss,
None,
learning_rate=self._Q_lr,
optimizer=Q_optimizer,
variables=Q.trainable_variables,
increment_global_step=False,
summaries=())
for i, (Q, Q_loss, Q_optimizer)
in enumerate(zip(self._Qs, Q_losses, self._Q_optimizers)))

self._training_ops.append(tf.group(Q_training_ops))
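An illustrative numerical sketch (not part of this commit's diff) of the soft value estimate computed above (Equation 10): a logsumexp over the min-of-target-Qs particle values, corrected for the uniform proposal over [-1, 1]^D by subtracting log K and adding D * log 2. The particle count, action dimension, and Q values below are hypothetical:

import numpy as np

K, D = 4, 2                                     # hypothetical particle count and action dimension
min_q_next = np.array([[1.0, 0.2, 0.7, 0.4],
                       [0.3, 0.9, 0.5, 0.8]])   # min over the two target Qs, shape (batch, K)

# Equation 10 with a uniform proposal over [-1, 1]^D:
next_value = np.log(np.exp(min_q_next).sum(axis=1, keepdims=True))  # logsumexp over particles
next_value -= np.log(K)                          # importance weight: divide by the particle count
next_value += D * np.log(2)                      # ... and by the uniform density 2**(-D)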

def _create_svgd_update(self):
"""Create a minimization operation for policy update (SVGD)."""
@@ -252,12 +277,13 @@ def _create_svgd_update(self):
assert_shape(updated_actions,
[None, n_updated_actions, *self._action_shape])

svgd_target_values = self._Q_target([
tf.tile(self._observations_ph, (n_fixed_actions, 1)),
tf.reshape(fixed_actions, (-1, *self._action_shape))
])
Q_log_targets = tuple(
Q([tf.tile(self._observations_ph, (n_fixed_actions, 1)),
tf.reshape(fixed_actions, (-1, *self._action_shape))])
for Q in self._Qs)
min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)
svgd_target_values = tf.reshape(
svgd_target_values,
min_Q_log_target,
(-1, n_fixed_actions, 1))

# Target log-density. Q_soft in Equation 13:
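The SVGD update uses the same clipped estimate: the minimum over the per-Q outputs is taken before reshaping to (batch, n_fixed_actions, 1). An illustrative numpy sketch (not part of this commit's diff) with hypothetical shapes and values:

import numpy as np

batch, n_fixed_actions = 2, 3
q1 = np.array([[0.5], [1.0], [0.2], [0.4], [0.9], [0.7]])   # shape (n_fixed_actions * batch, 1)
q2 = np.array([[0.6], [0.8], [0.3], [0.1], [1.1], [0.5]])

min_q = np.minimum(q1, q2)                                   # element-wise min over the Q pair
svgd_target_values = min_q.reshape(-1, n_fixed_actions, 1)   # shape (batch, n_fixed_actions, 1)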
@@ -320,12 +346,13 @@ def _init_training(self):
def _update_target(self, tau=None):
tau = tau or self._tau

source_params = self._Q.get_weights()
target_params = self._Q_target.get_weights()
self._Q_target.set_weights([
tau * source + (1.0 - tau) * target
for source, target in zip(source_params, target_params)
])
for Q, Q_target in zip(self._Qs, self._Q_targets):
source_params = Q.get_weights()
target_params = Q_target.get_weights()
Q_target.set_weights([
tau * source + (1.0 - tau) * target
for source, target in zip(source_params, target_params)
])
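An illustrative sketch (not part of this commit's diff) of the soft target update above: each target network tracks its source Q-network through an exponential moving average of the per-layer weights. The tau and weight values are hypothetical:

import numpy as np

tau = 5e-3
source_params = [np.array([1.0, 2.0]), np.array([[0.5]])]   # hypothetical per-layer Q-network weights
target_params = [np.array([0.0, 0.0]), np.array([[0.0]])]   # corresponding target-network weights

target_params = [tau * source + (1.0 - tau) * target
                 for source, target in zip(source_params, target_params)]
# Repeated updates move the target weights smoothly toward the source weights.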

def _do_training(self, iteration, batch):
"""Run the operations for updating training and target ops."""
@@ -364,13 +391,13 @@ def get_diagnostics(self,
"""

feeds = self._get_feed_dict(batch)
Q_np, bellman_residual = self._session.run(
[self._Q_values, self._bellman_residual], feeds)
Q_values, Q_losses = self._session.run(
[self._Q_values, self._Q_losses], feeds)

diagnostics = OrderedDict({
'Q-avg': np.mean(Q_np),
'Q-std': np.std(Q_np),
'mean-sq-bellman-error': bellman_residual,
'Q-avg': np.mean(Q_values),
'Q-std': np.std(Q_values),
'Q_loss': np.mean(Q_losses),
})

policy_diagnostics = self.policy.get_diagnostics(batch['observations'])
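Since `Q_values` and `Q_losses` are now tuples with one entry per Q-function, the reported statistics aggregate over both Q heads as well as the batch. An illustrative sketch (not part of this commit's diff) with hypothetical values:

import numpy as np

Q_values = (np.array([[1.0], [2.0]]),   # per-head (batch, 1) outputs, hypothetical
            np.array([[3.0], [4.0]]))

print(np.mean(Q_values))   # 2.5    -> reported as 'Q-avg'
print(np.std(Q_values))    # ~1.118 -> reported as 'Q-std'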
@@ -407,7 +434,9 @@ def get_snapshot(self, epoch):
@property
def tf_saveables(self):
return {
'_Q_target': self._Q_target,
'_Q_optimizer': self._Q_optimizer,
'_policy_optimizer': self._policy_optimizer
'_policy_optimizer': self._policy_optimizer,
**{
f'Q_optimizer_{i}': optimizer
for i, optimizer in enumerate(self._Q_optimizers)
},
}
