Add Pytorch TRPO #1018
Conversation
src/garage/torch/algos/trpo.py
Outdated
            rewards)
        self._optimizer.step(closure)

    def _build_closure(self, itr, paths, valids, obs, actions, rewards):
is there a reason you did this instead of just using lambdas?
i'd prefer you just # noqa: E731 them rather than add 50 lines of boilerplate. also, the rule forbids assigning them, not passing them as function arguments. seems to me you could:
def _optimizer(...):
    self._optimizer.step(
        compute_loss=lambda: self._compute_loss(itr, paths, valids, obs, actions, rewards),
        compute_kl=lambda: self._compute_kl_constraint(obs))
up to you. i think either is a better option than what you had to do here.
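The suggestion above hinges on a detail of PEP 8: E731 only forbids *assigning* a lambda to a name, not passing one as an argument. A minimal stdlib sketch of the pattern (the names `step`, `compute_loss`, and `compute_kl` here are illustrative stand-ins, not garage's actual API):

```python
# Hypothetical sketch: lambdas passed as keyword arguments (allowed by E731)
# instead of a separate _build_closure helper. Stand-in names throughout.

def step(compute_loss, compute_kl):
    """Stand-in optimizer step that evaluates both closures."""
    return compute_loss(), compute_kl()


def optimize(obs, actions, rewards):
    # Each lambda captures the local variables it needs at call time.
    return step(
        compute_loss=lambda: sum(rewards) - len(actions),
        compute_kl=lambda: len(obs) * 2,
    )


result = optimize(obs=[1, 2], actions=[0, 1], rewards=[1.0, 2.0])
```

Because the lambdas close over the caller's locals, no intermediate closure-builder method is needed.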
src/garage/torch/algos/vpg.py
Outdated
@@ -188,7 +198,7 @@ def _compute_loss(self, itr, paths, valids, obs, actions, rewards):
             objective += self._policy_ent_coeff * policy_entropies

         valid_objectives = loss_function_utils.filter_valids(objective, valids)
-        return torch.cat(valid_objectives).mean()
+        return -1 * torch.cat(valid_objectives).mean()
why not just -torch.cat(valid_objectives).mean()?
        """Take an optimization step.

        Args:
            closure (tuple[function]): Functions to compute loss and
why not separate arguments?
This is overriding torch.optim.Optimizer's step function, which only takes one parameter.
ah i see.
it seems clear to me that torch intends closure to be a single function. If you want to pseudo-implement their API, I think it would be more analogous to amend the signature to accept a second closure (which computes the constraint).
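A minimal sketch of that suggestion, keeping the spirit of torch's single-closure `step()` while accepting an optional second closure for the constraint (`ToyOptimizer`, `f_loss`, and `f_constraint` are hypothetical names, not garage's or torch's actual classes):

```python
# Hedged sketch: an optimizer step that takes a loss closure plus an
# optional constraint closure. Illustrative only; not garage's API.

class ToyOptimizer:
    def __init__(self, max_constraint=0.01):
        self._max_constraint = max_constraint

    def step(self, f_loss, f_constraint=None):
        """Evaluate the loss closure; optionally check a constraint closure."""
        loss = f_loss()
        if f_constraint is not None and f_constraint() > self._max_constraint:
            raise ValueError('constraint violated')
        return loss


opt = ToyOptimizer()
result = opt.step(f_loss=lambda: 0.5, f_constraint=lambda: 0.001)
```

Making the constraint a separate, optional parameter keeps the signature compatible with callers that only supply a loss closure.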
src/garage/torch/algos/trpo.py
Outdated
                 env_spec,
                 policy,
                 baseline,
                 max_path_length=500,
this is really high for most environments, which often are <100
Code is looking pretty great, how are the benchmarks?
@@ -15,13 +15,16 @@ def set_seed(seed):

     """
     seed %= 4294967294
-    global seed_
+    global seed_  # pylint: disable=global-statement
you can disable this function-wide here. the reason is fairly obvious.
from garage.torch.utils import update_tensor_list_from_flat_tensor


def build_hessian_vector_product(func, params, reg_coeff=1e-5):
does this need to be public?
        computation 6.1 (1994): 147-160.`

    Args:
        func (function): A function that returns a torch.Tensor. Hessian of
Callable is the type for this.
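For context on why this helper exists: TRPO solves H x = g by conjugate gradient, using only Hessian-vector products H @ v rather than H itself. A minimal pure-Python sketch of that consumption pattern (here `hvp` multiplies an explicit 2x2 matrix for illustration; the real helper builds the product from autograd without materializing H, and all names here are illustrative):

```python
# Pure-Python conjugate gradient that only needs a Hessian-vector product.
# The explicit matrix H is a stand-in for an autograd-derived product.

def conjugate_gradient(hvp, g, iters=10, tol=1e-10):
    x = [0.0] * len(g)
    r = list(g)                # residual g - H @ x, with x = 0
    p = list(r)
    rs_old = sum(ri * ri for ri in r)
    for _ in range(iters):
        hp = hvp(p)
        alpha = rs_old / sum(pi * hpi for pi, hpi in zip(p, hp))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, hp)]
        rs_new = sum(ri * ri for ri in r)
        if rs_new < tol:
            break
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


H = [[4.0, 1.0], [1.0, 3.0]]


def hvp(v):
    # Matrix-vector product standing in for a Pearlmutter-style autograd HVP.
    return [sum(hij * vj for hij, vj in zip(row, v)) for row in H]


x = conjugate_gradient(hvp, [1.0, 2.0])
```

For a symmetric positive-definite 2x2 system, CG converges in at most two iterations; the solution here is x = [1/11, 7/11].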
Hmm, I'm also not sure what happened here. It seems like the variance of the PyTorch plots is much lower. Did you ensure that the initial standard deviation of the policies is the same in both benchmarks? What about other hyperparameters? There's a reason for every metric which the TF TRPO implementation plots, so make sure your version is plotting every one of those too. How do the Policy/MeanKL plots compare? Policy/Entropy? Baseline/ExplainedVariance? Check all of these before resorting to more drastic debugging.
I'd go ahead and take a look at the plots for mean kl, mean mu, mean std. Those will probably be the most helpful in debugging the decreased performance on cheetah @utkarshjp7
Are you using the params from the original TRPO paper to get the results that you posted?
@ryanjulian The hyperparameters are all the same for both versions. I added logging to PyTorch TRPO and started benchmarks with 5 trials. I will post the results once it's finished.
Great! Note that you can upload your results to https://tensorboard.dev for easy reviewing.
Please prioritize this PR, since it is blocking your MAML implementation.
The latest benchmark data can be found here -> It is the same as the plots I posted before, but with the additional logging. The main difference I observed is that the lower bound of mean policy KL in TensorFlow is 6.5e-3 while in PyTorch it's 0.0.
In future runs I recommend you make the name of the PyTorch policy also "GaussianMLPPolicy" so that these plots overlap. They're hard to analyze apart. Also, LinearFeatureBaseline/ExplainedVariance is a really essential debugging stat for on-policy RL, so make sure to add that to the torch version. Perhaps we should resurrect the @krzentner WDYT?
Some disorganized thoughts on these plots: looking at
Notice that KL, LossAfter, and dLoss are all 0 for itrs 486 and 487. This suggests to me that your optimizer is silently failing to create an update. Perhaps it can't calculate an update if the loss is > 0?
Force-pushed from e11a24d to 8dd2d23
Codecov Report
@@ Coverage Diff @@
## master #1018 +/- ##
==========================================
+ Coverage 85.07% 85.18% +0.11%
==========================================
Files 157 160 +3
Lines 7478 7616 +138
Branches 930 955 +25
==========================================
+ Hits 6362 6488 +126
- Misses 932 935 +3
- Partials 184 193 +9
Continue to review full report at Codecov.
@@ -61,6 +64,7 @@ def __init__(
             center_adv=True,
             positive_adv=False,
             optimizer=None,
+            optimizer_args=None,
i'm really not a fan of this pattern of passing dicts of args for constructors into other constructors. can we somehow construct-and-pass the optimizer instead? or just flatten these args into the parent constructor?
@krzentner your thoughts?
I agree, it is less ambiguous in the code if we construct and pass the optimizer.
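A stdlib sketch contrasting the two patterns under discussion (`SGDish`, `AlgoWithArgsDict`, and `AlgoWithInstance` are hypothetical stand-ins, not garage classes):

```python
# Hedged sketch: forwarding a dict of constructor args vs. constructing
# the optimizer up front and passing the instance. Stand-in names only.

class SGDish:
    def __init__(self, lr=0.01):
        self.lr = lr


class AlgoWithArgsDict:
    # Pattern under review: the algo constructs the optimizer from a dict.
    def __init__(self, optimizer=None, optimizer_args=None):
        optimizer_args = optimizer_args or {}
        self.optimizer = (optimizer or SGDish)(**optimizer_args)


class AlgoWithInstance:
    # Suggested pattern: the caller constructs and passes the optimizer.
    def __init__(self, optimizer=None):
        self.optimizer = optimizer if optimizer is not None else SGDish()


a = AlgoWithArgsDict(optimizer_args={'lr': 0.1})
b = AlgoWithInstance(optimizer=SGDish(lr=0.1))
```

The second form keeps the valid argument set visible at the call site and lets the optimizer's own constructor validate its arguments, instead of deferring errors to an opaque `**optimizer_args` expansion.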
src/garage/torch/algos/vpg.py
Outdated
@@ -308,5 +361,12 @@ def _log(self, itr, paths):
         tabular.record('StdReturn', np.std(undiscounted_returns))
         tabular.record('MaxReturn', np.max(undiscounted_returns))
         tabular.record('MinReturn', np.min(undiscounted_returns))
+        tabular.record('{0}/LossBefore'.format(self.policy.name), loss_before)
you can use tabular.prefix for this: https://github.com/rlworkgroup/dowel/blob/master/src/dowel/tabular_input.py
you might also be interested in tabular.record_misc_stat
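To illustrate the prefix idea, here is a simplified stand-in (this is a mock, not dowel itself): a context manager prepends a scope to every recorded key, so each record call doesn't need to repeat `'{0}/'.format(self.policy.name)`.

```python
# Simplified mock of a prefix-scoped tabular logger. Not dowel's actual
# implementation; illustrates the pattern only.
import contextlib


class Tabular:
    def __init__(self):
        self._prefix = ''
        self.data = {}

    @contextlib.contextmanager
    def prefix(self, prefix):
        # Push the scope, restore it even if recording raises.
        old, self._prefix = self._prefix, self._prefix + prefix
        try:
            yield
        finally:
            self._prefix = old

    def record(self, key, value):
        self.data[self._prefix + key] = value


tabular = Tabular()
with tabular.prefix('GaussianMLPPolicy/'):
    tabular.record('LossBefore', 1.5)
    tabular.record('LossAfter', 1.2)
```

Every key recorded inside the `with` block lands under the `GaussianMLPPolicy/` namespace.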
src/garage/torch/policies/base.py
Outdated
@@ -3,37 +3,88 @@


 class Policy(abc.ABC):
-    """
-    Policy base class without Parameterzied.
+    """Policy base class without Parameterzied.
i think "without Parameterized" is pretty outdated here.
src/garage/torch/policies/base.py
Outdated
            * torch.Tensor: Predicted action.
            * dict:
                * list[float]: Mean of the distribution
                * list[float]: Standard deviation of logarithmic values of
it is actually the log of the stddev, not the stddev of the log
src/garage/torch/policies/base.py
Outdated
            * torch.Tensor: Predicted actions.
            * dict:
                * list[float]: Mean of the distribution
                * list[float]: Standard deviation of logarithmic values of
log(std) not std(log)
            * torch.Tensor: Predicted action.
            * dict:
                * list[float]: Mean of the distribution
                * list[float]: Standard deviation of logarithmic values of
log(std)
            * torch.Tensor: Predicted actions.
            * dict:
                * list[float]: Mean of the distribution
                * list[float]: Standard deviation of logarithmic values of
log(std)

     """

-    def __init__(self, env_spec, **kwargs):
+    def __init__(self, env_spec, name='GaussianMLPPolicy', **kwargs):
please don't swallow **kwargs in the constructor. (I know you didn't write this, but let's remove it)
If we remove it, then how can one change the non-linearity of a layer or the initialization of weights? Or is the point here that no one should be changing that through GaussianMLPPolicy?
I see what you mean. We should explicitly mention each keyword argument that this class supports?
yes. always.
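A sketch of what "explicitly mention each keyword argument" might look like (the parameter names `hidden_nonlinearity` and `init_std` are hypothetical examples, not necessarily garage's actual signature):

```python
# Hedged sketch: every supported keyword is named in the signature instead
# of being swallowed by **kwargs. Parameter names are illustrative.

class BasePolicy:
    def __init__(self, env_spec, name):
        self.env_spec = env_spec
        self.name = name


class GaussianMLPPolicy(BasePolicy):
    def __init__(self, env_spec, name='GaussianMLPPolicy',
                 hidden_nonlinearity='tanh', init_std=1.0):
        # Each argument is named and discoverable; a typo in a keyword
        # now raises TypeError instead of being silently forwarded.
        super().__init__(env_spec, name)
        self.hidden_nonlinearity = hidden_nonlinearity
        self.init_std = init_std


policy = GaussianMLPPolicy(env_spec=None, init_std=2.0)
```

The trade-off is a longer signature, but callers (and docstring tools) can see exactly which options the class supports.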
# pylint: disable=not-callable  # https://github.com/pytorch/pytorch/issues/24807  # noqa: E501


class TestConjugateGradientOptimizer:
note that pytest is perfectly happy running test functions outside of classes, as long as they are named test_*
Force-pushed from 6a8fb49 to 81ecd2b
I'm still worried that PyTorch may be systematically just a little bit worse. What do these look like with 10 trials? And what's the average performance gap for each at itr 999?
Okay, this is as close to equal as algos get. I think the performance is ready. Did you have a chance to look at the runtime?
I ran the Python profiler, and it seems the overhead in PyTorch is computing gradients (the
Force-pushed from 81ecd2b to 296da04
I think that this is a general issue in PyTorch which we can investigate later. You can try some of the suggestions here: pytorch/pytorch#975 Is your CPU usage near 100% (on all cores) during the optimization phase? Anyway, please don't block commit on this one since it looks like it's framework-wide.
LGTM
Force-pushed from 296da04 to 6b05023
Implemented Trust Region Policy Optimization in PyTorch.
Benchmarks are currently running and should be finished by tomorrow. I opened this PR to get some feedback since initial results and tests looked good.