In [170]:
import tensorflow as tf
import numpy as np

In [171]:
def squash(input_tensor):
    squared_norm = tf.reduce_sum((input_tensor ** 2), axis=-1, keepdims=True)
    output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * tf.sqrt(squared_norm))
    return output_tensor

In [251]:
def build(net_input, no_primary_capsules=8, no_digit_capsules=10, routes = 32*6*6):
    # (batch, channel, height, width) or (batch, height, width, channel)
    
    with tf.variable_scope('entry_conv'):
        conv_out = tf.layers.conv2d(inputs=net_input, filters=256, kernel_size=9, activation=tf.nn.relu)
    print("CONV OUT: {}".format(conv_out.get_shape()))
    # PRIMARY CAPSULE
    primary_capsules = []
    for i in range(no_primary_capsules):
        with tf.variable_scope('primary_capsule_{}'.format(i)):
            primary_capsules.append(tf.transpose(tf.layers.conv2d(inputs=conv_out, filters=32, kernel_size=9, strides=(2,2), padding='valid'), [0,3,1,2]))
    primary_capsules = tf.stack(primary_capsules, axis=1)
    print("PRIMARY_CAPSULES: {}".format(primary_capsules.get_shape()))
    primary_capsules = tf.reshape(primary_capsules, [conv_out.get_shape()[0],32*6*6,-1])
    print("PRIMARY_CAPSULES: {}".format(primary_capsules.get_shape()))
    primary_capsules = squash(primary_capsules)

    # DIGIT CAPSULE
    batch_size = primary_capsules.get_shape()[0]
    x = tf.expand_dims(tf.stack([primary_capsules] * no_digit_capsules, axis=2), axis=4)
    with tf.variable_scope("digit_weights", reuse=tf.AUTO_REUSE):
        W = tf.get_variable('W', [1, routes, no_digit_capsules, 16, 8], trainable=True)
    with tf.variable_scope("digit_bias", reuse=tf.AUTO_REUSE):
        b_ij = tf.get_variable('b', [1, routes, no_digit_capsules, 1])
    
    W_batch = tf.concat([W] * batch_size, axis=0)
    print("X SHAPE: {}".format(x.get_shape()))
    print("W BATCH SHAPE: {}".format(W_batch.get_shape()))
    
    u_hat = tf.matmul(W_batch,x)
    print('u_hat SHAPE: {}'.format(u_hat.get_shape()))
    
    num_iterations = 3
    for iteration in range(num_iterations):
        c_ij = tf.nn.softmax(b_ij)
        c_ij = tf.expand_dims(tf.concat([c_ij] * batch_size, axis=0), axis=4)
        
        s_j = tf.reduce_sum((c_ij * u_hat), axis=1, keepdims=True)
        v_j = squash(s_j)
        
        if iteration < num_iterations - 1:
            a_ij = tf.matmul(tf.transpose(u_hat, [0,1,2,4,3]), tf.concat([v_j] * routes, axis=1))
            b_ij = b_ij + tf.reduce_mean(tf.squeeze(a_ij, axis=4), axis=0, keepdims=True)
    digit_output = tf.squeeze(v_j, axis=1)
            
    # DECODER
    classes = tf.reduce_sum(tf.sqrt(digit_output ** 2), axis=2)
    classes = tf.nn.softmax(classes)
    
    max_length_indices = tf.argmax(classes, axis=1)
    masked = tf.eye(10)
    masked = tf.gather(masked, axis=0, indices=tf.squeeze(max_length_indices, axis=1))
    
    decoder_input = x * tf.reshape(masked[:, :, None, None], [x.get_shape()[0], -1])
    return masked

In [252]:
tf.reset_default_graph()
input_pl = tf.placeholder(shape=[3,28,28,1], dtype=tf.float32)
net = build(input_pl)
print()
numpy_input = np.random.rand(3,28,28,1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(net, feed_dict={input_pl: numpy_input})
    print('RES SHAPE: {}'.format(res.shape))
        

CONV OUT: (3, 20, 20, 256)
PRIMARY_CAPSULES: (3, 8, 32, 6, 6)
PRIMARY_CAPSULES: (3, 1152, 8)
X SHAPE: (3, 1152, 10, 8, 1)
W BATCH SHAPE: (3, 1152, 10, 16, 8)
u_hat SHAPE: (3, 1152, 10, 16, 1)


ValueError: Dimensions must be equal, but are 8 and 3 for 'mul_11' (op: 'Mul') with input shapes: [3,1152,10,8,1], [3,10].