
Force tensor evaluation inside while-loop, scan and others. #13616

Closed

awav opened this issue Oct 10, 2017 · 30 comments

Labels: stat:awaiting tensorflower (Status - Awaiting response from tensorflower)

Comments

awav commented Oct 10, 2017

Hello everyone,

I had big plans for tf.while_loop until I discovered that it is impossible to re-evaluate a tensor inside it. Let's dive into the issue and a potentially useful feature:

In [13]: import tensorflow as tf
    ...: import numpy as np
    ...:
    ...: def cond(i, _x, _sq):
    ...:       return tf.less(i, 10)
    ...:
    ...: def gen_body(v):
    ...:     def body(i, x, sq):
    ...:         x = tf.Print(x, [x, sq], "x and sq: ")
    ...:         v_assign = v.assign(x + 1)
    ...:         v_assign = tf.Print(v_assign, [v_assign], "v_assign: ")
    ...:         with tf.control_dependencies([v_assign]):
    ...:             sq_neg = tf.negative(sq)
    ...:         sq_neg = tf.Print(sq_neg, [i, sq_neg], message='i and sq_neg:')
    ...:         return tf.add(i, 1), sq_neg, sq
    ...:     return body
    ...:

In [14]: sess = tf.InteractiveSession()

In [15]: i = tf.Variable(0)
    ...: v = tf.Variable(2)
    ...: sq = tf.square(v)
    ...: l = tf.while_loop(cond, gen_body(v), (i, v, sq))
    ...: sess.run(tf.global_variables_initializer())
    ...: sess.run(l)
    ...:
2017-10-10 22:59:44.819271: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [2][4]
2017-10-10 22:59:44.819405: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [3]
2017-10-10 22:59:44.819466: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[0][-4]
2017-10-10 22:59:44.819553: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.819615: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.819680: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[1][-4]
2017-10-10 22:59:44.819827: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.819885: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.819932: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[2][-4]
2017-10-10 22:59:44.820034: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820094: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820111: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[3][-4]
2017-10-10 22:59:44.820162: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820250: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820265: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[4][-4]
2017-10-10 22:59:44.820315: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820379: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820408: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[5][-4]
2017-10-10 22:59:44.820428: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820438: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820446: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[6][-4]
2017-10-10 22:59:44.820464: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820490: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820500: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[7][-4]
2017-10-10 22:59:44.820519: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820532: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820542: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[8][-4]
2017-10-10 22:59:44.820559: I tensorflow/core/kernels/logging_ops.cc:79] x and sq: [-4][4]
2017-10-10 22:59:44.820580: I tensorflow/core/kernels/logging_ops.cc:79] v_assign: [-3]
2017-10-10 22:59:44.820593: I tensorflow/core/kernels/logging_ops.cc:79] i and sq_neg:[9][-4]
Out[15]: (10, -4, 4)

I created the variable v and the tensor sq, which equals v^2. In practice we don't have control over them: they are our inputs x and y, and we only know that y depends on x. I would like to assign a new value to x (equivalent to v in the example) inside the TensorFlow loop and evaluate a fresh y (sq in the example) at each iteration. Meanwhile we can do other evaluations inside while_loop, but what matters most is that I need to update x and get the updated y. Currently, the assign operation doesn't propagate updates down to dependent nodes, and it shouldn't; but when someone evaluates a tensor that depends on a value updated via assign inside a while_loop, I would expect the tensor node to detect this change and evaluate a fresh value again.

Thanks!

tatatodd (Contributor) commented:

@skye Can you comment on this?

tatatodd added the stat:awaiting tensorflower label Oct 11, 2017
skye (Member) commented Oct 11, 2017

Let me make sure I understand what you're requesting: you would like a tensor computed from variables to have a new updated value every time it's executed in a while loop.

This is a reasonable expectation. However, I don't think we'll change this behavior. Right now, only tensors defined inside the loop will be evaluated every loop iteration. All tensors defined outside a loop will be evaluated exactly once. So you would have to write something like:

import tensorflow as tf
import numpy as np

def cond(i, _x):
    return tf.less(i, 10)

def gen_body(v):
    def body(i, x):
        x = tf.Print(x, [x], "x: ")
        v_assign = v.assign(x + 1)
        v_assign = tf.Print(v_assign, [v_assign], "v_assign: ")
        with tf.control_dependencies([v_assign]):
            sq = tf.square(v)  # <---- put 'sq' definition inside loop
            sq = tf.Print(sq, [sq], "sq: ")
            sq_neg = tf.negative(sq)
        sq_neg = tf.Print(sq_neg, [i, sq_neg], message='i and sq_neg:')
        return tf.add(i, 1), sq_neg
    return body

sess = tf.InteractiveSession()

v = tf.Variable(2)
l = tf.while_loop(cond, gen_body(v), (1, v))
sess.run(tf.global_variables_initializer())
sess.run(l)

Note that I moved the sq = ... inside the loop (and also made it depend on v_assign, to make sure we pick up the new assignment).

Does this help?

skye (Member) commented Oct 11, 2017

@ebrevdo @mrry FYI

skye added the stat:awaiting response label and removed the stat:awaiting tensorflower label Oct 11, 2017
ebrevdo (Contributor) commented Oct 11, 2017

I believe if you want a variable to be re-evaluated at each iteration of the loop, you should access it using variable.ref.

ebrevdo (Contributor) commented Oct 11, 2017

Or perhaps if you use a ResourceVariable (get_variable(..., use_resource=True)) it may do the right thing.

You may also need to set parallel_iterations=1, but probably not.
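A minimal sketch of this suggestion (not from the thread; my assumption, TF 1.x graph mode): with a ResourceVariable, read_value() inserts an explicit read op, and placing it under a control dependency on the assign orders the read after the write on every iteration.

import tensorflow as tf

tf.reset_default_graph()
v = tf.get_variable("v", initializer=2, use_resource=True)

def body(i, _prev):
    va = v.assign(v + 1)
    with tf.control_dependencies([va]):
        fresh = v.read_value()  # read is ordered after the assign
    return i + 1, fresh

_, out = tf.while_loop(lambda i, x: i < 3, body, [0, v.read_value()])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run(out))  # 5: the variable is re-read on every iteration

Note this only refreshes reads of the variable itself; as discussed below, tensors computed from it outside the loop are not recomputed.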

ebrevdo (Contributor) commented Oct 11, 2017

This is an optimization that we turned on because you don't want a Variable sitting on a parameter server elsewhere in the datacenter to be accessed at each iteration of an RNN; you want to access it only once and cache it. In this case you have an assign inside the while_loop, and converting a Variable to a Tensor doesn't check whether an assign has been performed just prior, so it doesn't know it must refresh the value. @mrry should we reset the caching identity tensor of a Variable after each assign?

ebrevdo (Contributor) commented Oct 11, 2017

@skye also, wdyt?

skye (Member) commented Oct 11, 2017

You're more familiar with these semantics than I am, but changing this might break (or at least regress) existing models, no?

ebrevdo (Contributor) commented Oct 11, 2017

Not sure. It depends on how many users perform a tf.assign/tf.assign_{add,sub} inside a while_loop (including applying an optimizer) and read the new value inside the while loop -- and what behavior they expect; they may be getting silent errors or models that fail to converge because of this issue.

If this already just "works" with ResourceVariable, then that's a sign that the Variable semantics are just broken.

mrry (Contributor) commented Oct 11, 2017

should we reset the caching identity tensor of a Variable after each assign?

I think this gets tricky because (from the perspective of a tf.Variable object) only some subset of the Variable.assign() (etc.) calls run in any particular step, and so you can't use control dependencies to ensure that the identity is ordered after an assign. (Perhaps I am being too pessimistic and you have a more cunning solution in mind?)

If this already just "works" with ResourceVariable, then that's a sign that the Variable semantics are just broken.

+1. I'd strongly encourage people to switch to ResourceVariable in new code, and it's more compatible with other recent features like tf.data and some of the Estimator variants. I'm not sure what it would take to switch the default, however.

awav (Author) commented Oct 11, 2017

Thank you for your responses. Here are some clarifications to the task:

  • There are a variable x and a tensor y, the result of a computation on x: y = tf.square(x). They are created outside of the loop.
  • I pass both tensors, x and y, to the while_loop implicitly.
  • I assign a new value to x inside the loop (say x.assign(x + 1)) and expect y to be evaluated with the new x value (x + 1) when it is used after the assignment. So the aim is to re-evaluate y at each iteration, along with the assignment to x.

@skye, so sq cannot be recreated inside while_loop, because we don't know what it computes.

aselle removed the stat:awaiting response label Oct 11, 2017
awav (Author) commented Oct 11, 2017

I may be wrong, but in the example below I never got different values for the assign operation b and the variable a; I'm sure it could be a coincidence. Anyway, I got a different value for c, and that is expected, because control_dependencies does not help in this case. I think it would be useful to be able to construct an operation like read_value(), but for any tensor, so that whenever it is called it re-evaluates that tensor.

In [91]: tf.reset_default_graph()
    ...: a = tf.get_variable('a', initializer=2, use_resource=True)
    ...: c = a + 1
    ...: b = a.assign(a * 10)
    ...: with tf.control_dependencies([b]):
    ...:     c = c # <<<<<<<<<<<<<<< 
    ...:              # here should be `c.eval_tensor()` or `c.read_tensor()`, 
    ...:              # which constructs an operation being equivalent
    ...:              # to read_value() for variables.
    ...:              # ***Edited***: tf.identity() doesn't work here!
    ...: sess = tf.InteractiveSession()
    ...: sess.run(tf.global_variables_initializer())
    ...: sess.run([a, b, c])
...
Out[91]: [20, 20, 3]   # <<< 3 
...
Out[92]: [20, 20, 21]   # <<< 21
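For comparison, a minimal sketch (my reconstruction of what the Out[92] line appears to show): with a resource variable, rebuilding c inside the control_dependencies block orders the read of a after the assign, so c comes out as 21.

import tensorflow as tf

tf.reset_default_graph()
a = tf.get_variable('a', initializer=2, use_resource=True)
b = a.assign(a * 10)
with tf.control_dependencies([b]):
    c = a + 1  # the read of `a` is ordered after the assign

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run([b, c]))  # [20, 21]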

tatatodd added the stat:awaiting tensorflower label Oct 12, 2017
awav (Author) commented Oct 12, 2017

@skye, @mrry, @ebrevdo I'm sorry for bombarding you with multiple messages; here is proof that assigning inside the loop is not the problem:

import tensorflow as tf
import numpy as np

def cond(i, _x):
    return tf.less(i, 10000)

def gen_body(x):
    def body(i, xa_prev):
        xa = x.assign(xa_prev + 1)
        with tf.control_dependencies([xa]):
            with tf.control_dependencies([tf.assert_equal(x, xa)]):
                i = tf.add(i, 1)
        return i, xa
    return body

tf.reset_default_graph()
i = tf.get_variable('i', initializer=0, use_resource=True)
v = tf.get_variable('v', initializer=0, use_resource=True)
sess = tf.InteractiveSession()
loop = tf.while_loop(cond, gen_body(v), [i, v])
sess.run(tf.global_variables_initializer())
sess.run([loop])

Out[203]: [(10000, 10000)]

The issue is that there is no way to recompute tensors that depend on the variable but were constructed outside of the while_loop. :)
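To make the limitation concrete, here is a minimal repro sketch (my construction, TF 1.x graph mode): a tensor built outside the loop is captured as a loop invariant, and even tf.identity under a control dependency does not recompute it.

import tensorflow as tf

tf.reset_default_graph()
v = tf.get_variable("v", initializer=2, use_resource=True)
sq = tf.square(v)  # built outside the loop: captured once

def body(i, _prev):
    va = v.assign(v + 1)
    with tf.control_dependencies([va]):
        stale = tf.identity(sq)  # does NOT recompute tf.square(v)
    return i + 1, stale

_, out = tf.while_loop(lambda i, s: i < 3, body, [0, sq])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run(out))  # 4: still the pre-loop value of v**2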

dchatterjee172 commented Oct 12, 2017

I faced the same problem:

def body(rawstate, x, v, coef):
    rawstate.assign_add(tf.matmul(x[v], coef[v]))
    v.assign_add(1)
    return [rawstate, x, v, coef]

calrawstate = tf.while_loop(cond, body, [rawstate, x, v, coef])

And I got this error:

AttributeError: 'Tensor' object has no attribute 'assign_add'

Here rawstate and v are two variables created using tf.get_variable(). Although this problem is trivial and can be solved without using assign or assign_add inside body, sometimes it's needed in a custom model; for example, if we want to assign a specific value to rawstate[0], it's hard...
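For what it's worth, the conventional workaround for this pattern is to carry the accumulator as a loop variable instead of calling assign_add on it; a minimal sketch with hypothetical shapes (x: [2, 1, 2], coef: [2, 2, 1]), my construction:

import tensorflow as tf

x = tf.constant([[[1.0, 2.0]], [[3.0, 4.0]]])
coef = tf.constant([[[1.0], [1.0]], [[2.0], [2.0]]])

def cond(v, rawstate):
    return v < 2

def body(v, rawstate):
    # accumulate through the loop variable rather than Variable.assign_add
    return v + 1, rawstate + tf.matmul(x[v], coef[v])

_, rawstate = tf.while_loop(cond, body, [0, tf.zeros([1, 1])])

sess = tf.InteractiveSession()
print(sess.run(rawstate))  # [[17.]] = (1*1 + 2*1) + (3*2 + 4*2)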

awav (Author) commented Oct 12, 2017

@dchatterjee172, I think you have a different issue. Inside the loop body, rawstate is a tensor, and by definition you can't assign a value to it. If you expected to see the same variable at each iteration, it works a bit differently. But it would be nice to be able to pass a variable to the loop as a reference.

@ebrevdo what is variable.ref?

awav closed this as completed Oct 12, 2017
awav reopened this Oct 12, 2017
ebrevdo (Contributor) commented Oct 12, 2017

After playing with this for a bit, I reached two conclusions:

  1. ResourceVariable has the same problem as Variable.
  2. Variable.ref has been hidden from view, but you can access it through a hidden API (meaning your code may break in the future) via Variable._ref() or Variable._variable.

Here's code that does what I think you want it to do:

def cond(i, _x, _sq):
  return i < 5

def gen_body(v):
  def body(i, x, sq):
    x = tf.Print(x, [x, sq], "x and sq: ")
    with tf.control_dependencies([v.assign(x + 1)]):
      v_assign = v._ref()
    v_assign = tf.Print(v_assign, [v_assign], "v_assign: ")
    with tf.control_dependencies([v_assign]):
      sq_neg = tf.negative(sq)
      sq_neg = tf.Print(sq_neg, [i, sq_neg], message='i and sq_neg:')
      return tf.add(i, 1), sq_neg, sq
  return body

sess = tf.InteractiveSession()


i = tf.get_variable("i", initializer=0)
v = tf.get_variable("v", initializer=2)
sq = tf.square(v)
l = tf.while_loop(cond, gen_body(v), (i, v, sq))
sess.run(tf.global_variables_initializer())
sess.run((l, v))

@alextp any suggestions on how one would force accessing the updated value of a ResourceVariable within a calculation? This has implications not only for while loops but also other distributed computations where one wants to access a new value from another machine after waiting for a bit.

ebrevdo (Contributor) commented Oct 12, 2017

@alextp perhaps a stateful read operator is in order?

alextp (Contributor) commented Oct 12, 2017

It is possible to force a resource variable read to happen after an assignment, and indeed it will happen. @ebrevdo , why did resource variables not work? Do you have an example with them?

(I do not fully understand why we're using variables at all here instead of the regular while loop variable mechanism)

awav (Author) commented Oct 12, 2017

@alextp, @ebrevdo Hello, your snippet doesn't do what I need; I simplified it a bit below. Let me explain what it does.
There is a function f and a variable x. At each iteration I increment x and want to re-evaluate f and take its negative. What's good: if I take the assign TensorFlow operation and pass it to the next iteration (even if I'm not using it), I can successfully increment the external variable x. What's bad: the function f is evaluated only once, but it must be evaluated with the freshly assigned x as many times as there are iterations. Check the example below:

Here f(x) = x^2, and f_neg = -f(x):

def cond(i, _x_prev, _f_prev):
  return i < 3

def gen_body(x, f):
  def body(i, _x_prev, _f_prev):
    x_assign = x.assign(x + 1)
    # with tf.control_dependencies([x.assign(x + 1)]):
    #   x_assign = x._ref()
    with tf.control_dependencies([x_assign]):
      f_neg = tf.negative(f)
      i = tf.add(i, 1)
      i = tf.Print(i, [i], message='>>> Iteration ')
      i = tf.Print(i, [x], message='x = ')
      i = tf.Print(i, [x_assign], message='x_assign = ')
      i = tf.Print(i, [f], message='f = ')
      i = tf.Print(i, [f_neg], message='f_neg = ')
      return i, x_assign, f_neg
  return body

tf.reset_default_graph()
sess = tf.InteractiveSession()
i = tf.get_variable("i", initializer=0)
v = tf.get_variable("v", initializer=0)
func_v = tf.square(v)
l = tf.while_loop(cond, gen_body(v, func_v), (i, v, func_v))
sess.run(tf.global_variables_initializer())
sess.run((l, v))

Output: x is incremented successfully from 0 to 3, but f is never updated! I need the updated f.

... >>> Iteration [1]
... x = [1]
... x_assign = [1]
... f = [0]
... f_neg = [0]

... >>> Iteration [2]
... x = [2]
... x_assign = [2]
... f = [0]
... f_neg = [0]

... >>> Iteration [3]
... x = [3]
... x_assign = [3]
... f = [0]
... f_neg = [0]
In [259]: tf.__version__
Out[259]: '1.2.1'

I propose adding a method to the tensor structure that would be equivalent to read_value() for a standard variable. It would create an operation that re-evaluates the tensor whenever it is required. In the example above, that tensor is f.

awav (Author) commented Oct 12, 2017

@alextp,

(I do not fully understand why we're using variables at all here instead of the regular while loop variable mechanism)

The task is to update the variable and the function that depends on it, both of which are defined outside of the while_loop. In the example above, f is the func_v tensor and x is the v variable; both are passed to gen_body.

alextp (Contributor) commented Oct 13, 2017

Right. If you want the function to be computed inside the while loop you need to call it inside the while loop. Something like

def cond(i, _x_prev, _f_prev):
  return i < 3

def gen_body(x, f):
  def body(i, _x_prev, _f_prev):
    x_assign = x.assign(x + 1)
    # with tf.control_dependencies([x.assign(x + 1)]):
    #   x_assign = x._ref()
    with tf.control_dependencies([x_assign]):
      f_neg = tf.negative(f(x_assign))
      i = tf.add(i, 1)
      i = tf.Print(i, [i], message='>>> Iteration ')
      i = tf.Print(i, [x], message='x = ')
      i = tf.Print(i, [f(x_assign)], message='f(x_assign) = ')
      i = tf.Print(i, [f(x)], message='f(x) = ')
      i = tf.Print(i, [f_neg], message='f_neg = ')
      return i, x_assign, f_neg
  return body

tf.reset_default_graph()
sess = tf.InteractiveSession()
i = tf.get_variable("i", initializer=0)
v = tf.get_variable("v", initializer=0)
func_v = lambda v: tf.square(v)
l = tf.while_loop(cond, gen_body(v, func_v), (i, v, func_v(v)))
sess.run(tf.global_variables_initializer())
sess.run((l, v))

When you define func_v as a tensor outside the loop it has a fixed value which will stay fixed forever. If you want it to be recomputed you need to actually recompute it.

awav (Author) commented Oct 13, 2017

@alextp, that's right. In fact, it was a feature request: to be able to recompute the tensor, and to introduce a property that generates this operation. I'm sorry if it wasn't clear from the previous messages. That would be super useful in while_loop-like control flow statements.

The solution you proposed doesn't work for me, as I have no control over the input tensors x and f, and I do not know what's inside f.

How hard is it to implement re-computation operation for tensor?

alextp (Contributor) commented Oct 13, 2017 via email

awav (Author) commented Oct 13, 2017

@alextp, oh, that's bad. Okay, if I manage to pass f(x) as a Python function, so that at each iteration I construct a new tensor y = f(x), does that mean the size of the graph will grow linearly with the number of iterations? Can unwise usage of while_loop lead to graph "overflow", given that graph size is limited?
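For reference, the body function of tf.while_loop is traced exactly once when the graph is constructed; iterations happen at runtime without adding ops, so building y = f(x) inside body adds a fixed number of nodes regardless of the iteration count. A quick sanity check (my sketch):

import tensorflow as tf

tf.reset_default_graph()
x = tf.get_variable("x", initializer=0.0)

def body(i, acc):
    return i + 1, acc + tf.square(x)  # f(x) is built once, at trace time

_, result = tf.while_loop(lambda i, acc: i < 10000, body, [0, 0.0])

print(len(tf.get_default_graph().get_operations()))  # same whether the bound is 10 or 10000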

alextp (Contributor) commented Oct 13, 2017 via email

awav (Author) commented Oct 13, 2017

@alextp, last question: if I create a C++ operation which accepts the variable x and the tensor f, will I be able to control the execution of f depending on changes in x? In other words, will I be able to execute the tensor f each time I change x? Thanks!

alextp (Contributor) commented Oct 13, 2017 via email

awav (Author) commented Oct 15, 2017

Thanks all for your responses!

alextp (Contributor) commented Nov 5, 2018 via email

jonas-eschle (Contributor) commented:

Ah, I see (so it's actually independent of a loop): one sess.run -> zero (if the value can be cached) or one evaluation of a tensor.

Thanks, @alextp!
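A tiny illustration of that rule (my sketch, TF 1.x): within a single sess.run, each op executes at most once, no matter how many fetches reference it.

import tensorflow as tf

tf.reset_default_graph()
a = tf.get_variable("a", initializer=1.0)
b = tf.Print(a + 1.0, [a], message="b evaluated: ")

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run([b, b, b * 2.0])  # "b evaluated:" is logged once: b runs a single time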
