Added implementation for GridRNN #1665
Can one of the admins verify this patch?
Will take a look tomorrow.
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.contrib.grid_rnn.python.ops.grid_rnn_cell import *
add newline
done
Thanks for the hard work! Some comments.
…input and output of the cell, updated tests
@ebrevdo Thanks for the comments. I updated the code.
Without a unit test that calls tf.nn.rnn or tf.nn.dynamic_rnn, it's not clear that these cells interact correctly with those methods. Can you add such a test for your classes?
ok will do.
I added the tests for |
Will look tomorrow, thanks!
Jenkins, test this please?
Can one of the admins verify this patch?
Sorry for not reviewing earlier. Let's retest this and we can get it merged.
@phvu can you rebase on HEAD and run the tests?
@ebrevdo: why does he have to rebase to HEAD? Do you think there will be a conflict in doing so? Also, only we can trigger tests.
@tensorflow-jenkins: test this please
Cool, tests passed. So I don't need to rebase master for now?
Excellent, we're good to go. Thanks for the contribution! If you have any example code using this we can consider adding it elsewhere in the repo (e.g. under models).
Cool, thanks for all the help. I think at some point I can try reproducing some experiments in the paper. Will submit PRs by then.
Hi @phvu, Thanks very much for your work. I'm a little confused as to how this can be applied generally. For example, in the paper a 3-LSTM network was applied to image patches of the MNIST dataset where each patch is c * m units long (which is equal to the depth dimension). So we have a 3D grid of c * m vectors. But in your implementation of |
Hi @jstaker7, Second, in the MNIST case, (as far as I understand) I suppose you could use Grid3LSTMCell (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py#L242). This cell receives input and gives output in the first dimension (index 0), and uses LSTM in the other 2 dimensions. So your assumption that input_dims should be |
Hi @phvu, Thank you, this helps quite a lot. Since each dimension will always have its own LSTM, does that mean we will never use an |
To replicate the experiments in the paper, yes. But I am not sure about that generally. When we set dimension i to be input_dims, it simply means for each So I guess we can be creative and feed inputs into more than one dimension. I used dimension 0 as the input and output dimension just as a convention. You can just as well construct your own GridRNNCell with any configuration.
Perfect, thanks so much for the clarification. Super helpful and makes a lot of sense. This might also be interesting to look into in the future. Grid LSTMs were mentioned and it looks like there was a recent merge: #2560
As discussed here: #1453
The implementation is generic: users can specify the number of dimensions and various configurations for those dimensions (input/output/priority/non-recurrent). The type of the cells along the dimensions can also be selected from LSTM, GRU, and vanilla RNN.
Comes with unit tests for the basic types: 2LSTM (tied weights, non-recurrent), 2BasicLSTM and 2RNN.
I made a simple test of Grid2LSTM for character-level language modeling: https://github.com/phvu/grid-lstm-tensorflow/tree/master/char-rnn.
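To make the per-dimension configuration described above a bit more concrete, here is a minimal plain-Python sketch of how such settings could be represented. It does not use TensorFlow and is not the code from this PR: `DimConfig`, `make_grid_config`, and the `"lstm"`/`"gru"`/`"rnn"` labels are hypothetical names chosen for illustration only. The only things taken from the thread are the four roles (input/output/priority/non-recurrent), the three cell types, and the convention of using dimension 0 for input and output.

```python
from dataclasses import dataclass
from typing import Iterable, List

# Cell types named in the PR description; these string labels are illustrative.
CELL_TYPES = ("lstm", "gru", "rnn")


@dataclass(frozen=True)
class DimConfig:
    """Hypothetical per-dimension settings (not the contrib API)."""
    idx: int
    is_input: bool
    is_output: bool
    is_priority: bool
    non_recurrent: bool
    cell_type: str


def make_grid_config(num_dims: int,
                     input_dims: Iterable[int] = (),
                     output_dims: Iterable[int] = (),
                     priority_dims: Iterable[int] = (),
                     non_recurrent_dims: Iterable[int] = (),
                     cell_type: str = "lstm") -> List[DimConfig]:
    """Build one DimConfig per grid dimension, validating the indices."""
    if cell_type not in CELL_TYPES:
        raise ValueError("unknown cell type: %r" % (cell_type,))
    groups = [set(input_dims), set(output_dims),
              set(priority_dims), set(non_recurrent_dims)]
    for group in groups:
        bad = sorted(d for d in group if not 0 <= d < num_dims)
        if bad:
            raise ValueError("dimension index out of range: %r" % (bad,))
    inp, out, pri, nonrec = groups
    return [DimConfig(i, i in inp, i in out, i in pri, i in nonrec, cell_type)
            for i in range(num_dims)]


# Convention mentioned in the thread: dimension 0 carries input and output.
config = make_grid_config(3, input_dims=[0], output_dims=[0])
for dim in config:
    print(dim)
```

Passing more than one index in `input_dims` would correspond to the "feed inputs into more than one dimension" idea raised in the discussion; how the actual cell handles that case is determined by grid_rnn_cell.py, not by this sketch.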