Cannot use AdagradOptimizer with MirroredStrategy #19551

Closed
ed-alertedh opened this issue May 25, 2018 · 5 comments
Labels: comp:dist-strat (Distribution Strategy related issues), type:bug (Bug)

Comments


ed-alertedh commented May 25, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): v1.8.0-0-g93bc2e2072 1.8.0
  • Python version: 3.6.3 64-bit
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 9.0 / 6.0
  • GPU model and memory: Nvidia GTX 1080 8 GB
  • Exact command to reproduce: python train_model.py

Describe the problem

It took me a while to work out what was going on, but it seems tf.train.AdagradOptimizer relies on some implementation detail that breaks when used with MirroredStrategy. I did a spot check with GradientDescentOptimizer and RMSPropOptimizer, and both appear to work in my environment. I'm happy to use a different optimizer as a workaround, but at the very least this might save others some time hunting down the cause of the error!
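
For anyone who just needs to unblock training, swapping the optimizer line in model_fn was enough for me (a sketch, only verified in my environment with the two optimizers I spot-checked):

# Workaround sketch: replace Adagrad with an optimizer that worked in my
# spot checks (GradientDescentOptimizer or RMSPropOptimizer).
train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)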

Source code / logs

This is almost exactly copied from the example at https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute (except for the choice of optimizer)

import tensorflow as tf

def model_fn(features, labels, mode):
    layer = tf.layers.Dense(1)
    logits = layer(features)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"logits": logits}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss = tf.losses.mean_squared_error(
            labels=labels, predictions=tf.reshape(logits, []))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdagradOptimizer(0.2).minimize(loss)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
    features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
    labels = tf.data.Dataset.from_tensors(1.).repeat(100)
    return tf.data.Dataset.zip((features, labels))

distribution = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)

Log output:

2018-05-25 15:30:12.300908: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-05-25 15:30:14.231174: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1356] Found device 0 with properties: 
name: GeForce GTX 1080 major: 6 minor: 1 memoryClockRate(GHz): 1.898
pciBusID: 0000:03:00.0
totalMemory: 7.93GiB freeMemory: 7.81GiB
2018-05-25 15:30:14.557081: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1356] Found device 1 with properties: 
name: GeForce GTX 1080 major: 6 minor: 1 memoryClockRate(GHz): 1.898
pciBusID: 0000:81:00.0
totalMemory: 7.93GiB freeMemory: 7.81GiB
2018-05-25 15:30:14.557174: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0, 1
2018-05-25 15:30:15.198082: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-05-25 15:30:15.198134: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0 1 
2018-05-25 15:30:15.198142: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N N 
2018-05-25 15:30:15.198145: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 1:   N N 
2018-05-25 15:30:15.198488: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 7542 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080, pci bus id: 0000:03:00.0, compute capability: 6.1)
2018-05-25 15:30:15.324510: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 7543 MB memory) -> physical GPU (device: 1, name: GeForce GTX 1080, pci bus id: 0000:81:00.0, compute capability: 6.1)
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpqglycjzk
2018-05-25 15:30:15.455314: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0, 1
2018-05-25 15:30:15.455414: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-05-25 15:30:15.455423: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0 1 
2018-05-25 15:30:15.455427: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N N 
2018-05-25 15:30:15.455431: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 1:   N N 
2018-05-25 15:30:15.455615: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:0 with 7542 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080, pci bus id: 0000:03:00.0, compute capability: 6.1)
2018-05-25 15:30:15.455720: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:1 with 7543 MB memory) -> physical GPU (device: 1, name: GeForce GTX 1080, pci bus id: 0000:81:00.0, compute capability: 6.1)
Traceback (most recent call last):
  File "train_model.py", line 62, in <module>
    classifier.train(input_fn=input_fn)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 841, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 884, in _train_model_distributed
    self.config)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/training/distribute.py", line 756, in call_for_each_tower
    return self._call_for_each_tower(fn, *args, **kwargs)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 254, in _call_for_each_tower
    coord.join(threads)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 248, in _call_for_each_tower
    self, *merge_args, **merge_kwargs)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 671, in _distributed_apply
    self._create_slots(var_list)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py", line 66, in _create_slots
    with ops.colocate_with(v):
  File "/home/aeiq/.pyenv/versions/3.6.3/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4186, in _colocate_with_for_gradient
    with self.colocate_with(op, ignore_existing):
  File "/home/aeiq/.pyenv/versions/3.6.3/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4239, in colocate_with
    op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1262, in internal_convert_to_tensor_or_indexed_slices
    value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/aeiq/.local/share/virtualenvs/echoiq-308-HW9bcZ4A/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/values.py", line 243, in _tensor_conversion
    assert not as_ref
AssertionError
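
The failing chain is visible in the frames above; condensed from the TF 1.8 sources referenced in the traceback:

# tensorflow/python/training/adagrad.py, _create_slots (line 66):
# v is a MirroredVariable when running under MirroredStrategy
with ops.colocate_with(v):
    ...
# tensorflow/python/framework/ops.py, colocate_with (line 4239):
op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
# tensorflow/contrib/distribute/python/values.py, _tensor_conversion (line 243):
assert not as_ref  # MirroredVariable refuses ref conversion -> AssertionError

So it looks like any optimizer (or other code) that wraps variable updates in ops.colocate_with(variable) will hit this assertion.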
DavidWiesner commented

I get the same problem when using MovingAverageOptimizer:

def model_fn(features, labels, mode):
    layer = tf.layers.Dense(1)
    logits = layer(features)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"logits": logits}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss = tf.losses.mean_squared_error(
            labels=labels, predictions=tf.reshape(logits, []))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(0.02)
        optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer, 0.9)
        train_op = optimizer.minimize(loss)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
    features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
    labels = tf.data.Dataset.from_tensors(1.).repeat(100)
    return tf.data.Dataset.zip((features, labels))

distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)

The error, with some context:

AssertionError                            Traceback (most recent call last)
<ipython-input-114-62e57e2880d4> in <module>()
     27 config = tf.estimator.RunConfig(train_distribute=distribution)
     28 classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
---> 29 classifier.train(input_fn=input_fn)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in train(self, input_fn, hooks, steps, max_steps, saving_listeners)
    361 
    362     saving_listeners = _check_listeners_type(saving_listeners)
--> 363     loss = self._train_model(input_fn, hooks, saving_listeners)
    364     logging.info('Loss for final step: %s.', loss)
    365     return self

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in _train_model(self, input_fn, hooks, saving_listeners)
    839   def _train_model(self, input_fn, hooks, saving_listeners):
    840     if self._distribution:
--> 841       return self._train_model_distributed(input_fn, hooks, saving_listeners)
    842     else:
    843       return self._train_model_default(input_fn, hooks, saving_listeners)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in _train_model_distributed(self, input_fn, hooks, saving_listeners)
    882             labels,  # although this will be None it seems
    883             model_fn_lib.ModeKeys.TRAIN,
--> 884             self.config)
    885 
    886         # TODO(anjalisridhar): Figure out how to resolve the folowing scaffold

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/distribute.py in call_for_each_tower(self, fn, *args, **kwargs)
    754     """
    755     _require_cross_tower_context(self)
--> 756     return self._call_for_each_tower(fn, *args, **kwargs)
    757 
    758   def _call_for_each_tower(self, fn, *args, **kwargs):

/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py in _call_for_each_tower(self, fn, *args, **kwargs)
    252       for t in threads:
    253         t.should_run.set()
--> 254       coord.join(threads)
    255 
    256     return values.regroup({t.device: t.main_result for t in threads})

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/coordinator.py in join(self, threads, stop_grace_period_secs, ignore_live_threads)
    387       self._registered_threads = set()
    388       if self._exc_info_to_raise:
--> 389         six.reraise(*self._exc_info_to_raise)
    390       elif stragglers:
    391         if ignore_live_threads:

/usr/local/lib/python3.5/dist-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/coordinator.py in stop_on_exception(self)
    295     """
    296     try:
--> 297       yield
    298     except:  # pylint: disable=bare-except
    299       self.request_stop(ex=sys.exc_info())

/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py in run(self)
    463                 self._captured_var_scope, reuse=self.tower_id > 0), \
    464             variable_scope.variable_creator_scope(self.variable_creator_fn):
--> 465           self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
    466           self.done = True
    467       finally:

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in _call_model_fn(self, features, labels, mode, config)
    829 
    830     logging.info('Calling model_fn.')
--> 831     model_fn_results = self._model_fn(features=features, **kwargs)
    832     logging.info('Done calling model_fn.')
    833 

<ipython-input-114-62e57e2880d4> in model_fn(features, labels, mode)
     16         optimizer = tf.train.GradientDescentOptimizer(0.02)
     17         optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer, 0.9)
---> 18         train_op = optimizer.minimize(loss)
     19         return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
     20 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/optimizer.py in minimize(self, loss, global_step, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, name, grad_loss)
    422 
    423     return self.apply_gradients(grads_and_vars, global_step=global_step,
--> 424                                 name=name)
    425 
    426   def compute_gradients(self, loss, var_list=None,

/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/opt/python/training/moving_average_optimizer.py in apply_gradients(self, grads_and_vars, global_step, name)
     97     if self._sequential_update:
     98       with ops.control_dependencies([train_op]):
---> 99         ma_op = self._ema.apply(var_list)
    100     else:
    101       ma_op = self._ema.apply(var_list)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/moving_averages.py in apply(self, var_list)
    423         zero_debias = self._averages[var] in zero_debias_true
    424         updates.append(assign_moving_average(
--> 425             self._averages[var], var, decay, zero_debias=zero_debias))
    426       return control_flow_ops.group(*updates, name=scope)
    427 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/moving_averages.py in assign_moving_average(variable, value, decay, zero_debias, name)
     82   with ops.name_scope(name, "AssignMovingAvg",
     83                       [variable, value, decay]) as scope:
---> 84     with ops.colocate_with(variable):
     85       decay = ops.convert_to_tensor(1.0 - decay, name="decay")
     86       if decay.dtype != variable.dtype.base_dtype:

/usr/lib/python3.5/contextlib.py in __enter__(self)
     57     def __enter__(self):
     58         try:
---> 59             return next(self.gen)
     60         except StopIteration:
     61             raise RuntimeError("generator didn't yield") from None

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py in _colocate_with_for_gradient(self, op, gradient_uid, ignore_existing)
   4184   def _colocate_with_for_gradient(self, op, gradient_uid,
   4185                                   ignore_existing=False):
-> 4186     with self.colocate_with(op, ignore_existing):
   4187       if gradient_uid is not None and self._control_flow_context is not None:
   4188         try:

/usr/lib/python3.5/contextlib.py in __enter__(self)
     57     def __enter__(self):
     58         try:
---> 59             return next(self.gen)
     60         except StopIteration:
     61             raise RuntimeError("generator didn't yield") from None

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py in colocate_with(self, op, ignore_existing)
   4237     if op is not None and not isinstance(op, Operation):
   4238       # We always want to colocate with the reference op.
-> 4239       op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
   4240 
   4241     # By default, colocate_with resets the device function stack,

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor_or_indexed_slices(value, dtype, name, as_ref)
   1260   else:
   1261     return internal_convert_to_tensor(
-> 1262         value, dtype=dtype, name=name, as_ref=as_ref)
   1263 
   1264 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
   1102 
   1103     if ret is None:
-> 1104       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1105 
   1106     if ret is NotImplemented:

/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/values.py in _tensor_conversion(var, dtype, name, as_ref)
    241   # Try to avoid assignments to and other mutations of MirroredVariable
    242   # state except through a DistributionStrategy.update() call.
--> 243   assert not as_ref
    244   return ops.internal_convert_to_tensor(
    245       var.get(), dtype=dtype, name=name, as_ref=as_ref)

AssertionError: 
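
Note that the final frames are the same as in the Adagrad case: assign_moving_average also wraps its update in ops.colocate_with(variable) (moving_averages.py line 84), which hits the same assert not as_ref in values.py.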

@rohan100jain rohan100jain added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug labels Jun 18, 2018
@guptapriya guptapriya assigned yuefengz and unassigned guptapriya Jun 19, 2018
guptapriya (Contributor) commented

Thanks for bringing this to our attention. We will look into why AdagradOptimizer and MovingAverageOptimizer don't work.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 20, 2018
@guptapriya guptapriya added the comp:dist-strat Distribution Strategy related issues label Jul 17, 2018
@guptapriya guptapriya assigned anj-s and unassigned yuefengz Jul 23, 2018
JRMeyer (Contributor) commented Jul 30, 2018

Hi all,

I'm getting this same error with Adagrad (the default optimizer for the pre-canned tf.estimator.DNNClassifier).

Here are my definitions:

# Setup MirroredStrategy
distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

# Define Specs                                                                                                                                                                               
train_spec_dnn = tf.estimator.TrainSpec(input_fn = lambda: my_input_fn('train.tfrecords'))
eval_spec_dnn = tf.estimator.EvalSpec(input_fn = lambda: my_input_fn('eval.tfrecords') )


DNNClassifier = tf.estimator.DNNClassifier(
    feature_columns = [tf.feature_column.numeric_column(key='feats', dtype=tf.float64, shape=(nDims,))],
    hidden_units = [256, 256, 256, 256],
    n_classes = 200,
    model_dir = '/tmp/tf',
    config = run_config)

# Run training and evaluation (this is the call that fails, per the traceback)
tf.estimator.train_and_evaluate(DNNClassifier, train_spec_dnn, eval_spec_dnn)

Below is the error message:

2018-07-30 11:57:19.736726: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0, 1, 2, 3
2018-07-30 11:57:19.736801: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-07-30 11:57:19.736818: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0 1 2 3 
2018-07-30 11:57:19.736826: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N Y Y Y 
2018-07-30 11:57:19.736833: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 1:   Y N Y Y 
2018-07-30 11:57:19.736840: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 2:   Y Y N Y 
2018-07-30 11:57:19.736846: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 3:   Y Y Y N 
2018-07-30 11:57:19.737320: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:0 with 14867 MB memory) -> physical GPU (device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:1b.0, compute capability: 7.0)
2018-07-30 11:57:19.737953: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:1 with 14867 MB memory) -> physical GPU (device: 1, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:1c.0, compute capability: 7.0)
2018-07-30 11:57:19.738083: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:2 with 14867 MB memory) -> physical GPU (device: 2, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:1d.0, compute capability: 7.0)
2018-07-30 11:57:19.738209: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/device:GPU:3 with 14867 MB memory) -> physical GPU (device: 3, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:1e.0, compute capability: 7.0)
WARNING:tensorflow:Partitioned variables are disabled when using DistributionStrategy.
Traceback (most recent call last):
  File "train_and_eval.py", line 137, in <module>
    tf.estimator.train_and_evaluate(DNNClassifier, train_spec_dnn, eval_spec_dnn)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 439, in train_and_evaluate
    executor.run()
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 518, in run
    self.run_local()
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 650, in run_local
    hooks=train_hooks)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 841, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 884, in _train_model_distributed
    self.config)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/training/distribute.py", line 756, in call_for_each_tower
    return self._call_for_each_tower(fn, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 254, in _call_for_each_tower
    coord.join(threads)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 248, in _call_for_each_tower
    self, *merge_args, **merge_kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 671, in _distributed_apply
    self._create_slots(var_list)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py", line 66, in _create_slots
    with ops.colocate_with(v):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4186, in _colocate_with_for_gradient
    with self.colocate_with(op, ignore_existing):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4239, in colocate_with
    op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1262, in internal_convert_to_tensor_or_indexed_slices
    value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/distribute/python/values.py", line 243, in _tensor_conversion
    assert not as_ref
AssertionError
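
Since Adagrad is only the default here, a possible workaround (an untested sketch, assuming DNNClassifier's optimizer argument behaves the same way under MirroredStrategy) is to pass one of the optimizers reported to work earlier in this thread:

# Untested workaround sketch: override DNNClassifier's default 'Adagrad'
# optimizer with one that does not trip the MirroredVariable assertion.
# The learning rate is illustrative only.
DNNClassifier = tf.estimator.DNNClassifier(
    feature_columns = [tf.feature_column.numeric_column(key='feats', dtype=tf.float64, shape=(nDims,))],
    hidden_units = [256, 256, 256, 256],
    n_classes = 200,
    optimizer = tf.train.GradientDescentOptimizer(0.1),
    model_dir = '/tmp/tf',
    config = run_config)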

tensorflowbutler (Member) commented

Nagging Assignee @anj-s: It has been 15 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

tianmingdu commented

optimizer1 = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.tnet_lr)
optimizer1 = tf.contrib.estimator.TowerOptimizer(optimizer1)
optimizer2 = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.mnet_lr)
optimizer2 = tf.contrib.estimator.TowerOptimizer(optimizer2)
tnet_variables = get_model_variables('tnet')
mnet_variables = get_model_variables('mnet')

train_op1 = slim.learning.create_train_op(loss,
                                          optimizer1,
                                          variables_to_train=tnet_variables,
                                          summarize_gradients=True)
train_op2 = slim.learning.create_train_op(loss,
                                          optimizer2,
                                          variables_to_train=mnet_variables,
                                          summarize_gradients=True)
train_op = tf.group(train_op1, train_op2)

The two GradientDescentOptimizers above (wrapped in TowerOptimizer) do not work with MirroredStrategy either.
