In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

  from ._conv import register_converters as _register_converters


In [2]:
def augmentation(x, max_offset=2):
    """Randomly translate a batch of images by at most ``max_offset`` pixels.

    The batch is pasted at a random position onto a zero canvas that is
    ``max_offset`` pixels larger on every side, and the central window of
    the original size is cropped back out.  Content shifted past the border
    is dropped and vacated pixels are zero.

    A single (row, column) offset pair is drawn per call, so every image in
    the batch receives the same shift.  With ``max_offset=0`` the input is
    returned unshifted (as a copy).

    Args:
        x: array with four axes, unpacked as (batch, height, width, channels).
        max_offset: maximum shift in pixels along each spatial axis.

    Returns:
        A float64 array (the ``np.zeros`` default dtype) with the same shape
        as ``x``.
    """
    bz, h, w, c = x.shape
    # The canvas must be laid out (height, width) to match the slicing below;
    # allocating it as (width, height) only works for square images.
    bg = np.zeros([bz, h + 2 * max_offset, w + 2 * max_offset, c])
    offsets = np.random.randint(0, 2 * max_offset + 1, 2)
    bg[:, offsets[0]:offsets[0] + h, offsets[1]:offsets[1] + w, :] = x
    return bg[:, max_offset:max_offset + h, max_offset:max_offset + w, :]


def mnist_train_iter(iters=1000, batch_size=32, is_shift_ag=True):
    """Generate ``iters`` training batches from the MNIST training split.

    Each yielded item is ``(images, labels)``: the 28x28 single-channel
    images are repeated three times along the channel axis and passed
    through ``augmentation``; the one-hot labels are repeated three times
    along a new trailing axis.

    Args:
        iters: number of batches to yield.
        batch_size: number of examples per batch.
        is_shift_ag: when truthy the augmentation offset is 2, otherwise 0.
    """
    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    shift = 2 * int(is_shift_ag)
    for _ in range(iters):
        pixels, labels = dataset.train.next_batch(batch_size)
        single = pixels.reshape([batch_size, 28, 28, 1])
        tiled = np.concatenate([single, single, single], axis=-1)
        shifted = augmentation(tiled, shift)
        targets = np.stack([labels, labels, labels], axis=-1)
        yield shifted, targets


def mnist_test_iter(iters=1000, batch_size=32, is_shift_ag=False):
    """Generate ``iters`` evaluation batches from the MNIST test split.

    Each yielded item is ``(images, labels)``: the 28x28 single-channel
    images are repeated three times along the channel axis and passed
    through ``augmentation``; the one-hot labels are repeated three times
    along a new trailing axis.

    Args:
        iters: number of batches to yield.
        batch_size: number of examples per batch.
        is_shift_ag: when truthy the augmentation offset is 2, otherwise 0
            (the default, which leaves the images unshifted).
    """
    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    shift = 2 * int(is_shift_ag)
    for _ in range(iters):
        pixels, labels = dataset.test.next_batch(batch_size)
        single = pixels.reshape([batch_size, 28, 28, 1])
        tiled = np.concatenate([single, single, single], axis=-1)
        shifted = augmentation(tiled, shift)
        targets = np.stack([labels, labels, labels], axis=-1)
        yield shifted, targets


def multimnist_train_iter(iters=1000, batch_size=32, is_shift_ag=True):
    """Generate ``iters`` overlapping-digit batches from the training split.

    Two independent batches are drawn per step and each is passed through
    ``augmentation`` separately.  The yielded ``images`` have three
    channels: the element-wise logical OR of both digits (as float32),
    the first digit, and the second digit.  The yielded ``labels`` stack,
    along a new trailing axis, the logical OR of both one-hot vectors
    followed by each individual one-hot vector.

    Args:
        iters: number of batches to yield.
        batch_size: number of examples per batch.
        is_shift_ag: when truthy the augmentation offset is 2, otherwise 0.
    """
    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    shift = 2 * int(is_shift_ag)
    image_shape = [batch_size, 28, 28, 1]
    for _ in range(iters):
        first_pixels, first_labels = dataset.train.next_batch(batch_size)
        second_pixels, second_labels = dataset.train.next_batch(batch_size)
        first = augmentation(first_pixels.reshape(image_shape), shift)
        second = augmentation(second_pixels.reshape(image_shape), shift)
        overlay = np.logical_or(first, second).astype(np.float32)
        stacked_images = np.concatenate([overlay, first, second], axis=-1)
        union_labels = np.logical_or(first_labels,
                                     second_labels).astype(np.float32)
        stacked_labels = np.stack(
            [union_labels, first_labels, second_labels], axis=-1)
        yield stacked_images, stacked_labels


def multimnist_test_iter(iters=1000, batch_size=32, is_shift_ag=True):
    """Generate ``iters`` overlapping-digit batches from the test split.

    Two independent batches are drawn per step and each is passed through
    ``augmentation`` separately.  The yielded ``images`` have three
    channels: the element-wise logical OR of both digits (as float32),
    the first digit, and the second digit.  The yielded ``labels`` stack,
    along a new trailing axis, the logical OR of both one-hot vectors
    followed by each individual one-hot vector.

    Args:
        iters: number of batches to yield.
        batch_size: number of examples per batch.
        is_shift_ag: when truthy the augmentation offset is 2, otherwise 0.
    """
    dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    shift = 2 * int(is_shift_ag)
    image_shape = [batch_size, 28, 28, 1]
    for _ in range(iters):
        first_pixels, first_labels = dataset.test.next_batch(batch_size)
        second_pixels, second_labels = dataset.test.next_batch(batch_size)
        first = augmentation(first_pixels.reshape(image_shape), shift)
        second = augmentation(second_pixels.reshape(image_shape), shift)
        overlay = np.logical_or(first, second).astype(np.float32)
        stacked_images = np.concatenate([overlay, first, second], axis=-1)
        union_labels = np.logical_or(first_labels,
                                     second_labels).astype(np.float32)
        stacked_labels = np.stack(
            [union_labels, first_labels, second_labels], axis=-1)
        yield stacked_images, stacked_labels

In [3]:
class CapsNet(object):
    """Capsule network with an MLP reconstruction decoder (TF1 graph mode).

    Constructing an instance adds the complete model to the default graph:
    input placeholders, the capsule encoder, weight-sharing decoders, the
    margin / reconstruction / L2 losses, an Adam training op and an accuracy
    metric.

    Data is fed through ``self.x`` ([batch, 28, 28, 3]) and ``self.y``
    ([batch, 10, 3]).  Slot 0 of the last axis is the image given to the
    encoder and the label vector used by the margin loss; slots 1 and 2 are
    the two reconstruction targets and the labels used to mask the digit
    capsules before decoding.

    Public graph handles: ``x``, ``y``, ``h_sample``, ``y_sample``,
    ``on_train``, ``loss_cls``, ``loss_rec``, ``loss``, ``train``,
    ``accuracy``, ``x_recs`` and ``x_sample``.
    """

    def __init__(self,
                 routing_iterations=3,
                 batch_size=128,
                 is_multi_mnist=False,
                 beta1=0.9,
                 steps=5000,
                 norm=True):
        """Build the graph.

        Args:
            routing_iterations: total number of routing passes; the first
                pass uses the learned coupling logits as-is and each further
                pass refines them by agreement.
            batch_size: stored on the instance; the graph itself uses a
                dynamic batch dimension.
            is_multi_mnist: selects the example mask and the accuracy metric
                for the overlapping-digit task.
            beta1: ``beta1`` of the Adam optimizer.
            steps: planned number of training steps; the learning rate decays
                by a factor 0.96 every ``steps / 10`` steps (continuously).
            norm: whether to insert the batch-normalisation layers.
        """
        self.iterations = routing_iterations
        self.batch_size = batch_size
        # Kept as a float so it can blend the two mask definitions below.
        self.is_multi_mnist = float(is_multi_mnist)

        self.x = tf.placeholder(tf.float32, [None, 28, 28, 3])
        self.h_sample = tf.placeholder(tf.float32, [None, 10, 16])
        self.y_sample = tf.placeholder(tf.float32, [None, 10])
        self.y = tf.placeholder(tf.float32, [None, 10, 3])

        self.norm = norm
        self.on_train = tf.placeholder(tf.bool)

        x_composed, x_a, x_b = tf.split(self.x, num_or_size_splits=3, axis=3)
        y_composed, y_a, y_b = tf.split(self.y, num_or_size_splits=3, axis=2)

        # Per-example weight.  Multi-digit mode: (number of active classes
        # in the composed label) - 1, i.e. 0 for an example whose composed
        # label has a single class.  Otherwise: 1 for every example.
        valid_mask = self.is_multi_mnist * (tf.reduce_sum(y_composed, axis=[1,2]) - 1.0) \
                      + (1.0 - self.is_multi_mnist) * tf.ones_like(y_composed[:,0,0])

        v_digit, reg_term = self.get_CapsNet(x_composed, self.norm,
                                             self.on_train)

        # Capsule length per class: [batch, 10].
        length_v = tf.reduce_sum(v_digit**2.0, axis=-1)**0.5

        # Decode each target digit from the capsules masked by its label.
        x_rec_a = self.get_mlp_decoder(v_digit * y_a)
        x_rec_b = self.get_mlp_decoder(v_digit * y_b, reuse=True)
        loss_rec_a = tf.reduce_sum((x_rec_a - x_a)**2.0, axis=[1, 2, 3])
        loss_rec_b = tf.reduce_sum((x_rec_b - x_b)**2.0, axis=[1, 2, 3])
        self.loss_rec = (loss_rec_a + loss_rec_b) / 2.0
        self.x_recs = [x_rec_a, x_rec_b]
        # Decoder driven by externally supplied capsules, for sampling.
        self.x_sample = self.get_mlp_decoder(
            self.h_sample * self.y_sample[:, :, None], reuse=True)

        # Margin loss with m+ = 0.9, m- = 0.1 and down-weighting 0.5.
        self.loss_cls = tf.reduce_sum(
            y_composed[:, :, 0] * tf.maximum(0.0, 0.9 - length_v)**2.0 + 0.5 *
            (1.0 - y_composed[:, :, 0]) * tf.maximum(0.0, length_v - 0.1)**2.0,
            axis=-1)

        # Guard the normaliser: if every example of a batch is masked out the
        # mask sums to zero and 0/0 would turn the loss (and, after one Adam
        # step, the weights) into NaN.  For 0/1 masks every batch with at
        # least one valid example is normalised exactly as before.
        mask_total = tf.maximum(tf.reduce_sum(valid_mask), 1.0)
        self.loss_cls = tf.reduce_sum(self.loss_cls * valid_mask) / mask_total
        self.loss_rec = tf.reduce_sum(self.loss_rec * valid_mask) / mask_total
        self.loss = self.loss_cls + 0.0005 * self.loss_rec + reg_term

        # The step counter is bookkeeping, not a model parameter, so keep it
        # out of the trainable-variables collection.
        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(
            0.00075, global_step, steps / 10, 0.96, staircase=False)

        self.train = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=beta1).minimize(
                self.loss, global_step=global_step)

        if is_multi_mnist:
            # A digit counts as found when its class is among the two longest
            # capsules; the metric averages the hit rate of both digits.
            # Squeeze only the trailing singleton axis so that a batch of
            # size 1 keeps its batch dimension.
            labels_a = tf.argmax(tf.squeeze(y_a, axis=2), 1)
            labels_b = tf.argmax(tf.squeeze(y_b, axis=2), 1)
            hits_a = tf.cast(
                tf.nn.in_top_k(length_v, labels_a, k=2), tf.float32)
            hits_b = tf.cast(
                tf.nn.in_top_k(length_v, labels_b, k=2), tf.float32)
            self.accuracy = (
                tf.reduce_mean(hits_a) + tf.reduce_mean(hits_b)) / 2.0
        else:
            correct_prediction = tf.equal(
                tf.argmax(y_composed[:, :, 0], 1), tf.argmax(length_v, 1))
            self.accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))

    def get_CapsNet(self, x, norm, on_train, reuse=False):
        """Build the capsule encoder.

        Args:
            x: single-channel image batch, [batch, 28, 28, 1] when called
                from ``__init__``.
            norm: if true, batch-normalise the input, both convolution
                outputs and the first digit-capsule estimate.
            on_train: boolean tensor selecting batch statistics (True) or
                their moving averages (False) inside the normalisation.
            reuse: forwarded to the ``'CapsNet'`` variable scope.

        Returns:
            ``(v_digit, reg_term)``: the digit capsules, [batch, 10, 16],
            and the L2 penalty over the convolution and capsule weights.
        """
        with tf.variable_scope('CapsNet', reuse=reuse):
            wconv1 = tf.get_variable(
                'wconv1', [9, 9, 1, 256],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            bconv1 = tf.get_variable(
                'bconv1', [256],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            wconv2 = tf.get_variable(
                'wconv2', [9, 9, 256, 8 * 32],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            bconv2 = tf.get_variable(
                'bconv2', [8 * 32],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            wcap = tf.get_variable(
                'wcap', [1, 6, 6, 32, 8, 10, 16],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            b = tf.get_variable(
                'coupling_coefficient_logits', [1, 6, 6, 32, 1, 10, 1],
                initializer=tf.constant_initializer(0.0))

        # Coupling coefficients of the first pass; no gradient flows into the
        # logits through this first softmax.
        c = tf.stop_gradient(tf.nn.softmax(b, axis=5))

        # L2 regularisation over everything in the WEIGHTS collection.
        tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv1)
        tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv2)
        tf.add_to_collection(tf.GraphKeys.WEIGHTS, wcap)
        regularizer = tf.contrib.layers.l2_regularizer(scale=5.0/50000)
        reg_term = tf.contrib.layers.apply_regularization(regularizer)

        if norm:
            # BN for the first input
            x = self._batch_norm(x, on_train)

        conv1 = tf.nn.conv2d(x, wconv1, [1, 1, 1, 1], padding='VALID') + bconv1

        if norm:
            # BN for the first conv layer
            conv1 = self._batch_norm(conv1, on_train)

        conv1 = tf.nn.relu(conv1)

        s_primary = tf.nn.conv2d(
            conv1, wconv2, [1, 2, 2, 1], padding='VALID') + bconv2

        if norm:
            # BN for the second conv layer
            s_primary = self._batch_norm(s_primary, on_train)

        # 6x6 grid of 32 primary capsules with 8 components each; the two
        # trailing singleton axes broadcast against (10 classes, 16 dims).
        s_primary = tf.reshape(s_primary, [-1, 6, 6, 32, 8, 1, 1])

        v_primary = self.squash(s_primary, axis=4)

        # Digit capsules: per-capsule prediction vectors, then the first
        # routing pass.
        u = v_primary
        u_ = tf.reduce_sum(u * wcap, axis=[4], keepdims=True)
        s = tf.reduce_sum(u_ * c, axis=[1, 2, 3], keepdims=True)
        v = self.squash(s, axis=-1)

        if norm:
            # BN for the capsule layer.  When further routing passes follow,
            # this normalised estimate only enters the first agreement update
            # because ``v`` is recomputed inside the loop.
            v = self._batch_norm(v, on_train)

        for _ in range(self.iterations - 1):
            # Agreement between predictions and the current output raises the
            # routing logits.  Plain addition builds a new tensor and leaves
            # the learned logits variable untouched.
            b = b + tf.reduce_sum(u_ * v, axis=-1, keepdims=True)
            c = tf.nn.softmax(b, axis=5)
            s = tf.reduce_sum(u_ * c, axis=[1, 2, 3], keepdims=True)
            v = self.squash(s, axis=-1)

        # v is [batch, 1, 1, 1, 1, 10, 16].  Name the singleton axes
        # explicitly so a batch of size 1 is not squeezed away and the static
        # shape [batch, 10, 16] is retained.
        v_digit = tf.squeeze(v, axis=[1, 2, 3, 4])

        return v_digit, reg_term

    def _batch_norm(self, inputs, on_train, epsilon=0.001, decay=0.5):
        """Batch-normalise ``inputs`` with a learned scalar scale and shift.

        Moments are taken over axes 0, 1 and 2.  While ``on_train`` is true
        the batch moments are used and their exponential moving averages are
        updated; otherwise the moving averages are used.
        """
        batch_mean, batch_var = tf.nn.moments(inputs, axes=[0, 1, 2])
        scale = tf.Variable(tf.ones([1]))
        shift = tf.Variable(tf.zeros([1]))
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # Tie the moving-average update to reading the batch moments.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(
            on_train, mean_var_with_update,
            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        return tf.nn.batch_normalization(inputs, mean, var, shift, scale,
                                         epsilon)

    def get_mlp_decoder(self, h, num_h=[10 * 16, 512, 1024, 784], reuse=False):
        """Build the fully connected decoder.

        Args:
            h: capsule activations; flattened to [-1, 160].
            num_h: layer widths, input first.  Hidden layers use ReLU and the
                final layer a sigmoid.  The list is only read, never mutated.
            reuse: forwarded to the ``'decoder'`` variable scope so several
                decoders can share weights.

        Returns:
            Reconstructed images, [-1, 28, 28, 1].
        """
        h = tf.reshape(h, [-1, 10 * 16])
        last_layer = len(num_h) - 2
        with tf.variable_scope('decoder', reuse=reuse):
            for i in range(len(num_h) - 1):
                w = tf.get_variable(
                    'wfc%d' % i, [num_h[i], num_h[i + 1]],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
                b = tf.get_variable(
                    'bfc%d' % i, [num_h[i + 1]],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
                pre_activation = tf.add(tf.matmul(h, w), b)
                if i < last_layer:
                    h = tf.nn.relu(pre_activation)
                else:
                    h = tf.nn.sigmoid(pre_activation)
        x_rec = tf.reshape(h, [-1, 28, 28, 1])
        return x_rec

    def squash(self, s, axis=-1):
        """Squashing non-linearity: ``s * |s| / (1 + |s|^2)`` along ``axis``.

        The direction of ``s`` is kept while its length is mapped into
        [0, 1).
        """
        # NOTE(review): there is no epsilon under the square root, so the
        # gradient is not finite for an exactly-zero capsule - confirm that
        # this cannot happen in practice.
        length_s = tf.reduce_sum(s**2.0, axis=axis, keepdims=True)**0.5
        v = s * length_s / (1.0 + length_s**2.0)
        return v

In [4]:
# Run configuration shared by the training cell below.
batch_size = 64  # examples per training / test batch
is_multi_mnist = True  # True: overlapping-digit task, False: plain MNIST
is_shift_ag = True  # shift-augmentation setting (the iterators' is_shift_ag argument)
irun = 0  # training-step counter, advanced once per loop iteration
steps = 5000  # number of training iterations; also given to CapsNet for LR decay

In [5]:
# Batch generators for the selected task.  The shift augmentation follows the
# `is_shift_ag` configuration flag rather than a hard-coded literal, so the
# setting in the configuration cell actually takes effect.
if is_multi_mnist:
    train_iter = multimnist_train_iter(
        iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)
    test_iter = multimnist_test_iter(
        iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)
else:
    train_iter = mnist_train_iter(
        iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)
    test_iter = mnist_test_iter(
        iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)

# Build the model graph and the summary ops that track test-set metrics.
net = CapsNet(is_multi_mnist=is_multi_mnist, steps=steps)
tf.summary.scalar('error_rate_on_test_set', (1.0 - net.accuracy) * 100.0)
tf.summary.scalar('loss_reconstruction_on_test_set', net.loss_rec)
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()

# The session is left open on purpose so that it stays usable after this
# cell has finished.
sess = tf.Session()

sess.run(init)

for X, Y in train_iter:
    # Both generators are created with `iters=steps`, so the test generator
    # yields exactly one batch per training step.
    X_TEST, Y_TEST = next(test_iter)

    # One optimisation step on the training batch (batch statistics in BN).
    LS, LS_REC, ACC, _ = sess.run(
        [net.loss, net.loss_rec, net.accuracy, net.train],
        feed_dict={
            net.x: X,
            net.y: Y,
            net.on_train: True
        })
    # Evaluate on a test batch (moving-average statistics in BN).  `result`
    # holds the serialized summaries; nothing in this cell writes them out.
    ACC_TEST, result = sess.run(
        [net.accuracy, merged],
        feed_dict={
            net.x: X_TEST,
            net.y: Y_TEST,
            net.on_train: False
        })

    # Progress line: step, total loss, reconstruction loss, train accuracy,
    # test accuracy.
    if irun % 100 == 0:
        print(irun, LS, LS_REC, ACC, ACC_TEST)

    irun += 1

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0 1.6316246 182.09749 0.1796875 0.3046875
100 0.68395054 53.946007 0.59375 0.71875
200 0.5422913 50.462147 0.7265625 0.7890625
300 0.48848277 51.95821 0.7578125 0.7265625
400 0.46840346 44.144913 0.7578125 0.7265625
500 0.4508866 47.69269 0.7890625 0.6875
600 0.4384737 43.762383 0.7734375 0.765625
700 0.45832253 50.35829 0.765625 0.7109375
800 0.46756113 47.370953 0.7578125 0.765625
900 0.3931062 43.12308 0.78125 0.7421875
1000 0.332156 41.22647 0.84375 0.828125
1100 0.41959065 42.84593 0.8125 0.7890625
1200 0.43307418 45.503494 0.75 0.7109375
1300 0.328545 38.79215 0.8359375 0.796875
1400 0.3981854 44.53