
Add value function for torch #1317

Merged · 1 commit merged into master on May 1, 2020
Conversation

yonghyuc
Contributor

Implement GaussianMLPValueFunction for PyTorch algorithms.

The value function computes its loss and returns it to the algorithm, which uses it to update the value function's weights.
An algorithm now has two optimizers: one for the policy and one for the value function.
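
For context, here is a minimal sketch of that two-optimizer flow in plain PyTorch, using placeholder networks and a simple MSE value loss rather than the actual GaussianMLPValueFunction interface added in this PR (names and shapes are illustrative assumptions):

```python
import torch
from torch import nn

# Placeholder networks -- illustrative only, not the garage classes in this PR.
obs_dim, act_dim = 4, 2
value_function = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, 1))
policy_mean = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, act_dim))

# The algorithm holds one optimizer for the policy and one for the value function.
policy_optimizer = torch.optim.Adam(policy_mean.parameters(), lr=3e-4)
vf_optimizer = torch.optim.Adam(value_function.parameters(), lr=3e-4)

# Fake batch of collected samples.
obs = torch.randn(64, obs_dim)
actions = torch.randn(64, act_dim)
advantages = torch.randn(64)
returns = torch.randn(64)

# Policy update: maximize the surrogate objective with the policy optimizer.
dist = torch.distributions.Normal(policy_mean(obs), 1.0)
policy_loss = -(dist.log_prob(actions).sum(-1) * advantages).mean()
policy_optimizer.zero_grad()
policy_loss.backward()
policy_optimizer.step()

# Value-function update: the value function computes its loss (here a plain MSE
# against the empirical returns) and its own optimizer updates its weights.
vf_loss = ((value_function(obs).squeeze(-1) - returns) ** 2).mean()
vf_optimizer.zero_grad()
vf_loss.backward()
vf_optimizer.step()
```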

@yonghyuc yonghyuc requested a review from a team as a code owner April 18, 2020 05:16
@codecov

codecov bot commented Apr 18, 2020

Codecov Report

Merging #1317 into master will increase coverage by 0.00%.
The diff coverage is 98.44%.


@@           Coverage Diff           @@
##           master    #1317   +/-   ##
=======================================
  Coverage   91.50%   91.50%           
=======================================
  Files         218      220    +2     
  Lines       10970    10975    +5     
  Branches     1322     1324    +2     
=======================================
+ Hits        10038    10043    +5     
- Misses        675      676    +1     
+ Partials      257      256    -1     
Impacted Files Coverage Δ
src/garage/torch/algos/trpo.py 91.66% <81.81%> (-8.34%) ⬇️
src/garage/torch/algos/maml.py 97.04% <100.00%> (+0.22%) ⬆️
src/garage/torch/algos/maml_ppo.py 100.00% <100.00%> (ø)
src/garage/torch/algos/maml_trpo.py 100.00% <100.00%> (ø)
src/garage/torch/algos/maml_vpg.py 100.00% <100.00%> (ø)
src/garage/torch/algos/ppo.py 100.00% <100.00%> (ø)
src/garage/torch/algos/vpg.py 100.00% <100.00%> (ø)
src/garage/torch/optimizers/__init__.py 100.00% <100.00%> (ø)
src/garage/torch/optimizers/optimizer_wrapper.py 100.00% <100.00%> (ø)
src/garage/torch/value_functions/__init__.py 100.00% <100.00%> (ø)
... and 12 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 652db03...2cfc032.

Contributor

@krzentner krzentner left a comment

This is definitely a solid change in the right direction. Having said that, I think the ValueFunction interface should be much smaller. Remember: it's easier to add new methods than to delete old ones.
Aside from that and some minor details, I'm quite happy with this change.

@@ -63,7 +61,8 @@ def __init__(self,
         self._meta_evaluator = meta_evaluator
         self._policy = policy
         self._env = env
-        self._value_function = value_function
+        self._value_function = copy.deepcopy(inner_algo._value_function)
+        self._lame_vf_state = self._value_function.state_dict()
Contributor

Can you call it old instead of lame?

Contributor Author

@yonghyuc yonghyuc Apr 20, 2020

@naeioi Can you check this?

Contributor

Conceptually it's the same, it's just that "lame" sounds a little unprofessional. Either way is okay, I guess.

Member

it's just confusing. use old or previous

Member

Sorry I am late to the party. I made this name. This is not an old or previous value function, but a value function without any training. Now I think initial_vf_state is the most appropriate name.

Contributor

I see. If it gets used in multiple places, I suppose I'm fine with it being fixed in a later PR. _initial_vf_state is definitely a better name.

Member

Can you change this to _initial_vf_state ?

src/garage/torch/algos/vpg.py
advs_flat)
logger.log('Policy loss: {}'.format(policy_loss))

batch_dataset = BatchDataset((obs_flat, returns_flat),
Member

why was this necessary for minibatching the value function, but not the policy?

Contributor Author

@yonghyuc yonghyuc Apr 20, 2020

I thought that for TRPO, the cg_iters field of ConjugateGradientOptimizer works very similarly to max_optimization_epochs.
Is it okay to use minibatching with ConjugateGradientOptimizer?
If so, should it use the same minibatch dataset but with different optimization iteration counts (cg_iters and max_optimization_epochs)?
Or should there be a single field covering both cg_iters and max_optimization_epochs?

Member

you can't use minibatching with ConjugateGradientOptimizer, or rather it would be kind of useless. Minibatching only works with SGD-based optimization algorithms, but a CG optimizer doesn't use SGD to choose parameters.

Contributor Author

Yes, so I used minibatching only for the value function because its default optimizer is Adam. When I ran the benchmark, I got better results with minibatching and optimization_epochs.
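
For reference, a minimal sketch of that kind of minibatched Adam update for the value function; max_optimization_epochs and minibatch_size mirror the names in the discussion, but the network, data, and loss below are placeholders rather than the PR's implementation:

```python
import torch
from torch import nn

# Placeholder value function and training data -- illustrative only.
value_function = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(value_function.parameters(), lr=2.5e-4)

obs = torch.randn(1000, 4)
returns = torch.randn(1000)

max_optimization_epochs = 10
minibatch_size = 64

for _ in range(max_optimization_epochs):
    # Shuffle once per epoch, then take an Adam step per minibatch.
    perm = torch.randperm(obs.shape[0])
    for start in range(0, obs.shape[0], minibatch_size):
        idx = perm[start:start + minibatch_size]
        loss = ((value_function(obs[idx]).squeeze(-1) - returns[idx]) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```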

Member

what if i use the non-default ConjugateGradientOptimizer? will it still minibatch? wouldn't that be incorrect?

Contributor

In the tensorflow branch, we have our own optimizer class partly for this reason -- so that users can pass either a minibatching optimizer (which does batching inside of our own optimizer class), or the ConjugateGradientOptimizer. The real solution for us is probably to do the same for the pytorch branch, but that seems like significant additional complexity for this change. I think the simplest change which wouldn't be wrong would be to not use minibatching with Adam for the time being (even though that makes performance worse), and open an issue to create a minibatching Adam optimizer for pytorch in another PR.

Member

i saw in the CHANGELOG that Torch 1.5 started making it easier to add custom optimizers, but didn't look in detail.

Member

@ryanjulian ryanjulian left a comment

Great PR. See my comments and KR's (we are in agreement).

Overall this is like 80% there, but the intention of this PR is actually to totally remove the ValueFunction as a "special" primitive and make it act just like any other network in garage, so "predict" should not be necessary (for instance).

this will allow us to put the ValueFunction in the computation graph of the algorithm, which allows us to make a lot of things cleaner, and also enables us to do some interesting things.

@ryanjulian
Member

What do the MuJoCo3M benchmarks look like?

@yonghyuc yonghyuc force-pushed the add_torch_gmvf branch 2 times, most recently from a0d74c2 to 7fbc0a4 Compare April 20, 2020 07:57
@yonghyuc
Contributor Author

This is the PPO benchmark result for Mujoco1M (5.5e+5).

[Benchmark plots: HalfCheetah-v2, Hopper-v2, InvertedDoublePendulum-v2, InvertedPendulum-v2, Reacher-v2, Swimmer-v2, Walker2d-v2]

@yonghyuc
Contributor Author

This is the TRPO benchmark result for Mujoco1M (5.5e+5).

[Benchmark plots: HalfCheetah-v2, Hopper-v2, InvertedDoublePendulum-v2, InvertedPendulum-v2, Reacher-v2, Swimmer-v2, Walker2d-v2]

Contributor

@krzentner krzentner left a comment

Overall, looks good to me. I still think we should avoid using the term "lame," since it's unprofessional.

@@ -432,7 +483,7 @@ def process_samples(self, itr, paths):
         for path in paths:
             if 'returns' not in path:
                 path['returns'] = tu.discount_cumsum(path['rewards'],
-                                                     self.discount)
+                                                     self.discount).copy()
Member

why copy this?

Contributor Author

@yonghyuc yonghyuc Apr 24, 2020

If I don't, I get this error:
E ValueError: some of the strides of a given numpy array are negative. This is currently not supported, but will be added in future releases.
This is because we reverse the order of the output with [::-1] in the discount_cumsum function, so I need to copy the array.

You can also check here

Member

How did it work before? What changed?

Member

Can you use some other numpy function to get a reversed view, other than negative indexing?

When you call copy(), you break the gradient path. If your returns have differentiable components, e.g. a differentiable reward augmentation, your augmented rewards will no longer be differentiable.

Contributor Author

This error is from here.
Previously, there was only LinearFeatureBaseline, which doesn't need to convert a numpy.array to a torch.Tensor. But now we need to convert the numpy.array to a torch.Tensor for GaussianMLPValueFunction.

Contributor Author

@yonghyuc yonghyuc Apr 24, 2020

According to this post, the problem is in PyTorch: PyTorch doesn't support a numpy.array with negative strides.
This is one solution that doesn't use .copy():

torch.Tensor(tu.discount_cumsum(path['rewards'], self.discount)[::-1]).flip(-1)

As for the gradient path, the return values from tu.discount_cumsum are a numpy.array, not a torch.Tensor, so I think there is no gradient path through these values yet.

Also, I'm curious: we use rewards from the environment to compute the return values; is there a module or function that computes gradients for the rewards?
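
A small sketch reproducing the issue and the workarounds discussed above (the exact error wording varies across PyTorch versions; the arrays here are placeholders):

```python
import numpy as np
import torch

rewards = np.arange(5, dtype=np.float32)
reversed_view = rewards[::-1]  # negative-stride view, no data is copied

# Converting a negatively-strided array raises a ValueError about negative strides.
try:
    torch.from_numpy(reversed_view)
except ValueError as err:
    print(err)

# Workaround 1: copy into a contiguous array (what this PR does with .copy()).
t1 = torch.from_numpy(reversed_view.copy())

# Workaround 2: make the array contiguous explicitly.
t2 = torch.from_numpy(np.ascontiguousarray(reversed_view))

# Workaround 3: the flip(-1) trick mentioned above -- hand PyTorch the
# positively-strided view and reverse it on the torch side instead.
t3 = torch.from_numpy(reversed_view[::-1]).flip(-1)

assert torch.equal(t1, t2) and torch.equal(t1, t3)
```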

@yonghyuc
Contributor Author

yonghyuc commented Apr 27, 2020

I added a new OptimizerWrapper class.
The OptimizerWrapper takes an optimizer type (torch.optim.Optimizer) and other parameters for minibatching. It behaves like a torch.optim.Optimizer (zero_grad, step), but it also has a get_minibatch function that provides batched data to the algorithm (sketched below).

The overall process is:

  • the algo gets a minibatch dataset from the OptimizerWrapper
  • OptimizerWrapper.zero_grad() -> inner torch.optim.Optimizer.zero_grad()
  • the algo computes the loss and backpropagates it
  • OptimizerWrapper.step() -> inner torch.optim.Optimizer.step()

This PR contains many changes, so I will list only the ones related to OptimizerWrapper:

  1. OptimizerWrapper
  2. Usage in BenchmarkPPO
  3. Changes in VPG
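
A rough sketch of the behavior described above: an inner torch.optim optimizer plus minibatch iteration on top. The constructor signature and the compute_loss call in the usage comment are illustrative assumptions, not necessarily the exact code in this PR:

```python
import torch


class OptimizerWrapper:
    """Wrap a torch.optim optimizer and add minibatch iteration (sketch)."""

    def __init__(self, optimizer_cls, module, max_optimization_epochs=1,
                 minibatch_size=None, **optimizer_args):
        self._optimizer = optimizer_cls(module.parameters(), **optimizer_args)
        self._max_optimization_epochs = max_optimization_epochs
        self._minibatch_size = minibatch_size

    def get_minibatch(self, *inputs):
        """Yield shuffled minibatches of `inputs` for every optimization epoch."""
        n = inputs[0].shape[0]
        batch_size = self._minibatch_size or n
        for _ in range(self._max_optimization_epochs):
            perm = torch.randperm(n)
            for start in range(0, n, batch_size):
                idx = perm[start:start + batch_size]
                yield tuple(x[idx] for x in inputs)

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step(self):
        self._optimizer.step()


# Hypothetical usage inside an algorithm, following the process listed above:
#
#   vf_optimizer = OptimizerWrapper(torch.optim.Adam, value_function,
#                                   max_optimization_epochs=10,
#                                   minibatch_size=64, lr=2.5e-4)
#   for obs_mb, returns_mb in vf_optimizer.get_minibatch(obs, returns):
#       vf_optimizer.zero_grad()
#       loss = value_function.compute_loss(obs_mb, returns_mb)
#       loss.backward()
#       vf_optimizer.step()
```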

@yonghyuc yonghyuc force-pushed the add_torch_gmvf branch 2 times, most recently from c1a5687 to 4326ba8 Compare April 28, 2020 22:34
@yonghyuc
Contributor Author

@naeioi Could you check the changes and how they affect the MAML algorithm?
I changed MAML to follow these changes, but I am worried that I missed something.
You can check the main changes here.
Thanks!

@krzentner Could you check this PR for me? Thanks!

@ryanjulian ryanjulian requested review from a team and zequnyu and removed request for a team April 29, 2020 22:10
@@ -264,49 +247,3 @@ def run_garage_tf(env, seed, log_dir):
     dowel_logger.remove_all()

     return tabular_log_file
-
-
-def run_baselines(env, seed, log_dir):
Member

@zequnyu zequnyu Apr 29, 2020

Do we want to delete this? This should be the only place we could locate the baselines benchmarking code.

Contributor Author

It looks like the openai.baseline version does not support TensorFlow 2.0.

@@ -218,59 +215,3 @@ def run_garage(env, seed, log_dir):
     dowel_logger.remove_all()

     return tabular_log_file
-
-
-def run_baselines(env, seed, log_dir):
Member

Same here

Contributor Author

It looks like the openai.baseline version does not support TensorFlow 2.0.

Member

@naeioi naeioi left a comment

LGTM. This is an important change! Please fix my minor comment.

@@ -63,7 +61,8 @@ def __init__(self,
         self._meta_evaluator = meta_evaluator
         self._policy = policy
         self._env = env
-        self._value_function = value_function
+        self._value_function = copy.deepcopy(inner_algo._value_function)
+        self._lame_vf_state = self._value_function.state_dict()
Member

Can you change this to _initial_vf_state ?

@ryanjulian
Member

@Mergifyio rebase

@mergify
Contributor

mergify bot commented Apr 30, 2020

Command rebase: success

Branch has been successfully rebased

@ryanjulian
Member

@Mergifyio rebase

@mergify
Contributor

mergify bot commented May 1, 2020

Command rebase: success

Branch has been successfully rebased

@mergify mergify bot merged commit e09e6dc into master May 1, 2020
@mergify mergify bot deleted the add_torch_gmvf branch May 1, 2020 19:40