Transformer #19274
Conversation
torch/nn/modules/transformer.py
Outdated
```python
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
```
I think it's fine if you move this explicitly into the call sites instead of creating another layer, since `self.d_model` is only a scalar.
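As a sketch of what the reviewer suggests, the scaling can happen directly at the call site instead of inside a dedicated `Embeddings` module (the standalone snippet below is hypothetical; only the `sqrt(d_model)` scaling comes from the PR):

```python
import math

import torch
import torch.nn as nn

d_model, vocab_size = 16, 50
embed = nn.Embedding(vocab_size, d_model)

x = torch.randint(0, vocab_size, (4, 2))  # (seq_len, batch) of token ids
# Scale by sqrt(d_model) at the call site -- no wrapper Module needed,
# since d_model is only a scalar.
h = embed(x) * math.sqrt(d_model)
```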
torch/nn/modules/transformer.py
Outdated
```python
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```
This is a very common pattern in general. For now I think it's ok if you do this explicitly instead of using a separate Module within Encoder and DecoderLayer and comment that it could be fused.
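A minimal sketch of the inlined pattern being suggested (the class name and sizes here are hypothetical, not from the PR):

```python
import torch
import torch.nn.functional as F
from torch import nn

class EncoderLayerSketch(nn.Module):
    """Hypothetical layer with the feed-forward block written inline."""
    def __init__(self, d_model=32, d_ff=64, dropout=0.1):
        super(EncoderLayerSketch, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Feed-forward inlined rather than a separate Module;
        # commented so it could be pulled out or fused later.
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

layer = EncoderLayerSketch()
y = layer(torch.randn(10, 32))
```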
Overall I think the temporary layers should be made private or merged into the respective models (and then marked as fuse-able). That'll minimize the number of potential additions to our nn library to only the Transformer components. After the tests pass and we have more docs, we should pull in more people to see if we can get them to use it within their libraries. If this diff becomes rather contentious, we should consider merging it into torchtext first.
cc @soumith
cc @gchanan
Force-pushed from 46f965b to cccab3c.
test/test_nn.py
Outdated
```python
d_ff = 64
batch_size = 400
milliseconds = int(round(time.time() * 1000))
torch.manual_seed(milliseconds)
```
Why is this necessary?
it is just a unit test for the transformer module.
can you not use any of the existing test helpers in this file?
test/test_nn.py
Outdated
```python
torch.manual_seed(milliseconds)

model = Transformer(V, V, d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                    num_decoder_layers=num_decoder_layers, d_ff=d_ff)
```
What is "d_ff"? The name is a bit nondescript.
Will make some changes in the next commit.
test/test_nn.py
Outdated
```python
model.eval()

for _ in range(3):
    src = Variable(torch.randint(1, V, (10, 1)))
```
I thought "Variable" was deprecated
I don't see a warning for now.
it's deprecated -- we don't warn on some ubiquitous calls because basically all pytorch code uses them.
yea, just `src = torch.randint(1, V, (10, 1))` is sufficient
but again, see my comments above: a full convergence test has no place in our unit test suite, and it does not catch implementation bugs.
test/test_nn.py
Outdated
```python
src = Variable(torch.randint(1, V, (10, 1)))
src[0][0] = 1
tgt = generate_test(model, src, max_len=10, start_symbol=1)
assert np.allclose(src, tgt, atol=1e-5)
```
There's a torch "allclose" equivalent. You can't always assume np is available. If you have to rely on numpy, decorate the test to be skipped if numpy isn't available (see other tests for code examples).
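For example, the assertion could use the torch equivalent (`torch.allclose`) and avoid the numpy dependency entirely; the tensors and tolerance below are illustrative:

```python
import torch

src = torch.randint(1, 11, (10, 1))
tgt = src.clone()
# torch.allclose compares floating-point tensors, so cast integer tensors first.
match = torch.allclose(src.float(), tgt.float(), atol=1e-5)
```

If numpy really were required, the convention the reviewer alludes to is a skip decorator gated on numpy availability, as used elsewhere in test_nn.py.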
I will search and make some changes.
test/test_nn.py
Outdated
```
@@ -8572,5 +8607,109 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra
 # end TestNN.test_affine_* helpers


 # The following are some helpers for TestNN.test_transformer_number_match
 class DataBatch:
```
If these are specific to your test, I'd add them to the local scope of that test, i.e. move them into the function itself. This way you avoid making these available to other tests and clearly specify the scope.
test/test_nn.py
Outdated
```python
        self.trg_y = trg[:, 1:].transpose(0, 1).contiguous()
        self.ntokens = (self.trg_y != pad).data.sum().item()


class NoamOpt:
```
There should be an optimizer that already does this and if not, if this is common, we should add it explicitly. What features do our current optimizers lack, that this one provides?
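For reference, what `NoamOpt` adds is not a new update rule but the warmup/decay learning-rate schedule from the paper; a pure-Python sketch (function name hypothetical) that could be plugged into `torch.optim.lr_scheduler.LambdaLR` as the `lr_lambda`:

```python
def noam_lr(step, model_size=512, warmup=4000, factor=1.0):
    """Rate from 'Attention Is All You Need': linear warmup for `warmup`
    steps, then decay proportional to the inverse square root of the step."""
    step = max(step, 1)  # guard against 0 ** -0.5
    return factor * model_size ** -0.5 * min(step ** -0.5,
                                             step * warmup ** -1.5)
```

The existing optimizers (e.g. Adam with `betas=(0.9, 0.98), eps=1e-9`, as in the paper) cover the update itself, which suggests a scheduler wrapper may suffice instead of a new optimizer class.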
test/test_nn.py
Outdated
```python
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))


class LabelSmoothingLoss(nn.Module):
```
If this is a common loss for Transformers, we should also add it to our losses.
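For context, the core of label smoothing is just reshaping the target distribution; a pure-Python sketch of that step (helper name and padding convention hypothetical):

```python
def smooth_target(target, vocab_size, smoothing=0.1, padding_idx=0):
    """Spread `smoothing` mass over the non-target, non-padding classes
    and keep 1 - smoothing confidence on the true class."""
    assert target != padding_idx
    # -2 excludes the true class and the padding index from the smoothed mass.
    dist = [smoothing / (vocab_size - 2)] * vocab_size
    dist[target] = 1.0 - smoothing
    dist[padding_idx] = 0.0
    return dist

d = smooth_target(3, vocab_size=11)
```

(Later PyTorch versions did add a `label_smoothing` argument to `nn.CrossEntropyLoss`, which covers this use case.)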
please remove before finalizing the PR (see my comments above)
test/test_nn.py
Outdated
```python
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))


def train_epoch(data_iter, model, criterion, opt):
```
Since this is a common name, it's important to keep it local in scope to the test.
please remove before finalizing the PR (see my comments above)
test/test_nn.py
Outdated
```
@@ -7699,6 +7701,39 @@ def test_adaptive_log_softmax(self):
     out = asfm.predict(x)
     self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))


 def test_transformer_number_match(self):
```
These types of integration tests are definitely useful. Are there smaller unit tests we could also look at, to test for edge cases (empty input, etc.)?
test/test_nn.py
Outdated
```python
print("Training the model...")
for epoch in range(10):
    model.train()
    train_epoch(data_gen(V, batch_size, 21), model, criterion, model_opt)
```
I'd move the Embedding outside of the Transformer. That is, the result of data_gen will be passed through an instance of nn.Embedding outside of this forward call. The reasoning is that someone might want to use a different kind of Embedding and would then need to reimplement their own Transformer. We also don't do this for LSTM.
This is crucial for libs like allennlp or pytext.
torch/nn/modules/transformer.py
Outdated
```python
    def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, d_ff=2048, dropout=0.1):
        encoder_layer = TransformerEncoderLayer(d_model, nhead, d_ff, dropout)
        src_embed = Embeddings(d_model, src_vocab)
```
I'd exclude this from the Transformer.
torch/nn/modules/transformer.py
Outdated
```python
                 num_decoder_layers=6, d_ff=2048, dropout=0.1):
        encoder_layer = TransformerEncoderLayer(d_model, nhead, d_ff, dropout)
        src_embed = Embeddings(d_model, src_vocab)
        pos_encoder = PositionalEncoding(d_model, dropout)
```
Same with the positional encoding.
torch/nn/modules/transformer.py
Outdated
```python
# Temporarily leave LayerNorm module here. Will be moved somewhere else.
class LayerNorm(Module):
```
This should be replaceable by a normalization layer.
we have a native LayerNorm as well... https://pytorch.org/docs/stable/nn.html?highlight=layernorm#torch.nn.LayerNorm
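A quick sketch of swapping in the built-in layer (shapes here are illustrative):

```python
import torch
import torch.nn as nn

d_model = 512
norm = nn.LayerNorm(d_model)        # normalizes over the last dimension
x = torch.randn(10, 32, d_model)    # (seq_len, batch, d_model)
y = norm(x)
```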
torch/nn/modules/transformer.py
Outdated
```python
    def __init__(self, d_model, nhead, d_ff=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ff = FeedForward(d_model, d_ff=d_ff, dropout=dropout)
```
I suggest we inline the implementation of FeedForward until we're implementing a fused layer or such.
The reasoning is that FeedForward is quite small and we've historically decided to not introduce an abstraction for this type of model, but rather require the user to implement it herself.
When you inline it, do mark it as such, so that we can possibly pull it back out later on
yea I agree.
torch/nn/modules/transformer.py
Outdated
```python
# Temporarily leave FeedForward module here. Will be moved somewhere else.
class FeedForward(Module):
```
I suggest we remove this entirely.
i concur
torch/nn/modules/transformer.py
Outdated
```python
# Temporarily leave Generator module here. Will be moved somewhere else.
class Generator(Module):
```
This also is very small and should be inlined.
torch/nn/modules/transformer.py
Outdated
```python
# Temporarily leave Embeddings module here. Will be moved somewhere else.
class Embeddings(Module):
```
This should be left to the user.
agreed
torch/nn/modules/transformer.py
Outdated
```python
# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(Module):
```
This should be left to the user. It's something you could define locally for your integration test.
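For reference, the sinusoidal encoding from the paper is short enough to define locally in the test; a pure-Python sketch (function name hypothetical):

```python
import math

def positional_encoding(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)),
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000.0 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe

pe = positional_encoding(50, 8)
```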
torch/nn/modules/transformer.py
Outdated
```python
        return output

    def encode(self, src, src_mask=None):
```
Why is this method necessary?
torch/nn/modules/transformer.py
Outdated
```python
        output = self.src_embed(output)

        if self.pos_encoder:
            output = self.pos_encoder(output)
```
You can assume that the user is passing the result of something akin to "self.src_embed(self.pos_encoder(src))" directly as the input (src). This is not something I'd add statically. We usually require the user to do the embedding herself and then pass in the result vector to our models.
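A sketch of the call pattern being described, written against the `nn.Transformer` API as it eventually merged (the sizes are illustrative; embedding and positional encoding happen outside the model, mirroring how `nn.LSTM` is used):

```python
import torch
import torch.nn as nn

d_model, vocab = 32, 100
embed = nn.Embedding(vocab, d_model)
model = nn.Transformer(d_model=d_model, nhead=4,
                       num_encoder_layers=2, num_decoder_layers=2,
                       dim_feedforward=64)

src = torch.randint(0, vocab, (10, 2))   # (src_len, batch) token ids
tgt = torch.randint(0, vocab, (9, 2))    # (tgt_len, batch)

# The user embeds (and optionally position-encodes) the ids herself,
# then passes the resulting vectors in directly.
out = model(embed(src), embed(tgt))
```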
torch/nn/modules/transformer.py
Outdated
```python
    Users are able to modify the attributes as needed.

    Args:
        src_vocab: the size of the source vocabulary (required).
```
If you remove the Embedding specialization, you can call this `input_size`.
Summary: Pull Request resolved: pytorch#19402 This pass propagate the qparams calculated after calibration to the quant nodes which will be used later for quantization Differential Revision: D14995230 fbshipit-source-id: 5709153ea1c039c4ab4470ddb689a303b0bcc6fd
Summary: Pull Request resolved: pytorch#19680 This was broken for quite some time because of an operator schema check that went into effect at some point in time. Reviewed By: manojkris Differential Revision: D15055082 fbshipit-source-id: 7f730f9b810bdaffd69bab7ac4d02c5b2e40645b
Summary: Pull Request resolved: pytorch#19966 Reviewed By: yinghai Differential Revision: D15096086 fbshipit-source-id: 8e6a26c46898f99d411dd5841f086946884b2457
…ytorch#19910) Summary: Pull Request resolved: pytorch#19910 This change modifies the quant-dequant node pattern from qparam->q->dq to qparam->q->int_repr->qparam->dq. The motivation for this change is to make the qparams required for op substition one level up at dequant node instead of multiple levels up. Differential Revision: D15120146 fbshipit-source-id: 74b0fd5cb50a338f562740a9cc727a7791c718c3
Summary: Pull Request resolved: pytorch#20001 att Reviewed By: zrphercule Differential Revision: D15164116 fbshipit-source-id: dab19fb84fa0ab648103317af5509703db918682
Summary: Pull Request resolved: pytorch#19981 Differential Revision: D15174219 Pulled By: pjh5 fbshipit-source-id: 205952aa90ed93f193f40d4293f5a8d82fa33ed6
Summary: Stack from [ghstack](https://github.com/ezyang/ghstack): * **pytorch#19686 [jit] Remove try/catch in constant propagation** The try-catch here gets tripped pretty often when constant prop is run which screws up `catch throw` in gdb.](https://our.intern.facebook.com/intern/diff/15170134/) Pull Request resolved: pytorch#19686 Pulled By: driazati Differential Revision: D15170134 fbshipit-source-id: 93688561126f3ab582c8358e8f2787f7fce9aa73
Summary: Stack from [ghstack](https://github.com/ezyang/ghstack): * **pytorch#20026 Remove warnings on new_* constructors** Revert of pytorch#16770, fixes pytorch#19995 Pull Request resolved: pytorch#20026 Pulled By: driazati Differential Revision: D15171691 fbshipit-source-id: 057c3b4a9fd6086ca240007e5404a286080f04b6
Summary: Pull Request resolved: pytorch#19999 ghimport-source-id: 81157c6 Differential Revision: D15169025 Pulled By: ljk53 fbshipit-source-id: 8e6f8df6dec6d21d6c7e743e974f4fcfff7cdeb5
Summary: Pull Request resolved: pytorch#19760 ghimport-source-id: d0aabee Differential Revision: D15087655 Pulled By: ljk53 fbshipit-source-id: ac133dfc2301c2a86c41b4b8f1483d7d23824e1e
Summary: Fixes pytorch#19314 Pull Request resolved: pytorch#19380 Differential Revision: D15167858 Pulled By: Krovatkin fbshipit-source-id: e87261bbf3e6f8df0601df80280eb3dba42798cd
…3599ef (pytorch#20012) Summary: Pull Request resolved: pytorch#20012 Previous import was f1311e74ec8a91cbf86094cd6f10157cbf00c536 Included changes: - **[7d7bc83d](onnx/onnx@7d7bc83d)**: fix shape inference (pytorch#1984) <Ashwini Khade> - **[68630bbd](onnx/onnx@68630bbd)**: fixing some of Mod test cases (pytorch#1962) <Jeff Saremi> Reviewed By: zrphercule Differential Revision: D15160934 fbshipit-source-id: c53aff401f56b2febeb6c4ee302670eb12b9b495
Avoid loading modules from torch.nn in transformer.py. Clean some lint errors in transformer.py and test_nn.py. Fix continuation line over-indented.
Change d_ff to dim_feedforward. Add more helper functions to test_transformer_number_match. Move Embeddings and PositionalEncoding out of transformer.py. Remove LayerNorm and FeedForward from transformer.py and use the corresponding modules in torch.nn. Remove test_transformer_number_match from test_nn.py. Update transformer docs. Minor updates in test_nn.py. Fix lint errors.
Force-pushed from c23dd64 to a108297.
I accidentally rebased and the commit log is too messy. I will create a new PR for the merge request.
Summary: Accidentally rebased the old PR and made it too messy. Find it here (#19274). Create a PR for comments. The model is still WIP, but I want to get some feedback before moving too far. The transformer model depends on several modules, like MultiheadAttention (landed). Transformer is implemented based on the paper (https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). Users have the flexibility to build a transformer with self-defined and/or built-in components (i.e., encoder, decoder, encoder_layer, decoder_layer). Users can use the Transformer class to build a standard transformer model and modify sub-layers as needed. Add a few unit tests for the transformer module, as follows:
- TestNN.test_Transformer_cell
- TestNN.test_transformerencoderlayer
- TestNN.test_transformerdecoderlayer
- TestNN.test_transformer_args_check
- TestScript.test_scriptmodule_transformer_cuda
There is another demonstration example for applying the transformer module to the word language problem: pytorch/examples#555
Pull Request resolved: #20170
Differential Revision: D15417983
Pulled By: zhangguanheng66
fbshipit-source-id: 7ce771a7e27715acd9a23d60bf44917a90d1d572
Create a PR for comments. The model is still WIP, but I want to get some feedback before moving too far. The transformer model depends on several modules, like MultiheadAttention (landed), PositionalEncoding, FeedForward... If they are already available, I will switch to the existing ones. Otherwise, those modules could be landed in separate PRs.
Transformer is implemented based on the paper (https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). Users have the flexibility to build a transformer with self-defined and/or built-in components (i.e., encoder, decoder, encoder_layer, decoder_layer). Calling the buildTransformerModel() function will generate a generic transformer model; users can then modify sub-layers as needed.
A test case was created to train a simple transformer model (see TestNN.test_transformer_number_match).