Skip to content

Commit

Permalink
support custom mean / std network in continuous policy
Browse files Browse the repository at this point in the history
  • Loading branch information
dementrock committed May 24, 2016
1 parent d7aa694 commit bc1b506
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions rllab/policies/gaussian_mlp_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(
std_hidden_nonlinearity=NL.tanh,
hidden_nonlinearity=NL.tanh,
output_nonlinearity=None,
mean_network=None,
std_network=None,
):
"""
:param env_spec:
Expand All @@ -44,6 +46,8 @@ def __init__(
:param std_hidden_nonlinearity:
:param hidden_nonlinearity: nonlinearity used for each hidden layer
:param output_nonlinearity: nonlinearity for the output layer
:param mean_network: custom network for the output mean
:param std_network: custom network for the output log std
:return:
"""
Serializable.quick_init(self, locals())
Expand All @@ -53,36 +57,40 @@ def __init__(
action_dim = env_spec.action_space.flat_dim

# create network
mean_network = MLP(
input_shape=(obs_dim,),
output_dim=action_dim,
hidden_sizes=hidden_sizes,
hidden_nonlinearity=hidden_nonlinearity,
output_nonlinearity=output_nonlinearity,
)
if mean_network is None:
mean_network = MLP(
input_shape=(obs_dim,),
output_dim=action_dim,
hidden_sizes=hidden_sizes,
hidden_nonlinearity=hidden_nonlinearity,
output_nonlinearity=output_nonlinearity,
)
self._mean_network = mean_network

l_mean = mean_network.output_layer
obs_var = mean_network.input_var

if adaptive_std:
std_network = MLP(
input_shape=(obs_dim,),
input_layer=mean_network.input_layer,
output_dim=action_dim,
hidden_sizes=std_hidden_sizes,
hidden_nonlinearity=std_hidden_nonlinearity,
output_nonlinearity=None,
)
if std_network is not None:
l_log_std = std_network.output_layer
else:
l_log_std = ParamLayer(
mean_network.input_layer,
num_units=action_dim,
param=lasagne.init.Constant(np.log(init_std)),
name="output_log_std",
trainable=learn_std,
)
if adaptive_std:
std_network = MLP(
input_shape=(obs_dim,),
input_layer=mean_network.input_layer,
output_dim=action_dim,
hidden_sizes=std_hidden_sizes,
hidden_nonlinearity=std_hidden_nonlinearity,
output_nonlinearity=None,
)
l_log_std = std_network.output_layer
else:
l_log_std = ParamLayer(
mean_network.input_layer,
num_units=action_dim,
param=lasagne.init.Constant(np.log(init_std)),
name="output_log_std",
trainable=learn_std,
)

self.min_std = min_std

Expand Down

0 comments on commit bc1b506

Please sign in to comment.