Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Fixed a couple RNN bugs. #522

Merged
merged 9 commits into from Nov 14, 2019
Merged

Fixed a couple RNN bugs. #522

merged 9 commits into from Nov 14, 2019

Conversation

eaplatanios
Copy link
Contributor

@rxwei @Shashi456 this tackles #518 and also includes the fix of #519. I haven't added a test yet, but will try to add one tomorrow. The fixed issues are:

  • Backpropagated gradients for RNN cells were computed wrongly, resulting in being unable to train RNNs (especially in cases where only the last time step output is being used -- e.g., sequence classification)
  • The default zero state for RNNs always had a batch size of 1 that could not be broadcast and caused failures (e.g., when concatenating the input and hidden state in LSTM cells). Now zeroState is a function that takes an example input as argument so that it can account for the batch size. In principle we can also switch it to just take the batch size as input, but in either case, until we can discuss the design of RNNs more extensively this solution suffices and fixes the current failures.

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you separate the code movement and refactoring to a separate patch, and make this patch be just a bug fix?

Sources/TensorFlow/Layers/Recurrent.swift Outdated Show resolved Hide resolved
@eaplatanios
Copy link
Contributor Author

I reverted the refactoring changes as suggested.

@eaplatanios
Copy link
Contributor Author

There is only a couple variable renames I kept to make all names consistent across the callAsFunction and its corresponding VJP.

@rxwei
Copy link
Contributor

rxwei commented Oct 2, 2019

Let's wait on #519 and also add a unit test.

@eaplatanios
Copy link
Contributor Author

I haven't really found time to add a test for this PR yet, but I believe it should be merged anyway because it fixes a bug in the RNN gradients that will almost definitely cause undesired consequences for users, and add a TODO for a test. What do you think?

@marcrasi marcrasi mentioned this pull request Nov 14, 2019
@marcrasi
Copy link
Contributor

That sounds good. I will merge and create an issue for adding a test.

@marcrasi marcrasi merged commit 25c7cfe into tensorflow:master Nov 14, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants