
Add optimizers with decoupled weight decay. #17438

Merged

merged 4 commits into tensorflow:master from PhilJd:phil/weight_decay_optimizer on Jun 13, 2018

Conversation

@PhilJd (Contributor) commented Mar 5, 2018

This pull request implements decoupled weight decay as described in "Fixing Weight Decay Regularization" by Loshchilov & Hutter (https://arxiv.org/abs/1711.05101).

The paper shows that for adaptive gradient algorithms, the implemented method regularizes variables with large gradients more than L2 regularization would, and that this yields better training loss and generalization error.

For SGD variants, this simplifies hyperparameter search since it decouples the settings of weight decay and learning rate, which is nicely visualized in Fig. 2 in the paper:
[image: Fig. 2 from the paper]

This implementation explicitly adds the optimizers described in the paper (AdamW and MomentumW) to `tf.contrib.opt` and provides a factory function `extend_with_decoupled_weight_decay` that can be used to create a new optimizer class with decoupled weight decay.
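For illustration, a minimal sketch of the added API (class and function names as added by this PR to `tf.contrib.opt`; the hyperparameter values are placeholders):

```python
import tensorflow as tf

# Pre-built AdamW; each step roughly applies
#   var <- var - lr * adam_step(grad) - weight_decay * var
loss = ...  # your model loss
optimizer = tf.contrib.opt.AdamWOptimizer(weight_decay=1e-4,
                                          learning_rate=1e-3)
train_op = optimizer.minimize(loss)
```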

Closes #15237.

@googlebot added the cla: yes label Mar 5, 2018

@PhilJd changed the title from "Add weightdecay_optimizers." to "Add optimizers with decoupled weight decay." Mar 5, 2018

@PhilJd (Contributor) commented Mar 7, 2018

I'll be travelling for the next 4 weeks, so I probably won't be able to answer comments immediately ;)

@AntreasAntoniou commented Mar 27, 2018

I just tried using AdamW as implemented above. Please correct me if I'm wrong, but I think it breaks tf.layers.batch_norm during inference (when training=False).

@PhilJd (Contributor) commented Mar 28, 2018

Thanks for trying this out! :)
I tested this implementation by adapting the official ResNet implementation, and it worked for me, both with and without decaying the batch norm variables.
I suspect that your training diverged, as the optimizer is not active in the forward pass (assuming you don't do inference on the backward pass). A few things to check:

  • Did you adapt the weight decay rate? It should most likely be lower than for L2 loss regularization.
  • Do you decay all variables, i.e., including the batch norm variables? It's often useful to exclude the batch norm vars from weight decay (see the sketch below).
  • Do you schedule the weight decay similarly to the learning rate, i.e., when lowering the learning rate, do you also lower the weight decay?
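A minimal sketch of excluding the batch norm variables via the decay_var_list argument that these optimizers accept (the name filter and hyperparameter values are illustrative assumptions):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10])
net = tf.layers.dense(x, 5)
net = tf.layers.batch_normalization(net, training=True)
loss = tf.reduce_mean(tf.square(net))

optimizer = tf.contrib.opt.AdamWOptimizer(weight_decay=1e-4,
                                          learning_rate=1e-3)
# Decay everything except the batch norm gammas/betas.
decay_vars = [v for v in tf.trainable_variables()
              if 'batch_normalization' not in v.name]
train_op = optimizer.minimize(loss, decay_var_list=decay_vars)
```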

@PhilJd (Contributor) commented Apr 9, 2018

@andrewharp Sorry for manually pulling you in; I thought it might make sense, as the tensorflow butler assigned you to triage the last few pull requests and this PR fell off the radar. Could you triage this PR? Thanks a lot!

@martinwicke requested a review from alextp Apr 17, 2018

@martinwicke self-assigned this Apr 17, 2018

```python
def _decay_weights(self, var):
  if (not self._decay_var_list or
      (self._decay_var_list and var in self._decay_var_list)):
    return var.assign_sub(self._weight_decay * var, self._use_locking)
```

@facaiy (Member) commented Apr 17, 2018

This method keeps the original optimizer unchanged; however, it seems a little tricky.

@PhilJd (Contributor) commented Apr 17, 2018

Following the discussion in #15237, I re-checked the apply functions of the other optimizers.
This implementation relies on the fact that the apply_ functions already receive the precomputed gradient, and the Adam and Momentum optimizers don't compute values based on var, so the pre-decay implementation is equivalent to the algorithm described in the paper for these optimizers.
However, e.g. Ftrl computes factors based on var, so decoupling the decay for such optimizers would need a custom op.
My question now is: do you favor limiting the decoupled implementation to the Adam and Momentum optimizers, i.e., removing the extend_with_decoupled_weight_decay function, or do you think a warning in the documentation is enough, still allowing e.g. AdadeltaW, AdagradW, or RMSPropW to be created with one line of code?
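For context, creating such a variant with the factory is a one-liner. A sketch (whether the result is mathematically sound depends on the base optimizer, as discussed above):

```python
import tensorflow as tf

# Wrap an existing optimizer class; the returned class takes weight_decay
# in addition to the base class's constructor arguments.
RMSPropW = tf.contrib.opt.extend_with_decoupled_weight_decay(
    tf.train.RMSPropOptimizer)
optimizer = RMSPropW(weight_decay=1e-4, learning_rate=1e-3)
```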

@facaiy (Member) commented Apr 18, 2018

@PhilJd I think the hack is feasible (and looks concise and flexible); however, I have no idea whether it is the best solution. I might prefer to implement the weight decay for each optimizer individually, however awkward that is. Anyway, thanks for your nice work :) Let's wait for a reply from tensorflower @alextp.

```python
    return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
        grad, var)

  def _apply_sparse(self, grad, var):
```

@alextp (Member) commented Apr 18, 2018

Doing a dense weight decay when doing a sparse variable update seems like a bad idea. I'd prefer it if these optimizers logged a warning or something like that when used for sparse models, since you're unlikely to see reasonable performance.

@PhilJd (Contributor) commented Apr 18, 2018

I completely agree, but I wasn't sure what your preference would be here, as e.g. Adam does dense updates of the momentum even when using _apply_sparse, if I remember correctly.
What would you prefer:

  • add a warning,
  • do sparse weight decay (a sketch of what this could look like follows this list), or
  • add a flag force_dense_decay to __init__, which always computes dense decay if True and sparse decay for _apply_sparse if False?
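For the second option, a hedged sketch of what a per-row sparse decay could look like (illustrative only; the method name and attributes follow the pattern of this PR but are assumptions):

```python
import tensorflow as tf

def _decay_weights_sparse(self, var, indices):
  # Decay only the rows touched by this sparse update: gather the
  # affected slices, scale them, and scatter the decrement back.
  update = -self._weight_decay * tf.gather(var, indices)
  return tf.scatter_add(var, indices, update,
                        use_locking=self._use_locking)
```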
@tensorflowbutler (Member) commented May 4, 2018

It has been 14 days with no activity and the awaiting response label was assigned. Is this PR still valid? Assigning the stalled label. Please comment to reassure me that this is still being worked on.

@PhilJd (Contributor) commented May 7, 2018

I'm working on the sparse updates, but I'm rather busy due to the upcoming ICRA.

@PhilJd (Contributor) commented Jun 13, 2018

I'm not sure what causes this error:

`RuntimeError: The following files are missing load("//tensorflow:tensorflow.bzl", "py_test").`

`py_test` is included at the top of the BUILD file.

The do_buildifier error is now fixed (wrong order).

@alextp approved these changes Jun 13, 2018

@martinwicke merged commit feb9a3e into tensorflow:master Jun 13, 2018

16 checks passed: Android Demo App, GPU CC, GPU Python3, MacOS Contrib, MacOS Python2 and CC, Ubuntu CC, Ubuntu Makefile, Ubuntu Python2, Ubuntu Python3, Ubuntu Python3 PIP, Ubuntu Sanity, Ubuntu contrib, Windows Bazel, Windows CMake, and XLA (all "Internal CI build successful"), plus cla/google (all necessary CLAs are signed).
@pierremac commented Jun 15, 2018

Thanks a lot for the great work @PhilJd!
I've been trying it out a bit and have been getting very chaotic results so far.
I'm a bit lost with the weight decay parameter. Compared to the original paper, does your weight_decay parameter correspond to their normalized or regular weight decay?
Any insights on what kind of values we should use?

@PhilJd (Contributor) commented Jun 15, 2018

@pierremac: The weight decay parameter (let's call it w here) is the parameter that is multiplied with the variable before subtracting it, i.e., an update step looks roughly like this:

`var = var - lr * grad - (w * var)`

It's not possible to compute the normalized weight decay within the optimizer, as it depends on your specific dataset.
To create AdamWR as in the paper, you need to compute the decay manually by multiplying your initial decay (ideally the normalized version) with your learning rate schedule, e.g.:

```python
from math import sqrt
import tensorflow as tf

LR = ...                 # set your learning rate here
W_NORM = ...             # set your (normalized) weight decay here
FIRST_DECAY_STEPS = ...  # set the steps until the first warm restart here
batch_size = ...
num_training_samples = ...
num_epochs = ...

global_step = tf.train.get_or_create_global_step()
schedule = tf.train.cosine_decay_restarts(1.0, global_step,
                                          first_decay_steps=FIRST_DECAY_STEPS,
                                          t_mul=2.0, m_mul=1.0, alpha=0.0)
lr = LR * schedule
# sqrt(b / (B * T)) normalization from the paper (note the parentheses).
weight_decay = W_NORM * sqrt(
    batch_size / (num_training_samples * num_epochs)) * schedule

loss = ...
optimizer = tf.contrib.opt.AdamWOptimizer(weight_decay, lr)
train_op = optimizer.minimize(loss)
```

Regarding values for LR and W_NORM, this really depends on your model. W_NORM should usually be smaller than LR, since otherwise the regularization is stronger than the update step. The nice thing about decoupled weight decay is that you can tune the learning rate and the weight decay independently. So I'd suggest first setting the decay to 0 and doing hyperparameter optimization for the learning rate; once you have found the best-performing learning rate, keep it fixed and start optimizing the decay parameter.

Hope that helps a bit ;)

@PhilJd deleted the PhilJd:phil/weight_decay_optimizer branch Jun 15, 2018

@safrooze referenced this pull request Jun 15, 2018: Inconsistent weight decay logics in multiple optimizers #9881 (open)
@iron9light commented Jul 5, 2018

Will this change be added to the v1.9.0 release? I really want this feature!

@martinwicke (Member) commented Jul 5, 2018

I believe it missed 1.9; it will be in 1.10.

@yym-ustc commented Jul 18, 2018

@PhilJd I tried the code and got this error:

```
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py", line 179, in _apply_sparse
  decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py", line 163, in _decay_weights_sparse_op
  self._use_locking)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/state_ops.py", line 405, in scatter_add
  use_locking=use_locking, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 591, in scatter_add
  use_locking=use_locking, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
  op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
  return func(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3195, in create_op
  op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
  self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [4000000,64], indices.shape [20612], params.shape [4000000,64]
```

@PhilJd (Contributor) commented Jul 19, 2018

@yym-ustc Do you have a small example that reproduces your error? And which TF version are you using?

@yym-ustc commented Jul 19, 2018

@PhilJd I'm using the latest TensorFlow code.

@yym-ustc commented Jul 19, 2018

@PhilJd I used an embedding matrix (4000000 × 64). The error happened when TF applied the gradient update to the embedding matrix.

@PhilJd (Contributor) commented Jul 19, 2018

@yym-ustc Could you share a small code snippet demonstrating your error?

@yym-ustc commented Jul 20, 2018

@PhilJd It's fine when using AdamOptimizer, but when I change to AdamWOptimizer I get the error:

```python
weights_var = tf.trainable_variables()
gradients = tf.gradients(cost, weights_var)
# optimizer = tf.train.AdamOptimizer(learning_rate=deep_learning_rate)
AdamWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
    tf.train.AdamOptimizer)
optimizer = AdamWOptimizer(weight_decay=weight_decay,
                           learning_rate=deep_learning_rate)
train_op = optimizer.apply_gradients(zip(gradients, weights_var))
```

@PhilJd (Contributor) commented Jul 20, 2018

Do you have a small, self-contained example that I can run to reproduce the bug? E.g., without knowing what your weights_var looks like (I suppose the update is sparse, but resource vs. non-resource, etc.), I can't really debug the problem.
Thanks!
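For reference, a self-contained repro along these lines might look like the following sketch (hypothetical: the variable names and sizes are illustrative, and tf.nn.embedding_lookup is used only to force the sparse-gradient path):

```python
import tensorflow as tf

# Hypothetical minimal repro: an embedding lookup yields IndexedSlices
# gradients, which routes the update through _apply_sparse.
emb = tf.get_variable('emb', shape=[1000, 8])
ids = tf.constant([1, 2, 3])
cost = tf.reduce_sum(tf.nn.embedding_lookup(emb, ids))

AdamWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
    tf.train.AdamOptimizer)
optimizer = AdamWOptimizer(weight_decay=1e-4, learning_rate=1e-3)
gradients = tf.gradients(cost, [emb])
train_op = optimizer.apply_gradients(zip(gradients, [emb]))
```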
