[jit] Make nn.Transformer TorchScript compatible #28561
Conversation
This makes `MultiheadAttention` TorchScript compatible. It changes the BC-compatible code in `forward` to use `__setstate__` instead, so that `torch.load` still works correctly for old models.
Add BC-breaking tag for release note. @gchanan
Remember to add the label "topic: bc-breaking" to such changes as well :).
@@ -171,11 +173,10 @@ def forward(self, src, mask=None, src_key_padding_mask=None):
        """
        output = src

        for i in range(self.num_layers):
So `range` is not jitable as well?
`ModuleList`s can't be indexed, but can be used with for-in loops. This should be the same.
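A minimal sketch of the pattern under discussion (assuming a recent PyTorch; the `Stack` module and its sizes are made up for illustration): indexing `self.layers[i]` with a loop variable is not accepted by the TorchScript compiler, but iterating the `ModuleList` with for-in compiles the same computation.

import torch
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self, dim: int, num_layers: int):
        super(Stack, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x):
        output = x
        for mod in self.layers:  # for-in over a ModuleList is scriptable
            output = mod(output)
        return output

scripted = torch.jit.script(Stack(8, 3))
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])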
-        if self.norm:
+        if self.norm is not None:
Is this change due to jitability?
Right, the check has to be that the value is explicitly not `None` so that this branch doesn't get compiled when there is no code to call (i.e. when `norm` is `None`).
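A minimal sketch of why the explicit check matters (the `Block` module is hypothetical): inside an `is not None` check TorchScript refines the attribute's type, and when the attribute is `None` the branch is dropped entirely, whereas a plain truthiness test does not give the compiler that guarantee.

from typing import Optional

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, norm: Optional[nn.Module] = None):
        super(Block, self).__init__()
        self.norm = norm

    def forward(self, x):
        if self.norm is not None:  # refined to nn.Module inside this branch
            x = self.norm(x)
        return x

torch.jit.script(Block())                 # norm is None, branch is never compiled
torch.jit.script(Block(nn.LayerNorm(4)))  # norm present, call is compiled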
torch/nn/modules/transformer.py
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
This is a BC-breaking change.
This seems like redundant info since someone can just do `len(self.layers)` to get the same value. Since this change is BC-breaking anyway, I think this cleanup is warranted.
Since I want to cover BC for the transformer, please add those two attributes back and I will approve it.
`self.num_layers` is added to Encoder but not Decoder.
Do you need tests to cover the scriptable transformer?
We are thinking about using the `_load_from_state_dict` function to avoid breaking BC. See the PR I added: #29001
By the time 1.4 is released there will have been a whole minor version for people to update their models, which I think is enough time to make a BC-breaking change here on loading old models. Either way, it would be good to settle on something; the first version of these PRs (this and #28555) had the proper BC-maintaining changes, which I reverted based on our discussions.
Yes. Based on our discussions, we agreed to remove … Since I have the PR (#28555) to add …
Please consider overwriting …
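For context, a hedged sketch of the `_load_from_state_dict` approach mentioned above (the module and the key names `old_weight`/`new_weight` are invented for illustration; this is not the code from #29001): the hook remaps a legacy state-dict key so old checkpoints still load under `strict=True`, without keeping BC branches in `forward`.

import torch
import torch.nn as nn

class RenamedParam(nn.Module):
    def __init__(self, dim: int = 4):
        super(RenamedParam, self).__init__()
        self.new_weight = nn.Parameter(torch.zeros(dim))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        old_key = prefix + "old_weight"
        if old_key in state_dict:
            # remap the legacy key to the new attribute name
            state_dict[prefix + "new_weight"] = state_dict.pop(old_key)
        super(RenamedParam, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)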
@@ -355,10 +366,7 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                              key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
-       if hasattr(self, "activation"):
Since this gets patched up in `__setstate__`, the other case is not necessary.
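A minimal sketch of that pattern (an illustrative module, not the PR's exact code): `__setstate__` injects a default `activation` when unpickling models saved before the attribute existed, so `forward` can use it unconditionally and TorchScript never has to compile a `hasattr` check.

import torch.nn as nn
import torch.nn.functional as F

class EncoderLayerLike(nn.Module):
    def __init__(self, activation=F.relu):
        super(EncoderLayerLike, self).__init__()
        self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu  # default for models pickled before the attribute existed
        super(EncoderLayerLike, self).__setstate__(state)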
torch/nn/modules/transformer.py
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
nit: this will be BC-breaking. I'd prefer not to remove it, since we won't claim a BC break for this module. Will keeping it block a jitable transformer?
I re-added it in the latest commit
One more comment. Thanks.
`self.num_layers` is added to Encoder but not Decoder. Please add it to Decoder and the PR will be ready to merge. Thanks.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This makes `nn.Transformer` usable from TorchScript. It preserves backwards compatibility via `__setstate__` on the encoder/decoder. Fixes pytorch#24173
Pull Request resolved: pytorch#28561
Differential Revision: D18124753
Pulled By: driazati
fbshipit-source-id: 7314843e5aa9c9bf974c4672e4edb24ed8ef4a6f
This makes `nn.Transformer` usable from TorchScript. It preserves backwards compatibility via `__setstate__` on the encoder/decoder. Fixes #24173
Differential Revision: D18124753