Rename sql attributes to match sac
hartikainen committed Feb 1, 2019
1 parent df5db9f commit b1ba430
Showing 1 changed file with 39 additions and 28 deletions.
softlearning/algorithms/sql.py: 67 changes (39 additions, 28 deletions)
@@ -31,10 +31,11 @@ class SQL(RLAlgorithm):
def __init__(
self,
env,
pool,
Qs,
policy,
Qs,
pool,
plotter=None,

policy_lr=1E-3,
Q_lr=1E-3,
value_n_particles=16,
@@ -55,11 +56,11 @@ def __init__(
"""
Args:
env (`SoftlearningEnv`): Environment object used for training.
pool (`PoolBase`): Replay pool to add gathered samples to.
policy: A policy function approximator.
Qs: Q-function approximators. The min of these
approximators will be used. Usage of at least two Q-functions
improves performance by reducing overestimation bias.
policy: A policy function approximator.
pool (`PoolBase`): Replay pool to add gathered samples to.
plotter (`QFPolicyPlotter`): Plotter instance to be used for
visualizing Q-function during training.
Q_lr (`float`): Learning rate used for the Q-function approximator.
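
The `Qs` entry above states that the minimum over at least two Q-function approximators is used to reduce overestimation bias. Below is a small, self-contained TF1-style toy of that idea; the placeholders and two-layer Q-networks are dummies for illustration, not the repository's models or its actual target computation.

import tensorflow as tf

observations_ph = tf.placeholder(tf.float32, shape=(None, 4))
actions_ph = tf.placeholder(tf.float32, shape=(None, 2))

def make_Q():
    # Dummy Q-network over concatenated (observation, action) inputs.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(6,)),
        tf.keras.layers.Dense(1),
    ])

Qs = (make_Q(), make_Q())
inputs = tf.concat([observations_ph, actions_ph], axis=-1)
Q_values = tuple(Q(inputs) for Q in Qs)

# Elementwise minimum over the ensemble; using the pessimistic estimate
# counteracts the positive bias a single bootstrapped estimate accumulates.
min_Q = tf.reduce_min(tf.stack(Q_values, axis=0), axis=0)
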
@@ -86,12 +87,15 @@ def __init__(
"""
super(SQL, self).__init__(**kwargs)

self.env = env
self.pool = pool
self._env = env
self._policy = policy

self._Qs = Qs
self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)
self.policy = policy
self.plotter = plotter

self._pool = pool
self._plotter = plotter
self._env = env

self._Q_lr = Q_lr
self._policy_lr = policy_lr
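
`self._Q_targets` above is built by cloning each Q-network with `tf.keras.models.clone_model`, and `_init_training` further down calls `self._update_target(tau=1.0)`. A minimal sketch of that target-network pattern follows, assuming the usual Polyak convention in which `tau=1` amounts to a hard copy; it is an illustration, not the repository's `_update_target` implementation.

import tensorflow as tf

def make_Q_targets(Qs):
    # clone_model copies the architecture but reinitializes the weights,
    # so copy the source weights over once after cloning.
    Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)
    for Q, Q_target in zip(Qs, Q_targets):
        Q_target.set_weights(Q.get_weights())
    return Q_targets

def update_targets(Qs, Q_targets, tau=5e-3):
    # Soft update: target <- tau * source + (1 - tau) * target.
    # With tau=1.0 this degenerates to a hard copy of the source weights.
    for Q, Q_target in zip(Qs, Q_targets):
        Q_target.set_weights([
            tau * w + (1.0 - tau) * w_target
            for w, w_target in zip(Q.get_weights(), Q_target.get_weights())
        ])
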
@@ -110,8 +114,13 @@ def __init__(
self._train_Q = train_Q
self._train_policy = train_policy

self._observation_shape = list(self.env.observation_space.shape)
self._action_shape = list(self.env.action_space.shape)
observation_shape = env.active_observation_shape
action_shape = env.action_space.shape

assert len(observation_shape) == 1, observation_shape
self._observation_shape = observation_shape
assert len(action_shape) == 1, action_shape
self._action_shape = action_shape

self._create_placeholders()

@@ -131,7 +140,7 @@ def __init__(
for Q, Q_weights in zip(self._Qs, saved_Q_weights):
Q.set_weights(Q_weights)
if use_saved_policy:
self.policy.set_weights(saved_policy_weights)
self._policy.set_weights(saved_policy_weights)

def _create_placeholders(self):
"""Create all necessary placeholders."""
@@ -248,7 +257,7 @@ def _create_td_update(self):
def _create_svgd_update(self):
"""Create a minimization operation for policy update (SVGD)."""

actions = self.policy.actions([
actions = self._policy.actions([
tf.tile(
self._observations_ph,
(self._kernel_n_particles,
@@ -312,12 +321,12 @@ def _create_svgd_update(self):
# Propagate the gradient through the policy network (Equation 14).
gradients = tf.gradients(
updated_actions,
self.policy.trainable_variables,
self._policy.trainable_variables,
grad_ys=action_gradients)

surrogate_loss = tf.reduce_sum([
tf.reduce_sum(w * tf.stop_gradient(g))
for w, g in zip(self.policy.trainable_variables, gradients)
for w, g in zip(self._policy.trainable_variables, gradients)
])

self._policy_optimizer = tf.train.AdamOptimizer(
@@ -328,17 +337,19 @@ def _create_svgd_update(self):
if self._train_policy:
svgd_training_op = self._policy_optimizer.minimize(
loss=-surrogate_loss,
var_list=self.policy.trainable_variables)
var_list=self._policy.trainable_variables)
self._training_ops.append(svgd_training_op)

# TODO: do not pass env, policy, and pool to `__init__` directly.
def train(self):
initial_exploration_policy = None
def train(self, *args, **kwargs):
"""Initiate training of the SAC instance."""

return self._train(
self.env,
self.policy,
self.pool,
initial_exploration_policy=initial_exploration_policy)
self._env,
self._policy,
self._pool,
initial_exploration_policy=self._initial_exploration_policy,
*args,
**kwargs)

def _init_training(self):
self._update_target(tau=1.0)
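
The SVGD hunks above push externally computed action gradients into the policy parameters with `tf.gradients(..., grad_ys=...)` and a `tf.stop_gradient` surrogate loss (the "Equation 14" comment). The self-contained TF1-style toy below uses a dummy two-layer policy and a dummy gradient tensor in place of the real SVGD direction; it shows why the trick works: the gradient of `sum(w * stop_gradient(g))` with respect to each variable `w` is exactly `g`, so minimizing the negated surrogate lets a standard optimizer apply the supplied ascent direction.

import tensorflow as tf

observations_ph = tf.placeholder(tf.float32, shape=(None, 4))
policy_net = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2, activation='tanh'),
])
actions = policy_net(observations_ph)

# Stand-in for the SVGD direction; in SQL this comes from the kernel and the
# Q-function gradients, here it is simply a tensor shaped like `actions`.
action_gradients = tf.ones_like(actions)

# Vector-Jacobian product: grad_ys weights the action gradients through the
# policy network's Jacobian.
gradients = tf.gradients(
    actions, policy_net.trainable_variables, grad_ys=action_gradients)

# d(surrogate_loss)/dw = g for each variable w, since g is stopped.
surrogate_loss = tf.reduce_sum([
    tf.reduce_sum(w * tf.stop_gradient(g))
    for w, g in zip(policy_net.trainable_variables, gradients)])

train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(
    loss=-surrogate_loss, var_list=policy_net.trainable_variables)
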
@@ -400,14 +411,14 @@ def get_diagnostics(self,
'Q_loss': np.mean(Q_losses),
})

policy_diagnostics = self.policy.get_diagnostics(batch['observations'])
policy_diagnostics = self._policy.get_diagnostics(batch['observations'])
diagnostics.update({
f'policy/{key}': value
for key, value in policy_diagnostics.items()
})

if self.plotter:
self.plotter.draw()
if self._plotter:
self._plotter.draw()

return diagnostics

@@ -421,13 +432,13 @@ def get_snapshot(self, epoch):

state = {
'epoch': epoch,
'policy': self.policy,
'policy': self._policy,
'Q': self._Q,
'env': self.env,
'env': self._env,
}

if self._save_full_state:
state.update({'replay_pool': self.pool})
state.update({'replay_pool': self._pool})

return state
