import theano.tensor as TT
import theano
from rllab.misc import logger
from rllab.misc.overrides import overrides
from rllab.misc import ext
from rllab.algos.batch_polopt import BatchPolopt
from rllab.optimizers.first_order_optimizer import FirstOrderOptimizer
from rllab.core.serializable import Serializable


class VPG(BatchPolopt, Serializable):
    """
    Vanilla Policy Gradient.
    """

    def __init__(
            self,
            env,
            policy,
            baseline,
            optimizer=None,
            optimizer_args=None,
            **kwargs):
        Serializable.quick_init(self, locals())
        if optimizer is None:
            # Default to a single first-order pass over the whole batch of
            # collected samples per iteration.
            default_args = dict(
                batch_size=None,
                max_epochs=1,
            )
            if optimizer_args is None:
                optimizer_args = default_args
            else:
                optimizer_args = dict(default_args, **optimizer_args)
            optimizer = FirstOrderOptimizer(**optimizer_args)
        self.optimizer = optimizer
        self.opt_info = None
        super(VPG, self).__init__(env=env, policy=policy, baseline=baseline, **kwargs)

    @overrides
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)

        # Symbolic inputs: observations, actions and empirical advantages.
        # Recurrent policies get an extra leading time dimension plus a mask
        # of valid time steps.
        obs_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + is_recurrent,
        )
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1 + is_recurrent,
        )
        advantage_var = ext.new_tensor(
            'advantage',
            ndim=1 + is_recurrent,
            dtype=theano.config.floatX
        )
        dist = self.policy.distribution

        # Distribution parameters of the policy that collected the samples;
        # used only for the KL diagnostics below.
        old_dist_info_vars = {
            k: ext.new_tensor(
                'old_%s' % k,
                ndim=2 + is_recurrent,
                dtype=theano.config.floatX
            ) for k in dist.dist_info_keys
        }
        old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]

        if is_recurrent:
            valid_var = TT.matrix('valid')
        else:
            valid_var = None

        state_info_vars = {
            k: ext.new_tensor(
                k,
                ndim=2 + is_recurrent,
                dtype=theano.config.floatX
            ) for k in self.policy.state_info_keys
        }
        state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]

        dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
        logli = dist.log_likelihood_sym(action_var, dist_info_vars)
        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)

        # Formulate as a minimization problem: the surrogate objective is
        # -E[log pi(a|s) * A(s,a)], so its gradient is the (negated) policy
        # gradient. For recurrent policies, average only over valid steps.
        if is_recurrent:
            surr_obj = - TT.sum(logli * advantage_var * valid_var) / TT.sum(valid_var)
            mean_kl = TT.sum(kl * valid_var) / TT.sum(valid_var)
            max_kl = TT.max(kl * valid_var)
        else:
            surr_obj = - TT.mean(logli * advantage_var)
            mean_kl = TT.mean(kl)
            max_kl = TT.max(kl)

        input_list = [obs_var, action_var, advantage_var] + state_info_vars_list
        if is_recurrent:
            input_list.append(valid_var)

        self.optimizer.update_opt(surr_obj, target=self.policy, inputs=input_list)

        # Compiled function reporting the mean and max KL divergence between
        # the sampling policy and the updated policy, for logging only.
        f_kl = ext.compile_function(
            inputs=input_list + old_dist_info_vars_list,
            outputs=[mean_kl, max_kl],
        )
        self.opt_info = dict(
            f_kl=f_kl,
        )

    @overrides
    def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        # Take the gradient step(s) on the surrogate loss, then log the change
        # in loss and the KL divergence from the sampling policy.
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)
        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl)

    @overrides
    def get_itr_snapshot(self, itr, samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
            env=self.env,
        )
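

# Minimal usage sketch: trains VPG on a normalized Cartpole task, assuming the
# standard rllab Gaussian MLP policy and linear feature baseline modules. The
# hyperparameters below are illustrative only; adjust them to your setup.
if __name__ == "__main__":
    from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
    from rllab.envs.box2d.cartpole_env import CartpoleEnv
    from rllab.envs.normalized_env import normalize
    from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy

    env = normalize(CartpoleEnv())
    policy = GaussianMLPPolicy(env_spec=env.spec)
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    algo = VPG(
        env=env,
        policy=policy,
        baseline=baseline,
        n_itr=40,           # number of training iterations
        batch_size=4000,    # time steps collected per iteration
        max_path_length=100,
        discount=0.99,
    )
    algo.train()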