Implemented TRPO in PyTorch
utkarshjp7 committed Nov 14, 2019
1 parent eaf39da commit 6a8fb49
Showing 14 changed files with 1,069 additions and 165 deletions.
51 changes: 51 additions & 0 deletions examples/torch/trpo_pendulum.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
"""This is an example to train a task with TRPO algorithm (PyTorch).
Here it runs the InvertedDoublePendulum-v2 environment for 100 iterations.
"""
import torch

from garage.experiment import LocalRunner, run_experiment
from garage.np.baselines import LinearFeatureBaseline
from garage.tf.envs import TfEnv
from garage.torch.algos import TRPO
from garage.torch.policies import GaussianMLPPolicy


def run_task(snapshot_config, *_):
"""Set up environment and algorithm and run the task.
Args:
snapshot_config (garage.experiment.SnapshotConfig): The snapshot
configuration used by LocalRunner to create the snapshotter.
If None, it will create one with default settings.
_ : Unused parameters
"""
env = TfEnv(env_name='InvertedDoublePendulum-v2')

runner = LocalRunner(snapshot_config)

policy = GaussianMLPPolicy(env.spec,
hidden_sizes=[32, 32],
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = TRPO(env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
discount=0.99,
center_adv=False)

runner.setup(algo, env)
runner.train(n_epochs=100, batch_size=1024)


run_experiment(
run_task,
snapshot_mode='last',
seed=1,
)
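For reference, a minimal sketch of the same setup with the optimizer passed explicitly instead of relying on the TRPO default (it reuses env, policy, and baseline from run_task above, and mirrors the defaults added in trpo.py below, where max_kl=0.01 becomes max_constraint_value):

from garage.torch.algos import TRPO
from garage.torch.optimizers import ConjugateGradientOptimizer

# Equivalent to TRPO's default behavior when optimizer is None.
algo = TRPO(env_spec=env.spec,
            policy=policy,
            baseline=baseline,
            max_path_length=100,
            discount=0.99,
            center_adv=False,
            optimizer=ConjugateGradientOptimizer,
            optimizer_args={'max_constraint_value': 0.01})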
3 changes: 2 additions & 1 deletion src/garage/torch/algos/__init__.py
@@ -2,5 +2,6 @@
from garage.torch.algos.ddpg import DDPG
from garage.torch.algos.vpg import VPG
from garage.torch.algos.ppo import PPO # noqa: I100
from garage.torch.algos.trpo import TRPO

__all__ = ['DDPG', 'VPG', 'PPO']
__all__ = ['DDPG', 'VPG', 'PPO', 'TRPO']
15 changes: 5 additions & 10 deletions src/garage/torch/algos/ppo.py
@@ -1,6 +1,4 @@
"""Proximal Policy Optimization (PPO)."""
import copy

import torch

from garage.torch.algos import VPG
@@ -29,6 +27,7 @@ class PPO(VPG):
standardized before shifting.
optimizer (object): The optimizer of the algorithm. Should be the
optimizers in torch.optim.
optimizer_args (dict): Arguments required to initialize the optimizer.
policy_ent_coeff (float): The coefficient of the policy entropy.
Setting it to zero would mean no entropy regularization.
use_softplus_entropy (bool): Whether to estimate the softmax
@@ -56,18 +55,18 @@ def __init__(self,
center_adv=True,
positive_adv=False,
optimizer=None,
optimizer_args=None,
policy_ent_coeff=0.0,
use_softplus_entropy=False,
stop_entropy_gradient=False,
entropy_method='no_entropy'):
super().__init__(env_spec, policy, baseline, max_path_length,
policy_lr, n_samples, discount, gae_lambda,
center_adv, positive_adv, optimizer, policy_ent_coeff,
use_softplus_entropy, stop_entropy_gradient,
entropy_method)
center_adv, positive_adv, optimizer, optimizer_args,
policy_ent_coeff, use_softplus_entropy,
stop_entropy_gradient, entropy_method)

self._lr_clip_range = lr_clip_range
self._old_policy = copy.deepcopy(self.policy)

def _compute_objective(self, advantages, valids, obs, actions, rewards):
"""Compute objective using surrogate value and clipped surrogate value.
@@ -83,17 +82,13 @@ def _compute_objective(self, advantages, valids, obs, actions, rewards):
torch.Tensor: Calculated objective values
"""
# pylint: disable=unused-argument
# Compute constraint
with torch.no_grad():
old_ll = self._old_policy.log_likelihood(obs, actions)
new_ll = self.policy.log_likelihood(obs, actions)

likelihood_ratio = (new_ll - old_ll).exp()

# Memorize the policy state_dict
self._old_policy.load_state_dict(self.policy.state_dict())

# Calculate surrogate
surrogate = likelihood_ratio * advantages

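The clipping step itself falls outside the hunks shown above. As a rough sketch of the standard PPO clipped objective that lr_clip_range feeds into (illustrative only, not the garage implementation):

import torch

def clipped_objective(likelihood_ratio, advantages, lr_clip_range=0.2):
    # Plain surrogate, as computed in _compute_objective above.
    surrogate = likelihood_ratio * advantages
    # Clipped surrogate: ratios clamped to [1 - eps, 1 + eps].
    clipped = torch.clamp(likelihood_ratio,
                          1.0 - lr_clip_range,
                          1.0 + lr_clip_range) * advantages
    # PPO takes the elementwise minimum, keeping the update pessimistic.
    return torch.min(surrogate, clipped)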
105 changes: 105 additions & 0 deletions src/garage/torch/algos/trpo.py
@@ -0,0 +1,105 @@
"""Trust Region Policy Optimization."""
import torch

from garage.torch.algos import VPG
from garage.torch.optimizers import ConjugateGradientOptimizer


class TRPO(VPG):
"""Trust Region Policy Optimization (TRPO).
Args:
env_spec (garage.envs.EnvSpec): Environment specification.
policy (garage.torch.policies.base.Policy): Policy.
baseline (garage.np.baselines.Baseline): The baseline.
max_path_length (int): Maximum length of a single rollout.
policy_lr (float): Learning rate for training policy network.
num_train_per_epoch (int): Number of train_once calls per epoch.
discount (float): Discount.
gae_lambda (float): Lambda used for generalized advantage
estimation.
max_kl (float): The maximum KL divergence between old and new
policies.
center_adv (bool): Whether to rescale the advantages
so that they have mean 0 and standard deviation 1.
positive_adv (bool): Whether to shift the advantages
so that they are always positive. When used in
conjunction with center_adv the advantages will be
standardized before shifting.
optimizer (object): The optimizer of the algorithm. Should be the
optimizers in torch.optim.
optimizer_args (dict): Arguments required to initialize the optimizer.
policy_ent_coeff (float): The coefficient of the policy entropy.
Setting it to zero would mean no entropy regularization.
use_softplus_entropy (bool): Whether to pass the estimated entropy
through a softplus so that it cannot be negative.
stop_entropy_gradient (bool): Whether to stop the entropy gradient.
entropy_method (str): A string from: 'max', 'regularized',
'no_entropy'. The type of entropy method to use. 'max' adds the
dense entropy to the reward for each time step. 'regularized' adds
the mean entropy to the surrogate objective. See
https://arxiv.org/abs/1805.00909 for more details.
"""

def __init__(self,
env_spec,
policy,
baseline,
max_path_length=100,
policy_lr=3e-4,
num_train_per_epoch=1,
discount=0.99,
gae_lambda=0.98,
max_kl=0.01,
center_adv=True,
positive_adv=False,
optimizer=None,
optimizer_args=None,
policy_ent_coeff=0.0,
use_softplus_entropy=False,
stop_entropy_gradient=False,
entropy_method='no_entropy'):
if optimizer is None:
optimizer = ConjugateGradientOptimizer
optimizer_args = {'max_constraint_value': max_kl}

super().__init__(env_spec, policy, baseline, max_path_length,
policy_lr, num_train_per_epoch, discount, gae_lambda,
center_adv, positive_adv, optimizer, optimizer_args,
policy_ent_coeff, use_softplus_entropy,
stop_entropy_gradient, entropy_method)

self._kl = None

def _compute_objective(self, advantages, valids, obs, actions, rewards):
"""Compute the surrogate objective.
Args:
advantages (torch.Tensor): Advantage estimates for the chosen actions.
valids (list[int]): Length of the valid portion of each path.
obs (torch.Tensor): Observation from the environment.
actions (torch.Tensor): Predicted action.
rewards (torch.Tensor): Feedback from the environment.
Returns:
torch.Tensor: Calculated objective values
"""
with torch.no_grad():
old_ll = self._old_policy.log_likelihood(obs, actions)

new_ll = self.policy.log_likelihood(obs, actions)
likelihood_ratio = (new_ll - old_ll).exp()

# Calculate surrogate
surrogate = likelihood_ratio * advantages

return surrogate

def _optimize(self, itr, paths, valids, obs, actions, rewards):
self._optimizer.step(
f_loss=lambda: self._compute_loss(itr, paths, valids, obs, actions,
rewards),
f_constraint=lambda: self._compute_kl_constraint(obs))
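Here the KL constraint is enforced by the optimizer rather than by the loss. The natural-gradient direction is typically found by solving Hx = g with conjugate gradient, where H is the Hessian of the KL and g the policy gradient. A bare-bones solver of that kind looks roughly like the sketch below; the internals of garage's ConjugateGradientOptimizer may differ.

import torch

def conjugate_gradient(hvp, g, iters=10, tol=1e-10):
    """Solve Hx = g given a Hessian-vector-product function hvp."""
    x = torch.zeros_like(g)
    r = g.clone()          # residual; equals g since x starts at zero
    p = r.clone()          # search direction
    rdotr = torch.dot(r, r)
    for _ in range(iters):
        Hp = hvp(p)
        alpha = rdotr / torch.dot(p, Hp)
        x = x + alpha * p
        r = r - alpha * Hp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x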
80 changes: 70 additions & 10 deletions src/garage/torch/algos/vpg.py
@@ -1,5 +1,6 @@
"""Vanilla Policy Gradient (REINFORCE)."""
import collections
import copy

from dowel import tabular
import numpy as np
@@ -9,6 +10,7 @@
from garage.misc import tensor_utils
from garage.np.algos import BatchPolopt
from garage.torch.algos import loss_function_utils
from garage.torch.utils import flatten_batch


class VPG(BatchPolopt):
@@ -34,6 +36,7 @@ class VPG(BatchPolopt):
standardized before shifting.
optimizer (object): The optimizer of the algorithm. Should be the
optimizers in torch.optim.
optimizer_args (dict): Arguments required to initialize the optimizer.
policy_ent_coeff (float): The coefficient of the policy entropy.
Setting it to zero would mean no entropy regularization.
use_softplus_entropy (bool): Whether to estimate the softmax
@@ -61,6 +64,7 @@ def __init__(
center_adv=True,
positive_adv=False,
optimizer=None,
optimizer_args=None,
policy_ent_coeff=0.0,
use_softplus_entropy=False,
stop_entropy_gradient=False,
@@ -83,16 +87,20 @@ def __init__(
stop_entropy_gradient,
policy_ent_coeff)
self._episode_reward_mean = collections.deque(maxlen=100)
self._optimizer = optimizer(policy.parameters(),
lr=policy_lr,
eps=1e-5)

if optimizer_args is None:
optimizer_args = {'lr': policy_lr, 'eps': 1e-5}

self._optimizer = optimizer(policy.parameters(), **optimizer_args)

super().__init__(policy=policy,
baseline=baseline,
discount=discount,
max_path_length=max_path_length,
n_samples=n_samples)

self._old_policy = copy.deepcopy(self.policy)

@staticmethod
def _check_entropy_configuration(entropy_method, center_adv,
stop_entropy_gradient, policy_ent_coeff):
@@ -124,15 +132,27 @@ def train_once(self, itr, paths):
"""
valids, obs, actions, rewards = self.process_samples(itr, paths)
average_return = self._log(itr, paths)

loss = self._compute_loss(itr, paths, valids, obs, actions, rewards)

# Memorize the policy state_dict
self._old_policy.load_state_dict(self.policy.state_dict())

self._optimizer.zero_grad()
# using a negative because optimizers use gradient descent,
# whilst we want gradient ascent.
(-loss).backward()
self._optimizer.step()
loss.backward()

kl_before = self._compute_kl_constraint(obs).detach()
self._optimize(itr, paths, valids, obs, actions, rewards)

with torch.no_grad():
loss_after = self._compute_loss(itr, paths, valids, obs, actions,
rewards)
kl = self._compute_kl_constraint(obs)
policy_entropy = self._compute_policy_entropy(obs)
average_return = self._log(itr, paths, loss.item(),
loss_after.item(), kl_before.item(),
kl.item(),
policy_entropy.mean().item())

self.baseline.fit(paths)
return average_return
@@ -188,7 +208,31 @@ def _compute_loss(self, itr, paths, valids, obs, actions, rewards):
objective += self._policy_ent_coeff * policy_entropies

valid_objectives = loss_function_utils.filter_valids(objective, valids)
return torch.cat(valid_objectives).mean()
return -torch.cat(valid_objectives).mean()

def _compute_kl_constraint(self, obs):
"""Compute KL divergence.
Compute the KL divergence between the old policy distribution and
current policy distribution.
Args:
obs (torch.Tensor): Observation from the environment.
Returns:
torch.Tensor: Calculated mean KL divergence.
"""
flat_obs = flatten_batch(obs)
with torch.no_grad():
old_dist = self._old_policy.forward(flat_obs)

new_dist = self.policy.forward(flat_obs)

kl_constraint = torch.distributions.kl.kl_divergence(
old_dist, new_dist)

return kl_constraint.mean()

def _compute_policy_entropy(self, obs):
"""Compute entropy value of probability distribution.
@@ -247,6 +291,9 @@ def _get_baselines(self, path):
return torch.Tensor(self.baseline.predict_n(path))
return torch.Tensor(self.baseline.predict(path))

def _optimize(self, itr, paths, valids, obs, actions, rewards): # pylint: disable=unused-argument # noqa: E501
self._optimizer.step()

def process_samples(self, itr, paths):
"""Process sample data based on the collected paths.
@@ -282,12 +329,18 @@ def process_samples(self, itr, paths):

return valids, obs, actions, rewards

def _log(self, itr, paths):
def _log(self, itr, paths, loss_before, loss_after, kl_before, kl,
policy_entropy):
"""Log information per iteration based on the collected paths.
Args:
itr (int): Iteration number.
paths (list[dict]): A list of collected paths
loss_before (float): Loss before optimization step.
loss_after (float): Loss after optimization step.
kl_before (float): KL divergence before optimization step.
kl (float): KL divergence after optimization step.
policy_entropy (float): Policy entropy.
Returns:
float: The average return in last epoch cycle.
@@ -308,5 +361,12 @@ def _log(self, itr, paths):
tabular.record('StdReturn', np.std(undiscounted_returns))
tabular.record('MaxReturn', np.max(undiscounted_returns))
tabular.record('MinReturn', np.min(undiscounted_returns))
with tabular.prefix(self.policy.name):
tabular.record('LossBefore', loss_before)
tabular.record('LossAfter', loss_after)
tabular.record('dLoss', loss_before - loss_after)
tabular.record('KLBefore', kl_before)
tabular.record('KL', kl)
tabular.record('Entropy', policy_entropy)

return average_return
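A quick standalone check of the KL computation that _compute_kl_constraint relies on, using two hand-built diagonal Gaussians in place of the old and new policy distributions (the loc/scale values are arbitrary):

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

old_dist = Normal(loc=torch.zeros(4), scale=torch.ones(4))
new_dist = Normal(loc=0.1 * torch.ones(4), scale=1.1 * torch.ones(4))

# Mean KL over action dimensions; zero only when the two policies coincide.
print(kl_divergence(old_dist, new_dist).mean().item())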
5 changes: 5 additions & 0 deletions src/garage/torch/optimizers/__init__.py
@@ -0,0 +1,5 @@
"""PyTorch optimizers."""
from garage.torch.optimizers.conjugate_gradient_optimizer import (
ConjugateGradientOptimizer)

__all__ = ['ConjugateGradientOptimizer']