In [313]:
import tensorflow as tf
import numpy as np

In [314]:
def squash(input_tensor):
    squared_norm = tf.reduce_sum((input_tensor ** 2), axis=-1, keepdims=True)
    output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * tf.sqrt(squared_norm))
    return output_tensor

In [438]:
def forward(net_input, no_primary_capsules=8, no_digit_capsules=10, routes = 32*6*6):
    # (batch, channel, height, width) or (batch, height, width, channel)
    
    with tf.variable_scope('ReLU_Conv1'):
        conv_out = tf.layers.conv2d(inputs=net_input, filters=256, kernel_size=9, activation=tf.nn.relu)
    print("ReLU Conv1 OUT: {}".format(tf.shape(conv_out)))
    print("ReLU Conv1 OUT: {}".format(conv_out.get_shape()))
    # PRIMARY CAPSULE
    primary_capsules = []
    for i in range(no_primary_capsules):
        with tf.variable_scope('primary_capsule_{}'.format(i)):
            primary_capsules.append(tf.transpose(tf.layers.conv2d(inputs=conv_out, filters=32, kernel_size=9, strides=(2,2), padding='valid'), [0,3,1,2]))
    primary_capsules = tf.stack(primary_capsules, axis=1)
    print("PRIMARY CAPSULES (before reshape): {}".format(tf.shape(primary_capsules)))
    print("PRIMARY CAPSULES (before reshape): {}".format(primary_capsules.get_shape()))
    primary_capsules = tf.reshape(primary_capsules, [tf.shape(conv_out)[0],32*6*6,-1])
    print("PRIMARY CAPSULES (before squash): {}".format(tf.shape(primary_capsules)))
    print("PRIMARY CAPSULES (before squash): {}".format(primary_capsules.get_shape()))
    primary_capsules = squash(primary_capsules)
    print("PRIMARY CAPSULES OUT: {}".format(tf.shape(primary_capsules)))
    print("PRIMARY CAPSULES OUT: {}".format(primary_capsules.get_shape()))

    # DIGIT CAPSULE
    batch_size = tf.shape(primary_capsules)[0]
    x = tf.expand_dims(tf.stack([primary_capsules] * no_digit_capsules, axis=2), axis=4)
    print('X SHAPE: {}'.format(tf.shape(x)))
    with tf.variable_scope("transformation_matrix_weights", reuse=tf.AUTO_REUSE):
        W = tf.get_variable('W', [1, routes, no_digit_capsules, 16, 8], trainable=True)
    with tf.variable_scope("similarity_score", reuse=tf.AUTO_REUSE):
        b_ij = tf.get_variable('b', [1, routes, no_digit_capsules, 1], initializer=tf.zeros_initializer)
    
    W_batch = tf.tile(W, [batch_size,1,1,1,1])
    print("X SHAPE: {}".format(tf.shape(x)))
    print("X SHAPE: {}".format(x.get_shape()))
    print("W BATCH SHAPE: {}".format(tf.shape(W_batch)))
    print("W BATCH SHAPE: {}".format(W_batch.get_shape()))
    
    # u_hat = \hat{u}_{j|i} -- prediction vector
    u_hat = tf.matmul(W_batch,x)
    print('u_hat SHAPE: {}'.format(tf.shape(u_hat)))
    print('u_hat SHAPE: {}'.format(u_hat.get_shape()))
    
    num_iterations = 3
    for iteration in range(num_iterations):
        # c_ij -- coupling coefficients
        c_ij = tf.nn.softmax(b_ij)
        c_ij = tf.expand_dims(tf.concat(tf.tile(c_ij, [batch_size, 1, 1, 1]), axis=0), axis=4)
        
        s_j = tf.reduce_sum((c_ij * u_hat), axis=1, keepdims=True)
        # v_j -- output | activity vector
        v_j = squash(s_j)
        
        if iteration < num_iterations - 1:
            a_ij = tf.matmul(tf.transpose(u_hat, [0,1,2,4,3]), tf.concat([v_j] * routes, axis=1))
            b_ij = b_ij + tf.reduce_mean(tf.squeeze(a_ij, axis=4), axis=0, keepdims=True)
    digit_output = tf.squeeze(v_j, axis=1)
    print('DIGIT CAPSULE OUT: {}'.format(tf.shape(digit_output)))
    print('DIGIT CAPSULE OUT: {}'.format(digit_output.get_shape()))
            
    # DECODER
    classes = tf.reduce_sum(tf.sqrt(digit_output ** 2), axis=2)
    classes = tf.nn.softmax(classes)
    
    max_length_indices = tf.argmax(classes, axis=1)
    masked = tf.eye(10)
    masked = tf.gather(masked, axis=0, indices=tf.squeeze(max_length_indices, axis=1))
    decoder_input = tf.reshape(digit_output * masked[:, :, None, None], [tf.shape(digit_output)[0], 10*16*1])
    print('DECODER INPUT: {}'.format(tf.shape(decoder_input)))
    print('DECODER INPUT: {}'.format(decoder_input.get_shape()))
      
    net = tf.layers.dense(inputs=decoder_input, units=512, activation=tf.nn.relu)
    net = tf.layers.dense(inputs=decoder_input, units=1024, activation=tf.nn.relu)
    net = tf.layers.dense(inputs=decoder_input, units=784, activation=tf.nn.sigmoid)
    
    decoder_output = tf.reshape(net, [-1, 1, 28, 28])
    
    return digit_output, decoder_output, masked

In [439]:
def loss(data, x, target, reconstructions):
    loss_val = margin_loss(x, target)
    recon_loss_val = reconstruction_loss(data, reconstructions)
    return loss_val + recon_loss_val

In [440]:
def margin_loss(x, labels, size_average=True):
    batch_size = tf.shape(x)[0]
    v_c = tf.sqrt(tf.reduce_sum((x ** 2), axis=2, keepdims=True))
    
    left = tf.reshape(tf.nn.relu(0.9 - v_c), [batch_size, -1])
    right = tf.reshape(tf.nn.relu(v_c - 0.1), [batch_size, -1])
    
    loss = labels * left + 0.5 * (1.0 - labels) * right
    loss = tf.reduce_mean(tf.reduce_sum(loss, axis=1))
    return loss

In [441]:
def reconstruction_loss(data, reconstructions):
    labels = tf.reshape(reconstructions, [tf.shape(reconstructions)[0], -1])
    target = tf.reshape(data, [tf.shape(reconstructions)[0], -1])
    loss = tf.losses.mean_squared_error(labels, target)
    return loss * 0.0005

In [448]:
def train(loss, learning_rate=0.001):
    return tf.train.AdamOptimizer(learning_rate).minimize(loss)

In [None]:
tf.reset_default_graph()
input_pl = tf.placeholder(shape=[None,28,28,1], dtype=tf.float32)
target_pl = tf.placeholder(shape=[None,1], dtype=tf.float32)

net = forward(input_pl)
loss_op = loss(input_pl, net[0], target_pl, net[1])
train_op = train(loss_op)

print()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1000):
        numpy_input = np.random.rand(64,28,28,1)
        numpy_target = np.random.randint(low=0, high=2, size=(64,1))
        loss_val, res = sess.run([loss_op, train_op], feed_dict={input_pl: numpy_input, target_pl: numpy_target})
        print('Loss: {}'.format(loss_val))
        

ReLU Conv1 OUT: Tensor("Shape:0", shape=(4,), dtype=int32)
ReLU Conv1 OUT: (?, 20, 20, 256)
PRIMARY CAPSULES (before reshape): Tensor("Shape_1:0", shape=(5,), dtype=int32)
PRIMARY CAPSULES (before reshape): (?, 8, 32, 6, 6)
PRIMARY CAPSULES (before squash): Tensor("Shape_3:0", shape=(3,), dtype=int32)
PRIMARY CAPSULES (before squash): (?, 1152, ?)
PRIMARY CAPSULES OUT: Tensor("Shape_4:0", shape=(3,), dtype=int32)
PRIMARY CAPSULES OUT: (?, 1152, ?)
X SHAPE: Tensor("Shape_6:0", shape=(5,), dtype=int32)
X SHAPE: Tensor("Shape_7:0", shape=(5,), dtype=int32)
X SHAPE: (?, 1152, 10, ?, 1)
W BATCH SHAPE: Tensor("Shape_8:0", shape=(5,), dtype=int32)
W BATCH SHAPE: (?, 1152, 10, 16, 8)
u_hat SHAPE: Tensor("Shape_9:0", shape=(5,), dtype=int32)
u_hat SHAPE: (?, 1152, 10, 16, 1)
DIGIT CAPSULE OUT: Tensor("Shape_16:0", shape=(4,), dtype=int32)
DIGIT CAPSULE OUT: (?, 10, 16, 1)
DECODER INPUT: Tensor("Shape_20:0", shape=(2,), dtype=int32)
DECODER INPUT: (?, 160)

