Weight normalization for RNN Cells. #11573

Merged: 3 commits into tensorflow:master on Jan 22, 2018

Conversation

@discoveredcheck (Contributor)

The current RNN implementation executes a user-defined function (the call() method of subclasses of RNNCell) inside a tf.while_loop(). Weight normalization requires a one-time normalization of the transition matrices prior to entering the while loop. The following two edits have been made in tensorflow/python/ops to enable this functionality:

  • RNNCell now has a prepare() method, which is a no-op as implemented in the base class
  • A call to cell.prepare() has been added just before entering _dynamic_rnn_loop()

Subclasses of RNNCell may implement normalization in the cell's prepare() method. One implementation based on BasicLSTMCell, with associated tests, has been added to contrib. Note that any wrappers to be used with a weight-normalized cell need to be appropriately subclassed, as illustrated by the PrepareableMultiRNNCell example in contrib.
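
A minimal sketch of the hook described above (a sketch following the PR text, not the actual diff; base_layer stands in for the TF1 layer base class the cells derive from):

# Sketch only -- names follow the PR description above.
class RNNCell(base_layer.Layer):

  def prepare(self):
    """One-time setup hook; a no-op in the base class.

    Subclasses can override this to normalize their transition
    matrices once, before the tf.while loop starts stepping the cell.
    """
    pass

# ...and in rnn.dynamic_rnn(), just before the loop is entered:
#   cell.prepare()
#   (outputs, final_state) = _dynamic_rnn_loop(cell, inputs, state, ...)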

@tensorflow-jenkins (Collaborator)

Can one of the admins verify this patch?

gx = vs.get_variable("gx", [output_size], dtype=self.dtype)
gh = vs.get_variable("gh", [output_size], dtype=self.dtype)

wx = nn_impl.l2_normalize(wx, dim=0) * gx

@ebrevdo (Contributor)

why does this require a special prepare()?

@discoveredcheck (Contributor, Author)

Hi @ebrevdo. Thanks for your review. It's mainly because the weights should be normalized once before entering the RNN while loop (following https://github.com/openai/generating-reviews-discovering-sentiment/blob/master/encoder.py). Without the prepare() method we would have to add the normalization code inside the cell's call() function, which would (incorrectly) re-normalize the weights at every time step. We did try this implementation and found that it gives OOM errors during backpropagation in even moderately sized networks, due to the wasteful normalization op at every time step.

We also thought about putting this code inside the cell's constructor, but that would go against the functional semantics of how other existing RNNCell classes operate. It also would not be possible to inherit the variable scopes of any wrappers one might want to use (specifically MultiRNNCell, which adds cell_%d to the variable scope).

Our priority has been to avoid making any changes to core TensorFlow, but this seemed to be the way forward based on the above reasoning. Does this make sense? Or am I missing something here?

Ashwini
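
For context, the gx/gh snippet quoted above implements the reparameterization from weight normalization (Salimans & Kingma, 2016), w = g * v / ||v||, applied column-wise. A minimal standalone illustration, with hypothetical names:

import tensorflow as tf

v = tf.get_variable("v", [3, 4])      # unnormalized weight matrix
g = tf.get_variable("g", [4])         # one learned scale per column
w = tf.nn.l2_normalize(v, dim=0) * g  # unit-norm columns, rescaled by g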

@discoveredcheck (Contributor, Author)

Hi @ebrevdo, the NLP team from Winton will be at the ACL conference next week. We would be happy to meet with someone from the TensorFlow team if they are attending. We thought it might be more efficient, since this PR proposes an update to core TensorFlow and might require further discussion or alternate implementation ideas.

@vrv commented Jul 26, 2017

IIRC @ebrevdo just went on vacation for 2 weeks, and he's the best person to look at this change. Is this urgent, or can it wait for him to get back? Adding an (optional) API to something as core as the RNN interface probably deserves some careful thought, which I know you've already spent some time thinking about!

@ebrevdo (Contributor) commented Jul 26, 2017 via email

@discoveredcheck force-pushed the weightnorm branch 2 times, most recently from 14aab21 to 10411fd on July 26, 2017
@discoveredcheck (Contributor, Author)

@ebrevdo Thanks for pointing out the use of control_dependencies. I had read the semantics of the None argument slightly differently: that the ops would simply run as normal, as if tf.control_dependencies weren't there at all. I'm not quite sure how the tf.while loop skips these ops during iteration, but it solved the problem!
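
For what it's worth, the mechanism is roughly this (a hedged reading of TF1 graph-building behavior in general, not something shown in this PR's diff): tf.control_dependencies(None) clears the current control-flow context along with the control-dependency stack, so ops created under it inside a while_loop body are attached to the graph outside the loop and run once, as loop invariants. A sketch with hypothetical names:

import tensorflow as tf

def body(t, acc):
  with tf.control_dependencies(None):
    # Created outside the loop's control-flow context: this
    # normalization executes once, not on every iteration.
    w = tf.nn.l2_normalize(tf.ones([4, 4]), dim=0)
  return t + 1, acc + tf.reduce_sum(w)

t, acc = tf.while_loop(lambda t, a: t < 3, body,
                       [tf.constant(0), tf.constant(0.0)])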

@discoveredcheck force-pushed the weightnorm branch 2 times, most recently from 98f846b to e61e29e on August 8, 2017
@discoveredcheck (Contributor, Author)

Hi @ebrevdo, I hope you've had a nice vacation by the time you see this.

I've updated the PR with more (squashed) commits to include documentation and formatting, mostly following the pylint template and existing code in contrib. I should mention that most of the code is adapted from rnn_cell_impl.LSTMCell with an additional norm argument which, if True, applies the normalization, and otherwise returns the same outputs as rnn_cell_impl.LSTMCell. I've included this information in the docstrings.

Hopefully these changes are in the right direction. Looking forward to your comments.

@rmlarsen added the "awaiting review" label on Aug 8, 2017
@rmlarsen (Member) commented Aug 8, 2017

@ebrevdo could you please take another look?

@ebrevdo (Contributor) commented Aug 8, 2017 via email

@discoveredcheck (Contributor, Author)

@tensorflow-jenkins test this please

@discoveredcheck (Contributor, Author)

Hi @ebrevdo, did you get a chance to look at the code? I triggered a Jenkins build earlier, which seems to be okay.

@discoveredcheck (Contributor, Author)

Hi @rmlarsen @ebrevdo, any news on this?

@drpngx (Contributor) commented Sep 17, 2017

@ebrevdo any chance to take a look?

@sb2nov (Contributor) commented Sep 26, 2017

@ebrevdo please take a look

@sb2nov (Contributor) commented Sep 26, 2017

Jenkins, test this please.

@discoveredcheck (Contributor, Author)

Jenkins, test this please

@ebrevdo (Contributor) left a review:

approval contingent on a few nits.

"""Normalizes the columns of the given weight matrix and
multiplies each column with an independent scalar variable."""

output_size = weight.get_shape().as_list()[1]

@ebrevdo (Contributor)

are you guaranteed at this point that the shape has known rank and dim 1?

@discoveredcheck (Contributor, Author)

Yes, the shape of weight is known, since it is created via get_variable() with a full shape specification (line 2242). The assert on line 2232 will fire if the shape could not be calculated.

return self._output_size

def _normalize(self, weight, name):
"""Normalizes the columns of the given weight matrix and

@ebrevdo (Contributor)

nit: first line of a comment should be a short sentence that fits on one line and ends with a period.

@discoveredcheck (Contributor, Author)

Fixed.
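
A first line following that convention might read (a sketch, not necessarily the merged wording):

def _normalize(self, weight, name):
  """Applies weight normalization to the columns of a weight matrix.

  Each column is rescaled to unit L2 norm, then multiplied by an
  independent, learned scalar variable.
  """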

@discoveredcheck (Contributor, Author)

Jenkins, test this please

1 similar comment
@gunan (Contributor) commented Nov 5, 2017

Jenkins, test this please

c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])

#cell = contrib_rnn_cell.WeightNormLSTMCell(2, norm, peep)

@ebrevdo (Contributor)

remove commented out lines everywhere

@discoveredcheck (Contributor, Author)

Removed the lines of commented-out code. I have left in the pylint and explanatory comments.

@jhseu added the kokoro:force-run label on Nov 16, 2017
@kokoro-team removed the kokoro:force-run label on Nov 16, 2017
def testBasicCellWithNorm(self):
"""Tests cell w/o peepholes and with normalisation"""

cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,

@ebrevdo (Contributor)

Make this a def cell; multiline lambdas are discouraged.

@discoveredcheck (Contributor, Author)

Okay, updated in the latest commits.
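
The requested shape, roughly (a sketch; the constructor arguments are placeholders, since the snippet above is truncated):

def cell():
  # Hypothetical arguments -- the truncated snippet does not show them.
  return contrib_rnn_cell.WeightNormLSTMCell(2, norm=True, use_peepholes=False)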

@@ -42,6 +42,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
from tensorflow.python.training import adam

@ebrevdo (Contributor)

Do you really need Adam, or can you get away with just calling gradients?

@discoveredcheck (Contributor, Author)

You are right, it isn't required. I'll remove it.
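
For reference, the backward pass can be exercised without an optimizer, along these lines (a sketch using the test file's module-style imports; the loss and outputs names are hypothetical):

from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops

# Request gradients of a scalar loss directly instead of building Adam.
loss = math_ops.reduce_sum(outputs)
grads = gradients_impl.gradients(loss, variables.trainable_variables())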

@ebrevdo (Contributor) commented Nov 17, 2017

Update your PR message, it's outdated.

@gunan added the kokoro:force-run label on Nov 17, 2017
@kokoro-team removed the kokoro:force-run label on Nov 17, 2017
The class `WeightNormLSTMCell` is added in contrib/rnn. It implements most of the functionality provided by `LSTMCell`, with one additional constructor argument, `norm`. If set to `True`, it normalizes the state transition and projection matrices (the latter if used via the `num_proj` argument). If set to `False`, the outputs are identical to those of `LSTMCell`.

Tests are added to check this equivalence and the correctness of the normalization operation.
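
A hedged usage sketch of the class as described (the full constructor signature isn't shown in this thread; norm follows the PR message, everything else is assumed):

import tensorflow as tf

cell = tf.contrib.rnn.WeightNormLSTMCell(num_units=64, norm=True)
inputs = tf.placeholder(tf.float32, [None, 20, 32])  # [batch, time, features]
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# With norm=False the outputs should match a plain LSTMCell.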
@discoveredcheck (Contributor, Author)

Thanks for your comments @ebrevdo. I've updated the PR message and code accordingly, and squashed the commits as well.

Jenkins, test this please.

@ebrevdo (Contributor) commented Dec 14, 2017

@discoveredcheck quick ping on this.

@discoveredcheck (Contributor, Author)

Hi @ebrevdo, I have pushed commits covering your latest review comments (multiline lambdas, the PR message, and commented-out lines of code). Is there anything else that needs changing?

@drpngx (Contributor) commented Dec 27, 2017

@discoveredcheck the change has been approved, but you need to pull with rebase and push again. Then we'll run the tests.

@caisq (Contributor) commented Jan 1, 2018

@discoveredcheck can you please resolve conflicts before we test again and merge the PR?

@rorywaite

@discoveredcheck is on vacation until the 8th. Does this need to be resolved before the 8th? If so, I can have a look.

@discoveredcheck (Contributor, Author)

@rorywaite @caisq @drpngx I will have a look today. It shouldn't need much, since a rebase was done with the last round of changes.

@discoveredcheck (Contributor, Author)

Hi @caisq, I've made the necessary changes. The RNN tests pass locally. Please let us know if any further work is required on this.

Jenkins, test this please.

@yifeif (Contributor) commented Jan 11, 2018

@tensorflow-jenkins test this please

@yifeif added the kokoro:force-run label on Jan 11, 2018
@kokoro-team removed the kokoro:force-run label on Jan 11, 2018
@rorywaite

Hey, could you give us a heads-up a day or two before you plan to merge? We want to ensure that we're available to fix merge issues as they crop up.

@caisq (Contributor) commented Jan 17, 2018

@rorywaite we are ready to merge this PR sometime today or tomorrow.

@drpngx (Contributor) commented Jan 20, 2018

@rorywaite ok to merge?

@discoveredcheck (Contributor, Author)

@drpngx, yes, please go ahead with the merge.

@drpngx merged commit 57b32ea into tensorflow:master on Jan 22, 2018