Commit 6a8fb49 (1 parent: eaf39da)
Showing 14 changed files with 1,069 additions and 165 deletions.
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
"""This is an example to train a task with the TRPO algorithm (PyTorch).

Here it runs the InvertedDoublePendulum-v2 environment with 100 iterations.
"""
import torch

from garage.experiment import LocalRunner, run_experiment
from garage.np.baselines import LinearFeatureBaseline
from garage.tf.envs import TfEnv
from garage.torch.algos import TRPO
from garage.torch.policies import GaussianMLPPolicy


def run_task(snapshot_config, *_):
    """Set up environment and algorithm and run the task.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        _ : Unused parameters

    """
    env = TfEnv(env_name='InvertedDoublePendulum-v2')

    runner = LocalRunner(snapshot_config)

    policy = GaussianMLPPolicy(env.spec,
                               hidden_sizes=[32, 32],
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

    baseline = LinearFeatureBaseline(env_spec=env.spec)

    algo = TRPO(env_spec=env.spec,
                policy=policy,
                baseline=baseline,
                max_path_length=100,
                discount=0.99,
                center_adv=False)

    runner.setup(algo, env)
    runner.train(n_epochs=100, batch_size=1024)


run_experiment(
    run_task,
    snapshot_mode='last',
    seed=1,
)
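The example wires a policy, a baseline, and the algorithm into a LocalRunner. The LinearFeatureBaseline's value predictions are combined with discounted rewards into the advantage estimates that TRPO consumes (see the gae_lambda parameter of the algorithm below). Here is a minimal sketch of that combination, generalized advantage estimation, on made-up data; the helper name compute_gae and the toy numbers are illustrative, not part of garage:

import torch


def compute_gae(rewards, values, discount=0.99, gae_lambda=0.98):
    # values holds one extra entry: the predicted value of the state
    # reached after the final step.
    deltas = rewards + discount * values[1:] - values[:-1]  # TD residuals
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # Exponentially weighted sum of future TD residuals.
        gae = deltas[t] + discount * gae_lambda * gae
        advantages[t] = gae
    return advantages


# Toy rollout: four steps of reward 1.0, a baseline predicting 0.5 everywhere.
rewards = torch.ones(4)
values = torch.full((5, ), 0.5)
print(compute_gae(rewards, values))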
@@ -0,0 +1,105 @@
"""Trust Region Policy Optimization.""" | ||
import torch | ||
|
||
from garage.torch.algos import VPG | ||
from garage.torch.optimizers import ConjugateGradientOptimizer | ||
|
||
|
||
class TRPO(VPG): | ||
"""Trust Region Policy Optimization (TRPO). | ||
Args: | ||
env_spec (garage.envs.EnvSpec): Environment specification. | ||
policy (garage.torch.policies.base.Policy): Policy. | ||
baseline (garage.np.baselines.Baseline): The baseline. | ||
max_path_length (int): Maximum length of a single rollout. | ||
policy_lr (float): Learning rate for training policy network. | ||
num_train_per_epoch (int): Number of train_once calls per epoch. | ||
discount (float): Discount. | ||
gae_lambda (float): Lambda used for generalized advantage | ||
estimation. | ||
max_kl (float): The maximum KL divergence between old and new | ||
policies. | ||
center_adv (bool): Whether to rescale the advantages | ||
so that they have mean 0 and standard deviation 1. | ||
positive_adv (bool): Whether to shift the advantages | ||
so that they are always positive. When used in | ||
conjunction with center_adv the advantages will be | ||
standardized before shifting. | ||
optimizer (object): The optimizer of the algorithm. Should be the | ||
optimizers in torch.optim. | ||
optimizer_args (dict): Arguments required to initialize the optimizer. | ||
policy_ent_coeff (float): The coefficient of the policy entropy. | ||
Setting it to zero would mean no entropy regularization. | ||
use_softplus_entropy (bool): Whether to estimate the softmax | ||
distribution of the entropy to prevent the entropy from being | ||
negative. | ||
stop_entropy_gradient (bool): Whether to stop the entropy gradient. | ||
entropy_method (str): A string from: 'max', 'regularized', | ||
'no_entropy'. The type of entropy method to use. 'max' adds the | ||
dense entropy to the reward for each time step. 'regularized' adds | ||
the mean entropy to the surrogate objective. See | ||
https://arxiv.org/abs/1805.00909 for more details. | ||
""" | ||
|
||
def __init__(self, | ||
env_spec, | ||
policy, | ||
baseline, | ||
max_path_length=100, | ||
policy_lr=3e-4, | ||
num_train_per_epoch=1, | ||
discount=0.99, | ||
gae_lambda=0.98, | ||
max_kl=0.01, | ||
center_adv=True, | ||
positive_adv=False, | ||
optimizer=None, | ||
optimizer_args=None, | ||
policy_ent_coeff=0.0, | ||
use_softplus_entropy=False, | ||
stop_entropy_gradient=False, | ||
entropy_method='no_entropy'): | ||
if optimizer is None: | ||
optimizer = ConjugateGradientOptimizer | ||
optimizer_args = {'max_constraint_value': max_kl} | ||
|
||
super().__init__(env_spec, policy, baseline, max_path_length, | ||
policy_lr, num_train_per_epoch, discount, gae_lambda, | ||
center_adv, positive_adv, optimizer, optimizer_args, | ||
policy_ent_coeff, use_softplus_entropy, | ||
stop_entropy_gradient, entropy_method) | ||
|
||
self._kl = None | ||
|
||
def _compute_objective(self, advantages, valids, obs, actions, rewards): | ||
"""Compute the surrogate objective. | ||
Args: | ||
advantages (torch.Tensor): Expected rewards over the actions | ||
valids (list[int]): length of the valid values for each path | ||
obs (torch.Tensor): Observation from the environment. | ||
actions (torch.Tensor): Predicted action. | ||
rewards (torch.Tensor): Feedback from the environment. | ||
Returns: | ||
torch.Tensor: Calculated objective values | ||
""" | ||
with torch.no_grad(): | ||
old_ll = self._old_policy.log_likelihood(obs, actions) | ||
|
||
new_ll = self.policy.log_likelihood(obs, actions) | ||
likelihood_ratio = (new_ll - old_ll).exp() | ||
|
||
# Calculate surrogate | ||
surrogate = likelihood_ratio * advantages | ||
|
||
return surrogate | ||
|
||
def _optimize(self, itr, paths, valids, obs, actions, rewards): | ||
self._optimizer.step( | ||
f_loss=lambda: self._compute_loss(itr, paths, valids, obs, actions, | ||
rewards), | ||
f_constraint=lambda: self._compute_kl_constraint(obs)) |
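To make the surrogate concrete: _compute_objective builds the importance ratio pi_new(a|s) / pi_old(a|s) in log space and weights it by the advantages; the optimizer then maximizes the mean of these values subject to the KL constraint. A minimal numeric sketch, with made-up log-likelihoods and advantages standing in for real policy outputs:

import torch

# Made-up per-step log-likelihoods under the old and new policies,
# plus matching advantage estimates.
old_ll = torch.tensor([-1.2, -0.8, -1.5])
new_ll = torch.tensor([-1.0, -0.9, -1.3])
advantages = torch.tensor([0.5, -0.2, 1.0])

# Importance ratio computed in log space for numerical stability,
# mirroring _compute_objective above.
likelihood_ratio = (new_ll - old_ll).exp()

# Per-step surrogate values whose mean TRPO maximizes under the
# KL constraint enforced in _optimize.
surrogate = likelihood_ratio * advantages
print(surrogate)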
@@ -0,0 +1,5 @@
"""PyTorch optimizers.""" | ||
from garage.torch.optimizers.conjugate_gradient_optimizer import ( | ||
ConjugateGradientOptimizer) | ||
|
||
__all__ = ['ConjugateGradientOptimizer'] |
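The ConjugateGradientOptimizer itself is part of this commit but not reproduced in this view. As a rough sketch of the idea behind it (not garage's actual implementation): a KL-constrained policy step solves H x = g, where g is the policy gradient and H the Hessian of the KL divergence, and the conjugate gradient method does this using only Hessian-vector products, so H is never formed explicitly. All names below are illustrative:

import torch


def conjugate_gradient(hvp, g, iters=10, tol=1e-10):
    """Solve H x = g given only a Hessian-vector product callable hvp."""
    x = torch.zeros_like(g)
    r = g.clone()  # residual g - H x (x starts at zero)
    p = g.clone()  # search direction
    rr = r.dot(r)
    for _ in range(iters):
        hp = hvp(p)
        alpha = rr / p.dot(hp)
        x += alpha * p
        r -= alpha * hp
        new_rr = r.dot(r)
        if new_rr < tol:
            break
        p = r + (new_rr / rr) * p
        rr = new_rr
    return x


# Toy check with an explicit positive-definite H.
H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
g = torch.tensor([1.0, 1.0])
x = conjugate_gradient(lambda v: H @ v, g)
print(x, H @ x)  # H @ x should be close to g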