
Add Pytorch TRPO #1018

Merged: 1 commit merged into master from pytorch_trpo on Nov 16, 2019
Conversation

utkarshjp7 (Member):

Implemented Trust Region Policy Optimization in PyTorch.

Benchmarks are currently running and should be finished by tomorrow. I opened this PR to get some feedback, since initial results and tests looked good.

@utkarshjp7 utkarshjp7 requested a review from a team as a code owner November 12, 2019 00:17
rewards)
self._optimizer.step(closure)

def _build_closure(self, itr, paths, valids, obs, actions, rewards):
Member:

is there a reason you did this instead of just using lambdas?

Member:

i'd prefer you just `# noqa: E731` them rather than add 50 lines of boilerplate. also, the rule forbids assigning them, not passing them as function arguments. seems to me you could:

def _optimizer(...):
    self._optimizer.step(
        compute_loss=lambda: self._compute_loss(itr, paths, valids, obs, actions, rewards),
        compute_kl=lambda: self._compute_kl_constraint(obs),
    )

Member:

up to you. i think either is a better option than what you had to do here.

@@ -188,7 +198,7 @@ def _compute_loss(self, itr, paths, valids, obs, actions, rewards):
objective += self._policy_ent_coeff * policy_entropies

valid_objectives = loss_function_utils.filter_valids(objective, valids)
return torch.cat(valid_objectives).mean()
return -1 * torch.cat(valid_objectives).mean()
Member:

why not just -torch.cat(valid_objectives).mean()?

"""Take an optimization step.

Args:
closure (tuple[function]): Functions to compute loss and
Member:

why not separate arguments?

Member Author:

This is overriding torch.optim.Optimizer's step function, which only takes one parameter.

Member:

ah i see.

Member:

it seems clear to me that torch intends closure to be a single function. If you want to pseudo-implement their API, I think it would be more analogous to amend the signature to accept a second closure (which computes the constraint).
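For illustration, a minimal sketch of that idea with hypothetical names; a plain gradient step stands in for the real constrained update (this is not the optimizer in this PR):

```python
import torch


class ConstrainedOptimizer(torch.optim.Optimizer):
    """Sketch: keep torch's loss closure, add a second closure for the KL."""

    def __init__(self, params, lr=0.01):
        super().__init__(params, {'lr': lr})

    def step(self, closure, constraint_closure=None):
        loss = closure()
        kl = constraint_closure() if constraint_closure is not None else None
        # A real TRPO optimizer would use `kl` to restrict the update;
        # this sketch just takes a plain gradient step on the loss.
        self.zero_grad()
        loss.backward()
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.add_(-group['lr'] * p.grad)
        return loss, kl


# Hypothetical usage:
# opt = ConstrainedOptimizer(policy.parameters())
# opt.step(lambda: compute_loss(batch), lambda: compute_kl(batch))
```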

@ryanjulian ryanjulian requested review from krzentner and a team November 12, 2019 01:43
@ghost ghost requested review from nish21 and removed request for a team November 12, 2019 01:43
env_spec,
policy,
baseline,
max_path_length=500,
Member:

this is really high for most environments, where path lengths are often <100

@ryanjulian (Member):

Code is looking pretty great, how are the benchmarks?

@@ -15,13 +15,16 @@ def set_seed(seed):

"""
seed %= 4294967294
global seed_
global seed_ # pylint: disable=global-statement
Member:

you can disable this function-wide here. the reason is fairly obvious.

from garage.torch.utils import update_tensor_list_from_flat_tensor


def build_hessian_vector_product(func, params, reg_coeff=1e-5):
Member:

does this need to be public?

computation 6.1 (1994): 147-160.`

Args:
func (function): A function that returns a torch.Tensor. Hessian of
Member:

Callable is the type for this.
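As background on what this helper does: the Pearlmutter (1994) trick evaluates Hessian-vector products with two backward passes instead of materializing the Hessian. A generic sketch (not the code in this PR), with the same reg_coeff-style regularization:

```python
import torch


def hessian_vector_product(func, params, vector, reg_coeff=1e-5):
    """Compute (H + reg_coeff * I) @ vector, where H is the Hessian of func()."""
    output = func()
    grads = torch.autograd.grad(output, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    # d/dparams (grad . vector) gives the Hessian-vector product.
    gvp = torch.dot(flat_grad, vector)
    hvp = torch.autograd.grad(gvp, params)
    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])
    return flat_hvp + reg_coeff * vector


# Tiny check: for 0.5 * ||x||^2 the Hessian is the identity, so H @ v ~= v.
x = torch.ones(3, requires_grad=True)
v = torch.tensor([1.0, 2.0, 3.0])
print(hessian_vector_product(lambda: 0.5 * (x ** 2).sum(), [x], v))
```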

@utkarshjp7 (Member Author) commented Nov 12, 2019:

I am not sure what happened with HalfCheetah. Any thoughts?
These are averaged over 3 trials.

[Benchmark plots: Hopper-v2, Reacher-v2, Swimmer-v2, Walker2d-v2, HalfCheetah-v2]

@ryanjulian (Member):

Hmm, I'm also not sure what happened with HalfCheetah -- but it's actually consistent with the rest of your results, which are slightly lower than TF in all cases. HalfCheetah is just harder than the rest, so the difference looks bigger.

It seems like the variance of the PyTorch plots is much lower. Did you ensure that the initial standard deviation of the policies is the same in both benchmarks? What about other hyperparameters?

There's a reason for every metric which the TF TRPO implementation plots, so make sure your version is plotting every one of those too. How do the Policy/MeanKL plots compare? Policy/Entropy? Baseline/ExplainedVariance? Check all these before resorting to more drastic debugging.

@avnishn (Member) commented Nov 12, 2019:

I'd go ahead and take a look at the plots for mean kl, mean mu, mean std. Those will probably be the most helpful in debugging the decreased performance on cheetah @utkarshjp7

@avnishn (Member) left a comment:
Are you using the params from the original TRPO paper to get the results that you posted?

src/garage/torch/algos/ppo.py (resolved)
@utkarshjp7 (Member Author):

@ryanjulian The hyperparameters are all the same for both versions. I added logging to PyTorch TRPO and started benchmarks with 5 trials. I will post the results once it's finished.

@ryanjulian (Member):

Great! Note that you can upload your results to https://tensorboard.dev for easy reviewing.

@ryanjulian (Member):

Please prioritize this PR, since it is blocking your MAML implementation.

@utkarshjp7 (Member Author):

The latest benchmark data can be found here ->
https://tensorboard.dev/experiment/OP0fqpjNQzawLEuLDslKjA

It is the same as the plots I posted before, but with the additional logging. The main difference I observed is that the lower bound of mean policy KL in TensorFlow is 6.5e-3, while in PyTorch it's 0.0.

@ryanjulian (Member):

In future runs I recommend you make the name of the PyTorch policy also "GaussianMLPPolicy" so that these plots overlap; they're hard to analyze when shown apart.

Also, LinearFeatureBaseline/ExplainedVariance is a really essential debugging stat for on-policy RL so make sure to add that to the torch version. Perhaps we should resurrect the log_diagnostics interface or similar to make these things consistent.

@krzentner WDYT?

@ryanjulian (Member):

Some disorganized thoughts on these plots:

looking at Reacher-v2/trial_1_seed_18/garage vs Reacher-v2/trial_1_seed_18/garage_pytorch:

  • LossAfter for PyTorch is never positive, which is a bit strange. Loss magnitudes and signs don't have much meaning in policy gradients, but it's weird to see it clamped at 0. This suggests to me that either the optimizer or loss function is truncating, scaling, or normalizing something unexpectedly.
  • MeanKL magnitudes are as-expected, but the PyTorch version has many steps with 0 MeanKL -- does this mean the line search is failing and it aborts the optimization for those steps? Do you produce some sort of warning when this happens?
  • dLoss in PyTorch is normal but again has many itrs where it is 0. This adds more weight to my suspicion above that many optimization epochs are failing.
  • All of the above seems to apply to HalfCheetah experiments

Here's some plots to chew on:
[screenshot: Screen Shot 2019-11-13 at 11.45.36 AM]

Notice that KL, LossAfter, and dLoss are all 0 for itrs 486 and 487. This suggests to me that your optimizer is silently failing to create an update. Perhaps it can't calculate an update if the loss is > 0?
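For reference, a generic sketch (not this PR's optimizer) of the kind of warning being asked about when the backtracking line search can't find an acceptable step:

```python
import warnings


def backtracking_line_search(params, descent_step, f_loss, f_constraint,
                             max_constraint=0.01, max_backtracks=15,
                             backtrack_ratio=0.8):
    """Return updated parameters, or warn and return the old ones unchanged."""
    loss_before = f_loss(params)
    for i in range(max_backtracks):
        candidate = params + (backtrack_ratio ** i) * descent_step
        if (f_loss(candidate) < loss_before
                and f_constraint(candidate) <= max_constraint):
            return candidate
    warnings.warn('Line search failed to find an improving step that satisfies '
                  'the KL constraint; keeping previous parameters.')
    return params
```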

@codecov bot commented Nov 14, 2019:

Codecov Report

Merging #1018 into master will increase coverage by 0.11%.
The diff coverage is 91.61%.


@@            Coverage Diff             @@
##           master    #1018      +/-   ##
==========================================
+ Coverage   85.07%   85.18%   +0.11%     
==========================================
  Files         157      160       +3     
  Lines        7478     7616     +138     
  Branches      930      955      +25     
==========================================
+ Hits         6362     6488     +126     
- Misses        932      935       +3     
- Partials      184      193       +9
Impacted Files Coverage Δ
src/garage/torch/algos/ppo.py 100% <ø> (ø) ⬆️
src/garage/torch/policies/gaussian_mlp_policy.py 100% <100%> (ø) ⬆️
src/garage/torch/policies/base.py 80% <100%> (+15.71%) ⬆️
src/garage/torch/utils.py 100% <100%> (ø) ⬆️
src/garage/torch/algos/vpg.py 98.42% <100%> (+0.42%) ⬆️
.../garage/torch/policies/deterministic_mlp_policy.py 100% <100%> (ø) ⬆️
src/garage/torch/optimizers/__init__.py 100% <100%> (ø)
...e/torch/optimizers/conjugate_gradient_optimizer.py 86.66% <86.66%> (ø)
src/garage/torch/algos/trpo.py 94.73% <94.73%> (ø)
.../exploration_strategies/epsilon_greedy_strategy.py 96.29% <0%> (-3.71%) ⬇️
... and 3 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ca2d1cf...6b05023.

@@ -61,6 +64,7 @@ def __init__(
center_adv=True,
positive_adv=False,
optimizer=None,
optimizer_args=None,
Member:

i'm really not a fan of this pattern of passing dicts of args for constructors into other constructors. can we somehow construct-and-pass the optimizer instead? or just flatten these args into the parent constructor?

@krzentner your thoughts?

Member:

I agree, it is less ambiguous in the code if we construct and pass the optimizer.
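A hypothetical sketch of the construct-and-pass alternative, using a factory (functools.partial) so the algorithm still binds its own parameters; the names here are illustrative, not this PR's code:

```python
from functools import partial

import torch


class Algo:
    """Illustrative stand-in for an algorithm constructor."""

    def __init__(self, policy, optimizer=None):
        if optimizer is None:
            # Default factory instead of a default optimizer_args dict.
            optimizer = partial(torch.optim.Adam, lr=3e-4)
        self._optimizer = optimizer(policy.parameters())


policy = torch.nn.Linear(4, 2)  # toy stand-in for a policy network
algo = Algo(policy, optimizer=partial(torch.optim.Adam, lr=1e-3))
```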

@@ -308,5 +361,12 @@ def _log(self, itr, paths):
tabular.record('StdReturn', np.std(undiscounted_returns))
tabular.record('MaxReturn', np.max(undiscounted_returns))
tabular.record('MinReturn', np.min(undiscounted_returns))
tabular.record('{0}/LossBefore'.format(self.policy.name), loss_before)
Member:

you can use tabular.prefix for this: https://github.com/rlworkgroup/dowel/blob/master/src/dowel/tabular_input.py

you might also be interested in tabular.record_misc_stat
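A rough sketch of that suggestion (hypothetical helper; check dowel's tabular_input.py for the exact API):

```python
from dowel import tabular


def log_policy_stats(policy_name, loss_before, loss_after, undiscounted_returns):
    # tabular.prefix groups keys under "<policy_name>/" without formatting
    # each key by hand.
    with tabular.prefix(policy_name + '/'):
        tabular.record('LossBefore', loss_before)
        tabular.record('LossAfter', loss_after)
    # record_misc_stat emits Max/Min/Average/Std of a sequence under one key.
    tabular.record_misc_stat('Return', undiscounted_returns)
```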

@@ -3,37 +3,88 @@


class Policy(abc.ABC):
"""
Policy base class without Parameterzied.
"""Policy base class without Parameterzied.
Member:

i think "without Parameterized" is pretty outdated here.

* torch.Tensor: Predicted action.
* dict:
* list[float]: Mean of the distribution
* list[float]: Standard deviation of logarithmic values of
Member:

it is actually the log of the stddev, not the stddev of the log

* torch.Tensor: Predicted actions.
* dict:
* list[float]: Mean of the distribution
* list[float]: Standard deviation of logarithmic values of
Member:

log(std) not std(log)

* torch.Tensor: Predicted action.
* dict:
* list[float]: Mean of the distribution
* list[float]: Standard deviation of logarithmic values of
Member:

log(std)

* torch.Tensor: Predicted actions.
* dict:
* list[float]: Mean of the distribution
* list[float]: Standard deviation of logarithmic values of
Member:

log(std)


"""

def __init__(self, env_spec, **kwargs):
def __init__(self, env_spec, name='GaussianMLPPolicy', **kwargs):
Member:

please don't swallow **kwargs in the constructor. (I know you didn't write this, but let's remove it)

Member Author:

If we remove it, then how can one change the non-linearity of a layer or the initialization of weights? Or is the point here that no one should be changing those through GaussianMLPPolicy?

Member Author:

I see what you mean. We should explicitly mention each keyword argument that this class supports?

Member:

yes. always.
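For illustration only, a hypothetical constructor that spells out its supported options instead of forwarding **kwargs (the parameter names are assumptions, not the actual garage signature):

```python
import torch
from torch import nn


class ExplicitGaussianMLPPolicy(nn.Module):
    """Sketch: every supported option appears explicitly in the signature."""

    def __init__(self, env_spec, name='GaussianMLPPolicy',
                 hidden_sizes=(32, 32), hidden_nonlinearity=torch.tanh,
                 init_std=1.0):
        super().__init__()
        self._env_spec = env_spec
        self.name = name
        self._hidden_sizes = hidden_sizes
        self._hidden_nonlinearity = hidden_nonlinearity
        self._init_std = init_std
```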

# pylint: disable=not-callable #https://github.com/pytorch/pytorch/issues/24807 # noqa: E501


class TestConjugateGradientOptimizer:
Member:

note that pytest is perfectly happy running test functions outside of classes, as long as they are named test_*
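For example, a module-level test in the same file would be collected just as well (a sketch, not one of the PR's actual test cases):

```python
import torch


def test_identity_hessian_vector_product():
    # For 0.5 * ||x||^2 the Hessian is the identity, so H @ v should equal v.
    x = torch.ones(3, requires_grad=True)
    grad = torch.autograd.grad(0.5 * (x ** 2).sum(), x, create_graph=True)[0]
    v = torch.tensor([1.0, 2.0, 3.0])
    hvp = torch.autograd.grad(torch.dot(grad, v), x)[0]
    assert torch.allclose(hvp, v)
```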

@utkarshjp7 utkarshjp7 force-pushed the pytorch_trpo branch 4 times, most recently from 6a8fb49 to 81ecd2b on November 14, 2019 09:36
@utkarshjp7 (Member Author):

Benchmark results look much better after fixing a bug in the optimizer. These are averaged over 5 trials.

[Benchmark plots: HalfCheetah-v2, Hopper-v2, Reacher-v2, Swimmer-v2, Walker2d-v2]

@ryanjulian (Member):

Looks like the torch version is 2x slower than the TF version?
[screenshot: Screen Shot 2019-11-14 at 8.32.31 AM]

Any insight into why that is? Perhaps we haven't enabled CPU parallelism in torch? We don't need to fix it for this PR but it would be nice to know what's going on.

@ryanjulian (Member):

I'm still worried that PyTorch may be systematically just a little bit worse.

What do these look like with 10 trials? And what's the average performance gap for each at itr 999?

@utkarshjp7 (Member Author):

Difference in means at itr 999, averaged over 10 trials (tensorflow - pytorch):

HalfCheetah-v2: -17.11101
Hopper-v2: 13.40361
Reacher-v2: 0.71291
Swimmer-v2: 1.15873
Walker2d-v2: 1.80943

[Benchmark plots: HalfCheetah-v2, Hopper-v2, Reacher-v2, Swimmer-v2, Walker2d-v2]

@ryanjulian (Member):

Okay, this is as close to equal as algos get. I think the performance is ready.

Did you have a chance to look at the runtime?

@utkarshjp7 (Member Author):

I ran the Python profiler, and it seems the overhead in PyTorch is in computing gradients (the backward function call). These stats are for running PyTorch TRPO for 10 epochs.

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      230    8.429    0.037    8.429    0.037 {method 'run_backward' of 'torch._C._EngineBase' objects}
     3208    3.015    0.001    3.015    0.001 {built-in method cholesky}
     6416    0.832    0.000    0.832    0.000 {built-in method tanh}
    52225    0.661    0.000    0.661    0.000 {method 'step' of 'mujoco_py.cymj.MjSim' objects}
     8808    0.500    0.000    0.500    0.000 {built-in method addmm}
      292    0.485    0.002    1.697    0.006 inspect.py:714(getmodule)
      452    0.453    0.001    0.453    0.001 {built-in method triangular_solve}
  198/179    0.389    0.002    0.462    0.003 {built-in method _imp.create_dynamic}
     3866    0.314    0.000    0.314    0.000 {built-in method marshal.loads}

@ryanjulian (Member):

I think that this is a general issue in PyTorch which we can investigate later. You can try some of the suggestions here: pytorch/pytorch#975

Is your CPU usage near 100% (on all cores) during the optimization phase? Anyway, please don't block commit on this one since it looks like it's framework-wide.
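One quick thing worth checking (hypothetical values, nothing to do with this PR's code): how many threads torch is actually using for intra-op CPU parallelism.

```python
import torch

print(torch.get_num_threads())  # threads used for intra-op CPU parallelism
torch.set_num_threads(8)        # hypothetical value; match your physical core count
```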

@avnishn (Member) left a comment:

LGTM

@mergify mergify bot merged commit 02d6ef7 into master Nov 16, 2019
@mergify mergify bot deleted the pytorch_trpo branch November 16, 2019 00:30