
[jit] Make nn.Transformer TorchScript compatible #28561

Closed
driazati wants to merge 21 commits

Conversation

@driazati (Contributor) commented Oct 23, 2019

This makes nn.Transformer usable from TorchScript. It preserves backwards compatibility via __setstate__ on the encoder/decoder.

Fixes #24173

Differential Revision: D18124753
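For context, a minimal usage sketch of what this change enables (the dimensions below are arbitrary and not taken from the PR):

```python
import torch
import torch.nn as nn

# Build the stock module and compile it with TorchScript.
model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
scripted = torch.jit.script(model)

# nn.Transformer expects (sequence, batch, d_model) shaped inputs.
src = torch.rand(10, 2, 32)
tgt = torch.rand(7, 2, 32)
out = scripted(src, tgt)  # same forward signature as the eager module
```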

Your Name added 2 commits October 23, 2019 15:00
This makes MultiheadAttention TorchScript compatible. It changes the
BC-compatible code in `forward` to use `__setstate__` instead so that
`torch.load` still works correctly for old models
@zhangguanheng66 (Contributor) commented:

Add a BC-breaking tag for the release notes. @gchanan

@zhangguanheng66 zhangguanheng66 changed the title [jit] Make nn.Transformer TorchScript compatible [jit][BC-breaking] Make nn.Transformer TorchScript compatible Oct 25, 2019
@gchanan gchanan added the module: bc-breaking Related to a BC-breaking change label Oct 27, 2019
@gchanan (Contributor) commented Oct 27, 2019

Remember to add label "topic: bc-breaking" to such changes as well :).

@@ -171,11 +173,10 @@ def forward(self, src, mask=None, src_key_padding_mask=None):
         """
         output = src

-        for i in range(self.num_layers):
Contributor:

So range is not jitable as well?

Contributor Author:

ModuleLists can't be indexed, but can be used with for-in loops. This should be the same
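A minimal sketch of the pattern described above, using toy layers rather than the PR's actual diff: TorchScript cannot index a ModuleList with a loop variable, but it can compile a for-in loop over it.

```python
import copy
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, layer, num_layers):
        super().__init__()
        # Clone the layer, mirroring what a _get_clones-style helper does.
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

    def forward(self, x):
        output = x
        # Iterating the ModuleList directly is scriptable; indexing it with a
        # runtime integer (self.layers[i]) is what the old range loop relied on.
        for mod in self.layers:
            output = mod(output)
        return output

scripted = torch.jit.script(TinyEncoder(nn.Linear(8, 8), num_layers=3))
print(scripted(torch.rand(4, 8)).shape)  # torch.Size([4, 8])
```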


-        if self.norm:
+        if self.norm is not None:
Contributor:

Is this change needed to make it jitable?

Contributor Author:

Right, the check has to be that the value is not None, so this branch doesn't get compiled if there is no code to call (i.e. if norm is None).
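A standalone sketch of that point (NormWrapper is a made-up toy module, not code from the PR): with an explicit None check, the compiler skips the branch entirely when norm is None, and compiles it only when a real module is present.

```python
from typing import Optional

import torch
import torch.nn as nn

class NormWrapper(nn.Module):
    def __init__(self, norm: Optional[nn.Module] = None):
        super().__init__()
        self.norm = norm

    def forward(self, x):
        # Explicit `is not None` check: when norm is None this branch is treated
        # as dead code, so self.norm(x) is never compiled.
        if self.norm is not None:
            x = self.norm(x)
        return x

torch.jit.script(NormWrapper())                 # norm absent: branch skipped
torch.jit.script(NormWrapper(nn.LayerNorm(8)))  # norm present: branch compiled
```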


     def __init__(self, decoder_layer, num_layers, norm=None):
         super(TransformerDecoder, self).__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
-        self.num_layers = num_layers
Contributor:

This is a BC-breaking change.

Contributor Author:

This seems like redundant info since someone can just do len(self.layers) to get the same value. Since this change is BC-breaking anyway, I think this cleanup is warranted.

@zhangguanheng66 (Contributor) commented Nov 4, 2019

Since I want to cover BC for transformer, please add those two attributes back and I will approve it.

Contributor:

self.num_layers is added to Encoder but not Decoder.

@driazati driazati changed the base branch from driazati/transformer/1 to master October 30, 2019 00:55
@driazati driazati changed the base branch from master to driazati/transformer/1 October 30, 2019 00:55
@zhangguanheng66 (Contributor) left a comment:

Do you need tests to cover the scriptable transformer?

@zhangguanheng66 (Contributor) left a comment:

We are thinking about the `_load_from_state_dict` func to avoid BC breaking. See the PR I added: #29001.

@driazati (Contributor Author) commented Nov 4, 2019

By the time 1.4 is released there will have been a whole minor version for people to update their models, which I think is enough time to make a BC-breaking change here on loading old models.

Either way it'd be good to settle on something, the first version of these PRs (this and #28555) had the proper BC-maintaining changes which I reverted based on our discussions.

@zhangguanheng66 (Contributor) commented:

> By the time 1.4 is released there will have been a whole minor version for people to update their models, which I think is enough time to make a BC-breaking change here on loading old models.
>
> Either way it'd be good to settle on something, the first version of these PRs (this and #28555) had the proper BC-maintaining changes which I reverted based on our discussions.

Yes. Based on our discussions, we agreed to remove the hasattr check, which introduces BC breaking. Adding `_load_from_state_dict` could avoid such breaking while your current version still works.

Since I have the PR (#28555) to add _load_from_state_dict func for MultiheadAttention, I will add _load_from_state_dict func to nn.Transformer there, as well.
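For reference, the general shape of a `_load_from_state_dict` override; this is a generic sketch with a hypothetical key rename, not the code from #28555 or #29001.

```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Hypothetical migration: an older checkpoint stored this parameter under
        # a different name, so remap it before the default loading logic runs.
        for suffix in ("weight", "bias"):
            old_key = prefix + "projection." + suffix  # hypothetical old key name
            if old_key in state_dict:
                state_dict[prefix + "proj." + suffix] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```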

@zhangguanheng66 (Contributor) commented:

Please consider overriding the `__setstate__` func to avoid BC breaking (#29001).
Once that's done, we should remove the BC-breaking flag.

@@ -355,10 +366,7 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                                    key_padding_mask=memory_key_padding_mask)[0]
         tgt = tgt + self.dropout2(tgt2)
         tgt = self.norm2(tgt)
-        if hasattr(self, "activation"):
Contributor Author:

Since this gets patched up in __setstate__, the other case is not necessary
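A minimal sketch of that `__setstate__` approach, using a toy module; assuming F.relu as the fallback matches the layer's historical default activation.

```python
import torch.nn as nn
import torch.nn.functional as F

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = F.relu

    def forward(self, x):
        # No hasattr() guard needed: __setstate__ guarantees the attribute exists.
        return self.activation(x)

    def __setstate__(self, state):
        # Pickles written before `activation` existed lack the key; fill in the
        # default so torch.load still works for old models.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)
```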

@driazati driazati changed the title [jit][BC-breaking] Make nn.Transformer TorchScript compatible [jit] Make nn.Transformer TorchScript compatible Nov 27, 2019

     def __init__(self, encoder_layer, num_layers, norm=None):
         super(TransformerEncoder, self).__init__()
         self.layers = _get_clones(encoder_layer, num_layers)
-        self.num_layers = num_layers
Contributor:

nit: this will be BC-breaking. Prefer not to remove it, since we won't claim a BC break for this module. Will this block a jitable transformer?

Contributor Author:

I re-added it in the latest commit

@zhangguanheng66 (Contributor) left a comment:

One more comment. Thanks.

torch/nn/modules/transformer.py (outdated review thread, resolved)
@driazati driazati removed the module: bc-breaking Related to a BC-breaking change label Nov 28, 2019
@zhangguanheng66 (Contributor) left a comment:

self.num_layers is added to Encoder but not Decoder. Please add it to Decoder and the PR is ready to merge. Thanks.

@facebook-github-bot (Contributor) left a comment:

@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@driazati merged this pull request in 1f87e82.

driazati pushed a commit that referenced this pull request Dec 17, 2019
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
This makes `nn.Transformer` usable from TorchScript. It preserves backwards compatibility via `__setstate__` on the encoder/decoder.

Fixes pytorch#24173
Pull Request resolved: pytorch#28561

Differential Revision: D18124753

Pulled By: driazati

fbshipit-source-id: 7314843e5aa9c9bf974c4672e4edb24ed8ef4a6f
@facebook-github-bot facebook-github-bot deleted the driazati/transformer/2 branch July 13, 2020 17:55

Successfully merging this pull request may close these issues.

Transformer model seems not supported in TorchScript?
6 participants