Move squash_correction under NNPolicy
hartikainen committed Jul 21, 2018
1 parent 328458e commit 36155f7
Showing 4 changed files with 20 additions and 37 deletions.
13 changes: 0 additions & 13 deletions softlearning/policies/gaussian_policy.py
@@ -158,19 +158,6 @@ def get_actions(self, observations, with_log_pis=False, with_raw_actions=False):
        return super(GaussianPolicy, self).get_actions(
            observations, with_log_pis, with_raw_actions)

-    def _squash_correction(self, actions):
-        if not self._squash:
-            return 0
-
-        # Numerically stable squash correction without bias from EPS,
-        # return tf.reduce_sum(tf.log(1 - tf.tanh(actions) **2 + EPS), axis=1)
-        return tf.reduce_sum(
-            2.0 * (
-                tf.log(2.0)
-                - actions
-                - tf.nn.softplus(-2. * actions)
-            ), axis=1)
-

    @contextmanager
    def deterministic(self, set_deterministic=True):
        """Context manager for changing the determinism of the policy.
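For reference, the helper being consolidated here is the log-determinant correction for the tanh squashing of Gaussian actions. The stable form used in the diff follows from the identity

    1 - tanh(a)^2 = 4 / (e^a + e^-a)^2 = 4 * e^(-2a) / (1 + e^(-2a))^2,

so that

    log(1 - tanh(a)^2) = 2 * (log(2) - a - softplus(-2a)),

with softplus(x) = log(1 + e^x). This avoids both the catastrophic cancellation in 1 - tanh(a)^2 for large |a| and the bias from adding EPS inside the log.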
18 changes: 7 additions & 11 deletions softlearning/policies/gmm.py
@@ -12,13 +12,16 @@
from softlearning.policies import NNPolicy
from softlearning.misc import tf_utils


EPS = 1e-6


class GMMPolicy(NNPolicy, Serializable):
    """
    Gaussian Mixture Model policy
-    TODO: change interfaces to match other policies to support returning as log_pis for given actions.
+    TODO: change interfaces to match other policies to support returning as
+    log_pis for given actions.
    """
    def __init__(self, env_spec, K=2, hidden_layer_sizes=(100, 100), reg=1e-3,
                 squash=True, reparameterize=False, qf=None, name='gmm_policy'):
@@ -152,7 +155,7 @@ def get_actions(self, observations, with_log_pis=False, with_raw_actions=False):
            qs = self._qf.eval(observations, squashed_mus)

            if self._fixed_h is not None:
-                h = self._fixed_h  # TODO.code_consolidation: this needs to be tiled
+                h = self._fixed_h  # TODO.code_consolidation: needs to be tiled
            else:
                h = np.argmax(qs)  # TODO.code_consolidation: check the axis
@@ -164,15 +167,8 @@

            return actions

-        return super(GMMPolicy, self).get_actions(observations, with_log_pis, with_raw_actions)
-
-    def _squash_correction(self, actions):
-        if not self._squash:
-            return 0
-        # return tf.reduce_sum(tf.log(1 - tf.tanh(actions) **2 + EPS), axis=1)
-
-        # numerically stable squash correction without bias from EPS
-        return tf.reduce_sum(2. * (tf.log(2.) - actions - tf.nn.softplus(-2. * actions)), axis=1)
+        return super(GMMPolicy, self).get_actions(
+            observations, with_log_pis, with_raw_actions)

    @contextmanager
    def deterministic(self, set_deterministic=True, latent=None):
13 changes: 0 additions & 13 deletions softlearning/policies/latent_space_policy.py
@@ -225,19 +225,6 @@ def get_actions(self,
        return super(LatentSpacePolicy, self).get_actions(
            observations, with_log_pis, with_raw_actions)

-    def _squash_correction(self, actions):
-        if not self._squash:
-            return 0
-        # return tf.reduce_sum(tf.log(1 - tf.tanh(actions) **2 + EPS), axis=1)
-
-        # numerically stable squash correction without bias from EPS
-        return tf.reduce_sum(
-            2.0 * (
-                tf.log(2.)
-                - actions
-                - tf.nn.softplus(-2.0 * actions)
-            ), axis=1)
-

    @contextmanager
    def deterministic(self, set_deterministic=True, h=None):
        """Context manager for changing the determinism of the policy.
13 changes: 13 additions & 0 deletions softlearning/policies/nn_policy.py
@@ -19,6 +19,19 @@ def __init__(self, name, env_spec, observation_ph, actions):

        super(NNPolicy, self).__init__(env_spec)

+    def _squash_correction(self, actions):
+        if not self._squash:
+            return 0
+
+        # Numerically stable squash correction without bias from EPS,
+        # return tf.reduce_sum(tf.log(1 - tf.tanh(actions) **2 + EPS), axis=1)
+        return tf.reduce_sum(
+            2.0 * (
+                tf.log(2.0)
+                - actions
+                - tf.nn.softplus(-2. * actions)
+            ), axis=1)
+

    @overrides
    def get_action(self,
                   observation,
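To see what the relocated helper buys numerically, here is a quick standalone check (a sketch using NumPy rather than TensorFlow; the sample action values are made up). It compares the EPS-biased form kept in the commented-out line against the stable form that now lives in NNPolicy._squash_correction:

import numpy as np

def softplus(x):
    # log(1 + exp(x)) via logaddexp, stable for large |x|.
    return np.logaddexp(0.0, x)

def naive_correction(actions, eps=1e-6):
    # The commented-out form: biased by eps once tanh(a) saturates.
    return np.sum(np.log(1 - np.tanh(actions) ** 2 + eps), axis=1)

def stable_correction(actions):
    # The form now in NNPolicy: log(1 - tanh(a)^2) == 2*(log 2 - a - softplus(-2a)).
    return np.sum(2.0 * (np.log(2.0) - actions - softplus(-2.0 * actions)), axis=1)

actions = np.array([[0.0, 1.0, -3.0],
                    [10.0, -10.0, 5.0]])
print(naive_correction(actions))   # second row drifts toward log(eps) terms
print(stable_correction(actions))  # accurate even where tanh has saturated

For |a| near 10, 1 - tanh(a)^2 underflows toward zero and the EPS term dominates, so the naive form flattens at roughly log(EPS) per saturated coordinate; the softplus form has no such floor.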
