[RLlib] Implement TorchPolicy.export_model #13989

Conversation
dummy_inputs = self._lazy_tensor_dict(self._dummy_batch.data)
# Provide dummy state inputs if not an RNN (torch cannot jit with
# returned empty internal states list).
if "state_in_0" not in dummy_inputs:
Why does torch jit require state_in even if the model is not an RNN?
Agree, this is a total hack. However, torch requires the output to be some tensor (or nested struct of tensors), but NOT an empty list :/ That's why we need to fake it here. One more reason to keep thinking about a possible new ModelV3 API :)
torch jit doesn't require state_in, it's the self.model that requires it.
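The constraint described in this exchange can be illustrated with a minimal sketch. This is a toy module, not RLlib's actual ModelV2 (all names here are hypothetical): `torch.jit.trace` only supports tensors and nested tuples/lists *of tensors* as outputs, so a `forward()` that returns an empty state list cannot be traced, while echoing a fake state tensor back keeps the module traceable.

```python
import torch
import torch.nn as nn


class TraceableModel(nn.Module):
    """Toy stand-in for a non-RNN policy model (hypothetical).

    Returning ``out, []`` from forward() would break torch.jit.trace,
    since traced outputs must be tensors or nested tuples/lists of
    tensors. Passing a dummy state tensor through instead works.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, obs, state):
        # Echo the (fake) state back so the output contains only tensors.
        return self.fc(obs), state


model = TraceableModel()
dummy_obs = torch.randn(1, 4)
dummy_state = torch.zeros(1, 1)  # fake state tensor; any shape is fine
traced = torch.jit.trace(model, (dummy_obs, dummy_state))
out, state_out = traced(dummy_obs, dummy_state)
print(out.shape)  # torch.Size([1, 2])
```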
(dummy_inputs, state_ins, seq_lens))
if not os.path.exists(export_dir):
    os.makedirs(export_dir)
file_name = os.path.join(export_dir, "model.pt")
Also enable the user to customize the model name here.
Yeah, was thinking about this, too. The problem is: I didn't want to change the method's signature, which is export_model(self, export_dir). For TF, this is ok b/c a TF model export will produce many files (inside the export_dir). For Torch, it's just a single file.
But yes, we could add an optional arg (filename=None) to the torch method (and then make the base's signature: export_dir, **kwargs).
Once the conflict is resolved, we can merge it ^^
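The signature change discussed above could look roughly like this (a hypothetical sketch, not the actual RLlib code; the real TorchPolicy method would perform the TorchScript export, whereas this sketch only computes the target path):

```python
import os
import tempfile


class Policy:
    # Base class keeps a generic signature; torch-specific extras
    # (like a filename) pass through **kwargs.
    def export_model(self, export_dir, **kwargs):
        raise NotImplementedError


class TorchPolicy(Policy):
    # Torch override adds an optional filename, defaulting to "model.pt".
    def export_model(self, export_dir, filename=None):
        os.makedirs(export_dir, exist_ok=True)
        # The real method would torch.jit.save() the traced model here;
        # this sketch only returns where the file would go.
        return os.path.join(export_dir, filename or "model.pt")


export_dir = tempfile.mkdtemp()
default_path = TorchPolicy().export_model(export_dir)
custom_path = TorchPolicy().export_model(export_dir, filename="policy.pt")
print(os.path.basename(default_path))  # model.pt
print(os.path.basename(custom_path))  # policy.pt
```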
Implement TorchPolicy.export_model. This method is currently missing, and users get a NotImplementedError when trying to call it on any TorchPolicy.
This PR provides a solution that stores the Policy's model as a TorchScript.
The only remaining problem is for non-RNNs: a fake state-in tensor (any tensor is fine) must be provided when calling the TorchScript model. This is because torch.jit cannot handle the otherwise empty internal-states list returned by RLlib's ModelV2s.
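An end-to-end round trip of what this PR enables might look like the following sketch. The toy model, shapes, and paths here are assumptions for illustration, not RLlib's actual export code: trace with a fake state tensor, save to export_dir/model.pt, reload, and call the loaded model with the same fake state.

```python
import os
import tempfile

import torch
import torch.nn as nn


class ToyPolicyModel(nn.Module):
    """Hypothetical stand-in for a non-RNN RLlib model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, obs, state):
        # Pass the fake state through so outputs are tensors only.
        return self.fc(obs), state


export_dir = tempfile.mkdtemp()
dummy_obs = torch.randn(1, 4)
fake_state = torch.zeros(1, 1)  # fake "state_in_0" for a non-RNN model

# Trace and export, mirroring the model.pt naming used in the PR.
traced = torch.jit.trace(ToyPolicyModel(), (dummy_obs, fake_state))
file_name = os.path.join(export_dir, "model.pt")
torch.jit.save(traced, file_name)

# Reload and call; the fake state must be supplied again.
loaded = torch.jit.load(file_name)
logits, _ = loaded(dummy_obs, fake_state)
print(logits.shape)  # torch.Size([1, 2])
```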
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.