Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rllib] Refactor pytorch custom model support #3634

Merged
merged 23 commits into from
Jan 3, 2019

Conversation

ericl
Copy link
Contributor

@ericl ericl commented Dec 25, 2018

What do these changes do?

  • Clean up the pytorch model API to support RNNs, Dict / Tuple spaces
  • Unify QMIX RNN model with model catalog

I expect we'll have to make more changes (and add more tests) as we implement PyTorch support more fully; this is just an initial cleanup to better unify QMIX with pytorch A3C.

#3365

@@ -292,8 +301,8 @@ def to_batches(arr):
@override(PolicyGraph)
def get_initial_state(self):
return [
self.model.init_hidden().numpy().squeeze()
for _ in range(self.n_agents)
s.expand([self.n_agents, -1]).numpy()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main change here is that we return a list of [num_agents, h_size_i] for i in n_state_tensors, rather than [h_size] for i in n_agents previously (which was kind of unnatural and didn't generalize well to multiple state tensors).

Relatedly, I'm regretting allowing multiple hidden state tensors, we probably should have just required them to be fused into one element.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10386/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10387/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10390/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10388/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10391/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10415/
Test FAILed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10453/
Test FAILed.

@AmplabJenkins
Copy link

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10468/
Test PASSed.

@ericl ericl added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Dec 28, 2018
@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10475/
Test FAILed.

@AmplabJenkins
Copy link

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10474/
Test PASSed.

@AmplabJenkins
Copy link

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/10472/
Test FAILed.


def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
hiddens = options.get("fcnet_hiddens")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this have a default?

Copy link
Contributor

@richardliaw richardliaw Jan 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right now, if not provided, this fails at line 26 (which then you may as well make this options["fcnet_hiddens"])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's filled in by agent, so it should be ok either way.

obs: observations and features"""
res = self._convs(obs)
def _hidden_layers(self, obs):
res = self._convs(obs.permute(0, 3, 1, 2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be documented somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


return (TupleActions(list(actions.transpose([1, 0]))),
hiddens.transpose([1, 0, 2]), {})
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the second [1, 0, 2] not needed anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah (see comment above), hiddens are handled differently now.

@ericl ericl merged commit 47d36d7 into ray-project:master Jan 3, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants