
Frame of the Variable #4478

Closed
link-er opened this issue Sep 19, 2016 · 14 comments
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower

Comments


link-er commented Sep 19, 2016

Hello

I am using tf.scan to implement a memory-augmented network, and when trying to run tf.initialize_all_variables() I get an error about the frame of the variables:

All inputs to node scan_1/while/Variable_13/Assign must be from the same frame.

As I understand it, this frame is some internal identification of the variables. What should I do to get rid of this error? Scan works if I do not define any variables outside the step, but I need at least the weight/bias and loss/accuracy variables.
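
For context, here is a minimal sketch (not the original code) of the pattern that triggers this error in graph-mode TensorFlow of that era: a tf.Variable whose initial value is a tensor computed inside the scan step.

import tensorflow as tf

elems = tf.placeholder(tf.float32, [5, 4])

def step(prev, cur):
    # The Variable's initial value is a tensor computed inside the loop,
    # so its Assign op receives inputs from both the while-loop frame and
    # the outer frame. That is the "same frame" error reported above.
    ref = tf.Variable(cur * prev, trainable=False)
    return prev + ref

out = tf.scan(step, elems, initializer=tf.zeros([4]))

sess = tf.Session()
sess.run(tf.initialize_all_variables())  # raises InvalidArgumentError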


link-er commented Sep 20, 2016

@ebrevdo Hi! I saw that you solved the same problem for while_loop in #3114; maybe the same fix is needed here?

jmchen-g added the stat:awaiting tensorflower label Sep 20, 2016

yaroslavvb (Contributor) commented:

It sounds like you are trying to use external tensors in scan (i.e., variables that are not passed as arguments).



link-er commented Sep 20, 2016

Yes, I am, because I have weight variables that need to be initialized outside.
In my loop I just iterate over the input sequence, but I do not need to iterate over the weights, and I cannot initialize them at every step. What is the way to solve this?
If needed, I can provide the code (at least the critical parts of it).


ebrevdo commented Sep 20, 2016

Please provide the critical part of the code.



ebrevdo commented Sep 20, 2016

P.S. Is this something for which you can use tf.get_variable?



link-er commented Sep 21, 2016

Please provide the critical part of the code.

def step(prev_step, cur_input):
    layer_1 = tf.add(tf.matmul(cur_input, weights['h1']), biases['b1'])
    layer_1 = tf.nn.sigmoid(layer_1)

    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    layer_2 = tf.nn.sigmoid(layer_2)

    layer_3 = tf.add(tf.matmul(layer_2, weights['h3']), biases['b3'])
    layer_3 = tf.nn.sigmoid(layer_3)

    key = tf.nn.tanh(tf.add(tf.matmul(layer_3, weights['out']), biases['out']))

    # calculate alpha gateway depending on the key
    alpha = tf.nn.softmax(tf.add(tf.matmul(key, weights['alpha']), biases['alpha']))  # (batch_size, 1)
    # calculate write weights as combination of previously read rows and least used rows
    w_weights = tf.sigmoid(alpha) * prev_step[1]  # (batch_size, mem_height)

    # n-th smallest element's index in previous usage_weights
    nth_smallest = tf.nn.top_k(-1 * prev_step[2], k=number_of_reads, sorted=True)[1]  # (batch_size, 1)

    linear_index = tf.reshape(nth_smallest, [batch_size]) + (mem_height * tf.range(0, batch_size))
    linear_w_weights = tf.reshape(w_weights, [-1])
    ref = tf.Variable(linear_w_weights, trainable=False)
    least_used_update = tf.reshape((1 - tf.sigmoid(alpha)), [batch_size])
    w_weights = tf.stop_gradient(tf.reshape(tf.scatter_add(ref, linear_index, least_used_update, use_locking=True),
                                            [batch_size, mem_height]))

    # put to 0 the least used row in memory
    linear_memory = tf.reshape(prev_step[0], [batch_size * mem_height, mem_width])
    ref1 = tf.Variable(linear_memory, trainable=False)
    memory = tf.stop_gradient(tf.reshape(tf.scatter_update(ref1, linear_index, tf.zeros([batch_size, mem_width]), use_locking=True),
                                         (batch_size, mem_height, mem_width)))  # (batch_size, mem_height, mem_width)
    # update memory state with write weights
    # (batch_size, mem_height, mem_width)
    memory = memory + \
        tf.batch_matmul(tf.reshape(w_weights, (batch_size, mem_height, 1)), tf.reshape(key, (batch_size, 1, mem_width)))
    normed_key = key / tf.sqrt(tf.reduce_sum(tf.square(key)))  # (batch_size, mem_width)
    # (batch_size, mem_height, mem_width)
    normed_memory = tf.div(memory, tf.sqrt(tf.reduce_sum(tf.square(memory), 1, keep_dims=True)))
    # calculate similarities to each memory row
    # (batch_size, mem_height)
    similarity = tf.reshape(tf.batch_matmul(normed_memory, tf.reshape(normed_key, (batch_size, mem_width, 1))),
                            (batch_size, mem_height))
    # calculate read weights as a softmax probability distribution
    r_weights = tf.nn.softmax(similarity)  # (batch_size, mem_height)
    # retrieve memory
    retrieved_memory = tf.reshape(tf.batch_matmul(tf.reshape(r_weights, (batch_size, 1, mem_height)), memory),
                                  (batch_size, mem_width))  # (batch_size, mem_width)
    # put retrieved memory through the output layer to get a prediction
    do_output = tf.add(tf.matmul(retrieved_memory, weights['do']), biases['do'])
    prediction = tf.nn.softmax(do_output)  # (batch_size, n_classes)
    # calculate usage weights
    u_weights = gamma * prev_step[2] + r_weights + w_weights  # (batch_size, mem_height)

    return (memory, r_weights, u_weights, prediction)

def model(elems):
    _, _, _, predictions = tf.scan(step, elems,
                                   initializer=(np.zeros((batch_size, mem_height, mem_width)).astype(np.float32),
                                                np.random.rand(batch_size, mem_height).astype(np.float32),
                                                np.random.rand(batch_size, mem_height).astype(np.float32),
                                                np.zeros((batch_size, n_classes)).astype(np.float32)),
                                   parallel_iterations=1, back_prop=True, swap_memory=False)
    return predictions

weights = {
  'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1],stddev=stddev)),
  'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2],stddev=stddev)),
  'h3': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_3],stddev=stddev)),
  'out': tf.Variable(tf.random_normal([n_hidden_3, mem_width],stddev=stddev)),
  'do': tf.Variable(tf.random_normal([mem_width, n_classes],stddev=stddev)),
  'alpha': tf.Variable(tf.random_normal([mem_width, 1],stddev=stddev))
}
biases = {
  'b1': tf.Variable(tf.random_normal([n_hidden_1])),
  'b2': tf.Variable(tf.random_normal([n_hidden_2])),
  'b3': tf.Variable(tf.random_normal([n_hidden_3])),
  'out': tf.Variable(tf.random_normal([mem_width])),
  'do': tf.Variable(tf.random_normal([n_classes])),
  'alpha': tf.Variable(tf.random_normal([1]))
}

input_data = tf.placeholder("float32",[n_classes*n_examples, batch_size, n_input], name='input_data')
labels = tf.placeholder("float32",[n_classes*n_examples, batch_size, n_classes],name='labels')

predictions = model(input_data)

And after this, calling initialize_all_variables() causes that error.

p.s. is this something for which you can use tf.get_variable?

What exactly will this call do?
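
For later readers, a minimal sketch of what tf.get_variable does, in the graph-mode API used throughout this thread (the scope and variable names are illustrative):

import tensorflow as tf

# tf.get_variable creates a variable the first time a name is used and,
# inside a reusing variable_scope, returns the existing variable on later
# calls, so weights can be shared instead of re-created.
with tf.variable_scope("layer"):
    w1 = tf.get_variable("w", shape=[3, 3],
                         initializer=tf.random_normal_initializer(stddev=0.1))
with tf.variable_scope("layer", reuse=True):
    w2 = tf.get_variable("w", shape=[3, 3])
assert w1 is w2  # the same underlying variable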


osdf commented Oct 21, 2016

I'm getting the same error message that is reported in this issue. I have an RNN-like structure with some building blocks (component neural networks) that are passed in by the user. Here is a minimal example:

import tensorflow as tf
tf.reset_default_graph()

def initialize(shape):
    init = tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
    return init

def test_rnn_with_external(input, hiddens, external_fct):
    """
    A simple rnn that makes the standard update, then
    feeds the new hidden state through some external
    function.
    """
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        concat = tf.concat(1, [input, previous])     
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        h_t = external_fct(h_t)

        return h_t

    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

# the external function, relying on the templating mechanism.
def ext_fct(hiddens):
    """
    """
    def tmp(input):
        shape = (hiddens, hiddens)
        _init = initialize(shape)
        W = tf.get_variable("ext_w", initializer=_init)
        b = 0
        return tf.add(tf.matmul(input, W), b, name="external")
    return tf.make_template(name_="external_fct", func_=tmp)

# run from here on
t = 5
btsz = 4
dim = 2
hiddens = 3

x = tf.placeholder(tf.float32, shape=(t, btsz, dim))
ext = ext_fct(hiddens)

states = test_rnn_with_external(x, hiddens, external_fct=ext)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

with the error ending in:

InvalidArgumentError: All inputs to node external_fct/ext_w/Assign must be from the same frame.

I would associate a "frame" with an area on the stack. So I thought that maybe tf.make_template does something very weird and thus is not usable here. The external function can be rewritten a bit and then called more directly, like so:

import tensorflow as tf
tf.reset_default_graph()

def initialize(shape):
    init = tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
    return init

def test_rnn_with_external(input, hiddens, external_fct):
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        """
        """
        concat = tf.concat(1, [input, previous])     
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        h_t = external_fct(h_t, hiddens)

        return h_t

    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

def ext_fct_new(input, hiddens):
    """
    """
    shape = (hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("ext_w_new", initializer=_init)
    b = 0
    return tf.add(tf.matmul(input, W), b, name="external_new")

t = 5
btsz = 4
dim = 2
hiddens = 3
x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

states = test_rnn_with_external(x, hiddens, external_fct=ext_fct_new)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

However, this still produces the same error: InvalidArgumentError: All inputs to node ext_w_new/Assign must be from the same frame.

Of course, moving the contents of the external function into the _step part (and calling tf.get_variable beforehand) works. But then the flexibility (necessary in the original code) is gone.
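
For concreteness, a sketch of that working-but-inflexible inline variant, assuming the initialize helper and shapes from above (the W_ext name is illustrative):

def test_rnn_inline(input, hiddens):
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    W = tf.get_variable("rnn_w", initializer=initialize((dim_in + hiddens, hiddens)))
    b = tf.get_variable("rnn_b", initializer=tf.zeros([hiddens]))
    # the "external" weight is created before the loop, so no variable
    # creation happens inside the scan body
    W_ext = tf.get_variable("ext_w_inline", initializer=initialize((hiddens, hiddens)))

    def _step(previous, input):
        concat = tf.concat(1, [input, previous])
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))
        return tf.matmul(h_t, W_ext)  # external computation inlined

    return tf.scan(_step, input, initializer=tf.zeros([btsz, hiddens]))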

What am I doing wrong?


osdf commented Oct 25, 2016

This is probably related to #2211


ebrevdo commented Oct 25, 2016

Can you do this without the templating mechanism?



osdf commented Oct 25, 2016

No, it does not work without templating either. The second code block in the previous post tries this, leading to the same error message.


ebrevdo commented Oct 25, 2016

Yuan pushed improved error messaging today. Can you try using a fresh build of TensorFlow tomorrow and report back the more informative error messages?



akosiorek commented Nov 17, 2016

I have the same problem. Passing an initializer function instead of a tensor/constant, as in tf.get_variable(name, shape, initializer=init_fun), works as a workaround, though.
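
A minimal sketch of that workaround in the graph-mode API of this thread (names are illustrative); the point, per this comment, is passing a callable rather than a ready tensor:

import tensorflow as tf

shape = (4, 4)

# Reported to fail when called inside a scan/while_loop body: the
# initializer is a concrete tensor, so the variable's Assign op pulls an
# input from the loop frame.
# W = tf.get_variable("w", initializer=tf.random_normal(shape, stddev=0.1))

# Reported to work: the initializer is a function; TensorFlow invokes it
# when building the init op, keeping it out of the loop frame.
init_fun = tf.random_normal_initializer(mean=0.0, stddev=0.1)
W = tf.get_variable("w", shape=shape, initializer=init_fun)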


osdf commented May 15, 2017

@akosiorek Thank you! I didn't try your approach, but a similar one suggested by @mrry here: http://stackoverflow.com/questions/42564698/invalidargumenterror-the-node-has-inputs-from-different-frames. A tf.constant_initializer resolves the described problem.
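
The tf.constant_initializer variant from that Stack Overflow answer might look like this (the shape and name here are illustrative):

import tensorflow as tf
import numpy as np

# An initializer object defers creating the initial-value tensor until
# get_variable builds the init op, which avoids the cross-frame Assign.
init = tf.constant_initializer(np.zeros((4, 4), dtype=np.float32))
ref = tf.get_variable("ref", shape=(4, 4), initializer=init, trainable=False)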

itsmeolivia (Contributor) commented:

Automatically closing due to lack of recent activity. Since this issue is old at this point, please reopen it if it still occurs with the latest version of TensorFlow. Thank you.

ti250 added a commit to TrMPS/MPS-MNIST that referenced this issue Jun 29, 2017

Will now construct while loops, but cannot actually run them due to tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'while_1/Variable/Assign' has inputs from different frames. The input 'while_1/Size' is in frame 'while_1/while_1/'. The input 'while_1/Variable' is in frame ''. Seems to be related to tensorflow/tensorflow#4478, but I don't believe I'm passing in external tensors. Once this is fixed, however, I should be able to apply the same structure to a left-propagating update of the weights, after which the whole thing should (?) be done.