
[RLlib] Implement TorchPolicy.export_model. #13989

Merged
merged 3 commits into from
Feb 22, 2021

Conversation

sven1977
Contributor

@sven1977 sven1977 commented Feb 8, 2021

Implement TorchPolicy.export_model.

This method is currently missing, and users get a NotImplementedError when trying to call it on any TorchPolicy.
This PR provides a solution by storing the Policy's model as a TorchScript file.
The only remaining problem concerns non-RNNs: a fake state-in tensor (any tensor is fine) must be provided when calling the TorchScript model. This is because torch.jit cannot handle the otherwise empty internal-states list returned by RLlib's ModelV2s.
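The export flow described above can be sketched as follows, using a toy module in place of a real RLlib ModelV2 (all names here are illustrative, not the merged implementation):

```python
import os
import tempfile
import torch
import torch.nn as nn

# Toy stand-in for a non-RNN ModelV2's forward pass.
class TinyNonRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, obs, state_in):
        # A non-RNN ModelV2 would return (out, []); the empty state
        # list breaks torch.jit, so a dummy tensor is threaded through.
        return self.fc(obs), state_in

export_dir = tempfile.mkdtemp()
dummy_obs = torch.randn(1, 4)
dummy_state = torch.zeros(1, 1)  # the "fake state-in" mentioned above
traced = torch.jit.trace(TinyNonRNN(), (dummy_obs, dummy_state))
torch.jit.save(traced, os.path.join(export_dir, "model.pt"))

# Consumer side: the fake state tensor must be passed again even
# though the model is not an RNN.
loaded = torch.jit.load(os.path.join(export_dir, "model.pt"))
out, _ = loaded(torch.randn(1, 4), dummy_state)
print(tuple(out.shape))  # (1, 2)
```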

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

dummy_inputs = self._lazy_tensor_dict(self._dummy_batch.data)
# Provide dummy state inputs if not an RNN (torch cannot jit with
# returned empty internal states list).
if "state_in_0" not in dummy_inputs:
Contributor

Why does torch jit require state_in even if the model is not an RNN?

Contributor Author

@sven1977 sven1977 Feb 10, 2021

Agree, this is a total hack. However, torch requires the output to be some tensor (or nested struct of tensors), but NOT an empty list :/ That's why we need to fake it here. One more reason to keep thinking about a possible new ModelV3 API :)


torch jit doesn't require state_in, it's the self.model that requires it.
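A minimal sketch of the reviewer's point: torch.jit itself happily traces a stateless module; the state_in requirement comes from self.model, i.e. the ModelV2 forward signature, not from torch.jit:

```python
import torch
import torch.nn as nn

# torch.jit.trace on a plain stateless module works without any
# state tensors at all.
stateless = nn.Linear(3, 1)
traced = torch.jit.trace(stateless, torch.randn(1, 3))
out = traced(torch.randn(5, 3))
print(tuple(out.shape))  # (5, 1)
```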

(dummy_inputs, state_ins, seq_lens))
if not os.path.exists(export_dir):
os.makedirs(export_dir)
file_name = os.path.join(export_dir, "model.pt")
Contributor

Also enable the user to customize the model name^

Contributor Author

Yeah, I was thinking about this, too. The problem is that I didn't want to change the method's signature, which is
export_model(self, export_dir). For TF, this is OK b/c a TF model export produces many files (inside the export_dir).
For Torch, it's just a single file.
But yes, we could add an optional arg (filename=None) to the torch method (and then make the base's signature: export_dir, **kwargs).
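The signature change discussed above could look like the following hypothetical sketch (not the merged API; class and parameter names are illustrative only):

```python
import os
import tempfile

# Hypothetical base class: keeps export_dir plus **kwargs so
# subclasses can add backend-specific options.
class Policy:
    def export_model(self, export_dir, **kwargs):
        raise NotImplementedError

# Hypothetical torch override: a torch export is a single file, so an
# optional filename makes sense (a TF export would instead write many
# files into export_dir).
class TorchPolicy(Policy):
    def export_model(self, export_dir, filename=None):
        if not os.path.exists(export_dir):
            os.makedirs(export_dir)
        file_name = os.path.join(export_dir, filename or "model.pt")
        # A real implementation would torch.jit.save(traced_model,
        # file_name) here; this sketch only returns the path.
        return file_name

export_dir = tempfile.mkdtemp()
path = TorchPolicy().export_model(export_dir, filename="custom.pt")
print(os.path.basename(path))  # custom.pt
```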

@michaelzhiluo
Contributor

Once the conflict is resolved, we can merge it ^^

@sven1977 sven1977 merged commit 95ef04b into ray-project:master Feb 22, 2021