Dependencies of tensors created within a tf.while_loop() might not be executed #15891

Closed
ghost opened this issue Jan 5, 2018 · 12 comments
Labels: stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:bug (Bug)

ghost commented Jan 5, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes. See test case below.
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 'Sierra' Version 10.12.6 (16G1114)
  • TensorFlow installed from (source or binary): Both. I have compiled TensorFlow at 136697e with my small change from PR #15823 (Fix a pessimizing-move warning in GetDeviceLapackInfo()). I have also tried using the pip package.
  • TensorFlow version (use command below): ('v1.4.0-19-ga52c8d9b01', '1.4.1') (pip package)
  • Python version: 2.7.10
  • Bazel version (if compiling from source): 0.9.0-homebrew
  • GCC/Compiler version (if compiling from source): Apple LLVM version 8.1.0 (clang-802.0.42)
  • CUDA/cuDNN version: CUDA 9.0.176_mac, cuDNN 9.0-osx-x64-v7
  • GPU model and memory: NVIDIA GeForce GT 750M with 2048 MB device memory (CUDA Compute Capability 3.0)
  • Exact command to reproduce:

python repro.py

where repro.py contains the test case to reproduce, listed below.

Describe the problem

Here is my test case:

# Part I: create two random 10×10 matrices and record their singular values.
from __future__ import division, print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops as rr

rs = np.random.RandomState(seed = 2)
A = rs.normal(size = (10, 10,))
print('singular values of A: %s' % (np.linalg.svd(A, compute_uv = False),))
B = rs.normal(size = (10, 10,))
print('singular values of B: %s' % (np.linalg.svd(B, compute_uv = False),))



# Part II: control dependency outside any loop; tf.svd() sees A (works as expected).
A_var = tf.Variable(B)
init_A_var_op = tf.assign(A_var, A)
A_dep = tf.constant(10, tf.int32)

with tf.control_dependencies([init_A_var_op]):
    A_dep = A_dep + 1

with tf.control_dependencies([A_dep]):
    var_s = tf.svd(A_var, compute_uv = False)
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    computed_s, computed_A_dep = session.run([var_s, A_dep])
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))



# Part III: the same dependency, but inside a tf.while_loop() body (tf.svd() incorrectly sees B).
A_var = tf.Variable(B)
init_A_var_op = tf.assign(A_var, A)
A_dep = tf.constant(9, tf.int32)

def loop_condition(j, A_dep):
    return j < 1
def loop_body(j, A_dep):
    with tf.control_dependencies([init_A_var_op]):
        A_dep = A_dep + 1
    return j + 1, A_dep

_, A_dep = tf.while_loop(loop_condition,
                         loop_body,
                         loop_vars = [tf.constant(0, tf.int32), A_dep],
                         parallel_iterations = 1,
                         back_prop = False)

with tf.control_dependencies([A_dep]):
    var_s = tf.svd(A_var, compute_uv = False)
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    computed_s, computed_A_dep = session.run([var_s, A_dep])
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))



# Part IV: same as Part III, but using a ResourceVariable (still sees B).
A_var = rr.ResourceVariable(B)
init_A_var_op = A_var.assign(A)
A_dep = tf.constant(8, tf.int32)

def loop_condition(j, A_dep):
    return j < 1
def loop_body(j, A_dep):
    with tf.control_dependencies([init_A_var_op]):
        A_dep = A_dep + 1
    return j + 1, A_dep

_, A_dep = tf.while_loop(loop_condition,
                         loop_body,
                         loop_vars = [tf.constant(0, tf.int32), A_dep],
                         parallel_iterations = 1,
                         back_prop = False)

with tf.control_dependencies([A_dep]):
    var_s = tf.svd(A_var.read_value(), compute_uv = False)
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    computed_s, computed_A_dep = session.run([var_s, A_dep])
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))

Part I is basic setup. I create two random 10×10 matrices and compute their singular values:

singular values of A: [ 5.65906715  4.9420261   4.40626739  3.73506125  2.70703249  2.57429488
  1.73387162  1.16000494  0.58836563  0.39101954]
singular values of B: [ 7.0283055   4.65840063  4.48502098  3.25319445  2.94667168  2.74267484
  1.86004291  1.6626967   0.63884034  0.27131664]

Part II uses tf.control_dependencies() to guarantee that A has been assigned to A_var before the singular values of A_var are computed. The output from this part is:

computed_s = [ 5.65906715  4.9420261   4.40626739  3.73506125  2.70703249  2.57429488
  1.73387162  1.16000494  0.58836563  0.39101954], computed_A_dep = 11

(This is the expected result for Part II.)

In Part III, I have introduced a tf.while_loop(). Now tf.svd() returns the singular values of B:

computed_s = [ 7.0283055   4.65840063  4.48502098  3.25319445  2.94667168  2.74267484
  1.86004291  1.6626967   0.63884034  0.27131664], computed_A_dep = 10

(This is not the expected result for Part III. I expect that the singular values of A would be printed.)

In Part IV, based on reading #4663 (comment), I switched to using a ResourceVariable. However, the output is still the same (the singular values of B):

computed_s = [ 7.0283055   4.65840063  4.48502098  3.25319445  2.94667168  2.74267484
  1.86004291  1.6626967   0.63884034  0.27131664], computed_A_dep = 9

(This is not the expected result for Part IV. I expect that the singular values of A would be printed.)

It appears that the issue is this: when tf.control_dependencies() is applied to tensors returned by tf.while_loop(), the dependencies attached to the ops created inside the loop body might not be executed.

This used to work okay (around TensorFlow 1.1, if I recall correctly).
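
A workaround that I would expect to avoid the problem (a sketch only, not verified against this exact build) is to attach the control dependency to the initial loop variable, outside the loop body, so that the assignment must run before the loop is entered at all. The sketch below reuses A_var, init_A_var_op, A_dep and loop_condition from Part III; loop_body_plain is a hypothetical body without the control_dependencies block:

def loop_body_plain(j, A_dep):
    # Same as loop_body in Part III, but without the control_dependencies block.
    return j + 1, A_dep + 1

with tf.control_dependencies([init_A_var_op]):
    # tf.identity is created inside the scope, so it carries the control input;
    # the loop's Enter op then cannot start before init_A_var_op has run.
    initial_j = tf.identity(tf.constant(0, tf.int32))

_, A_dep = tf.while_loop(loop_condition,
                         loop_body_plain,
                         loop_vars = [initial_j, A_dep],
                         parallel_iterations = 1,
                         back_prop = False)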

While searching for a previous report of this issue, I found #6087, which appears related in that the sample code there also has a tf.while_loop() that creates tensors with control dependencies. When I run that sample code, I consistently get result = 10, which I consider unexpected. What is happening is that update_x runs exactly once, so x has the value 2 for each of the 5 loop iterations.

I tried rewriting the sample code to use a ResourceVariable, but the output is the same:

from __future__ import division, print_function
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops as rr

with tf.variable_scope('state'):
    x = rr.ResourceVariable(tf.constant(1, dtype=tf.float32))
    update_x = x.assign(x.read_value() + 1)

def iter_fun(i, y):
    # If the line below is commented out, the program runs without any error,
    # but I need control_dependencies, or at least some way to replace it...
    with tf.control_dependencies([update_x]):
        y = y + tf.Print(x.read_value(), ['i = ', i, 'y = ', y, 'x = ', x.read_value()])
    return (i+1, y,)

with tf.variable_scope('iteration'):
    num_iterations = 5
    initial_i = tf.constant(0, dtype=tf.int32)
    initial_y = tf.constant(0, dtype=tf.float32)
    _, result = tf.while_loop(
        cond=lambda i, *_: i < num_iterations,
        body=iter_fun,
        loop_vars=(initial_i, initial_y))

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(result))
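
For what it is worth, a sketch of the usual way to get one increment per iteration in the #6087 sample (assuming per-iteration updates are the goal) is to create the assign op inside the loop body, so that it lives in the while context and runs on every iteration. Names follow the sample above; this is a sketch, not verified on this build:

def iter_fun(i, y):
    # Created inside the body function, so the op is part of the while context
    # and executes once per iteration.
    update_x = x.assign_add(1.0)
    with tf.control_dependencies([update_x]):
        # The read is created inside the scope, so it is ordered after the update.
        y = y + x.read_value()
    return i + 1, y
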
@shivaniag shivaniag added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 11, 2018
shivaniag (Contributor) commented:

@ebrevdo Do you have any thoughts on this or know who would know it best?

ebrevdo (Contributor) commented Jan 11, 2018

Can you write a much smaller minimal failing test?

ebrevdo (Contributor) commented Jan 11, 2018

oh wait; i see you did!

ebrevdo (Contributor) commented Jan 11, 2018

Try passing parallel_iterations=1 to your while_loop call.

ghost (Author) commented Jan 11, 2018

I just now tried adding parallel_iterations=1 to my adapted #6087 test case, but this didn't fix the problem.

As for the first test case, you can take Parts I and III separately to reproduce the issue:

from __future__ import division, print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops as rr

rs = np.random.RandomState(seed = 2)
A = rs.normal(size = (10, 10,))
print('singular values of A: %s' % (np.linalg.svd(A, compute_uv = False),))
B = rs.normal(size = (10, 10,))
print('singular values of B: %s' % (np.linalg.svd(B, compute_uv = False),))

A_var = tf.Variable(B)
init_A_var_op = tf.assign(A_var, A)
A_dep = tf.constant(9, tf.int32)

def loop_condition(j, A_dep):
    return j < 1
def loop_body(j, A_dep):
    with tf.control_dependencies([init_A_var_op]):
        A_dep = A_dep + 1
    return j + 1, A_dep

_, A_dep = tf.while_loop(loop_condition,
                         loop_body,
                         loop_vars = [tf.constant(0, tf.int32), A_dep],
                         parallel_iterations = 1,
                         back_prop = False)

with tf.control_dependencies([A_dep]):
    var_s = tf.svd(A_var, compute_uv = False)
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    computed_s, computed_A_dep = session.run([var_s, A_dep])
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))

(Alternatively, take Parts I and IV separately.)

computed_s is the vector of singular values of B, whereas I am expecting that it will be the singular values of A.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 23, 2018
tensorflowbutler (Member) commented:

A member of the TensorFlow organization has replied after the stat:awaiting tensorflower label was applied.

reedwm (Member) commented Jan 30, 2018

/CC @alextp can you take a look?

@reedwm reedwm added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 30, 2018
alextp (Contributor) commented Jan 30, 2018

Can you add a print(tf.get_default_graph().as_graph_def()) so I can understand how we're generating the wrong graph? (i.e. are control dependencies missing, mangled, or ignored because of weird variable stuff)

@alextp alextp added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jan 30, 2018
ghost (Author) commented Jan 30, 2018

I created a gist for the graph def: https://gist.github.com/dtrebbien/f917cb2891e0b141b9fa6323a3c55239

Here is the exact modified test case that I used to print the graph def:

from __future__ import division, print_function
import numpy as np
import tensorflow as tf

rs = np.random.RandomState(seed = 2)
A = rs.normal(size = (10, 10,))
print('singular values of A: %s' % (np.linalg.svd(A, compute_uv = False),))
B = rs.normal(size = (10, 10,))
print('singular values of B: %s' % (np.linalg.svd(B, compute_uv = False),))

graph = tf.Graph()
with graph.as_default():
    A_var = tf.Variable(B, name = 'A_var')
    init_A_var_op = tf.assign(A_var, A, name = 'init_A_var_op')
    A_dep = tf.constant(9, tf.int32, name = 'initial_A_dep')

    def loop_condition(j, A_dep):
        return j < 1
    def loop_body(j, A_dep):
        with tf.control_dependencies([init_A_var_op]):
            A_dep = tf.add(A_dep, 1, name = 'increment_A_dep')
        return j + 1, A_dep

    _, A_dep = tf.while_loop(loop_condition,
                             loop_body,
                             loop_vars = [tf.constant(0, tf.int32), A_dep],
                             parallel_iterations = 1,
                             back_prop = False)

    with tf.control_dependencies([A_dep]):
        var_s = tf.svd(A_var, compute_uv = False)

print(graph.as_graph_def())

with tf.Session(graph = graph) as session:
    session.run(tf.global_variables_initializer())
    computed_A_dep, computed_s, computed_A_dep2 = session.run([A_dep, var_s, A_dep])
    assert computed_A_dep == computed_A_dep2
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))

By the way, is there a utility that can graphically display this?
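
In case it is useful, a sketch of one way to view it (assuming TensorBoard is installed) is to write the graph out with tf.summary.FileWriter and open TensorBoard's Graphs tab:

# Sketch: dump the graph so TensorBoard can render it (log directory is arbitrary).
writer = tf.summary.FileWriter('/tmp/tf_graph_logs', graph)
writer.close()
# Then run: tensorboard --logdir=/tmp/tf_graph_logs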

ghost (Author) commented Jan 30, 2018

Looking at the graph def, it looks like no node has control input "^init_A_var_op".

In contrast, in the following working script, which does not use a tf.while_loop(), the "increment_A_dep/y" node (the constant second argument to the "increment_A_dep" tf.add() op) does have the control input "^init_A_var_op":

from __future__ import division, print_function
import numpy as np
import tensorflow as tf

rs = np.random.RandomState(seed = 2)
A = rs.normal(size = (10, 10,))
print('singular values of A: %s' % (np.linalg.svd(A, compute_uv = False),))
B = rs.normal(size = (10, 10,))
print('singular values of B: %s' % (np.linalg.svd(B, compute_uv = False),))

graph = tf.Graph()
with graph.as_default():
    A_var = tf.Variable(B, name = 'A_var')
    init_A_var_op = tf.assign(A_var, A, name = 'init_A_var_op')
    A_dep = tf.constant(9, tf.int32, name = 'initial_A_dep')

    with tf.control_dependencies([init_A_var_op]):
        A_dep = tf.add(A_dep, 1, name = 'increment_A_dep')

    with tf.control_dependencies([A_dep]):
        var_s = tf.svd(A_var, compute_uv = False)

print(graph.as_graph_def())

with tf.Session(graph = graph) as session:
    session.run(tf.global_variables_initializer())
    computed_A_dep, computed_s, computed_A_dep2 = session.run([A_dep, var_s, A_dep])
    assert computed_A_dep == computed_A_dep2
print('computed_s = %s, computed_A_dep = %d' % (computed_s, computed_A_dep,))

Here is an excerpt from the working script's graph def:

node {
  name: "initial_A_dep"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 9
      }
    }
  }
}
node {
  name: "increment_A_dep/y"
  op: "Const"
  input: "^init_A_var_op"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "increment_A_dep"
  op: "Add"
  input: "initial_A_dep"
  input: "increment_A_dep/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Svd"
  op: "Svd"
  input: "A_var/read"
  input: "^increment_A_dep"
  attr {
    key: "T"
    value {
      type: DT_DOUBLE
    }
  }
  attr {
    key: "compute_uv"
    value {
      b: false
    }
  }
  attr {
    key: "full_matrices"
    value {
      b: false
    }
  }
}

The non-working graph's "while/increment_A_dep/y" node has control input "^while/Identity" but not "^init_A_var_op".
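
A quick sketch of how this can be checked programmatically (scanning the GraphDef for the control input; node names as in the gist above):

# Sketch: list every node that carries the "^init_A_var_op" control input.
gd = graph.as_graph_def()
nodes_with_dep = [n.name for n in gd.node if '^init_A_var_op' in n.input]
print(nodes_with_dep)  # non-empty for the working script, empty for the while_loop version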

alextp (Contributor) commented Jan 30, 2018

Ok, so this is a real bug. @asimshankar who do you think we should assign this to?

There's a bug somewhere in the control dependency processing logic of WhileContext, somewhere around here most likely:

def _MaybeAddControlDependency(self, op):

asimshankar (Contributor) commented:

@skye was looking into something similar, I think.
@skye - let me know if I'm mistaken and will try to find someone else.

@asimshankar asimshankar added type:bug Bug stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Jan 30, 2018
@case540 case540 closed this as completed in f8f921c Feb 6, 2018