In [1]:
import tensorflow as tf

In [2]:
def maxpool2d(x):
    """
       Max-pool x with a 2x2 window moved in steps of 2 ('SAME' padding).
       Return:
           whatever tf.nn.max_pool_with_argmax returns: the pooled tensor
           together with the argmax indices of the selected values
    """
    window = [1, 2, 2, 1]   # size of window
    step = [1, 2, 2, 1]     # movement of window
    return tf.nn.max_pool_with_argmax(x, ksize=window, strides=step, padding='SAME')

In [3]:
def unpool_with_with_argmax(pooled, ind, ksize=[1, 2, 2, 1]):
    """
       Unpool a tensor that was produced by max_pool_with_argmax.

       Every pooled value is written back to the position its argmax index
       points at; every other position of the result is zero.

       The result is built with the stateless tf.scatter_nd op, so no
       tf.Variable is created: nothing has to be initialised before the
       returned tensor is evaluated and no values are carried over from one
       session.run call to the next.

       Args:
           pooled:    the max pooled output tensor, 4-D ([b, y, x, c]) with a
                      fully defined static shape
           ind:       argmax indices, the second output of max_pool_with_argmax.
                      They are used as offsets into each image's flattened
                      (height * width * channels) block, i.e. the flattened
                      index (y * width + x) * channels + c without a batch term
           ksize:     ksize should be the same as what you have used to pool
       Return:
           ret:       the tensor after unpooling, with shape
                      [b, y * ksize[1], x * ksize[2], c] and the dtype of pooled
       Raises:
           ValueError: if any dimension of pooled is not statically known

       NOTE(review): the output size is derived as pooled size * ksize, which
       presumably only matches the original input when its height and width
       are multiples of the window - confirm before using odd-sized inputs.
    """
    # static shape of the pooled tensor as a plain Python list
    in_shape = pooled.get_shape().as_list()
    if None in in_shape:
        raise ValueError("pooled must have a fully defined static shape, got %s" % in_shape)
    batch, height, width, channels = in_shape

    # spatial size of the unpooled result
    out_height = height * ksize[1]
    out_width = width * ksize[2]
    num_values = batch * height * width * channels
    flat_size = out_height * out_width * channels

    # flatten the pooled values into one long vector
    values = tf.reshape(pooled, [num_values])

    # batch number (0 .. batch-1) of every pooled value, broadcast over y, x, c
    batch_range = tf.reshape(tf.range(batch, dtype=ind.dtype), shape=[batch, 1, 1, 1])
    batch_idx = tf.reshape(tf.ones_like(ind) * batch_range, [num_values, 1])

    # pair every argmax offset with its batch number -> [num_values, 2] coordinates
    flat_idx = tf.reshape(ind, [num_values, 1])
    scatter_idx = tf.concat([batch_idx, flat_idx], 1)

    # scatter the values into a fresh all-zero [batch, flat_size] tensor;
    # the shape tensor has to use the same integer dtype as the indices
    scatter_shape = tf.constant([batch, flat_size], dtype=ind.dtype)
    ret = tf.scatter_nd(scatter_idx, values, scatter_shape)

    # reshape the flat rows back into images to get the final result
    ret = tf.reshape(ret, [batch, out_height, out_width, channels])
    return ret

In [None]:
orig = tf.random_uniform([16,32,32, 8],maxval=500,dtype='float32',seed=2)

In [None]:
with tf.Session() as sess:
    orig_np=sess.run(orig)

In [None]:
orig_np.shape

In [None]:
pooled_tf,max_indices=maxpool2d(orig)

In [None]:
pooled_tf

In [None]:
unpooled2=unpool_with_with_argmax(pooled_tf,max_indices)

In [None]:
unpooled2

In [None]:
#####all experiments below this#########

In [None]:
input_shape = pooled_tf.get_shape().as_list()
pooled_ = tf.reshape(pooled_tf, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]])

In [None]:
pooled_

In [None]:
max_indices.dtype

In [None]:
type(max_indices)

In [None]:
max_indices

In [None]:
orig2 = tf.random_uniform([2,4,4,3],maxval=100,dtype='float32',seed=2)

In [None]:
pooled_tf2,max_indices2=maxpool2d(orig2)

In [None]:
#visualise the original tensor , visualise the pooled tensor , visualise the argmax , visualise the unpooled tensor

In [None]:
unpooled22=unpool_with_with_argmax(pooled_tf2,max_indices2)

In [None]:
orig2.shape

In [None]:
with tf.Session() as sess:
    print(sess.run(orig2))

In [None]:
with tf.Session() as sess:
    print(sess.run(pooled_tf2))

In [None]:
with tf.Session() as sess:
    print(sess.run(max_indices2))

In [None]:
with tf.Session() as sess:
    print(sess.run(pooled_tf2))

In [None]:
##from now on we will just visualize what happens with a single example !

In [None]:
input_shape = pooled_tf2.get_shape().as_list()

In [None]:
input_shape

In [None]:
#ksize[1] is the height of the window
#ksize[2] is the width of the window
ksize=[1, 2, 2, 1]
output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])

In [None]:
output_shape

In [None]:
#reshape the given POOLED into one giant tensor for better workability
pooled_ = tf.reshape(pooled_tf2, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]])

In [None]:
pooled_.shape

In [None]:
#ind is the argmax
ind = max_indices2
batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])

In [None]:
batch_range.shape

In [None]:
# ^^ all ones because we are testing it with a single batch only; changing the batch size will affect only the first dimension of the shape, the others will still remain one

In [None]:
#Let me myself figure out whats happening over here !
b = tf.ones_like(ind) * batch_range

In [None]:
b

In [None]:
#I think the batch range was just introduced to effect the multiplication by the 'b' in the formula ((b * height + y) * width + x) * channels +c

In [None]:
#I have to apply some nasty broadcasting rules to get this !

In [None]:
with tf.Session() as sess:
    # evaluate the broadcast batch-number tensor and print the same array that was fetched
    b_as_nparray=sess.run(b)
    print(b_as_nparray)

In [None]:
with tf.Session() as sess:
    batch_range_as_nparray=sess.run(batch_range)
    print(batch_range_as_nparray)

In [None]:
b = tf.reshape(b, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])

In [None]:
with tf.Session() as sess:
    b_as_nparray=sess.run(b)
    print(b_as_nparray)

In [None]:
ind_ = tf.reshape(ind, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])

In [None]:
with tf.Session() as sess:
    # evaluate the flattened argmax column and print the same array that was fetched
    ind__as_nparray=sess.run(ind_)
    print(ind__as_nparray)

In [None]:
# Let's concatenate them along axis 1, that is, horizontally

In [None]:
ind_ = tf.concat([b, ind_],1)

In [None]:
with tf.Session() as sess:
    # ind_ already holds the [batch number, flat index] pairs built in the previous cell;
    # evaluate it directly instead of concatenating b onto it a second time
    ind__as_nparray = sess.run(ind_)
    print(ind__as_nparray)

In [None]:
b

In [None]:
ind_

In [None]:
ind__as_nparray.shape

In [None]:
ref = tf.Variable(tf.zeros([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]))

In [None]:
ref

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ref_as_nparray=sess.run(ref)
    print(ref_as_nparray)

In [None]:
ret = tf.scatter_nd_update(ref, ind_, pooled_)

In [None]:
ret

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ret_as_nparray=sess.run(ret)
    print(ret_as_nparray)

In [None]:
ret = tf.reshape(ret, [output_shape[0], output_shape[1], output_shape[2], output_shape[3]])

In [None]:
ret

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ret_as_nparray=sess.run(ret)
    print(ret_as_nparray)

In [None]:
with tf.Session() as sess:
    print(sess.run(orig2))

In [4]:
############################################################################################3

In [51]:
orig2 = tf.random_uniform([1,4,4,3],maxval=100,dtype='float32',seed=2)

In [6]:
pooled_tf2,max_indices2=maxpool2d(orig2)

In [7]:
#visualise the original tensor , visualise the pooled tensor , visualise the argmax , visualise the unpooled tensor

In [8]:
unpooled22=unpool_with_with_argmax(pooled_tf2,max_indices2)

In [9]:
orig2.shape

TensorShape([Dimension(1), Dimension(4), Dimension(4), Dimension(3)])

In [53]:
with tf.Session() as sess:
    # print the original (un-pooled) input tensor
    print(sess.run(orig2))

[[[[ 67.87465668  71.41509247  46.98169327]
   [ 48.19697189  47.83817673  55.389534  ]
   [ 23.51480675  43.86433411  48.48768616]
   [ 83.74739075  66.82654572  81.30921173]]

  [[ 21.17181969  14.59358978  47.21593857]
   [ 25.94748688  76.70527649   9.67714787]
   [ 74.96897888  15.72545815  36.21578217]
   [ 40.15810394  46.98136902  37.93695068]]

  [[ 19.17402649  13.34600449  63.32398605]
   [ 75.63816071  39.96374512  50.45898056]
   [  4.35934067  71.83367157  95.57234192]
   [  0.63753128  26.81066895  83.41733551]]

  [[ 27.07540894  90.00960541   6.51501417]
   [ 66.9224472   60.66981506  63.90122223]
   [  9.09529877  87.97344208  11.55341911]
   [ 67.55619049  50.58227921  57.45575333]]]]


In [11]:
with tf.Session() as sess:
    # print the max-pooled tensor
    print(sess.run(pooled_tf2))

[[[[ 67.87465668  76.70527649  55.389534  ]
   [ 83.74739075  66.82654572  81.30921173]]

  [[ 75.63816071  90.00960541  63.90122223]
   [ 67.55619049  87.97344208  95.57234192]]]]


In [12]:
with tf.Session() as sess:
    # print the argmax indices that belong to pooled_tf2
    print(sess.run(max_indices2))

[[[[ 0 16  5]
   [ 9 10 11]]

  [[27 37 41]
   [45 43 32]]]]


In [13]:
with tf.Session() as sess:
    print(sess.run(pooled_tf2))

[[[[ 67.87465668  76.70527649  55.389534  ]
   [ 83.74739075  66.82654572  81.30921173]]

  [[ 75.63816071  90.00960541  63.90122223]
   [ 67.55619049  87.97344208  95.57234192]]]]


In [14]:
##from now on we will just visualize what happens with a single example !

In [15]:
input_shape = pooled_tf2.get_shape().as_list()

In [16]:
input_shape

[1, 2, 2, 3]

In [17]:
#ksize[1] is the height of the window
#ksize[2] is the width of the window
ksize=[1, 2, 2, 1]
output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])

In [18]:
output_shape

(1, 4, 4, 3)

In [19]:
#reshape the given POOLED into one giant tensor for better workability
pooled_ = tf.reshape(pooled_tf2, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]])

In [20]:
pooled_.shape

TensorShape([Dimension(12)])

In [21]:
#ind is the argmax
ind = max_indices2
batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])

In [22]:
batch_range.shape

TensorShape([Dimension(1), Dimension(1), Dimension(1), Dimension(1)])

In [23]:
# ^^ all ones because we are testing it with a single batch only; changing the batch size will affect only the first dimension of the shape, the others will still remain one

In [24]:
#Let me myself figure out whats happening over here !
b = tf.ones_like(ind) * batch_range

In [25]:
b

<tf.Tensor 'mul_1:0' shape=(1, 2, 2, 3) dtype=int64>

In [26]:
#I think the batch range was just introduced to effect the multiplication by the 'b' in the formula ((b * height + y) * width + x) * channels +c

In [27]:
#I have to apply some nasty broadcasting rules to get this !

In [28]:
with tf.Session() as sess:
    b_as_nparray=sess.run(b)
    print(b_as_nparray)

[[[[0 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 0]]]]


In [29]:
with tf.Session() as sess:
    batch_range_as_nparray=sess.run(batch_range)
    print(batch_range_as_nparray)

[[[[0]]]]


In [30]:
b = tf.reshape(b, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])

In [31]:
with tf.Session() as sess:
    b_as_nparray=sess.run(b)
    print(b_as_nparray)

[[0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]]


In [32]:
ind_ = tf.reshape(ind, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])

In [33]:
with tf.Session() as sess:
    ind__as_nparray=sess.run(ind_)
    print(ind__as_nparray)

[[ 0]
 [16]
 [ 5]
 [ 9]
 [10]
 [11]
 [27]
 [37]
 [41]
 [45]
 [43]
 [32]]


In [34]:
# Let's concatenate them along axis 1, that is, horizontally

In [35]:
ind_ = tf.concat([b, ind_],1)

In [36]:
with tf.Session() as sess:
    # NOTE(review): ind_ was already rebuilt as concat([b, ind_]) in the previous cell,
    # so concatenating b again here prepends a second batch column. The array printed
    # by this cell therefore has 3 columns, while ind_ itself has 2. Evaluate ind_
    # directly to see the indices that are actually passed to scatter_nd_update.
    ind__as_nparray = sess.run(tf.concat([b, ind_],1))
    print(ind__as_nparray)

[[ 0  0  0]
 [ 0  0 16]
 [ 0  0  5]
 [ 0  0  9]
 [ 0  0 10]
 [ 0  0 11]
 [ 0  0 27]
 [ 0  0 37]
 [ 0  0 41]
 [ 0  0 45]
 [ 0  0 43]
 [ 0  0 32]]


In [37]:
b

<tf.Tensor 'Reshape_7:0' shape=(12, 1) dtype=int64>

In [38]:
ind_

<tf.Tensor 'concat_1:0' shape=(12, 2) dtype=int64>

In [39]:
ind__as_nparray.shape

(12, 3)

In [40]:
ref = tf.Variable(tf.zeros([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]))

In [41]:
ref

<tf.Variable 'Variable_1:0' shape=(1, 48) dtype=float32_ref>

In [42]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ref_as_nparray=sess.run(ref)
    print(ref_as_nparray)

[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]


In [43]:
ret = tf.scatter_nd_update(ref, ind_, pooled_)

In [44]:
ret

<tf.Tensor 'ScatterNdUpdate_1:0' shape=(1, 48) dtype=float32_ref>

In [45]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ret_as_nparray=sess.run(ret)
    print(ret_as_nparray)

[[ 67.87465668   0.           0.           0.           0.          55.389534
    0.           0.           0.          83.74739075  66.82654572
   81.30921173   0.           0.           0.           0.          76.70527649
    0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.          75.63816071   0.
    0.           0.           0.          95.57234192   0.           0.
    0.           0.          90.00960541   0.           0.           0.
   63.90122223   0.          87.97344208   0.          67.55619049   0.
    0.        ]]


In [46]:
ret = tf.reshape(ret, [output_shape[0], output_shape[1], output_shape[2], output_shape[3]])

In [47]:
ret

<tf.Tensor 'Reshape_9:0' shape=(1, 4, 4, 3) dtype=float32>

In [48]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ret_as_nparray=sess.run(ret)
    print(ret_as_nparray)

[[[[ 67.87465668   0.           0.        ]
   [  0.           0.          55.389534  ]
   [  0.           0.           0.        ]
   [ 83.74739075  66.82654572  81.30921173]]

  [[  0.           0.           0.        ]
   [  0.          76.70527649   0.        ]
   [  0.           0.           0.        ]
   [  0.           0.           0.        ]]

  [[  0.           0.           0.        ]
   [ 75.63816071   0.           0.        ]
   [  0.           0.          95.57234192]
   [  0.           0.           0.        ]]

  [[  0.          90.00960541   0.        ]
   [  0.           0.          63.90122223]
   [  0.          87.97344208   0.        ]
   [ 67.55619049   0.           0.        ]]]]


In [49]:
with tf.Session() as sess:
    print(sess.run(orig2))

[[[[ 67.87465668  71.41509247  46.98169327]
   [ 48.19697189  47.83817673  55.389534  ]
   [ 23.51480675  43.86433411  48.48768616]
   [ 83.74739075  66.82654572  81.30921173]]

  [[ 21.17181969  14.59358978  47.21593857]
   [ 25.94748688  76.70527649   9.67714787]
   [ 74.96897888  15.72545815  36.21578217]
   [ 40.15810394  46.98136902  37.93695068]]

  [[ 19.17402649  13.34600449  63.32398605]
   [ 75.63816071  39.96374512  50.45898056]
   [  4.35934067  71.83367157  95.57234192]
   [  0.63753128  26.81066895  83.41733551]]

  [[ 27.07540894  90.00960541   6.51501417]
   [ 66.9224472   60.66981506  63.90122223]
   [  9.09529877  87.97344208  11.55341911]
   [ 67.55619049  50.58227921  57.45575333]]]]


In [56]:
unpooled_tensor=unpool_with_with_argmax(pooled_tf2,max_indices2)

<tf.Tensor 'Reshape_17:0' shape=(1, 4, 4, 3) dtype=float32>

In [58]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(unpooled_tensor))

[[[[ 67.87465668   0.           0.        ]
   [  0.           0.          55.389534  ]
   [  0.           0.           0.        ]
   [ 83.74739075  66.82654572  81.30921173]]

  [[  0.           0.           0.        ]
   [  0.          76.70527649   0.        ]
   [  0.           0.           0.        ]
   [  0.           0.           0.        ]]

  [[  0.           0.           0.        ]
   [ 75.63816071   0.           0.        ]
   [  0.           0.          95.57234192]
   [  0.           0.           0.        ]]

  [[  0.          90.00960541   0.        ]
   [  0.           0.          63.90122223]
   [  0.          87.97344208   0.        ]
   [ 67.55619049   0.           0.        ]]]]
