
Cannot train canned estimators in multiple estimator.train() calls when using tf.keras.optimizers or tf.optimizers #33358

Closed
JoshEZiegler opened this issue Oct 14, 2019 · 30 comments
Assignees
Labels
2.6.0 comp:keras Keras related issues type:support Support issues

Comments

@JoshEZiegler

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • Platform: code run in Google Colab
  • Python version: Python 3
  • Tensorflow version: v2.0.0-rc1-51-g2646d23 2.0.0-rc2

Describe the current behavior
When training a canned estimator with multiple estimator.train() calls while using any tf.keras.optimizers optimizer, the optimizer raises an exception.

Describe the expected behavior
Repeated estimator.train() calls train for the given number of steps.

Code to reproduce the issue
Lightly edited example using canned estimators:
https://gist.github.com/JoshEZiegler/2a923a707d831ca7efd33dbfbf9779c9

Other info / logs

RuntimeError                              Traceback (most recent call last)
in ()
      5 classifier.train(
      6     input_fn=lambda: input_fn(train, train_y, training=True),
----> 7     steps=500)

7 frames
/tensorflow-2.0.0-rc2/python3.6/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py in iterations(self, variable)
    660   def iterations(self, variable):
    661     if self._iterations is not None:
--> 662       raise RuntimeError("Cannot set iterations to a new Variable after "
    663                          "the Optimizer weights have been created")
    664     self._iterations = variable

RuntimeError: Cannot set iterations to a new Variable after the Optimizer weights have been created

@oanush oanush self-assigned this Oct 15, 2019
@oanush oanush added comp:keras Keras related issues TF 2.0-rc0 labels Oct 15, 2019
@oanush

oanush commented Oct 15, 2019

@JoshEZiegler ,
Hi, I tried running the given gist with TF 2.0 and 2.0-rc1 and did not face any error. Can you provide a Colab gist where you are facing the issue? Thanks!

@oanush oanush added stat:awaiting response Status - Awaiting response from author type:support Support issues labels Oct 15, 2019
@JoshEZiegler JoshEZiegler changed the title Cannot train canned estimators in estimator.train() calls when using tf.keras.optimizers Cannot train canned estimators in multiple estimator.train() calls when using tf.keras.optimizers Oct 15, 2019
@JoshEZiegler
Author

I'm not sure what you mean by a Colab gist. The gist I provided was saved from Colab, and it shows me a link to open the ipynb in Colab. It may have still been processing something when you saw it? I'm a bit unfamiliar with the particulars of Colab/gists, so I could certainly be making a mistake.

After checking the traceback, it appears that the issue occurred with 2.0-rc2, so it could be specific to that version (maybe that's what you're saying here). In that case, maybe this is solved in the most recent release?

@JoshEZiegler
Author

I tried versions 2.0.0-rc1 and 2.0.0 with this same ipynb in colab but they both gave me the same error.

Changing to tf.optimizers rather than tf.keras.optimizers does not change it either.

@JoshEZiegler JoshEZiegler changed the title Cannot train canned estimators in multiple estimator.train() calls when using tf.keras.optimizers Cannot train canned estimators in multiple estimator.train() calls when using tf.keras.optimizers or tf.optimizers Oct 15, 2019
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Oct 16, 2019
@oanush

oanush commented Oct 17, 2019

@JoshEZiegler ,
I tried with versions 2.0, 2.0-rc1, and 2.0-rc2 and didn't face any error. Please find the Colab gists for the respective versions.

@oanush oanush added the stat:awaiting response Status - Awaiting response from author label Oct 17, 2019
@vivarose

I tried running the linked code and I confirm I also see an error:

RuntimeError: Cannot set iterations to a new Variable after the Optimizer weights have been created

@oanush oanush assigned rmothukuru and unassigned oanush Oct 18, 2019
@rmothukuru rmothukuru added TF 2.0 Issues relating to TensorFlow 2.0 and removed TF 2.0-rc0 labels Oct 18, 2019
@rmothukuru rmothukuru assigned tanzhenyu and unassigned rmothukuru Oct 18, 2019
@rmothukuru rmothukuru added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Oct 18, 2019
@JoshEZiegler
Author

With some investigation, it looks like running estimator.train() multiple times is fine with the default optimizer or when the optimizer is specified by a string. I believe this is because estimator.train() actually creates a new instance of a string-specified optimizer with each call, but retains the optimizer object if an instance was passed in (see below).

A possible workaround for the above error could be to modify this function to return a fresh optimizer with the same parameters as the opt instance specified. However, I'm not sure whether this is easily achievable or whether there are any use cases it would break...

From /tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/canned/optimizers.py

def get_optimizer_instance(opt, learning_rate=None):
  if isinstance(opt, six.string_types):
    if opt in six.iterkeys(_OPTIMIZER_CLS_NAMES):
      if not learning_rate:
        raise ValueError('learning_rate must be specified when opt is string.')
      return _OPTIMIZER_CLS_NAMES[opt](learning_rate=learning_rate)
    raise ValueError(
        'Unsupported optimizer name: {}. Supported names are: {}'.format(
            opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES)))))
  if callable(opt):
    opt = opt()
  if not isinstance(opt, optimizer_lib.Optimizer):
    raise ValueError(
        'The given object is not an Optimizer instance. Given: {}'.format(opt))
  return opt
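The "fresh optimizer with the same parameters" idea could be sketched roughly as below. This is a hypothetical illustration, not the actual estimator code: `FakeOptimizer` is a stand-in for a Keras optimizer (TensorFlow is not imported here), and the cloning relies only on the `get_config()`/`from_config()` protocol that Keras optimizers expose:

```python
class FakeOptimizer:
    """Stand-in for a tf.keras optimizer, exposing the config protocol."""

    def __init__(self, learning_rate=0.01, momentum=0.0):
        self.learning_rate = learning_rate
        self.momentum = momentum

    def get_config(self):
        # Keras optimizers serialize their hyperparameters this way.
        return {"learning_rate": self.learning_rate, "momentum": self.momentum}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def get_fresh_optimizer(opt):
    """Return a brand-new optimizer instance with the same hyperparameters."""
    return opt.__class__.from_config(opt.get_config())


original = FakeOptimizer(learning_rate=0.1, momentum=0.9)
fresh = get_fresh_optimizer(original)
assert fresh is not original                        # a new instance each call
assert fresh.get_config() == original.get_config()  # same hyperparameters
```

Whether the real fix could be this simple depends on whether any user code relies on the passed-in instance being the one actually used, which is exactly the concern raised below.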

@awolant

awolant commented Oct 29, 2019

@JoshEZiegler I'm not sure if this is applicable here, but I got the same error in a different context, and the solution for me was to pass a callable instead of an instance, which the optimizer parameter allows (docs). In your case, I think you can pass optimizer=tf.keras.optimizers.Ftrl.

@JoshEZiegler
Author

@awolant I believe that workaround ends up being equivalent to what I mentioned just above your comment: passing a string to specify the optimizer. The key use case excluded by these workarounds is manual tuning of the optimizer's hyperparameters. It's unclear to me whether there's a way to do that without passing an optimizer instance...

@awolant

awolant commented Oct 30, 2019

@JoshEZiegler Right. My use case was a bit different, and I missed that. But since we can pass any callable, maybe something like this will work for you:

from functools import partial
from tensorflow.keras.optimizers import Adam

AdamWithParams = partial(Adam, learning_rate=0.1)

...

optimizer = AdamWithParams

You can find more about what partial does in the docs. In my code it worked as expected.
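The key property of this workaround is that the estimator invokes the callable itself, so every train() call can get a fresh optimizer object while the hyperparameters stay bound. A pure-Python illustration of that mechanic (no TensorFlow needed; `FakeAdam` is a stand-in for tf.keras.optimizers.Adam):

```python
from functools import partial


class FakeAdam:
    """Stand-in for tf.keras.optimizers.Adam (illustration only)."""

    def __init__(self, learning_rate=0.001):
        self.learning_rate = learning_rate


# partial binds learning_rate=0.1 but defers construction until called.
AdamWithParams = partial(FakeAdam, learning_rate=0.1)

opt1 = AdamWithParams()
opt2 = AdamWithParams()
assert opt1 is not opt2           # a distinct optimizer on every call
assert opt1.learning_rate == 0.1  # the bound hyperparameter is preserved
```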

@JoshEZiegler
Author

@awolant Nice! Thanks, that sounds like the perfect workaround.

I'll leave this issue open because I don't believe that's the intended way to use optimizers with Estimators. At the very least, it's not the way it's done in the TF docs.

@yhliang2018
Contributor

Can you try with the latest tf-nightly? This should be already fixed.

@JoshEZiegler
Author

JoshEZiegler commented Mar 7, 2020

@yhliang2018 This issue still shows up with tf-nightly-2.2.0.dev20200306. See this colab notebook.

Was there a specific version where you believe it should work?

@YingDongLuo

Have you resolved it in TF v2?

@mustafa-qamaruddin

I've got the same issue here too.

@dynamicwebpaige
Contributor

dynamicwebpaige commented May 17, 2020

Experienced the same issue in the latest tf-nightly, and reopening.

RuntimeError: Cannot set `iterations` to a new Variable after the Optimizer weights have been created

@yhliang2018
Contributor

@mustafa-qamaruddin @JoshEZiegler Could you provide your use case for calling estimator.train() multiple times in a training pipeline? If the model needs to train for more steps, you can increase the steps arg in the train() method. I feel it's arguable to call the train() method multiple times with the same optimizer in a training pipeline.

I'm working on a fix and would like to check it in if more context on the use cases is provided. Thanks!

@JoshEZiegler
Author

@yhliang2018 Sure, my use case was to manually log loss/accuracy/metrics for plotting, etc. within a notebook.

To get around this error I switched to using TensorBoard to achieve the same thing, so I'm no longer affected by this.

Hopefully this is the type of context you're looking for?

  steps_per_period = steps / periods

  # Create DNNRegressor object.
  my_optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, clipnorm=5.0)
  dnn_regressor = tf.estimator.DNNRegressor(
      feature_columns=construct_feature_columns(samples),
      hidden_units=hidden_units,
      optimizer=my_optimizer
  )

  # Create input functions.
  training_input_fn = lambda: input_fn(samples, targets, batch_size=batch_size)
  predict_training_input_fn = lambda: input_fn(samples, targets,
                                               num_epochs=1, shuffle=False)
  predict_validation_input_fn = lambda: input_fn(test_samples, test_targets,
                                                 num_epochs=1, shuffle=False)

  # Train the model, but do so inside a loop so that we can periodically
  # assess loss metrics.
  print("Training model...")
  print("RMSE (on training data):")
  training_rmse = []
  validation_rmse = []
  for period in range(0, periods):
    # Train the model, starting from the prior state.
    print("Period[%s]" % (period + 1))
    dnn_regressor.train(
        input_fn=training_input_fn,
        steps=steps_per_period
    )

@yhliang2018
Contributor

@JoshEZiegler Thanks a lot for providing your use case. Yes, TensorBoard is definitely a good option for checking such info.

I think it's still good to support such use cases where you prefer to log/check model-related info manually. I'll have the fix submitted soon and will let you know when it's available in tf-nightly.

@yhliang2018
Contributor

More thoughts on this issue: if someone creates an optimizer instance for a canned estimator, it's natural for them to think that the created optimizer object is the one used in the model optimization process. However, if a different optimizer instance is created for each estimator.train() call, the optimizer actually used in the optimization is always the new one, which has confused people in some use cases. One example:

    dnn_opt = tf.keras.optimizers.SGD(1.)
    linear_opt = tf.keras.optimizers.SGD(0.5)
    input_fn = ...
    est = dnn_linear_combined.DNNLinearCombinedClassifierV2(
        ...
        linear_optimizer=linear_opt,
        ...
        dnn_optimizer=dnn_opt)
    num_steps = 1
    est.train(input_fn, steps=num_steps)
    # linear_opt is never used in the optimization, so accessing
    # linear_opt.iterations breaks:
    assert num_steps == est.get_variable_value(linear_opt.iterations.name)
    # Same as above: dnn_opt is never used either, and its iterations
    # cannot be accessed.
    assert num_steps == est.get_variable_value(dnn_opt.iterations.name)

@JoshEZiegler Given this, how about we hold off on checking in the fix for now and wait for more feedback?

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label May 20, 2020
@zjfGit

zjfGit commented Dec 1, 2020

Experienced the same issue in the latest tf-nightly, and reopening.

RuntimeError: Cannot set `iterations` to a new Variable after the Optimizer weights have been created

I got the same error. Have you solved it?

@JoshEZiegler
Author

@JoshEZiegler Right. My use case was a bit different and I missed that. But since we can pass any callable, then maybe something like this will work for you:

from functools import partial
AdamWithParams = partial(Adam, learning_rate = 0.1)

...

optimizer = AdamWithParams

More on what partial is you can find in the docs. In my code it worked as expected.

This workaround could be an option, but possibly not for your use case.

@JoshEZiegler
Author

@yhliang2018 This issue still shows up with tf-nightly-2.2.0.dev20200306. See this colab notebook.

Was there a specific version where you believe it should work?

The same error occurs when running this same colab notebook using the latest tf-nightly.

@simnalamburt

Instead of partial, you can use a lambda to create the optimizer object each time. It's much easier this way.

linear_regressor = tf.estimator.LinearRegressor(
    feature_columns=feature_columns,
    optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.0000001, clipnorm=5.0),
)
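The lambda works for the same reason the functools.partial workaround earlier in the thread does: the estimator calls it afresh each time, so every train() call gets a new optimizer with the chosen hyperparameters. A pure-Python illustration of the mechanic (no TensorFlow needed; `FakeSGD` is a stand-in for tf.keras.optimizers.SGD):

```python
class FakeSGD:
    """Stand-in for tf.keras.optimizers.SGD (illustration only)."""

    def __init__(self, learning_rate=0.01, clipnorm=None):
        self.learning_rate = learning_rate
        self.clipnorm = clipnorm


# The lambda defers construction, exactly like partial(FakeSGD, ...).
make_opt = lambda: FakeSGD(learning_rate=0.0000001, clipnorm=5.0)

a, b = make_opt(), make_opt()
assert a is not b                  # each call builds a distinct optimizer
assert a.learning_rate == 0.0000001
assert a.clipnorm == 5.0
```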

@sanatmpa1 sanatmpa1 self-assigned this Oct 11, 2021
@sanatmpa1

@JoshEZiegler,

We are checking to see if this is still an issue. Can you take a look at the workaround proposed by @simnalamburt and let us know if it helps? Thanks!

@sanatmpa1 sanatmpa1 added the stat:awaiting response Status - Awaiting response from author label Oct 11, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Oct 18, 2021
@JoshEZiegler
Author

@sanatmpa1 Hi, it looks like the workaround suggested by @simnalamburt works as well.

Without a workaround, it still appears to be an issue (tested with version 2.8.0-dev20211019).

@google-ml-butler google-ml-butler bot removed stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author labels Oct 20, 2021
@sanatmpa1

Thanks for the confirmation, @JoshEZiegler.

@sanatmpa1 sanatmpa1 removed their assignment Oct 26, 2021
@sanatmpa1 sanatmpa1 added 2.6.0 and removed TF 2.0 Issues relating to TensorFlow 2.0 labels Oct 26, 2021
@tensorflowbutler
Member

Hi There,

This is a stale issue. As you are using an older version of tensorflow, we are checking to see if you still need help on this issue. Please test the issue with the latest TensorFlow (TF2.7 and tf-nightly). If the issue still persists with the newer versions of TF, please feel free to open it in keras-team/keras repository by providing details about the issue and a standalone code to reproduce the issue. Thanks!

Please note that Keras development has moved to a separate Keras-team/keras repository to focus entirely on only Keras. Thanks!

