-
Notifications
You must be signed in to change notification settings - Fork 6
/
unpool.py
33 lines (32 loc) · 2.2 KB
/
unpool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def unpool_with_with_argmax(pooled, ind, ksize=(1, 2, 2, 1)):
    """
    Unpool a tensor that was max-pooled with ``tf.nn.max_pool_with_argmax``.

    Arguments:
        pooled: the max-pooled output tensor, shape [batch, height, width, channels].
            The static shape must be fully defined (``get_shape().as_list()`` is used).
        ind: argmax indices, the second output of max_pool_with_argmax
        ksize: must be the same ksize that was used to pool
    Returns:
        unpooled: tensor of shape [batch, height * ksize[1], width * ksize[2], channels]
            with each pooled value scattered back to its argmax position and
            zeros everywhere else.
    Notes:
        1. TensorFlow flattens the argmax indices, so a maximum at position
           [b, y, x, c] is documented as ((b * height + y) * width + x) * channels + c.
        2. NOTE(review): this code pairs each index with an explicit batch id and
           scatters into a per-batch row of size height*width*channels, which assumes
           ``ind`` holds *within-batch* flattened offsets (i.e. values < H*W*C).
           Confirm against the producing op — the documented formula above includes
           the batch offset.
        3. Uses the stateless, differentiable ``tf.scatter_nd`` instead of scattering
           into a ``tf.Variable`` with ``tf.scatter_nd_update``: the Variable version
           created a new variable per call, leaked stale values between runs (only the
           argmax positions were overwritten), and blocked gradient flow.
    """
    # Static input shape [batch, height, width, channels]; must be fully defined.
    input_shape = pooled.get_shape().as_list()
    # Each spatial dimension grows by the corresponding pooling stride.
    output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
    # Total number of pooled elements and per-batch size of the unpooled output.
    flat_input_size = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]
    flat_output_size = output_shape[1] * output_shape[2] * output_shape[3]
    # Flatten the pooled values into one vector so they can be scattered.
    pooled_ = tf.reshape(pooled, [flat_input_size])
    # Build a batch id for every element (broadcast a [batch,1,1,1] range over ind)
    # so each value is scattered into its own batch row.
    batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
    b = tf.ones_like(ind) * batch_range
    b_ = tf.reshape(b, [flat_input_size, 1])
    ind_ = tf.reshape(ind, [flat_input_size, 1])
    # [flat_input_size, 2] pairs of (batch id, flattened position within the batch).
    ind_ = tf.concat([b_, ind_], 1)
    # Stateless scatter into an implicit zero tensor — no Variable, gradients flow.
    unpooled_ = tf.scatter_nd(ind_, pooled_, [output_shape[0], flat_output_size])
    # Reshape the flat result back to [batch, out_height, out_width, channels].
    unpooled = tf.reshape(unpooled_, output_shape)
    return unpooled