Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for masked input in TrainingSampler #546

Merged

Conversation

kazemnejad
Copy link
Contributor

@kazemnejad kazemnejad commented Sep 27, 2019

By applying this patch, TrainingSampler will be able to support both mask and sequnce_length arguments. But sequnce_length will have higher priority. So if both are provided, the mask argument will be ignored. I had to add this priority because Keras implicitly passes the mask information to layers upon the invocation (even in model sub-classing). Without this priority, the user will face an error by using the following snippet (due to the automatically passed mask argument):

some_masked_input = ...
my_decoder(some_masked_input, initial_state=..., sequence_length=sequence_length)

To overcome this issue, the user had to manually delete the _keras_mask attribute from the some_masked_input variable.( e.g. delattr(some_masked_input, '_keras_mask')) to be able to apply the sequence_length parameter only.

It Fixes (#534) by adding support for masked inputs. Also, this will bring better Keras integration.

Fixes #534

@kazemnejad kazemnejad changed the title Training helper masked input Add support for masked input in TrainingSampler Sep 27, 2019
@kazemnejad
Copy link
Contributor Author

@seanpmorgan Could you please force run the CI tests on this PR?

guillaumekln
guillaumekln previously approved these changes Sep 30, 2019
Copy link
Contributor

@guillaumekln guillaumekln left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@qlzh727 should also validate before merging.

Copy link
Member

@qlzh727 qlzh727 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. The major change I would like is to raise error if both sequence length and masking are provided.

tensorflow_addons/seq2seq/sampler.py Outdated Show resolved Hide resolved
tensorflow_addons/seq2seq/basic_decoder_test.py Outdated Show resolved Hide resolved
tensorflow_addons/seq2seq/basic_decoder_test.py Outdated Show resolved Hide resolved
tensorflow_addons/seq2seq/sampler.py Outdated Show resolved Hide resolved
It will throw an error if both sequence_length and mask are provided.
@kazemnejad
Copy link
Contributor Author

kazemnejad commented Oct 1, 2019

@qlzh727 Did you have a chance to take a look at the new commits?

tensorflow_addons/seq2seq/sampler.py Outdated Show resolved Hide resolved
tensorflow_addons/seq2seq/sampler.py Outdated Show resolved Hide resolved
qlzh727
qlzh727 previously approved these changes Oct 7, 2019
@seanpmorgan seanpmorgan merged commit 83d28c9 into tensorflow:master Oct 7, 2019
@kazemnejad
Copy link
Contributor Author

@qlzh727 @guillaumekln Thank you very much for your feedbacks.

@guillaumekln
Copy link
Contributor

Thank you for these improvements. Looking forward to the seq2seq tutorial!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BasicDecoder fails with masked input tensor
6 participants