Transformer #19274

Closed
wants to merge 524 commits into from

Conversation

zhangguanheng66 (Contributor)

Create a PR for comments. The model is still WIP, but I want to get some feedback before moving too far. The transformer model depends on several modules, like MultiheadAttention (landed), PositionalEncoding, FeedForward... If they are already available, I will switch to the existing ones. Otherwise, those modules could be landed in separate PRs.

Transformer is implemented based on the paper (https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). Users have the flexibility to build a transformer with self-defined and/or built-in components (i.e., encoder, decoder, encoder_layer, decoder_layer). Calling the buildTransformerModel() function will generate a generic transformer model; users can then modify sub-layers as needed.
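
For concreteness, here is a minimal usage sketch based on the description above and the constructor exercised in the test below. The constructor signature follows this PR's WIP code, and the import path and any custom encoder class are assumptions, not final API:

import torch
# Assumes this PR's WIP Transformer class, e.g. once landed:
# from torch.nn import Transformer

src_vocab, tgt_vocab = 11, 11
model = Transformer(src_vocab, tgt_vocab, d_model=512, nhead=8,
                    num_encoder_layers=6, num_decoder_layers=6, d_ff=2048)

# Sub-layers are ordinary attributes, so the generic model can be customized afterwards,
# e.g. by swapping in a user-defined encoder with the same interface (hypothetical):
# model.encoder = MyEncoder(d_model=512)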

A test case was created to train a simple transformer model (see TestNN.test_transformer_number_match).

@zhangguanheng66 (Contributor Author)

@cpuhrsch @soumith @myleott. Any feedback is appreciated.

        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

Contributor:
I think it's fine if you move this explicitly into the callsites instead of creating another layer, since self.d_model is only a scalar.
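
A minimal sketch of what inlining the scaling at the callsite could look like, assuming a plain nn.Embedding and a scalar d_model (names here are illustrative, not taken from the diff):

import math
import torch
import torch.nn as nn

vocab_size, d_model = 1000, 512
embed = nn.Embedding(vocab_size, d_model)

tokens = torch.randint(0, vocab_size, (10, 32))  # (seq_len, batch)
x = embed(tokens) * math.sqrt(d_model)           # scale inline instead of in a wrapper module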

        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

Contributor:
This is a very common pattern in general. For now I think it's ok if you do this explicitly instead of using a separate Module within Encoder and DecoderLayer and comment that it could be fused.
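
A sketch of the inlined pattern inside an encoder layer, with module and argument names assumed for illustration; a comment marks the block as potentially fuse-able:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayerFFN(nn.Module):
    # Position-wise feed-forward written inline; could be fused into a single op later.
    def __init__(self, d_model=512, dim_feedforward=2048, dropout=0.1):
        super(EncoderLayerFFN, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(self, x):
        # linear -> relu -> dropout -> linear, same pattern as the FeedForward module above
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

print(EncoderLayerFFN()(torch.randn(10, 32, 512)).shape)  # torch.Size([10, 32, 512])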

@cpuhrsch (Contributor)

Overall I think the temporary layers should be made private or merged into the respective models (and then marked as fuse-able). That'll minimize the number of potential additions to our nn library to only the Transformer components. After the tests pass and we have more docs, we should pull in more people to see if we can get them to use it within their libraries. If this diff becomes rather contentious, we should consider merging it into torchtext first.

@cpuhrsch (Contributor)

cc @soumith

@cpuhrsch (Contributor)

cc @gchanan

test/test_nn.py Outdated
d_ff = 64
batch_size = 400
milliseconds = int(round(time.time() * 1000))
torch.manual_seed(milliseconds)

Contributor:
Why is this necessary?

Contributor Author:
it is just a unit test for the transformer module.

Contributor:
can you not use any of the existing test helpers in this file?

test/test_nn.py Outdated
torch.manual_seed(milliseconds)

model = Transformer(V, V, d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                    num_decoder_layers=num_decoder_layers, d_ff=d_ff)

Contributor:
What is "d_ff"? The name is a bit nondescript.

Contributor Author:
Will make some changes in the next commit.

test/test_nn.py Outdated
model.eval()

for _ in range(3):
    src = Variable(torch.randint(1, V, (10, 1)))

Contributor:
I thought "Variable" was deprecated

Contributor Author:
I don't see a warning for now.

Contributor:
it's deprecated -- we don't warn on some ubiquitous calls because basically all pytorch code uses them.

Member:
yea, just src = torch.randint(1, V, (10, 1)) is sufficient

Member:
but again, see my comments above. A full convergence test has no place in our unit test suite; it does not catch implementation bugs.

test/test_nn.py Outdated
src = Variable(torch.randint(1, V, (10, 1)))
src[0][0] = 1
tgt = generate_test(model, src, max_len=10, start_symbol=1)
assert np.allclose(src, tgt, atol=1e-5)

Contributor:
There's a torch "allclose" equivalent. You can't always assume np is available. If you have to rely on numpy, decorate the test to be skipped if numpy isn't available (see other tests for code examples).
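
A sketch of the numpy-free version of that assertion (the float cast is only because the tensors hold integer token ids):

import torch

src = torch.randint(1, 11, (10, 1))
tgt = src.clone()
assert torch.allclose(src.float(), tgt.float(), atol=1e-5)  # torch.allclose replaces np.allclose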

Contributor Author:
I will search and make some changes.

test/test_nn.py Outdated
@@ -8572,5 +8607,109 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra
# end TestNN.test_affine_* helpers


# The following are some helpers for TestNN.test_transformer_number_match
class DataBatch:

Contributor:
If these are specific to your test, I'd add them to the local scope of that test, i.e. move them into the function itself. This way you avoid making these available to other tests and clearly specify the scope.

test/test_nn.py Outdated
        self.trg_y = trg[:, 1:].transpose(0, 1).contiguous()
        self.ntokens = (self.trg_y != pad).data.sum().item()

class NoamOpt:

Contributor:
There should be an optimizer that already does this and if not, if this is common, we should add it explicitly. What features do our current optimizers lack, that this one provides?
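
For reference, a sketch of how the Noam warmup schedule could be expressed with the existing optimizer machinery via torch.optim.lr_scheduler.LambdaLR, rather than a hand-rolled NoamOpt wrapper (hyperparameter values are illustrative):

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(512, 512)                      # stand-in model
d_model, warmup, factor = 512, 4000, 1.0

def noam(step):
    step = max(step, 1)                          # avoid 0 ** -0.5 at initialization
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = LambdaLR(optimizer, lr_lambda=noam)  # base lr of 1.0 makes the lambda the full rate

loss = model(torch.randn(8, 512)).sum()
loss.backward()
optimizer.step()
scheduler.step()                                 # advance the warmup schedule once per step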

test/test_nn.py Outdated
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))

class LabelSmoothingLoss(nn.Module):

Contributor:
If this is a common loss for Transformers, we should also add it to our losses.
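
For context, a self-contained sketch of the kind of label-smoothing criterion this refers to, built on nn.KLDivLoss; it is a generic implementation of the technique, not the exact module in this diff:

import torch
import torch.nn as nn

class LabelSmoothing(nn.Module):
    # KL divergence against a smoothed one-hot target distribution.
    def __init__(self, num_classes, padding_idx=0, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.num_classes = num_classes
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, log_probs, target):
        # log_probs: (N, num_classes) log-probabilities; target: (N,) class indices
        true_dist = log_probs.new_full(log_probs.size(),
                                       self.smoothing / (self.num_classes - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist[target == self.padding_idx] = 0   # zero out rows that are padding
        return self.criterion(log_probs, true_dist)

crit = LabelSmoothing(num_classes=11)
x = torch.randn(4, 11).log_softmax(dim=-1)
print(crit(x, torch.tensor([2, 3, 4, 5])))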

Member:
please remove before finalizing the PR (see my comments above)

test/test_nn.py Outdated
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

def train_epoch(data_iter, model, criterion, opt):

Contributor:
Since this is a common name, it's important to be local in scope to the test.

Member:
please remove before finalizing the PR (see my comments above)

test/test_nn.py Outdated
@@ -7699,6 +7701,39 @@ def test_adaptive_log_softmax(self):
        out = asfm.predict(x)
        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))

    def test_transformer_number_match(self):

Contributor:
This type of integration test is definitely useful. Are there smaller unit tests we could also look at to test for edge cases (empty input, etc.)?

test/test_nn.py Outdated
print("Training the model...")
for epoch in range(10):
    model.train()
    train_epoch(data_gen(V, batch_size, 21), model, criterion, model_opt)

Contributor:
I'd move the Embedding outside of the Transformer. That is, the result of data_gen would be passed through an instance of nn.Embedding outside of this forward call. The reasoning is that someone might want to use a different kind of Embedding and would then need to reimplement their own Transformer. We also don't do this for LSTM.
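
A sketch of the proposed calling convention: embedding (and positional encoding) applied outside the model, so the Transformer only ever sees d_model-sized vectors; the final forward signature is still WIP, so that call is left as a comment:

import torch
import torch.nn as nn

vocab_size, d_model = 11, 512
src_embed = nn.Embedding(vocab_size, d_model)
tgt_embed = nn.Embedding(vocab_size, d_model)

src_tokens = torch.randint(1, vocab_size, (10, 32))   # (seq_len, batch)
tgt_tokens = torch.randint(1, vocab_size, (9, 32))

src = src_embed(src_tokens)   # done by the user, mirroring how nn.LSTM consumes embedded input
tgt = tgt_embed(tgt_tokens)
# out = transformer(src, tgt)  # hypothetical call once forward accepts embedded tensors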

@MaksymDel (May 3, 2019):
This is crucial for libs like allennlp or pytext.

def __init__(self, src_vocab, tgt_vocab, d_model=512, nhead=8, num_encoder_layers=6,
             num_decoder_layers=6, d_ff=2048, dropout=0.1):
    encoder_layer = TransformerEncoderLayer(d_model, nhead, d_ff, dropout)
    src_embed = Embeddings(d_model, src_vocab)

Contributor:
I'd exclude this from the Transformer.

             num_decoder_layers=6, d_ff=2048, dropout=0.1):
    encoder_layer = TransformerEncoderLayer(d_model, nhead, d_ff, dropout)
    src_embed = Embeddings(d_model, src_vocab)
    pos_encoder = PositionalEncoding(d_model, dropout)

Contributor:
Same with the positional encoding.



# Temporarily leave LayerNorm module here. Will be moved somewhere else.
class LayerNorm(Module):

Contributor:
This should be replaceable by a normalization layer.
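
For instance, the existing nn.LayerNorm could presumably stand in for the temporary module (the later commits in this PR do switch to the torch.nn version):

import torch
import torch.nn as nn

norm = nn.LayerNorm(512)                 # normalizes over the last (feature) dimension
out = norm(torch.randn(10, 32, 512))     # same shape in and out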

def __init__(self, d_model, nhead, d_ff=2048, dropout=0.1):
    super(TransformerEncoderLayer, self).__init__()
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
    self.ff = FeedForward(d_model, d_ff=d_ff, dropout=dropout)

Contributor:
I suggest we inline the implementation of FeedForward until we're implementing a fused layer or such.

Contributor:
The reasoning is that FeedForward is quite small and we've historically decided to not introduce an abstraction for this type of model, but rather require the user to implement it herself.

Contributor:
When you inline it, do mark it as such, so that we can possibly pull it back out later on

Member:
yea I agree.



# Temporarily leave FeedForward module here. Will be moved somewhere else.
class FeedForward(Module):

Contributor:
I suggest we remove this entirely.

Member:
i concur



# Temporarily leave Generator module here. Will be moved somewhere else.
class Generator(Module):

Contributor:
This also is very small and should be inlined.



# Temporarily leave Embeddings module here. Will be moved somewhere else.
class Embeddings(Module):

Contributor:
This should be left to the user.

Member:
agreed



# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(Module):

Contributor:
This should be left to the user. It's something you could define locally for your integration test.
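
A self-contained sketch of the kind of sinusoidal positional encoding a user (or the integration test) could define locally; it follows the formulation in the paper rather than the exact module removed from this PR:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # Adds fixed sinusoidal position information to a (seq_len, batch, d_model) input.
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

enc = PositionalEncoding(d_model=512)
print(enc(torch.randn(10, 32, 512)).shape)  # torch.Size([10, 32, 512])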


        return output

    def encode(self, src, src_mask=None):

Contributor:
Why is this method necessary?

output = self.src_embed(output)

if self.pos_encoder:
    output = self.pos_encoder(output)

Contributor:
You can assume that the user is passing the result of something akin to "self.src_embed(self.pos_encoder(src))" directly as the input (src). This is not something I'd add statically. We usually require the user to do the embedding herself and then pass in the result vector to our models.

User is able to modify the attributes as needed.

Args:
    src_vocab: the number of vocabularies in the source sequence (required).

Contributor:
If you remove the Embedding specialization you can call this input_size

nishantpdce and others added 14 commits May 6, 2019 09:07
Summary:
Pull Request resolved: pytorch#19402

This pass propagate the qparams calculated after calibration to the
quant nodes which will be used later for quantization

Differential Revision: D14995230

fbshipit-source-id: 5709153ea1c039c4ab4470ddb689a303b0bcc6fd
Summary:
Pull Request resolved: pytorch#19680

This was broken for quite some time because of an operator schema
check that went into effect at some point in time.

Reviewed By: manojkris

Differential Revision: D15055082

fbshipit-source-id: 7f730f9b810bdaffd69bab7ac4d02c5b2e40645b
Summary: Pull Request resolved: pytorch#19966

Reviewed By: yinghai

Differential Revision: D15096086

fbshipit-source-id: 8e6a26c46898f99d411dd5841f086946884b2457
…ytorch#19910)

Summary:
Pull Request resolved: pytorch#19910

This change modifies the quant-dequant node pattern from
qparam->q->dq to qparam->q->int_repr->qparam->dq. The motivation for
this change is to make the qparams required for op substition one level
up at dequant node instead of multiple levels up.

Differential Revision: D15120146

fbshipit-source-id: 74b0fd5cb50a338f562740a9cc727a7791c718c3
Summary:
Pull Request resolved: pytorch#20001

att

Reviewed By: zrphercule

Differential Revision: D15164116

fbshipit-source-id: dab19fb84fa0ab648103317af5509703db918682
Summary: Pull Request resolved: pytorch#19981

Differential Revision: D15174219

Pulled By: pjh5

fbshipit-source-id: 205952aa90ed93f193f40d4293f5a8d82fa33ed6
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **pytorch#19686 [jit] Remove try/catch in constant propagation**

The try-catch here gets tripped pretty often when constant prop is run which screws up `catch throw` in gdb.](https://our.intern.facebook.com/intern/diff/15170134/)
Pull Request resolved: pytorch#19686

Pulled By: driazati

Differential Revision: D15170134

fbshipit-source-id: 93688561126f3ab582c8358e8f2787f7fce9aa73
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **pytorch#20026 Remove warnings on new_* constructors**

Revert of pytorch#16770, fixes pytorch#19995
Pull Request resolved: pytorch#20026

Pulled By: driazati

Differential Revision: D15171691

fbshipit-source-id: 057c3b4a9fd6086ca240007e5404a286080f04b6
Summary:
Pull Request resolved: pytorch#19999
ghimport-source-id: 81157c6

Differential Revision: D15169025

Pulled By: ljk53

fbshipit-source-id: 8e6f8df6dec6d21d6c7e743e974f4fcfff7cdeb5
Summary:
Pull Request resolved: pytorch#19760
ghimport-source-id: d0aabee

Differential Revision: D15087655

Pulled By: ljk53

fbshipit-source-id: ac133dfc2301c2a86c41b4b8f1483d7d23824e1e
Summary:
Fixes pytorch#19314
Pull Request resolved: pytorch#19380

Differential Revision: D15167858

Pulled By: Krovatkin

fbshipit-source-id: e87261bbf3e6f8df0601df80280eb3dba42798cd
…3599ef (pytorch#20012)

Summary:
Pull Request resolved: pytorch#20012

Previous import was f1311e74ec8a91cbf86094cd6f10157cbf00c536

Included changes:
- **[7d7bc83d](onnx/onnx@7d7bc83d)**: fix shape inference (pytorch#1984) <Ashwini Khade>
- **[68630bbd](onnx/onnx@68630bbd)**: fixing some of Mod test cases (pytorch#1962) <Jeff Saremi>

Reviewed By: zrphercule

Differential Revision: D15160934

fbshipit-source-id: c53aff401f56b2febeb6c4ee302670eb12b9b495
Avoid loading modules from torch.nn in transformer.py.
clean some lint errors in transformer.py and test_nn.py.
A few lint errors.
Fix continuation line over-indented.
Change d_ff to dim_feedforward.
more help functions to test_transformer_number_match.
Remove Embeddings and PositionalEncoder out of transformer.py.
Remove LayerNorm and FeedForward in transformer.py. Use the corresponding modules in torch.nn.
Remove test_transformer_number_match from test_nn.py.
Remove Embeddings and PositionalEncoder out of transformer.py.
Update transformer docs.
Minor updates in test_nn.py.
Fix lint errors.

@zhangguanheng66 (Contributor Author)

I accidentally rebased and the commit log is too messy. Will create a new PR for the merge request.

@zhangguanheng66 zhangguanheng66 mentioned this pull request May 6, 2019
@zhangguanheng66 zhangguanheng66 deleted the transformer branch May 16, 2019 20:31
facebook-github-bot pushed a commit that referenced this pull request Jun 12, 2019
Summary:
Accidentally rebased the old PR and made it too messy. Find it here (#19274).

Create a PR for comments. The model is still WIP, but I want to get some feedback before moving too far. The transformer model depends on several modules, like MultiheadAttention (landed).

Transformer is implemented based on the paper (https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). Users have the flexibility to build a transformer with self-defined and/or built-in components (i.e., encoder, decoder, encoder_layer, decoder_layer). Users can use the Transformer class to build a standard transformer model and modify sub-layers as needed.

Add a few unit tests for the transformer module, as follows:
TestNN.test_Transformer_cell
TestNN.test_transformerencoderlayer
TestNN.test_transformerdecoderlayer
TestNN.test_transformer_args_check
TestScript.test_scriptmodule_transformer_cuda

There is another demonstration example applying the transformer module to the word language model problem: pytorch/examples#555
Pull Request resolved: #20170

Differential Revision: D15417983

Pulled By: zhangguanheng66

fbshipit-source-id: 7ce771a7e27715acd9a23d60bf44917a90d1d572
@zhangguanheng66 zhangguanheng66 changed the title [WIP] Transformer Transformer Jun 12, 2019