
[rllib] Support torch device and distributions. #4553

Merged 8 commits into ray-project:master on Apr 12, 2019
Conversation

@cffan (Contributor) commented Apr 3, 2019

What do these changes do?

  • Move torch model and inputs to GPU if specified.
  • Wrapper class for torch native distributions.
  • Support "grad_clip" in A3C.
  • Return grad_info in A3C.
  • Fix A3C and PG computing log_prob without using the sampled actions.

Related issue number

Closes #4333

Linter

  • I've run scripts/format.sh to lint the changes in this PR.

@AmplabJenkins

Can one of the admins verify this patch?

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13475/

@cffan (Contributor, Author) commented Apr 3, 2019

The test environment is using pytorch-cpu. Should I change it to pytorch?

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13476/

@ericl self-assigned this Apr 3, 2019
@ericl added this to Needs triage in RLlib via automation Apr 3, 2019
@ericl changed the title Support torch device and distributions. [rllib] Support torch device and distributions. Apr 3, 2019
@ericl (Contributor) commented Apr 3, 2019

Hm, does that make a gpu available? AFAIK, none of our tests are currently run with GPUs. What is the limitation of pytorch-cpu?

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13484/

@cffan (Contributor, Author) commented Apr 4, 2019

Never mind. CPU case should be handled correctly now.

logits, _, values, _ = policy_model(
    {SampleBatch.CUR_OBS: observations}, [])
logits = logits
values = values
Contributor:

These two lines seem redundant?

log_probs = log_probs.sum(-1)
self.entropy = dist.entropy().mean().cpu()
self.pi_err = -advantages.dot(log_probs.reshape(-1)).cpu()
self.value_err = F.mse_loss(values.reshape(-1), value_targets).cpu()
Contributor:

Is it necessary to move the loss to cpu?

action_distribution_cls=dist_class)

@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
@ericl (Contributor), Apr 4, 2019:

Could we keep this method impl in TorchPolicyGraph and have options to clip grads / return extra stats as generic functionality?

@cffan (Contributor, Author):

I can make an abstract method for getting extra grad info.
For grad clipping, I can either make config a property of TorchPolicyGraph, so that compute_gradients() in TorchPolicyGraph knows whether to clip gradients, or add an abstract method extra_grad_processing(self, grad) to TorchPolicyGraph and let subclasses process the gradients. What's your preference?

Contributor:

TFPolicyGraph offers the extra grad processing method, so it's probably better to do that for consistency.
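As a concrete illustration of that pattern, here is a torch-free sketch of a base class that owns compute_gradients() and exposes an overridable post-processing hook. The names (PolicyGraphSketch, extra_grad_process, grad_clip) are illustrative, not RLlib's actual API, and plain floats stand in for tensors:

```python
class PolicyGraphSketch:
    """Base class: owns the generic gradient pipeline."""

    def compute_gradients(self, grads):
        # Let the subclass post-process each gradient (e.g. clipping)
        # before it is applied.
        return [self.extra_grad_process(g) for g in grads]

    def extra_grad_process(self, grad):
        # Default hook: no-op, mirroring TFPolicyGraph's approach of
        # keeping extra processing optional.
        return grad


class ClippingPolicyGraph(PolicyGraphSketch):
    """Subclass that opts into gradient clipping via the hook."""

    def __init__(self, grad_clip):
        self.grad_clip = grad_clip

    def extra_grad_process(self, grad):
        # Clamp each gradient value into [-grad_clip, grad_clip].
        return max(-self.grad_clip, min(self.grad_clip, grad))
```

With this shape, grad clipping lives in one overridable method while compute_gradients() stays in the base class, which is the consistency point being made about TFPolicyGraph.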

dist = self.dist_class(logits)
log_probs = dist.logp(actions)
if len(log_probs.shape) > 1:
    log_probs = log_probs.sum(-1)
Contributor:

In which cases does log_probs have a nontrivial second dimension? Wondering if the reshape() is sufficient?

Same question for A3CLoss.

@cffan (Contributor, Author):

I haven't tried the others, but the Normal distribution's log_prob returns a vector of shape (n,), where n is the number of Gaussians. I can absorb this into TorchDiagGaussian.
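For context on that .sum(-1): a diagonal Gaussian's joint log-probability is the sum of its per-dimension log-probabilities, so folding the sum into a TorchDiagGaussian wrapper is mathematically equivalent. A stdlib-only sketch of the identity (helper names are illustrative):

```python
import math

def normal_logp(x, mean, std):
    # Log-density of a univariate Gaussian N(mean, std^2) at x.
    return (-0.5 * ((x - mean) / std) ** 2
            - math.log(std)
            - 0.5 * math.log(2 * math.pi))

def diag_gaussian_logp(xs, means, stds):
    # For a diagonal Gaussian the dimensions are independent, so the
    # joint log-prob is the sum of per-dimension log-probs -- the same
    # reduction as log_probs.sum(-1) on a (batch, n) tensor.
    return sum(normal_logp(x, m, s) for x, m, s in zip(xs, means, stds))
```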

"""
self.observation_space = observation_space
self.action_space = action_space
self.lock = Lock()
self._model = model
cuda_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
Contributor:

could simply check if bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
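A minimal sketch of the suggested check (the helper name gpu_requested is hypothetical):

```python
import os

def gpu_requested():
    # os.environ.get() returns None when the variable is unset, and
    # bool() treats None and "" the same way, so this avoids the
    # KeyError that os.environ["CUDA_VISIBLE_DEVICES"] would raise
    # when no GPU is configured.
    return bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
```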

@@ -285,3 +286,40 @@ def kl(self, other):
@override(ActionDistribution)
def _build_sample_op(self):
return self.dist.sample()


class TorchDistributionWrapper(ActionDistribution):
Contributor:

Could the torch classes for action dist be placed in a separate file?

@@ -5,6 +5,7 @@
from collections import namedtuple
import distutils.version
import tensorflow as tf
import torch
Contributor:

Let's make sure to not import torch unless we hit a torch=true code path, to avoid acquiring a hard dependency on torch.
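One common way to defer the dependency is a try_import helper that only touches torch on the torch=true code path; this is a generic sketch, not RLlib's actual utility:

```python
import importlib

def try_import(name):
    # Return the module if it is installed, else None. Callers on an
    # optional code path can then raise a clear error themselves,
    # instead of the whole module acquiring a hard import-time
    # dependency.
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

# Usage on a torch code path (illustrative):
# torch = try_import("torch")
# if torch is None:
#     raise ImportError("torch is required when use_pytorch=True")
```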

@@ -120,7 +121,8 @@ def get_action_dist(action_space, config, dist_type=None):
elif dist_type == "deterministic":
return Deterministic, action_space.shape[0]
elif isinstance(action_space, gym.spaces.Discrete):
return Categorical, action_space.n
dist = TorchCategorical if torch else Categorical
Contributor:

could we add if torch: raise NotImplementedError for the other dist types?
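In spirit, the guard could look like this sketch (the function shape and dist-type strings are illustrative; only the categorical and diagonal-Gaussian paths have torch wrappers in this PR):

```python
def pick_dist(dist_type, use_torch):
    # Torch currently backs only these two distribution paths, so any
    # other branch should fail loudly rather than silently fall back
    # to a TF distribution class.
    supported_torch = {"categorical", "diag_gaussian"}
    if use_torch and dist_type not in supported_torch:
        raise NotImplementedError(
            "Torch does not yet support dist type: {}".format(dist_type))
    return dist_type
```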

@ericl (Contributor) commented Apr 4, 2019

Thanks for opening this! Overall looks solid; have some comments.

model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0)
return (actions.numpy(), [h.numpy() for h in state],
action_dist = self._action_dist_cls(logits)
Contributor:

Now that A2C/PG presumably work with continuous action spaces, you can add two entries to run_rllib_tests.sh to check they work on Pendulum-v0:

Similar to the CartPole-v0 entries:
https://github.com/ray-project/ray/blob/master/ci/jenkins_tests/run_rllib_tests.sh#L407

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13720/

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13721/

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13723/

@ericl (Contributor) left a comment:

LGTM. One thing I'm wondering is if it's possible to test that GPU mode works properly, without a real GPU. It seems easy to forget a cpu().

@cffan (Contributor, Author) commented Apr 11, 2019

Is it possible to spin up a gpu instance every night and run a nightly test build on it to catch some errors?

@ericl (Contributor) commented Apr 11, 2019

Hm, potentially. I'm not sure if travis supports GPU instances though.

@FlyClover tests look good, but you have a couple lint changes: https://travis-ci.com/ray-project/ray/jobs/192067752

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13748/

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/13755/

@ericl merged commit bb207a2 into ray-project:master Apr 12, 2019
RLlib automation moved this from Needs triage to Done Apr 12, 2019
@ericl (Contributor) commented Apr 12, 2019

Merged, thanks!

Successfully merging this pull request may close these issues.

[rllib] PyTorch A2C is not GPU accelerated