
"Cannot assign a device to node..." bug in TensorArrayScatter_grad when using pre_scanned tensor in double loop of scan/while/map #5117

Closed
yhg0112 opened this issue Oct 21, 2016 · 11 comments

Comments

@yhg0112

yhg0112 commented Oct 21, 2016

Environment info

tensorflow branch : 0.11.0rc0
CUDA version : 7.0
cuDNN version : 6.5.48
OS version : Ubuntu 14.04.5 LTS
GPU : GPU0 Titan X (Maxwell), GPU1 Tesla K20c (not used in this code)
(Also using an Anaconda2 environment and Jupyter with tf.InteractiveSession())

The bug (or is this an intended error?)

I was using tf.scan() and tf.map_fn() to build a seq2seq encoder-decoder structure with an attention mechanism.
When I feed a scanned tensor into a map_fn() inside another scan(), the graph is built normally and I can even evaluate the value of the output tensor.

However, when I try to optimize, or take the gradient of, that tensor, the bug pops up saying InvalidArgumentError: Cannot assign a device to node 'gradients/scan_1/while/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/GPU:0'.

I tried config.allow_soft_placement = True, but it only changed the error log and did not fix the problem.
Oddly, with config.allow_soft_placement = True the error log instead complains that AttrValue must not have reference type value of string_ref.

My code, in simplified form, looks like the following (a full bug-reproducing example is at the bottom):
encoder_states = tf.scan(_encoder_step, encoder_inputs, initializer=encoder_initial_states)
decoder_states = tf.scan(_decoder_step, decoder_inputs, initializer=encoder_states[-1])

Inside def _decoder_step(prev_h, inputs): I use tf.map_fn() to compute an aligned context over the encoder states, following the attention mechanism of https://arxiv.org/abs/1409.0473.

It looks like the following, inside _decoder_step(prev_h, inputs):

    def alignment_model(inputs):
        # prev_h has the shape [num_layer, num_batch, num_hidden]; prev_h[0] is the state of the first decoder layer.
        alignment_state = tf.nn.tanh(_linear([inputs, prev_h[0]], output_dim=num_slot, bias=False))
        return _linear([alignment_state], output_dim=1, bias=False)
    alignment = tf.map_fn(alignment_model, encoder_states)  # and here the bug appears.
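
(`_linear` is a small helper in my model code, omitted above; roughly, it concatenates its inputs along the feature axis and applies a learned linear map. A simplified sketch, with the exact signature assumed:)

    # Simplified sketch of the assumed _linear helper (not a TensorFlow API).
    # Concatenate the inputs along the feature axis and apply a learned linear map.
    def _linear(args, output_dim, bias=True, scope='linear'):
        with tf.variable_scope(scope):  # each call site would pass a distinct scope in practice
            x = args[0] if len(args) == 1 else tf.concat(1, args)  # TF 0.11 concat signature
            w = tf.get_variable('W', [x.get_shape().as_list()[-1], output_dim], dtype=x.dtype)
            y = tf.matmul(x, w)
            if bias:
                b = tf.get_variable('b', [output_dim], dtype=x.dtype,
                                    initializer=tf.constant_initializer(0.0))
                y = y + b
            return y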

error message

with config.allow_soft_placement = False :

InvalidArgumentError: Cannot assign a device to node 'gradients/Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/GPU:0'
Colocation Debug Info:
Colocation group had the following types and devices: 
TensorArrayWrite: GPU CPU 
TensorArray: GPU CPU 
StackPush: GPU CPU 
TensorArrayRead: GPU CPU 
Range: GPU CPU 
Stack: GPU CPU 
StackPop: GPU CPU 
Const: GPU CPU 
TensorArrayScatter: GPU CPU 
RefEnter: GPU CPU 
Enter: GPU CPU 
TensorArrayGather: GPU CPU 
StridedSlice: GPU CPU 
TensorArrayGrad: GPU CPU 
Identity: GPU CPU 
Shape: GPU CPU 
     [[Node: gradients/Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc = Stack[_class=["loc:@Decoder/scan/while/Attention/map/TensorArray"], elem_type=DT_INT32, stack_name=""]()]]

Caused by op u'gradients/Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc', defined at:
  File "/home/youaredeadl/anaconda2/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/traitlets/config/application.py", line 596, in launch_instance
    app.start()
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 442, in start
    ioloop.IOLoop.instance().start()
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 162, in start
    super(ZMQIOLoop, self).start()
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tornado/ioloop.py", line 883, in start
    handler_func(fd_obj, events)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 391, in execute_request
    user_expressions, allow_stdin)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 199, in do_execute
    shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2723, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2825, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2885, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-940fe231159a>", line 3, in <module>
    grads, _ = tf.clip_by_global_norm(tf.gradients(model.loss, tvars, aggregation_method=tf.AggregationMethod.ADD_N), 1)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gradients.py", line 469, in gradients
    in_grads = _AsList(grad_fn(op, *out_grads))
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/tensor_array_grad.py", line 213, in _TensorArrayScatterGrad
    grad = g.gather(indices)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/tensor_array_ops.py", line 301, in gather
    element_shape=element_shape)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 1302, in _tensor_array_gather
    element_shape=element_shape, name=name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 749, in apply_op
    op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2380, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1302, in __init__
    self._control_flow_context.AddOp(self)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1941, in AddOp
    self._AddOpInternal(op)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1965, in _AddOpInternal
    self.AddValue(x)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1900, in AddValue
    real_val = grad_ctxt.grad_state.GetRealValue(val)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 987, in GetRealValue
    history_value = cur_grad_state.AddForwardAccumulator(cur_value)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 861, in AddForwardAccumulator
    acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 1104, in _stack
    stack_name=stack_name, name=name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 749, in apply_op
    op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2380, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1298, in __init__
    self._traceback = _extract_stack()

...which was originally created as op u'Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter', defined at:
  File "/home/youaredeadl/anaconda2/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
[elided 17 identical lines from previous traceback]
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2885, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-15-153e1e7c7c3b>", line 6, in <module>
    num_slot=_num_slot)
  File "<ipython-input-13-1ccd26e81475>", line 34, in __init__
    initializer=self.decoder_initial_states) # shape of [time, num_layer, batch, num_hidden]
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/functional_ops.py", line 563, in scan
    back_prop=back_prop, swap_memory=swap_memory)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2518, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2356, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2306, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/functional_ops.py", line 553, in compute
    a_out = fn(packed_a, packed_elems)
  File "<ipython-input-13-1ccd26e81475>", line 98, in _decoder_step
    context = self._get_context(encoder_last_states, states[0])
  File "<ipython-input-13-1ccd26e81475>", line 81, in _get_context
    alignment = tf.map_fn(alignment_hidden_layer, encoder_states) # shape of [time, batch, 1]
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/functional_ops.py", line 333, in map_fn
    elem_ta.unpack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/tensor_array_ops.py", line 391, in unpack
    indices=math_ops.range(0, num_elements), value=value, name=name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/tensor_array_ops.py", line 412, in scatter
    name=name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 1447, in _tensor_array_scatter
    name=name)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 749, in apply_op
    op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2380, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/youaredeadl/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1298, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Cannot assign a device to node 'gradients/Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/GPU:0'
Colocation Debug Info:
Colocation group had the following types and devices: 
TensorArrayWrite: GPU CPU 
TensorArray: GPU CPU 
StackPush: GPU CPU 
TensorArrayRead: GPU CPU 
Range: GPU CPU 
Stack: GPU CPU 
StackPop: GPU CPU 
Const: GPU CPU 
TensorArrayScatter: GPU CPU 
RefEnter: GPU CPU 
Enter: GPU CPU 
TensorArrayGather: GPU CPU 
StridedSlice: GPU CPU 
TensorArrayGrad: GPU CPU 
Identity: GPU CPU 
Shape: GPU CPU 
     [[Node: gradients/Decoder/scan/while/Attention/map/TensorArrayPack/TensorArrayScatter_grad/TensorArrayGather/f_acc = Stack[_class=["loc:@Decoder/scan/while/Attention/map/TensorArray"], elem_type=DT_INT32, stack_name=""]()]]

with config.allow_soft_placement = True:

InvalidArgumentError: AttrValue must not have reference type value of string_ref
     for attr 'tensor_type'
    ; NodeDef: scan_1/while/map/TensorArray/_211 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_666_scan_1/while/map/TensorArray", tensor_type=DT_STRING_REF, _device="/job:localhost/replica:0/task:0/cpu:0"](^_cloopscan_1/while/map/TensorArrayPack_1/range/delta/_37); Op<name=_Recv; signature= -> tensor:tensor_type; attr=tensor_type:type; attr=tensor_name:string; attr=send_device:string; attr=send_device_incarnation:int; attr=recv_device:string; attr=client_terminated:bool,default=false; is_stateful=true>
     [[Node: scan_1/while/map/TensorArray/_211 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_666_scan_1/while/map/TensorArray", tensor_type=DT_STRING_REF, _device="/job:localhost/replica:0/task:0/cpu:0"](^_cloopscan_1/while/map/TensorArrayPack_1/range/delta/_37)]]

reproducible example code

import tensorflow as tf
import numpy as np

in_data_for_pre_scan = np.ones(shape=[2, 8, 5], dtype=np.float64)
in_data_for_post_scan = np.ones(shape=[2, 8, 5], dtype=np.float64)
initial_state_data_for_pre_scan = np.zeros(shape=[8, 5], dtype=np.float64)
initials_state_data_for_post_scan = np.zeros(shape=[8, 5], dtype=np.float64)

inputs_for_pre_scan = tf.placeholder(shape=[None, None, 5], dtype=tf.float64)
inputs_for_post_scan = tf.placeholder(shape=[None, None, 5], dtype=tf.float64)
initial_state_for_pre_scan = tf.placeholder(shape=[None, 5], dtype=tf.float64)
initial_state_for_post_scan = tf.placeholder(shape=[None, 5], dtype=tf.float64)

weight = tf.get_variable('W', [5, 5], dtype=tf.float64)

def pre_scan(states, inputs):    
    return states + tf.matmul(inputs, weight)

def post_scan(states, inputs):
    def inner_map(inputs):
        return inputs
    loop_output = tf.map_fn(inner_map, pre_scanned[0]) 

    return states + loop_output + inputs

pre_scanned = tf.scan(pre_scan, inputs_for_pre_scan, initializer=initial_state_for_pre_scan)
res = tf.scan(post_scan, inputs_for_post_scan, initializer=initial_state_for_post_scan)

opt_func = tf.train.AdamOptimizer()
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(tf.reduce_mean(tf.square(res)), tvars, aggregation_method=tf.AggregationMethod.ADD_N), 1)
optimizer = opt_func.apply_gradients(zip(grads, tvars))

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
#config.allow_soft_placement = True
sess = tf.InteractiveSession(config=config)
sess.run(tf.initialize_all_variables())

sess.run(res, feed_dict={inputs_for_pre_scan:in_data_for_pre_scan, 
                         initial_state_for_pre_scan:initial_state_data_for_pre_scan, 
                         inputs_for_post_scan:in_data_for_post_scan, 
                         initial_state_for_post_scan:initials_state_data_for_post_scan}) # this runs as just fine.

sess.run(optimizer, feed_dict={inputs_for_pre_scan:in_data_for_pre_scan, 
                               initial_state_for_pre_scan:initial_state_data_for_pre_scan, 
                               inputs_for_post_scan:in_data_for_post_scan, 
                               initial_state_for_post_scan:initials_state_data_for_post_scan}) # this doesn't work.

The log points to `scatter()` in `ops/tensor_array_ops.py` and `_tensor_array_scatter` in `ops/gen_data_flow_ops.py`, which are part of this branch.

@ebrevdo could anyone give me some hints about this?


Edited:

If I only run the optimizer part without running res first, i.e. skipping

    sess.run(res, feed_dict={inputs_for_pre_scan:in_data_for_pre_scan,
                             initial_state_for_pre_scan:initial_state_data_for_pre_scan,
                             inputs_for_post_scan:in_data_for_post_scan,
                             initial_state_for_post_scan:initials_state_data_for_post_scan})

and running only

    sess.run(optimizer, feed_dict={inputs_for_pre_scan:in_data_for_pre_scan,
                                   initial_state_for_pre_scan:initial_state_data_for_pre_scan,
                                   inputs_for_post_scan:in_data_for_post_scan,
                                   initial_state_for_post_scan:initials_state_data_for_post_scan})

then it works fine.

But still, if I run both of them, sess.run(optimizer) raises InvalidArgumentError. Could it be a problem with my GPU configuration? When I execute nvidia-smi it says GPU0 is the Tesla K20c and GPU1 is the GeForce GTX Titan X, but TensorFlow reports them in the reverse order.
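
(On the device-ordering question, a general note rather than anything specific to this bug: CUDA enumerates GPUs fastest-first by default, while nvidia-smi lists them by PCI bus ID, so the indices can legitimately disagree. To be certain which physical card TensorFlow sees, the process can be pinned before importing TensorFlow, e.g.:)

    import os
    # Optional: make CUDA's device indices match nvidia-smi's PCI ordering
    # (supported by recent CUDA releases).
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # Expose only one physical GPU to this process; the index is interpreted
    # under the ordering selected above.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # e.g. the Titan X in the nvidia-smi listing above

    import tensorflow as tf  # must be imported after the environment variables are set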

@drpngx
Contributor

drpngx commented Oct 24, 2016

@lukaszkaiser this started from seq2seq, but seems isolated to something else. Maybe you have some insight.

@lukaszkaiser
Contributor

Maybe it is a problem with dynamic loops and TensorArrays? I see a lot of scan use. I think ebrevdo@ or Yuan might know better.

@drpngx
Contributor

drpngx commented Oct 25, 2016

@ebrevdo Any clue?

@ebrevdo
Contributor

ebrevdo commented Oct 25, 2016

When you call tf.gradients, there's an option colocate_gradients_with_ops. Do you enable it?
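
(For concreteness, in the repro above that would mean something like the following; whether it changes the outcome here is exactly the question:)

    # Same gradient call as in the repro, with the colocation option set explicitly.
    grads = tf.gradients(tf.reduce_mean(tf.square(res)), tvars,
                         aggregation_method=tf.AggregationMethod.ADD_N,
                         colocate_gradients_with_ops=True)
    grads, _ = tf.clip_by_global_norm(grads, 1)
    optimizer = opt_func.apply_gradients(zip(grads, tvars))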


@drpngx drpngx added the stat:awaiting response Status - Awaiting response from author label Oct 25, 2016
@yhg0112
Author

yhg0112 commented Oct 25, 2016

I've just run with the tf.gradients colocate_gradients_with_ops option. Whether I set it to True or False, the same error is raised.

@ebrevdo
Contributor

ebrevdo commented Oct 25, 2016

We just pushed some better debugging for this to master. It should be
synced by EOD. Can you try using a build off master tonight or tomorrow
and report back?


@yhg0112
Author

yhg0112 commented Oct 26, 2016

Alright, I've just tested with the master branch, /tensorflow-0.11.0rc1-py2-none-any.whl.

Evaluation still works fine, but optimization with gradients breaks again with the same error.

I guess it's not merged yet.

@ebrevdo
Contributor

ebrevdo commented Oct 27, 2016

Please provide a minimal failing example, full code. Try to reduce it to
as few ops as possible. Are you running rnns on separate GPUs?


@yhg0112
Author

yhg0112 commented Oct 30, 2016

I'm really sorry about the late response.

I have just pulled the master branch and reinstalled TensorFlow (tf.__version__ is 0.11.0rc1), and nothing has changed yet.

Isn't my bug-reproducing example simple enough? The error seems to happen when I put a scanned tensor into a double map_fn (or scan) loop. It's really strange that it works if I run sess.run(optimizer, ...) first, but fails if I run sess.run(res, ...) before the optimizer.

@yhg0112
Author

yhg0112 commented Oct 30, 2016

I've run the example with the pip-installed version of TensorFlow (tf.__version__ is 0.11.0rc2).

The error didn't happen, so I think this is solved. Thank you all.
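
(For anyone landing here later, the quickest sanity check is the installed version:)

    import tensorflow as tf
    print(tf.__version__)  # the error no longer reproduces for me starting from 0.11.0rc2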

@yhg0112 yhg0112 closed this as completed Oct 30, 2016
@drpngx
Contributor

drpngx commented Oct 30, 2016

Yay!
