Weight normalization for RNN Cells. #11573

Merged
merged 3 commits into from Jan 22, 2018

@discoveredcheck
Contributor

discoveredcheck commented Jul 18, 2017

The current RNN implementation executes a user-defined
function (the call() method of RNNCell subclasses) inside a
tf.while_loop(). Weight normalization requires a one-time
normalization of the transition matrices prior to
entering the while loop. The following two edits have been made in
tensorflow/python/ops to enable this functionality:

  • RNNCell now has a prepare() method, which is a no-op as implemented
    in the base class.
  • A call to cell.prepare() has been added just before entering
    _dynamic_rnn_loop().

Subclasses of RNNCell may implement normalization in the cell's
prepare() method. One implementation with BasicLSTMCell
and associated tests have been added to contrib. Note that any
wrappers to be used with a weight-normalized cell need to be
appropriately subclassed, as illustrated with the
PrepareableMultiRNNCell example in contrib.
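The hook's call order can be sketched in plain Python. This is a hypothetical stand-in for the PR's design, not the TensorFlow code itself: `CountingCell`, `dynamic_rnn`, and the counters are illustrative, and the `for` loop stands in for the tf.while_loop body.

```python
class RNNCell:
    """Stand-in base class: prepare() is a no-op by default."""
    def prepare(self):
        pass

class CountingCell(RNNCell):
    """Hypothetical cell that records how often each hook runs."""
    def __init__(self):
        self.prepare_calls = 0
        self.step_calls = 0

    def prepare(self):
        # One-time setup (e.g. weight normalization) before the loop.
        self.prepare_calls += 1

    def __call__(self, x, state):
        # Per-time-step transition; uses the already-prepared weights.
        self.step_calls += 1
        return x, state

def dynamic_rnn(cell, inputs, state=None):
    cell.prepare()              # invoked once, just before the loop
    outputs = []
    for x in inputs:            # stands in for the tf.while_loop body
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state
```

Running a three-step sequence calls prepare() exactly once and the transition three times, which is the separation the PR is after.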

@tensorflow-jenkins

Collaborator

tensorflow-jenkins commented Jul 18, 2017

Can one of the admins verify this patch?

gx = vs.get_variable("gx", [output_size], dtype=self.dtype)
gh = vs.get_variable("gh", [output_size], dtype=self.dtype)
wx = nn_impl.l2_normalize(wx, dim=0) * gx
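Numerically, l2_normalize(wx, dim=0) * gx rescales each column of wx to unit L2 norm and then multiplies it by that column's learned gain. A minimal pure-Python check of that arithmetic (list-of-rows in place of a tensor; not the contrib code):

```python
import math

def l2_normalize_columns(w, g):
    """Divide each column of w by its L2 norm, then scale it by its gain g[j]."""
    rows, cols = len(w), len(w[0])
    norms = [math.sqrt(sum(w[i][j] ** 2 for i in range(rows)))
             for j in range(cols)]
    return [[g[j] * w[i][j] / norms[j] for j in range(cols)]
            for i in range(rows)]

# Column [3, 4] has norm 5; with gain 2 it becomes [1.2, 1.6].
out = l2_normalize_columns([[3.0], [4.0]], [2.0])
```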

@ebrevdo

ebrevdo Jul 24, 2017

Contributor

why does this require a special prepare()?

@discoveredcheck

discoveredcheck Jul 25, 2017

Contributor

Hi @ebrevdo. Thanks for your review. It's mainly because the weights should be normalized once, before entering the RNN while loop (following https://github.com/openai/generating-reviews-discovering-sentiment/blob/master/encoder.py). Without the prepare() method we would have to put the normalization code inside the cell's call() function, which would (incorrectly) re-normalize the weights at every time step. We did try that implementation and found that it gives OOM errors during backpropagation even in moderately sized networks, due to the wasteful normalization op at every time step.

We also thought about putting this code inside the cell's constructor, but that would go against the functional semantics of the other existing RNNCell classes. It also wouldn't be possible to inherit the variable scopes of any wrappers one might want to use (specifically MultiRNNCell, which adds cell_%d to the variable scope).

Our priority has been to try not to make any changes to core tensorflow, but this seemed to be the way forward based on the above reasoning. Does this make sense? Or am I missing something here?

Ashwini

@discoveredcheck

Contributor

discoveredcheck commented Jul 25, 2017

Hi @ebrevdo, the NLP team from Winton will be at the ACL conference next week. We will be happy to meet with someone from the tensorflow team if they are attending. We thought it might be more efficient seeing as this PR proposes an update in core tensorflow and might require further discussion/alternate implementation ideas.

@vrv

Contributor

vrv commented Jul 26, 2017

IIRC @ebrevdo just went on vacation for 2 weeks, and he's the best person to look at this change. Is this urgent, or can it wait for him to get back? Adding an (optional) API to something as core as the RNN interface probably deserves some careful thought, which I know you've already spent some time thinking about!

@ebrevdo

Contributor

ebrevdo commented Jul 26, 2017

@discoveredcheck discoveredcheck force-pushed the discoveredcheck:weightnorm branch 2 times, most recently from 14aab21 to 10411fd Jul 26, 2017

@discoveredcheck

Contributor

discoveredcheck commented Jul 27, 2017

@ebrevdo Thanks for pointing out the use of control_dependencies. I had a slightly different reading of the semantics of the None argument there, namely that the ops would just run normally, as if tf.control_dependencies wasn't there at all.
I'm not quite sure how tf.while skips these ops during the loop, but it solved the problem!

@discoveredcheck discoveredcheck force-pushed the discoveredcheck:weightnorm branch 2 times, most recently from 98f846b to e61e29e Aug 8, 2017

@discoveredcheck

Contributor

discoveredcheck commented Aug 8, 2017

Hi @ebrevdo, I hope you had a nice vacation.

I've updated the PR with more commits (squashed) to include documentation and formatting, mostly following the pylint template and existing code in contrib. I should mention that most of the code is adapted from rnn_cell_impl.LSTMCell, with an additional norm argument: if True, the cell performs the normalization; otherwise it returns the same outputs as rnn_cell_impl.LSTMCell. I've included this information in the docstrings.

Hopefully these changes are in the right direction. Looking forward to your comments.

@rmlarsen

Member

rmlarsen commented Aug 8, 2017

@ebrevdo could you please take another look?

@ebrevdo

Contributor

ebrevdo commented Aug 8, 2017

@discoveredcheck discoveredcheck force-pushed the discoveredcheck:weightnorm branch from e61e29e to dcbc1eb Aug 8, 2017

@discoveredcheck

Contributor

discoveredcheck commented Aug 15, 2017

@tensorflow-jenkins test this please

@discoveredcheck

Contributor

discoveredcheck commented Aug 17, 2017

Hi @ebrevdo , did you get a chance to look at the code? I triggered a Jenkins build earlier which seems to be okay.

@discoveredcheck

Contributor

discoveredcheck commented Sep 12, 2017

Hi @rmlarsen @ebrevdo , any news on this?

@drpngx

Member

drpngx commented Sep 17, 2017

@ebrevdo any chance to take a look?

@sb2nov

Member

sb2nov commented Sep 26, 2017

@ebrevdo please take a look

@sb2nov

Member

sb2nov commented Sep 26, 2017

Jenkins, test this please.

@discoveredcheck

Contributor

discoveredcheck commented Oct 9, 2017

Jenkins, test this please

@ebrevdo

ebrevdo approved these changes Oct 9, 2017

approval contingent on a few nits.

"""Normalizes the columns of the given weight matrix and
multiplies each column with an independent scalar variable."""
output_size = weight.get_shape().as_list()[1]

@ebrevdo

ebrevdo Oct 9, 2017

Contributor

are you guaranteed at this point that the shape has known rank and dim 1?

@discoveredcheck

discoveredcheck Oct 9, 2017

Contributor

Yes, the shape of weight is known, since it is created via get_variable() with a full shape specification (line 2242). The assert on line 2232 will fire if the shape cannot be inferred.

return self._output_size
def _normalize(self, weight, name):
"""Normalizes the columns of the given weight matrix and

@ebrevdo

ebrevdo Oct 9, 2017

Contributor

nit: first line of a comment should be a short sentence that fits on one line and ends with a period.

@discoveredcheck

discoveredcheck Oct 9, 2017

Contributor

Fixed.

@discoveredcheck

Contributor

discoveredcheck commented Oct 10, 2017

Jenkins, test this please

@gunan

Member

gunan commented Nov 5, 2017

Jenkins, test this please

@discoveredcheck

Contributor

discoveredcheck commented Nov 15, 2017

Removed lines of commented out code. Have left in pylint and explanatory comments.

def testBasicCellWithNorm(self):
"""Tests cell w/o peepholes and with normalisation"""
cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,

@ebrevdo

ebrevdo Nov 17, 2017

Contributor

Make this a def cell; multiline lambdas are discouraged.

@discoveredcheck

discoveredcheck Nov 17, 2017

Contributor

Okay, updated in the latest commits.
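The style fix above amounts to replacing the multiline lambda with a named function. Sketched with a hypothetical `make_cell` factory standing in for `contrib_rnn_cell.WeightNormLSTMCell` (which isn't imported here):

```python
def make_cell(num_units, norm):
    """Hypothetical stand-in for WeightNormLSTMCell(num_units, norm=norm)."""
    return {"num_units": num_units, "norm": norm}

# Discouraged: a multiline lambda as a cell factory.
# cell = lambda: make_cell(2,
#                          norm=True)

# Preferred: a named function with the same behavior.
def cell():
    return make_cell(2, norm=True)
```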

@@ -42,6 +42,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
from tensorflow.python.training import adam

@ebrevdo

ebrevdo Nov 17, 2017

Contributor

Do you really need Adam, or can you get away with just calling gradients?

@discoveredcheck

discoveredcheck Nov 17, 2017

Contributor

you are right, it isn't required. I'll remove it.

@ebrevdo

Contributor

ebrevdo commented Nov 17, 2017

Update your PR message, it's outdated.

Weight normalization for RNN Cells.
The class `WeightNormLSTMCell` is added in contrib/rnn. This
implements most of the functionality provided by `LSTMCell`
with one additional constructor argument `norm`. If set to `True`, this
will normalize the state transition and projection (if used via the
`num_proj` argument) matrices. If set to `False`, the outputs are identical
to those of `LSTMCell`.

Tests are added for checking this equivalence and correctness
of the normalization operation.
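The `norm=False` equivalence claimed above can be illustrated with a toy flag-guarded normalizer. This is pure Python for illustration, not the contrib code; `maybe_normalize` is a hypothetical helper mirroring the `norm` constructor argument:

```python
import math

def maybe_normalize(column, gain, norm):
    """With norm=True, return gain * column / ||column||;
    with norm=False, return the column untouched."""
    if not norm:
        return list(column)
    n = math.sqrt(sum(x * x for x in column))
    return [gain * x / n for x in column]

# norm=False leaves the weights exactly as they were.
unchanged = maybe_normalize([3.0, 4.0], 2.0, False)
# norm=True rescales the column ([3, 4] has norm 5) and applies the gain.
scaled = maybe_normalize([3.0, 4.0], 2.0, True)
```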

@discoveredcheck discoveredcheck force-pushed the discoveredcheck:weightnorm branch from dfee744 to 3b7520f Nov 17, 2017

@discoveredcheck

Contributor

discoveredcheck commented Nov 17, 2017

Thanks for your comments @ebrevdo. I've updated the PR message and code accordingly, and squashed the commits as well.

Jenkins, test this please.

@ebrevdo

Contributor

ebrevdo commented Dec 14, 2017

@discoveredcheck quick ping on this.

@discoveredcheck

Contributor

discoveredcheck commented Dec 14, 2017

Hi @ebrevdo, I have pushed in commits covering your latest review comments (multiline lambdas, PR message and commented-out lines of code). Is there anything else that needs change?

@drpngx

Member

drpngx commented Dec 27, 2017

@discoveredcheck the change has been approved, but you need to pull rebase and push again. Then we'll run the tests.

@caisq

Contributor

caisq commented Jan 1, 2018

@discoveredcheck can you please resolve conflicts before we test again and merge the PR?

@rorywaite

rorywaite commented Jan 2, 2018

@discoveredcheck is on vacation until the 8th. Does this need to be resolved before the 8th? If so, I can have a look.

@discoveredcheck

Contributor

discoveredcheck commented Jan 3, 2018

@rorywaite @caisq @drpngx I will have a look today. It shouldn't need much, since a rebase was done with the last round of changes.

discoveredcheck added some commits Jan 3, 2018

@discoveredcheck

Contributor

discoveredcheck commented Jan 3, 2018

Hi @caisq , I've made the necessary changes. RNN tests pass locally. Please let us know if any further work is required on this.

Jenkins, test this please.

@yifeif

Member

yifeif commented Jan 11, 2018

@tensorflow-jenkins test this please

@rorywaite

rorywaite commented Jan 17, 2018

Hey, could you give us a heads-up a day or two before you plan to merge? We want to ensure that we're available to fix merge issues as they crop up.

@caisq

Contributor

caisq commented Jan 17, 2018

@rorywaite we are ready to merge this PR sometime today or tomorrow.

@drpngx

Member

drpngx commented Jan 20, 2018

@rorywaite ok to merge?

@discoveredcheck

Contributor

discoveredcheck commented Jan 21, 2018

@drpngx , yes, please go ahead with the merge.

@drpngx drpngx merged commit 57b32ea into tensorflow:master Jan 22, 2018

18 checks passed

Android Demo App Internal CI build successful
GPU CC Internal CI build successful
GPU Python3 Internal CI build successful
Linux CPU Tests (Python 3) SUCCESS
MacOS Contrib Internal CI build successful
MacOS Python2 and CC Internal CI build successful
Sanity Checks SUCCESS
TF Test Suite SUCCESS
Ubuntu CC Internal CI build successful
Ubuntu Makefile Internal CI build successful
Ubuntu Python2 Internal CI build successful
Ubuntu Python3 Internal CI build successful
Ubuntu Sanity Internal CI build successful
Ubuntu contrib Internal CI build successful
Windows Cmake Tests SUCCESS
XLA Internal CI build successful
ci.tensorflow.org SUCCESS
cla/google All necessary CLAs are signed