RNNCellDropoutWrapper applies dropout on the LSTM c state #33690

Closed
@georgesterpu

Description

System information

  • Have I written custom code: No
  • TensorFlow installed from: source (pip)
  • TensorFlow version: v2.0.0-rc2-26-g64c3d38 2.0.0
  • Python version: 3.7.4

Describe the current behavior
The _call_wrapped_cell method of the DropoutWrapperBase class applies dropout to both the c and h states of an LSTM cell. Its default filter for deciding whether a state should receive dropout, _default_dropout_state_filter_visitor, only excludes the c state when the LSTM state is packed as an LSTMStateTuple namedtuple. This was fine in TensorFlow 1.x, where the LSTM state is passed around as an LSTMStateTuple. In TensorFlow 2.0, however, the state is a plain Python tuple, so the filter returns True for both substates.
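To make the difference concrete, here is a minimal pure-Python sketch of the behavior described above. It is not TensorFlow's code: LSTMStateTuple is a namedtuple stand-in, default_filter and dropout_mask are hypothetical names, and the strings stand in for state tensors. The traversal is a simplified assumption about how the filter result is expanded over a nested state.

```python
from collections import namedtuple

# Stand-in for TensorFlow's LSTMStateTuple (assumption: fields c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def default_filter(substate):
    """Simplified imitation of the default dropout state filter."""
    if isinstance(substate, LSTMStateTuple):
        # Recognized LSTM state: skip the memory (c) state, keep h.
        return LSTMStateTuple(c=False, h=True)
    # Anything else is treated as a dropout candidate.
    return True


def dropout_mask(state):
    """Return, per substate, whether dropout would be applied (hypothetical helper)."""
    decision = default_filter(state)
    if decision is True and isinstance(state, (tuple, list)):
        # The whole container was approved, so every element is visited in turn.
        return tuple(dropout_mask(s) for s in state)
    return decision


# TF 1.x style: the state type identifies which entry is c.
print(dropout_mask(LSTMStateTuple(c="c_state", h="h_state")))

# TF 2.0 style as described in this issue: a plain tuple carries no such hint.
print(dropout_mask(("c_state", "h_state")))
```

In this sketch the namedtuple case yields c=False, h=True, while the plain tuple yields True for both entries, which mirrors the reported behavior.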

Describe the expected behavior
Exclude the LSTM c state from the list of dropout candidates.

Labels

TF 2.0 (issues relating to TensorFlow 2.0), comp:apis (high-level API related issues), type:feature (feature requests)
