tf.unravel_index (Was: tf.argmin across all dimensions) #2075

Closed
educob opened this issue Apr 23, 2016 · 10 comments
Labels
stat:contribution welcome Status - Contributions welcome type:feature Feature requests

Comments

@educob

educob commented Apr 23, 2016

Hi,

tf.argmin only works along one dimension.
Let's say I have a picture that is a 2x2 array of pixels (each pixel is a 3-value array) and I want to know which pixel is closest to a certain color. In that case I have to reshape the 2x2 array into a 1-D array, get the index of the minimum, and then work out which position in the 2x2 array that index corresponds to.

Couldn't I just ask TensorFlow to give me the index in the 2-D array directly, for instance [1,15]?

I really don't understand why I have to translate between a linear index and an array location all the time.

Thanks.
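
The round trip described above (flatten, take the arg-min, map the flat index back to a row and column) can be sketched in plain Python. The image values, the target color, and the helper name `closest_pixel` are made up for illustration, and row-major layout is assumed.

```python
def closest_pixel(image, target):
    """Return (row, col) of the pixel whose color is nearest to `target`.

    `image` is a list of rows; each row is a list of (r, g, b) tuples.
    """
    width = len(image[0])
    # Flatten the rows x cols grid into a 1-D list (row-major order).
    flat = [pixel for row in image for pixel in row]
    # Squared Euclidean distance from every pixel to the target color.
    dists = [sum((p - t) ** 2 for p, t in zip(pixel, target)) for pixel in flat]
    # 1-D arg-min, analogous to tf.argmin on the reshaped tensor.
    flat_index = min(range(len(dists)), key=dists.__getitem__)
    # Map the flat index back to 2-D; this is the manual step the issue is about.
    return divmod(flat_index, width)


image = [[(255, 0, 0), (0, 255, 0)],
         [(0, 0, 255), (250, 250, 250)]]
print(closest_pixel(image, (255, 255, 255)))  # (1, 1)
```

For a 2-D grid the conversion is a single `divmod`; the general N-D case is what np.unravel_index covers.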

@girving girving changed the title Feature suggestion: tf.argmin returning min in all dimensions. tf.unravel_index (Was: tf.argmin across all dimensions) Jun 7, 2016
@girving girving added the stat:contribution welcome Status - Contributions welcome label Jun 7, 2016
@girving
Contributor

girving commented Jun 7, 2016

I think it'd be better to handle this the way numpy does it, with np.unravel_index: https://bytes.com/topic/python/answers/509074-numpy-argmin-multidimensional-arrays. Making tf.argmin do this directly would require either a new op or an unpleasant boolean flag that makes it index over everything, and it still wouldn't be what you want, since most of the time you want to minimize over some but not all of the dimensions (e.g., the last three dims of a 4D batched image tensor).

Cc @aselle since unravel_index sounds vaguely index related, but I'll mark this contributions welcome for now.
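
To make the "some but not all dimensions" case concrete, here is a minimal pure-Python sketch. It assumes each batch element is already stored flat in row-major order; `unravel_index` below is a hand-written stand-in that follows the row-major convention of np.unravel_index, and `argmin_over_trailing_dims` is a hypothetical helper name, not an existing API.

```python
def unravel_index(flat_index, shape):
    """Row-major conversion of one flat index into a coordinate tuple."""
    coords = []
    # Peel off the fastest-varying (last) dimension first.
    for dim in reversed(shape):
        flat_index, coord = divmod(flat_index, dim)
        coords.append(coord)
    return tuple(reversed(coords))


def argmin_over_trailing_dims(batch, trailing_shape):
    """For each batch element (stored flat), return the coordinates of its minimum."""
    result = []
    for flat_values in batch:
        flat_index = min(range(len(flat_values)), key=flat_values.__getitem__)
        result.append(unravel_index(flat_index, trailing_shape))
    return result


# Two batch elements, each a 2 x 2 x 3 block stored in row-major order.
batch = [
    [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2],
    [5, 5, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5],
]
print(argmin_over_trailing_dims(batch, (2, 2, 3)))  # [(1, 1, 0), (0, 0, 2)]
```

The arg-min itself stays one-dimensional per batch element; only the index conversion knows about the trailing shape, which is why a separate unravel step composes better than a flag on argmin.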

@jiankang1991

Where is tf.unravel_index in the documentation? I cannot find it.
Thank you.

@girving
Contributor

girving commented Jun 8, 2016

Not sure which document you mean, but tf.unravel_index doesn't exist yet. np.unravel_index is described in that thread I linked to.

@jiankang1991

Yeah, I know np.unravel_index, but can np.unravel_index be used to build the graph?
I wrote a simple example. How can this be achieved using TensorFlow?
Thank you.

import numpy as np
import tensorflow as tf

ksize = 3
stride = 1

input_image = tf.placeholder(tf.float32, name='input_image')

#conv1
kernel = tf.Variable(tf.truncated_normal([ksize, ksize, 3, 16],stddev=0.1),
                    name='kernel')
conv = tf.nn.conv2d(input_image, kernel, [1,stride,stride,1], padding='SAME')
biases = tf.Variable(tf.constant(0.0, shape = [16]), name = 'biases')
bias = tf.nn.bias_add(conv, biases)
conv1 = tf.nn.relu(bias, name='conv1')

#pool1
pool1, pool1_indices = tf.nn.max_pool_with_argmax(conv1, ksize=[1, 2, 2, 1], 
                                                  strides=[1, 2, 2, 1], 
                                                  padding='SAME', name='pool1')

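# Upsample by assigning the values of pool1 to the positions in the unpooling
# tensor given by pool1_indices. pool1_indices is a tf.Tensor, so the NumPy
# calls below cannot run as part of the graph; this is the step in question.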
indices = pool1_indices
unravel_pool1_indices = np.unravel_index(indices,[4,32,32,16])
unravel_pool1_coordinates = np.array(unravel_pool1_indices)
coor_shape = np.shape(unravel_pool1_coordinates)
unravel_pool1_coordinates = np.reshape(unravel_pool1_coordinates,(coor_shape[0],coor_shape[1]*coor_shape[2]*coor_shape[3]*coor_shape[4]))
unravel_pool1_coordinates = unravel_pool1_coordinates.T

values = pool1
values = np.reshape(values,(np.size(values)))

up1 = tf.constant(0.0, shape = [4,32,32,16])
delta = tf.SparseTensor(unravel_pool1_coordinates, values, shape = [4,32,32,16])

result = up1 + tf.sparse_tensor_to_dense(delta)


with tf.Session() as session:
    session.run(tf.initialize_all_variables())
    test_image = np.random.rand(4,32,32,3)
    sess_outputs = session.run([pool1, pool1_indices],
                               {input_image.name: test_image})

@girving
Contributor

girving commented Jun 9, 2016

Someone would have to write a TensorFlow version of np.unravel_index, which could be called tf.unravel_index. We might not do that soon, so PRs adding it would be welcome. tf.unravel_index could either be a new C++ op or something written in Python.

@ibab
Contributor

ibab commented Jun 30, 2016

Here's a sketch of a possible Python implementation using cumprod:

import tensorflow as tf

def unravel_index(indices, shape):
  with tf.name_scope('unravel_index'):
    indices = tf.expand_dims(indices, 0)
    shape = tf.expand_dims(shape, 1)
    strides = tf.cumprod(shape, reverse=True)
    strides_shifted = tf.cumprod(shape, exclusive=True, reverse=True)
    return (indices // strides_shifted) % strides

s = tf.Session()
out = unravel_index([22, 41, 37], (7, 6))
print(s.run(out))
# ==> [[3 6 6]
#      [4 5 1]]

@aselle aselle removed the triaged label Jul 28, 2016
@aselle aselle added type:feature Feature requests and removed enhancement labels Feb 9, 2017
@girving
Contributor

girving commented Jun 16, 2017

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

@girving girving closed this as completed Jun 16, 2017
@meijun
Contributor

meijun commented Aug 4, 2017

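@ibab The last line of the function unravel_index should be return (indices % strides) // strides_shifted. The original expression, (indices // strides_shifted) % strides, gives the right answer for 2-D shapes such as (7, 6), but not in general for shapes with three or more dimensions.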
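
The two expressions, ibab's `(indices // strides_shifted) % strides` and the corrected `(indices % strides) // strides_shifted`, can be compared without TensorFlow. The sketch below recomputes the strides in plain Python and checks each formula against a brute-force row-major enumeration; all helper names are illustrative only.

```python
from itertools import product


def strides_for(shape):
    """Return (strides, strides_shifted): reversed inclusive and exclusive cumprods."""
    strides, running = [], 1
    for dim in reversed(shape):
        running *= dim
        strides.append(running)
    strides.reverse()
    # Exclusive variant: every entry moves one slot to the left, ending in 1.
    strides_shifted = strides[1:] + [1]
    return strides, strides_shifted


def divide_then_mod(index, shape):
    strides, shifted = strides_for(shape)
    return tuple((index // sh) % st for st, sh in zip(strides, shifted))


def mod_then_divide(index, shape):
    strides, shifted = strides_for(shape)
    return tuple((index % st) // sh for st, sh in zip(strides, shifted))


def matches_reference(formula, shape):
    # itertools.product enumerates coordinates in row-major order.
    reference = product(*(range(d) for d in shape))
    return all(formula(i, shape) == coords for i, coords in enumerate(reference))


for shape in [(7, 6), (2, 3, 4), (2, 3, 4, 5)]:
    print(shape,
          matches_reference(divide_then_mod, shape),
          matches_reference(mod_then_divide, shape))
```

Under these assumptions, divide-then-mod agrees with the reference only for the rank-2 shape, while mod-then-divide agrees for all three. For shape (2, 3, 4) and index 12, for example, divide-then-mod yields (1, 3, 0) where the row-major coordinate is (1, 0, 0).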

@yongtang
Member

I think unravel_index is quite a useful feature that could be used in many places. I created PR #14895 to add a C++ kernel for it. Please take a look if you are interested.

@mehmetbasbug

mehmetbasbug commented Jan 15, 2018

@ibab @meijun

In fact, the implementation must depend on the parity of the rank of the shape.

The following modification works for both cases and runs smoothly on both CPU and GPU.

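import tensorflow as tf

def unravel_index(indices, shape):
    # Broadcast layout: indices -> [1, n], shape -> [ndim, 1].
    indices = tf.expand_dims(indices, 0)
    shape = tf.expand_dims(shape, 1)
    # Take the cumulative products in float32, then cast the strides back to int32.
    shape = tf.cast(shape, tf.float32)
    strides = tf.cumprod(shape, reverse=True)
    strides_shifted = tf.cumprod(shape, exclusive=True, reverse=True)
    strides = tf.cast(strides, tf.int32)
    strides_shifted = tf.cast(strides_shifted, tf.int32)
    def even():
        # (indices % strides) // strides_shifted, with the modulo spelled out.
        rem = indices - (indices // strides) * strides
        return rem // strides_shifted
    def odd():
        # (indices // strides_shifted) % strides, with the modulo spelled out.
        div = indices // strides_shifted
        return div - (div // strides) * strides
    # shape was expanded to [ndim, 1] above, so tf.rank(shape) is 2 at this point.
    rank = tf.rank(shape)
    return tf.cond(tf.equal(rank - (rank // 2) * 2, 0), even, odd)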
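
One thing that may be worth checking with the float32 cast in the snippet above: float32 represents integers exactly only up to 2**24, so a large cumulative product could be rounded before it is cast back to int32. The sketch below does not run TensorFlow; it emulates float32 rounding with the standard `struct` module, and the shape is an arbitrary example chosen to exceed that limit.

```python
import struct


def to_float32(value):
    """Round a Python number to the nearest float32 and return it as a float."""
    return struct.unpack('f', struct.pack('f', float(value)))[0]


def float32_cumprod_total(shape):
    """Product of all dims, rounding to float32 after every multiplication."""
    total = to_float32(1.0)
    for dim in shape:
        total = to_float32(total * to_float32(dim))
    return int(total)


shape = (4097, 4097)               # arbitrary example; the product exceeds 2**24
exact = shape[0] * shape[1]        # exact integer product
rounded = float32_cumprod_total(shape)
print(exact, rounded, exact == rounded)
```

For small shapes such as (7, 6) the float32 product is exact, so whether this matters depends on the sizes involved; keeping the strides in an integer dtype would sidestep the question where the cumulative product supports it.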


8 participants