RNNCellDropoutWrapper applies dropout on the LSTM c state #33690

Closed
georgesterpu opened this issue Oct 24, 2019 · 6 comments
Labels: comp:apis (High-level API related issues) · TF 2.0 (Issues relating to TensorFlow 2.0) · type:feature (Feature requests)

Comments

georgesterpu (Contributor) commented Oct 24, 2019

System information

  • Have I written custom code: No
  • TensorFlow installed from: source (pip)
  • TensorFlow version: v2.0.0-rc2-26-g64c3d38 2.0.0
  • Python version: 3.7.4

Describe the current behavior
The _call_wrapped_cell method of the DropoutWrapperBase class applies dropout to both the c and h states of an LSTM cell. Its default method for deciding whether a substate should receive dropout, _default_dropout_state_filter_visitor, only works correctly with LSTM states packed as an LSTMStateTuple namedtuple. This was fine in TensorFlow 1.x, where the LSTM state is passed around as an LSTMStateTuple; in TensorFlow 2.0, however, the state is a plain Python tuple, so the method returns True for both substates.

Describe the expected behavior
Exclude the LSTM c state from the list of dropout candidates.
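For reference, a paraphrased sketch of the filter logic described above (not the exact TensorFlow source; the TF 1.x state type is reached via tf.compat.v1 here):

```python
import tensorflow as tf

LSTMStateTuple = tf.compat.v1.nn.rnn_cell.LSTMStateTuple  # TF 1.x state type

def dropout_state_filter_visitor(substate):
    if isinstance(substate, LSTMStateTuple):
        # TF 1.x path: never apply dropout to the memory (c) state.
        return LSTMStateTuple(c=False, h=True)
    # A TF 2.0 Keras LSTMCell state is a plain [h, c] list/tuple, so each
    # substate falls through to this default and receives dropout.
    return True
```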

gadagashwini-zz self-assigned this Oct 25, 2019
gadagashwini-zz added the comp:apis, TF 2.0, and type:feature labels Oct 25, 2019
gowthamkpr assigned qlzh727 and unassigned gowthamkpr Oct 25, 2019
gowthamkpr added the stat:awaiting tensorflower label Oct 25, 2019
qlzh727 (Member) commented Oct 28, 2019

The Keras LSTM cell supports dropout out of the box: you can use the dropout or recurrent_dropout parameters for that. The DropoutWrapper is not expected to be used with Keras cells.

I am going to add an error message for when a Keras LSTM cell is used with the dropout wrapper.
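For illustration, a minimal sketch of the Keras-native route (real Keras parameters; the sizes and rates here are arbitrary):

```python
import tensorflow as tf

# Input and recurrent dropout via the cell's own parameters,
# instead of wrapping the cell in RNNCellDropoutWrapper.
cell = tf.keras.layers.LSTMCell(64, dropout=0.2, recurrent_dropout=0.2)
layer = tf.keras.layers.RNN(cell)

x = tf.random.normal([8, 10, 32])  # (batch, time, features)
y = layer(x, training=True)        # masks are only active in training mode
```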

tensorflowbutler removed the stat:awaiting tensorflower label Oct 29, 2019
tensorflow-copybara pushed a commit that referenced this issue Oct 29, 2019
See #33690 for details.

PiperOrigin-RevId: 277171614
Change-Id: I9925801a2b0c055c595ff20ecf063d6dd876e18e
qlzh727 closed this as completed Oct 29, 2019
georgesterpu (Contributor, Author) commented

Is there any reason why RNNCellDropoutWrapper cannot be updated to work with tf.keras.layers.LSTMCell? It seems that the only issue is the LSTMStateTuple code, which needs to be ported.

The dropout and recurrent_dropout parameters of LSTMCell do not support all the features of DropoutWrapper. For example, the dropout mask is reused across timesteps in LSTMCell, which is not the default behaviour of DropoutWrapper; this difference introduces a new independent variable into every experiment that runs code ported from TF 1.x.

Furthermore, dropout in LSTMCell is not compatible with tfa.seq2seq.BasicDecoder (when using variable batch sizes) or tfa.seq2seq.BeamSearchDecoder (because the beam width multiplication changes the input shape).
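To make the mask-reuse difference concrete, a conceptual sketch (not TensorFlow source; shapes and rate are arbitrary):

```python
import tensorflow as tf

rate = 0.2
x = tf.random.normal([4, 10, 8])  # (batch, time, features)

# Variational style (Keras LSTMCell): one mask sampled per forward pass,
# reused at every timestep by broadcasting over the time axis.
keep = tf.cast(tf.random.uniform([4, 1, 8]) >= rate, tf.float32) / (1.0 - rate)
y_variational = x * keep

# Per-timestep style (DropoutWrapper default, variational_recurrent=False):
# a fresh mask is sampled at every step.
y_per_step = tf.stack(
    [tf.nn.dropout(x[:, t], rate=rate) for t in range(x.shape[1])], axis=1)
```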

qlzh727 (Member) commented Nov 4, 2019

Sorry for the very late reply; I was at TF World the last few days.

The dropout wrapper was ported from the v1 tf.nn.rnn API, and its functionality duplicates the existing Keras API. We ported all the wrappers since they provide some value to users, but we would prefer users to rely on the Keras API, since it is better integrated.

For this particular issue, the fix isn't very straightforward, since the Keras LSTM cell returns its state only as a plain [h, c] list. We could check the cell type when it is passed in, but the cell could itself be wrapped in another wrapper, etc.
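A quick illustration of the state structure in question (a minimal sketch; cell size and inputs are arbitrary):

```python
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(4)
x = tf.random.normal([2, 3])
state = [tf.zeros([2, 4]), tf.zeros([2, 4])]  # plain [h, c] list, no named fields

out, new_state = cell(x, state)
print(type(new_state))  # list: nothing marks which element is c, so a generic
                        # state filter cannot tell c apart from h
```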

If you have another proposal, feel free to send a PR and I would be happy to review it. Thanks.

georgesterpu (Contributor, Author) commented Nov 4, 2019

Thanks for the reply, @qlzh727.
I agree that we should focus our effort on a single API.
Would you mind looking at #33991, which describes an error when using dropout in LSTMCell?

By the way, above I was saying that not all features of DropoutWrapper have been ported to TF 2.0's LSTMCell. It seems that variational_recurrent=False and output dropout (output_keep_prob) are no longer supported. Would it be possible to have an option in LSTM/LSTMCell to apply new dropout masks at every call?

qlzh727 (Member) commented Nov 7, 2019

Sorry for the late reply.

I will take a look at the issue you referred to.

For variational dropout, the default value is False in DropoutWrapper; however, I would expect most users to want True. The reason for having False by default was to be defensive and avoid changing users' code when the new flag was added. You can check the paper for that and see the performance/accuracy comparison there.

On the Keras side, we use variational dropout by default, since we believe it is better than the non-variational version. We didn't expose a knob for users to control this, out of consideration for API complexity and ease of use.

As for output dropout, it can easily be achieved by adding a dropout layer to the output tensor.
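For instance, a minimal sketch of that suggestion (real Keras layers; sizes and rates are arbitrary):

```python
import tensorflow as tf

# Input/recurrent dropout via the layer itself, output dropout via a separate
# Dropout layer, standing in for DropoutWrapper's output_keep_prob.
lstm = tf.keras.layers.LSTM(64, return_sequences=True,
                            dropout=0.2, recurrent_dropout=0.2)
output_dropout = tf.keras.layers.Dropout(0.2)  # ~ output_keep_prob=0.8

x = tf.random.normal([8, 10, 32])
y = output_dropout(lstm(x, training=True), training=True)
```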

georgesterpu (Contributor, Author) commented

Thanks again, @qlzh727
