Remove state value function from SAC
hartikainen committed Nov 22, 2018
1 parent f56457e commit d33c114
Showing 4 changed files with 25 additions and 77 deletions.
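
In short, SAC no longer maintains a separate soft state value network and its target copy. Instead, target copies of the Q-functions are kept, and the TD target is built directly from the Q-functions and the policy's log-probabilities. Restated from the new `_get_Q_target` (see the sac.py diff below), with d the terminal flag and a' sampled from the current policy:

    Q_{\mathrm{target}}(s, a) = r + \gamma\,(1 - d)\,\bigl(\min_i Q_i(s', a') - \alpha \log \pi(a' \mid s')\bigr), \qquad a' \sim \pi(\cdot \mid s')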
examples/development/main.py (6 changes: 1 addition & 5 deletions)
@@ -8,9 +8,7 @@
 from softlearning.preprocessors.utils import get_preprocessor_from_variant
 from softlearning.replay_pools.utils import get_replay_pool_from_variant
 from softlearning.samplers.utils import get_sampler_from_variant
-from softlearning.value_functions.utils import (
-    get_Q_function_from_variant,
-    get_V_function_from_variant)
+from softlearning.value_functions.utils import get_Q_function_from_variant

 from softlearning.misc.utils import set_seed

@@ -33,7 +31,6 @@ def _setup(self, variant):
         sampler = get_sampler_from_variant(variant)
         preprocessor = get_preprocessor_from_variant(variant, env)
         Qs = get_Q_function_from_variant(variant, env)
-        V = get_V_function_from_variant(variant, env)
         policy = get_policy_from_variant(variant, env, Qs, preprocessor)
         initial_exploration_policy = get_policy('UniformPolicy', env)

@@ -43,7 +40,6 @@ def _setup(self, variant):
             policy=policy,
             initial_exploration_policy=initial_exploration_policy,
             Qs=Qs,
-            V=V,
             pool=replay_pool,
             sampler=sampler,
         )
examples/development/variants.py (6 changes: 0 additions & 6 deletions)
@@ -205,12 +205,6 @@ def get_variant_spec(universe, domain, task, policy):
             POLICY_PARAMS_BASE[policy],
             POLICY_PARAMS_FOR_DOMAIN[policy].get(domain, {})
         ),
-        'V_params': {
-            'type': 'feedforward_V_function',
-            'kwargs': {
-                'hidden_layer_sizes': (M, M),
-            }
-        },
         'Q_params': {
             'type': 'double_feedforward_Q_function',
             'kwargs': {
examples/multigoal.py (12 changes: 1 addition & 11 deletions)
@@ -8,9 +8,7 @@
 from softlearning.samplers import SimpleSampler
 from softlearning.policies.utils import get_policy_from_variant
 from softlearning.replay_pools import SimpleReplayPool
-from softlearning.value_functions.utils import (
-    get_Q_function_from_variant,
-    get_V_function_from_variant)
+from softlearning.value_functions.utils import get_Q_function_from_variant
 from examples.utils import get_parser, launch_experiments_ray


@@ -31,7 +29,6 @@ def run_experiment(variant, reporter):
         max_path_length=30, min_pool_size=100, batch_size=64)

     Qs = get_Q_function_from_variant(variant, env)
-    V = get_V_function_from_variant(variant, env)

     policy = get_policy_from_variant(variant, env, Qs, preprocessor=None)
     plotter = QFPolicyPlotter(
@@ -60,7 +57,6 @@ def run_experiment(variant, reporter):
         initial_exploration_policy=None,
         pool=pool,
         Qs=Qs,
-        V=V,
         plotter=plotter,

         lr=3e-4,
@@ -100,12 +96,6 @@ def main():
                 'regularization_coeff': 1e-3,
             },
         },
-        'V_params': {
-            'type': 'feedforward_V_function',
-            'kwargs': {
-                'hidden_layer_sizes': (layer_size, layer_size),
-            }
-        },
         'Q_params': {
             'type': 'double_feedforward_Q_function',
             'kwargs': {
softlearning/algorithms/sac.py (78 changes: 23 additions & 55 deletions)
@@ -29,7 +29,6 @@ def __init__(
             policy,
             initial_exploration_policy,
             Qs,
-            V,
             pool,
             plotter=None,
             tf_summaries=False,
@@ -56,7 +55,6 @@ def __init__(
             Qs: Q-function approximators. The min of these
                 approximators will be used. Usage of at least two Q-functions
                 improves performance by reducing overestimation bias.
-            V: Soft value function approximator.
             pool (`PoolBase`): Replay pool to add gathered samples to.
             plotter (`QFPolicyPlotter`): Plotter instance to be used for
                 visualizing Q-function during training.
@@ -81,16 +79,14 @@ def __init__(
         self._initial_exploration_policy = initial_exploration_policy

         self._Qs = Qs
-        self._V = V
-        self._V_target = tf.keras.models.clone_model(self._V)
+        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

         self._pool = pool
         self._plotter = plotter
         self._tf_summaries = tf_summaries

         self._policy_lr = lr
         self._Q_lr = lr
-        self._V_lr = lr

         self._reward_scale = reward_scale
         self._target_entropy = (
@@ -131,7 +127,7 @@ def _build(self):

     def _initialize_tf_variables(self):
         # Initialize all uninitialized variables. This prevents initializing
-        # pre-trained policy and Q and V variables. tf.metrics (used at
+        # pre-trained policy and Q variables. tf.metrics (used at
         # least in the LFP-policy) uses local variables.
         uninit_vars = []
         for var in tf.global_variables() + tf.local_variables():
@@ -221,13 +217,21 @@ def _init_placeholders(self):
         )

     def _get_Q_target(self):
-        V_next_target = self._V_target([self._next_observations_ph])
+        next_actions = self._policy.actions([self._next_observations_ph])
+        next_log_pis = self._policy.log_pis(
+            [self._next_observations_ph], next_actions)
+
+        next_Qs_values = tuple(
+            Q([self._next_observations_ph, next_actions])
+            for Q in self._Qs)
+
+        min_next_Q = tf.reduce_min(next_Qs_values, axis=0)
+        next_value = min_next_Q - self._alpha * next_log_pis

         Q_target = td_target(
             reward=self._reward_scale * self._rewards_ph,
             discount=self._discount,
-            next_value=(1 - self._terminals_ph) * V_next_target
-        )  # N
+            next_value=(1 - self._terminals_ph) * next_value)

         return Q_target

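To make the new target computation above concrete, here is a small NumPy rendering of the same arithmetic (illustrative only; the helper name `soft_q_target` and the toy numbers are not part of the repository):

    import numpy as np

    def soft_q_target(rewards, terminals, next_q_values, next_log_pis,
                      reward_scale=1.0, discount=0.99, alpha=0.2):
        """r + discount * (1 - d) * (min_i Q_i(s', a') - alpha * log_pi(a'|s'))."""
        min_next_q = np.min(next_q_values, axis=0)      # clipped double-Q: elementwise min
        next_value = min_next_q - alpha * next_log_pis  # soft value of the next state
        return reward_scale * rewards + discount * (1.0 - terminals) * next_value

    # Toy batch of three transitions with two Q estimates per next state-action pair.
    rewards = np.array([1.0, 0.0, 0.5])
    terminals = np.array([0.0, 0.0, 1.0])
    next_q_values = np.stack([[2.0, 1.5, 3.0],   # Q_1(s', a')
                              [1.8, 1.7, 2.5]])  # Q_2(s', a')
    next_log_pis = np.array([-1.0, -0.5, -2.0])
    print(soft_q_target(rewards, terminals, next_q_values, next_log_pis))
    # -> approximately [2.98, 1.584, 0.5]
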
@@ -318,8 +322,6 @@ def _init_actor_update(self):

         self._alpha = alpha

-        V_value = self._V_value = self._V([self._observations_ph])  # N
-
         if self._action_prior == 'normal':
             policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                 loc=tf.zeros(self._action_shape),
@@ -339,13 +341,7 @@ def _init_actor_update(self):
                 - min_Q_log_target
                 - policy_prior_log_probs)
         else:
-            raise NotImplementedError(
-                "TODO(hartikainen): Make sure to stop policy gradients"
-                " correctly. See old GaussianPolicy implementation.")
-            policy_kl_losses = (
-                log_pis * tf.stop_gradient(
-                    alpha * log_pis - min_Q_log_target + V_value
-                    - policy_prior_log_probs))
+            raise NotImplementedError

         assert policy_kl_losses.shape.as_list() == [None, 1]

@@ -354,16 +350,6 @@ def _init_actor_update(self):

         policy_loss = policy_kl_loss + policy_regularization_loss

-        # We update the V towards the min of two Q-functions in order to
-        # reduce overestimation bias from function approximation error.
-        V_target = tf.stop_gradient(
-            min_Q_log_target
-            - alpha * log_pis
-            + policy_prior_log_probs)
-
-        V_loss = self._V_loss = tf.losses.mean_squared_error(
-            labels=V_target, predictions=V_value, weights=0.5)
-
         policy_train_op = tf.contrib.layers.optimize_loss(
             policy_loss,
             self.global_step,
@@ -376,34 +362,19 @@ def _init_actor_update(self):
                 "loss", "gradients", "gradient_norm", "global_gradient_norm"
             ) if self._tf_summaries else ())

-        V_train_op = tf.contrib.layers.optimize_loss(
-            V_loss,
-            self.global_step,
-            learning_rate=self._V_lr,
-            optimizer=tf.train.AdamOptimizer,
-            variables=self._V.trainable_variables,
-            increment_global_step=False,
-            name="V_optimizer",
-            summaries=(
-                "loss", "gradients", "gradient_norm", "global_gradient_norm"
-            ) if self._tf_summaries else ())
-
-        self._training_ops.update({
-            'policy': policy_train_op,
-            'V': V_train_op,
-        })
+        self._training_ops.update({'policy_train_op': policy_train_op})

     def _init_training(self):
         self._update_target()

     def _update_target(self):
-        source_params = self._V.get_weights()
-        target_params = self._V_target.get_weights()
-
-        self._V_target.set_weights([
-            (1 - self._tau) * target + self._tau * source
-            for target, source in zip(target_params, source_params)
-        ])
+        for Q, Q_target in zip(self._Qs, self._Q_targets):
+            source_params = Q.get_weights()
+            target_params = Q_target.get_weights()
+            Q_target.set_weights([
+                (1 - self._tau) * target + self._tau * source
+                for target, source in zip(target_params, source_params)
+            ])

     def _do_training(self, iteration, batch):
         """Runs the operations for updating training and target ops."""
@@ -448,9 +419,8 @@ def get_diagnostics(self, iteration, batch, paths):

         feed_dict = self._get_feed_dict(iteration, batch)

-        (Q_values, V_value, Q_losses, alpha, global_step) = self._sess.run(
+        (Q_values, Q_losses, alpha, global_step) = self._sess.run(
             (self._Q_values,
-             self._V_value,
              self._Q_losses,
              self._alpha,
              self.global_step),
@@ -459,8 +429,6 @@ def get_diagnostics(self, iteration, batch, paths):
         diagnostics = OrderedDict({
             'Q-avg': np.mean(Q_values),
             'Q-std': np.std(Q_values),
-            'V-avg': np.mean(V_value),
-            'V-std': np.std(V_value),
             'Q_loss': np.mean(Q_losses),
             'alpha': alpha,
         })
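
For reference, the target update in `_update_target` above is plain Polyak averaging applied to each cloned Q-function. Writing \theta_i for the weights of Q_i and \bar{\theta}_i for the weights of its target copy,

    \bar{\theta}_i \leftarrow (1 - \tau)\,\bar{\theta}_i + \tau\,\theta_i

with \tau the smoothing coefficient stored in `self._tau`.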
