
Added implementation for GridRNN #1665

Merged: 15 commits into tensorflow:master, Apr 14, 2016
Conversation

@phvu (Contributor) commented Mar 27, 2016

As discussed here: #1453

The implementation is generic: users can specify the number of dimensions and various configurations for those dimensions (input/output/priority/non-recurrent). The type of cell used along the dimensions can also be chosen among LSTM, GRU, and vanilla RNN.

It comes with unit tests for the basic types: 2LSTM (tied weights, non-recurrent), 2BasicLSTM, and 2RNN.

I made a simple test of Grid2LSTM for character-level language modeling: https://github.com/phvu/grid-lstm-tensorflow/tree/master/char-rnn.
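For illustration, here is a minimal sketch of how the pre-built cell classes from this PR might be constructed, assuming the contrib-era API (tensorflow.contrib.grid_rnn). The class names match this thread, but the keyword arguments shown are assumptions:

import tensorflow as tf
from tensorflow.contrib import grid_rnn

# Pre-built configurations corresponding to the unit-tested types above.
# The tied / non_recurrent_fn keyword arguments are assumptions.
cell_tied = grid_rnn.Grid2LSTMCell(num_units=128, tied=True)  # "2LSTM", tied weights
cell_nonrec = grid_rnn.Grid2LSTMCell(num_units=128,
                                     non_recurrent_fn=tf.nn.relu)  # non-recurrent variant
cell_basic = grid_rnn.Grid2BasicLSTMCell(num_units=128)  # "2BasicLSTM"
cell_rnn = grid_rnn.Grid2BasicRNNCell(num_units=128)  # "2RNN"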

@tensorflow-jenkins (Collaborator) commented

Can one of the admins verify this patch?

@ebrevdo (Contributor) commented Mar 29, 2016

Will take a look tomorrow.

from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.grid_rnn.python.ops.grid_rnn_cell import *
@ebrevdo (Contributor) commented on the snippet above:

add newline

@phvu (Contributor, Author) replied:

done

@ebrevdo (Contributor) commented Mar 30, 2016

Thanks for the hard work! Some comments.

@phvu (Contributor, Author) commented Mar 31, 2016

@ebrevdo Thanks for the comments. I updated the code.

  • Good to know about input_size and output_size; I updated the code to not depend on them. However, since LSTMCell still accepts input_size, I kept input_size as a parameter in the cell_fn callback. Should it be removed as well?
  • The tests are real. I computed the expected values using this. Since the initializer is fixed in the tests (weights are initialized to 0.5 or 0.2, depending on the test; biases are initialized to zero), the outputs are deterministic; a sketch of this test style is below. But of course I can switch to lightweight asserts if that is preferred.
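(For illustration, a minimal sketch of this style of deterministic test, assuming the contrib-era test API; the real tests assert against the precomputed values, which are not reproduced here:)

import numpy as np
import tensorflow as tf
from tensorflow.contrib import grid_rnn

class GridRNNCellTest(tf.test.TestCase):

  def testGrid2LSTMCellDeterministic(self):
    with self.test_session() as sess:
      # Fix all weights to 0.5 and biases to 0 so the outputs are
      # fully deterministic across runs.
      with tf.variable_scope("root",
                             initializer=tf.constant_initializer(0.5)):
        x = tf.placeholder(tf.float32, [1, 3])
        m = tf.placeholder(tf.float32, [1, 8])  # (c + m) * num_dims = 4 * 2
        cell = grid_rnn.Grid2LSTMCell(2)
        g, s = cell(x, m)
        sess.run(tf.initialize_all_variables())
        g_val, s_val = sess.run([g, s], {x: np.ones([1, 3]),
                                         m: np.zeros([1, 8])})
        self.assertEqual(g_val.shape, (1, 2))
        self.assertEqual(s_val.shape, (1, 8))
        # The real tests additionally compare g_val and s_val
        # elementwise against the precomputed values.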

@ebrevdo (Contributor) commented Apr 4, 2016

Without a unit test that calls tf.nn.rnn or tf.nn.dynamic_rnn, it's not clear that these cells interact correctly with those methods. Can you add such a test for your classes?

@phvu (Contributor, Author) commented Apr 4, 2016

OK, will do.

@phvu (Contributor, Author) commented Apr 6, 2016

I added tests for Grid1LSTM, Grid2LSTM, and Grid3LSTM (with ReLU), trained with tf.nn.rnn. Can you take a look?
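(For context, such a test drives the cell through the standard RNN constructor roughly as follows; a sketch assuming the TF 0.x API, where tf.nn.rnn takes a Python list of per-step input tensors, with arbitrary sizes:)

import tensorflow as tf
from tensorflow.contrib import grid_rnn

batch_size, input_size, num_units, num_steps = 2, 5, 8, 4  # arbitrary

cell = grid_rnn.Grid2LSTMCell(num_units)
# tf.nn.rnn statically unrolls the cell over the list of inputs.
inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
          for _ in range(num_steps)]
outputs, final_state = tf.nn.rnn(cell, inputs, dtype=tf.float32)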

@ebrevdo (Contributor) commented Apr 8, 2016

Will look tomorrow, thanks!

@ebrevdo (Contributor) commented Apr 13, 2016

Jenkins, test this please?

@gunan (Contributor) commented Apr 13, 2016

Can one of the admins verify this patch?

@ebrevdo (Contributor) commented Apr 13, 2016

Sorry for not reviewing earlier. Let's retest this and we can get it merged.

@ebrevdo (Contributor) commented Apr 14, 2016

@phvu can you rebase on HEAD and run the tests?

@vrv commented Apr 14, 2016

@ebrevdo: why does he have to rebase to HEAD? Do you think there will be a conflict in doing so?

Also, only we can trigger tests.

@vrv commented Apr 14, 2016

@tensorflow-jenkins: test this please

@phvu (Contributor, Author) commented Apr 14, 2016

Cool, tests passed. So I don't need to rebase onto master for now?

@ebrevdo (Contributor) commented Apr 14, 2016

Excellent, we're good to go. Thanks for the contribution! If you have any example code using this we can consider adding it elsewhere in the repo (e.g. under models).

@ebrevdo ebrevdo merged commit 97fb7dd into tensorflow:master Apr 14, 2016
@phvu (Contributor, Author) commented Apr 15, 2016

Cool, thanks for all the help.
Unfortunately, I can't share what I am working on with this yet. I have an application of it to character-level language modeling (https://github.com/phvu/grid-lstm-tensorflow/tree/master/char-rnn), but it isn't entirely my code, so I put it in a separate repo.

I think at some point I can try reproducing some of the experiments in the paper. I will submit PRs then.

@phvu phvu deleted the enhancement/grid-rnn branch May 12, 2016 23:33
@jstaker7 commented Jul 18, 2016

Hi @phvu,

Thanks very much for your work. I'm a little confused about how this can be applied generally. For example, in the paper a 3-LSTM network was applied to image patches of the MNIST dataset, where each patch is c * m units long (which is equal to the depth dimension), so we have a 3D grid of c * m vectors. But your implementation of __call__ expects a tensor of size batch * c * m * num_dims. For MNIST, shouldn't it support b * c * m * dim1_size * dim2_size? This would require input_dims to be something like [1, 2] (where dim 0 is batch), but I don't see any tests covering cases other than input_dims=0. Am I missing something?

@phvu (Contributor, Author) commented Jul 19, 2016

Hi @jstaker7,
There are some nuances here.
First, the __call__ method expects a state tensor of shape batch_size * ((c + m) * num_dims). When c = m = 2, c * m happens to equal c + m. Sorry for the confusing test case; I should have used other values.

Second, in the MNIST case (as far as I understand it), I suppose you could use Grid3LSTMCell (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py#L242). This cell receives input and gives output in the first dimension (index 0), and uses LSTM in the other two dimensions.
To replicate Figure 11 in the paper, you will need to construct four 3-LSTM cells, each handling one scan direction of the input image. The first hidden LSTM layer receives the original pixels as input, while every other hidden LSTM layer receives the output of the LSTM just below it.
How to do that is up to you. One simple way is to loop over the image in a given scan direction (say left-to-right, top-to-bottom), which gives a sequence of vectors, and then feed that sequence into the corresponding 3-LSTM cell for that scan direction; see the sketch below. For bigger images, you might want to use tf.scan().

So your assumption that input_dims should be [1, 2] is not true: the cells should receive input at dimension 0. To me this is the only reasonable interpretation of the paper, one which also takes the 1-LSTM and 2-LSTM cells into account (unless the authors release their implementation so that we can compare).
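(To make the scan-direction idea concrete, here is a rough sketch for one of the four directions, left-to-right top-to-bottom, assuming the contrib-era API; the patch and unit sizes are hypothetical:)

import tensorflow as tf
from tensorflow.contrib import grid_rnn

batch_size, patch = 32, 4  # hypothetical sizes
image = tf.placeholder(tf.float32, [batch_size, 28, 28])  # MNIST-shaped input

# Scan the image left-to-right, top-to-bottom into a sequence of
# flattened patch vectors.
patches = []
for i in range(0, 28, patch):
  for j in range(0, 28, patch):
    p = tf.slice(image, [0, i, j], [batch_size, patch, patch])
    patches.append(tf.reshape(p, [batch_size, patch * patch]))

# One 3-LSTM cell handles this scan direction; replicating Figure 11
# would need three more cells, one per remaining direction.
cell = grid_rnn.Grid3LSTMCell(num_units=64)
outputs, state = tf.nn.rnn(cell, patches, dtype=tf.float32)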

@jstaker7 replied:

Hi @phvu,

Thank you, this helps quite a lot. Since each dimension will always have its own LSTM, does that mean we will never use an input_dims other than 0?

@phvu (Contributor, Author) commented Jul 19, 2016

To replicate the experiments in the paper, yes. But I am not sure about that in general.

When we set dimension i to be in input_dims, it simply means that on each __call__, the cell expects an input tensor for dimension i. For non-input dimensions, the cell takes the recurrent values (extracted from the state tensor) as input.

So I guess we can be creative and feed inputs into more than one dimension. I used dimension 0 as the input and output dimension just as a convention. You can just as well construct your own GridRNNCell with any configuration; a sketch follows.
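(As a sketch of such a custom configuration, assuming the constructor parameters discussed in this thread, including the cell_fn(num_units, input_size) callback mentioned above; treat the exact signature as an assumption:)

from tensorflow.contrib import grid_rnn
from tensorflow.python.ops import rnn_cell  # contrib-era import path

# A hypothetical 2-D grid that receives input along both dimensions
# instead of only dimension 0.
cell = grid_rnn.GridRNNCell(
    num_units=16,
    num_dims=2,
    input_dims=[0, 1],   # expect an input tensor for both dimensions
    output_dims=[0],     # emit output from dimension 0 only
    priority_dims=[0],
    tied=False,
    cell_fn=lambda n, i: rnn_cell.LSTMCell(num_units=n, input_size=i))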

@jstaker7 commented Jul 19, 2016

Perfect, thanks so much for the clarification. Super helpful and makes a lot of sense.

This might also be interesting to look into in the future. Grid LSTMs were mentioned and it looks like there was a recent merge: #2560
