This repository has been archived by the owner on Oct 17, 2021. It is now read-only.

Add support for stateful RNN #292

Merged
merged 12 commits into from
Aug 22, 2018

Conversation

Contributor

@caisq caisq commented Aug 19, 2018

FEATURE

Fixes tensorflow/tfjs#23

Contributor

@bileschi bileschi left a comment

Reviewed 2 of 3 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @davidsoergel, @ericdnielsen, and @bileschi)


src/engine/container.ts, line 1427 at r1 (raw file):

Exapmles

nit: sp


src/engine/container.ts, line 1434 at r1 (raw file):

(layer as any).resetStates

base Layer has a stateful property. Does it make more sense to add resetStates there (as a no-op fn by default) than to check for the existence of the key here? Genuinely asking, which is more idiomatic.


src/layers/recurrent.ts, line 261 at r1 (raw file):

difference

I think you mean 'different' here?

Perhaps it would be clearer following the keras.io documentation to say something like "If the layer is stateful, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch."


src/layers/recurrent.ts, line 270 at r1 (raw file):

   * To enable "statefulness":
   *   - specify `stateful: true` in the layer constructor.
   *   - specify a fixed batch size for your model, by passing
   *     - if sequential model:
   *       `batchInputShape: [...]` to the first layer in your model.
   *     - else for functional model with 1 or more Input layers:
   *       `batchShape: [...]` to all the first layers in your model.
   *     This is the expected shape of your inputs
   *     *including the batch size*.
   *     It should be a tuple of integers, e.g., `[32, 10, 100]`.
   *   - specify `shuffle: false` when calling `Model.fit()`.

To what extent is any of this protectable via configuration validation? Can we keep users from shooting themselves in the foot here?


src/layers/recurrent.ts, line 275 at r1 (raw file):

resetState

nit: resetStates


src/layers/recurrent.ts, line 389 at r1 (raw file):

BPTT

What is BPTT? Backpropagation through time?


src/layers/recurrent_test.ts, line 789 at r1 (raw file):

5,

Nit: Small thing, I know, but the style of this Python code doesn't match that of the test above, which makes the differences harder to spot than they need to be. For instance, the code above separates out `units`, while here it's inline; above, the third dimension is named 'sequence_length', while here it's 'n_dims'. The arguments are also in a different order.


src/layers/recurrent_test.ts, line 1211 at r1 (raw file):

 });

Great testing, btw! Thanks for taking the time to do this well and right!

Contributor Author

@caisq caisq left a comment

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @bileschi, @caisq, @davidsoergel, and @ericdnielsen)


src/engine/container.ts, line 1427 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
Exapmles

nit: sp

Done.


src/engine/container.ts, line 1434 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
(layer as any).resetStates

base Layer has a stateful property. Does it make more sense to add resetStates there (as a no-op fn by default) than to check for the existence of the key here? Genuinely asking, which is more idiomatic.

Thanks for the suggestion. I think that makes sense. Done!
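
For illustration, a minimal sketch of the no-op-by-default approach discussed in this thread. The class names `BaseLayer`, `ToyRNN`, `ToyDense`, and `ToyContainer` are hypothetical stand-ins, not the actual classes in `src/engine`, and the scalar state stands in for real state tensors.

```typescript
// Hypothetical, simplified stand-in for the base layer class.
class BaseLayer {
  readonly stateful: boolean;

  constructor(stateful = false) {
    this.stateful = stateful;
  }

  // No-op by default, so callers never need to probe for the method.
  resetStates(): void {}
}

// Toy stateful layer: its state is a running sum of the inputs it has seen.
class ToyRNN extends BaseLayer {
  runningSum = 0;

  constructor() {
    super(true);
  }

  step(x: number): number {
    this.runningSum += x;
    return this.runningSum;
  }

  resetStates(): void {
    this.runningSum = 0;
  }
}

// A stateless layer simply inherits the no-op.
class ToyDense extends BaseLayer {}

class ToyContainer {
  readonly layers: BaseLayer[];

  constructor(layers: BaseLayer[]) {
    this.layers = layers;
  }

  // Resets every stateful layer without an `(layer as any)` cast.
  resetStates(): void {
    for (const layer of this.layers) {
      if (layer.stateful) {
        layer.resetStates();
      }
    }
  }
}

const toyRnn = new ToyRNN();
const toyContainer = new ToyContainer([toyRnn, new ToyDense()]);
toyRnn.step(2);
toyRnn.step(3);
const sumBeforeReset = toyRnn.runningSum;  // 2 + 3
toyContainer.resetStates();
const sumAfterReset = toyRnn.runningSum;  // back to the initial value
```

Because the method exists on every layer, the container code type-checks without checking for the key at runtime.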


src/layers/recurrent.ts, line 261 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
difference

I think you mean 'different' here?

Perhaps it would be clearer following the keras.io documentation to say something like "If the layer is stateful, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch."

Yep. I think the current doc string says the same thing.
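
To make the carry-over rule concrete, here is a toy sketch. It assumes a made-up recurrent layer whose state is just a running sum per sample; `ToyStatefulRNN` and its methods are hypothetical and are not the implementation in `src/layers/recurrent.ts`.

```typescript
// Toy recurrent layer: each sample's "state" is a running sum of its inputs.
// All names are hypothetical; this only illustrates the state hand-off.
class ToyStatefulRNN {
  private states: number[] | null = null;
  private readonly stateful: boolean;

  constructor(stateful: boolean) {
    this.stateful = stateful;
  }

  // `batch[i]` is the input sequence of the sample at index i.
  // Returns the final state of each sample.
  call(batch: number[][]): number[] {
    const initial: number[] = this.stateful && this.states !== null
        ? this.states
        : batch.map(() => 0);
    if (initial.length !== batch.length) {
      // States are matched by sample index, hence the fixed batch size.
      throw new Error(
          `Batch size ${batch.length} does not match ` +
          `stored state size ${initial.length}`);
    }
    const finalStates = batch.map(
        (sequence, i) => sequence.reduce((h, x) => h + x, initial[i]));
    if (this.stateful) {
      // The last state of sample i seeds sample i of the next batch.
      this.states = finalStates;
    }
    return finalStates;
  }

  resetStates(): void {
    this.states = null;
  }
}

const carryRnn = new ToyStatefulRNN(true);
const firstBatchOut = carryRnn.call([[1, 2], [10, 20]]);    // [3, 30]
const secondBatchOut = carryRnn.call([[3, 4], [30, 40]]);   // [10, 100]
carryRnn.resetStates();
const afterResetOut = carryRnn.call([[3, 4], [30, 40]]);    // [7, 70]
```

The second call starts from the states left by the first one, index by index, which is also why a stateful layer needs a fixed batch size.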


src/layers/recurrent.ts, line 270 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
   * To enable "statefulness":
   *   - specify `stateful: true` in the layer constructor.
   *   - specify a fixed batch size for your model, by passing
   *     - if sequential model:
   *       `batchInputShape: [...]` to the first layer in your model.
   *     - else for functional model with 1 or more Input layers:
   *       `batchShape: [...]` to all the first layers in your model.
   *     This is the expected shape of your inputs
   *     *including the batch size*.
   *     It should be a tuple of integers, e.g., `[32, 10, 100]`.
   *   - specify `shuffle: false` when calling `Model.fit()`.

To what extent is any of this protectable via configuration validation? Can we keep users from shooting themselves in the foot here?

One way users can shoot themselves in the foot is by forgetting to set `shuffle: false` in `Model.fit()` calls. This is also the case in Python Keras. Other than that, there is sufficient protection against a missing `batchInputShape` and so forth. I added a TODO item here to explore how we can warn the user about a missing `shuffle: false` when training a model that contains stateful layers.
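
One possible shape for such a warning, as a rough sketch only. The interfaces `LayerLike` and `FitConfigLike` and the function `checkShuffleForStatefulLayers` are hypothetical and do not exist in the codebase; the sketch also assumes that shuffling is on unless `shuffle: false` is passed explicitly.

```typescript
// Hypothetical, simplified shapes; not the actual library types.
interface LayerLike {
  stateful: boolean;
}

interface FitConfigLike {
  shuffle?: boolean;
}

// Returns a warning message if a model containing stateful layers would be
// trained with shuffling enabled, or null if the configuration looks safe.
// Assumption: shuffling is enabled unless `shuffle` is explicitly false.
function checkShuffleForStatefulLayers(
    layers: LayerLike[], config: FitConfigLike): string | null {
  const hasStatefulLayer = layers.some(layer => layer.stateful);
  const shuffleEnabled = config.shuffle !== false;
  if (hasStatefulLayer && shuffleEnabled) {
    return 'The model contains stateful layers, but shuffling is enabled. ' +
        'Pass `shuffle: false` so that sample i of each batch continues ' +
        'sample i of the previous batch.';
  }
  return null;
}

const shuffleWarning =
    checkShuffleForStatefulLayers([{stateful: true}, {stateful: false}], {});
if (shuffleWarning !== null) {
  console.warn(shuffleWarning);
}
```

A check like this could run at the start of training; whether it should warn or throw is a separate design question that the TODO leaves open.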


src/layers/recurrent.ts, line 275 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
resetState

nit: resetStates

Done, here and elsewhere.


src/layers/recurrent.ts, line 389 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
BPTT

What is BPTT? Backpropagation through time?

Spelling it out here. Done.


src/layers/recurrent_test.ts, line 789 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
5,

Nit: Small thing, I know, but the style of this Python code doesn't match that of the test above, which makes the differences harder to spot than they need to be. For instance, the code above separates out `units`, while here it's inline; above, the third dimension is named 'sequence_length', while here it's 'n_dims'. The arguments are also in a different order.

Thanks for catching that. I fixed these inconsistencies.


src/layers/recurrent_test.ts, line 1211 at r1 (raw file):

Previously, bileschi (Stanley Bileschi) wrote…
 });

Great testing, btw! Thanks for taking the time to do this well and right!

Ack.
