In [1]:
import tensorflow as tf
import numpy as np
np.set_printoptions(threshold=np.nan)
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

In [87]:
tt = np.array([range(1, 33)])
tt = np.reshape(tt, [1,4,4,2], order='F')
tt = np.transpose(tt, [0, 2, 1, 3])
tt = np.reshape(tt, [1,4,4,2], order='C')
tt[0,0,1,0] = 1
tt[0,1,2,0] = 8
tt[0,0,3,0] = 8
tt[0,1,3,0] = 4
tt[0,3,1,0] = 13
tt[0,2,2,0] = 16
tt[0,0,3,1] = 19
tt[0,1,2,1] = 19
tt[0,3,0,1] = 26
tt[0,3,1,1] = 26
tt[0,3,2,1] = 27
tt[0,3,3,1] = 27

In [88]:
tt[0,:,:,0]

array([[ 1,  1,  3,  8],
       [ 5,  6,  8,  4],
       [ 9, 10, 16, 12],
       [13, 13, 15, 16]])

In [89]:
tt[0,:,:,1]

array([[17, 18, 19, 19],
       [21, 22, 19, 24],
       [25, 26, 27, 28],
       [26, 26, 27, 27]])

In [5]:
# extract patches from feature maps
# input shape N, H, W, C
# output shape N, H, W, K, C
def extract_patches(x, padding, ksize=2, stride=2):
    temp = tf.extract_image_patches(images=x, ksizes=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], rates=[1,1,1,1], padding=padding)
    [N, H, W, C] = temp.get_shape().as_list()
    C = x.get_shape().as_list()[-1]
#     reshape to N,H,W,K,C
    temp = tf.reshape(temp, [N, H, W, ksize*ksize, C])
    return temp

In [37]:
# compute the frequency of element in each patch
# input extracted patches tensor in shape N, H, W, K, C
# output frequency tensor in shape N, H, W, K, C
def majority_frequency(temp):
    
    [N, H, W, K, C] = temp.get_shape().as_list()
    temp = tf.to_int32(tf.round(temp))
    
#     build one hot vector
    temp = tf.transpose(temp, [0,1,2,4,3])
    one_hot = tf.one_hot(indices=temp, depth=tf.reduce_max(temp) + 1, dtype=tf.float32)
#     the dimension is bathch, row, col, lay, one hot
#     the order tensorflow takes, when doiong transpose, it will from the most right to most left
    one_hot = tf.reduce_sum(one_hot, axis=4)
    temp = tf.transpose(temp, [0, 3, 1, 2, 4])
    temp = tf.reshape(temp, [-1,1])
    one_hot = tf.transpose(one_hot, [0,3,1,2,4])
    one_hot = tf.reshape(one_hot, [N*H*W*C, -1])
    
    index = tf.constant(np.array([range(temp.get_shape().as_list()[0])])/ K, dtype=tf.int32)
    temp = tf.concat((tf.transpose(index), temp), axis=1)
    
#     to get the percentage
    temp = tf.gather_nd(one_hot, temp)
    temp = tf.reshape(temp, [N, C, H, W, K])
#     finally we change it back to N,H,W,K,C
    temp = tf.transpose(temp, [0, 2, 3, 4, 1])
    
    return temp

In [7]:
# compute weight based on frequency tensor
# fun could be tf.reduce_max, tf.reduce_sum, reduce_size
# output in shape N, H, W, K, C
def compute_weight(w, fun):
    if isinstance(fun, str): deno = w.get_shape().as_list()[3]
    else: deno = fun(w, axis=3, keep_dims=True)
    temp = tf.divide(w, deno)
    return temp

In [84]:
# weight before maxpool p:= patches, w:= weights
def weight_max(p, w, fun):
    temp = tf.multiply(p, compute_weight(w, fun))
    temp = tf.reduce_max(temp, axis=3)
    return temp

# maxpool before weight
def max_weight(p, w, fun):
#     for now both p and w are in the shape of N,H,W,K,C
    [N, H, W, K, C] = p.get_shape().as_list()
    w = compute_weight(w, fun)
#     argmax in the shape of N, H, W, C
    argmax = tf.argmax(p, axis=3)
    p = tf.reduce_max(p, axis=3)
#     move C before H
    argmax = tf.transpose(argmax, [0, 3, 1, 2])
    w = tf.transpose(w, [0, 4, 1, 2, 3])
#     flatten argmax and w
    argmax = tf.reshape(argmax, [N*H*W*C, 1])
    w = tf.reshape(w, [N*H*W*C, K])
#     create index helper
    index = tf.constant(np.array([range(argmax.get_shape().as_list()[0])]), dtype=tf.int64)
    argmax = tf.concat((tf.transpose(index), argmax), axis=1)
#     get the corresponding weight of the max
    w = tf.gather_nd(w, argmax)
    w = tf.reshape(w, [N, C, H, W])
    w = tf.transpose(w, [0, 2, 3, 1])
    
    return tf.multiply(p, w)

In [90]:
x = tf.constant(tt, dtype=tf.float32)

p = extract_patches(x, "VALID", 2, 2)
f = majority_frequency(p)

temp = max_weight(p, f, tf.reduce_max)

# temp = tf.nn.max_pool(x, ksize=[1,2,2,1],padding='VALID',strides=[1,2,2,1])

with tf.Session() as sess:
    ret = sess.run(temp)

In [245]:
x = tf.constant(np.array([[1,2,3,4],[5,6,7,8]]))

with tf.Session() as sess:
    ret1 = sess.run(temp)

In [91]:
ret

array([[[[  3.   ,  22.   ],
         [  8.   ,   8.   ]],

        [[ 13.   ,  26.   ],
         [ 16.   ,   9.333]]]], dtype=float32)

In [65]:
ret

array([[0, 3],
       [1, 3],
       [2, 2],
       [3, 0],
       [4, 3],
       [5, 3],
       [6, 1],
       [7, 1]])

In [83]:
ret

array([[[[ 0.5  ,  1.   ],
         [ 0.5  ,  0.333]],

        [[ 1.   ,  1.   ],
         [ 1.   ,  0.333]]]], dtype=float32)

In [78]:
ret

array([[[[  6.,  22.],
         [  8.,  24.]],

        [[ 13.,  26.],
         [ 16.,  28.]]]], dtype=float32)

In [99]:
v = np.zeros(4,4)

TypeError: data type not understood

In [98]:
v

array([ 0.,  0.,  0.,  0.])

In [95]:
v[:,0] = 1

In [96]:
v

array([[ 1.,  0.,  0.,  0.],
       [ 1.,  0.,  0.,  0.],
       [ 1.,  0.,  0.,  0.],
       [ 1.,  0.,  0.,  0.]])