In [1]:
import os
import tensorflow.compat.v1 as tf
import numpy as np 
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
tf.compat.v1.disable_eager_execution()

In [2]:
batch_size = 128
iter_routing = 3
m_plus = 0.9
m_minus = 0.1
lambda_val = 0.5
regularization_scale = 0.392
epoch = 50
val_sum_freq = 500

In [3]:
x_input = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1], name='x_input')
y_input = tf.placeholder(tf.int32, shape= [batch_size, ], name='y_input')
y_onehot = tf.one_hot(y_input, depth=10, axis=1, dtype=tf.float32)

In [4]:
def squash(vector):
    # vector的形状为[batch_size, num_caps, vec_len, 1]
    # 这里是squash函数的实现
    vec_squared_norm = tf.reduce_sum(tf.square(vector), axis=-2, keepdims=True)
    scalar_factor = vec_squared_norm / (
        1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-6)
    vec_squashed = scalar_factor * vector
    return (vec_squashed)

In [5]:
def get_shape(inputs, name=None):
    name = "shape" if name is None else name
    with tf.name_scope(name):
        static_shape = inputs.get_shape().as_list()
        dynamic_shape = tf.shape(inputs)
        shape = []
        for i, dim in enumerate(static_shape):
            dim = dim if dim is not None else dynamic_shape[i]
            shape.append(dim)
        return(shape)

In [6]:
def routing(input, b_IJ, num_outputs=10, num_dims=16):
    input_shape = get_shape(input)
    # [1, num_caps_i, num_caps_j * len_v_j, len_u_j, 1]
    W = tf.get_variable('weight',
                        shape=[1, input_shape[1], num_dims * num_outputs] +
                        input_shape[-2:],
                        dtype=tf.float32,
                        initializer=tf.random_normal_initializer(stddev=0.01))
    biases = tf.get_variable('bias', shape=(1, 1, num_outputs, num_dims, 1))
    # 计算u_hat
    input = tf.tile(input, [1, 1, num_dims * num_outputs, 1, 1])
    u_hat = tf.reduce_sum(W * input, axis=3, keepdims=True)
    u_hat = tf.reshape(u_hat,
                       shape=[-1, input_shape[1], num_outputs, num_dims, 1])
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')
    for i in range(iter_routing):
        with tf.variable_scope('iter_' + str(i)):
            c_IJ = tf.nn.softmax(b_IJ, axis=2)
            if i == iter_routing - 1:
                s_J = tf.multiply(c_IJ, u_hat)
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True) + biases
                v_J = squash(s_J)
            elif i < iter_routing - 1:
                s_J = tf.multiply(c_IJ, u_hat_stopped)
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True) + biases
                v_J = squash(s_J)

                v_J_tiled = tf.tile(v_J, [1, input_shape[1], 1, 1, 1])
                u_produce_v = tf.reduce_sum(u_hat_stopped * v_J_tiled,
                                            axis=3,
                                            keepdims=True)
                b_IJ = u_produce_v
    return v_J

In [7]:
with tf.variable_scope('conv1_layer'):
    # conv1 的形状为[batch_size, 20, 20, 256]
    conv1 = tf.layers.conv2d(x_input,
                             filters=256,
                             kernel_size=9,
                             strides=1,
                             padding='valid')
with tf.variable_scope('primarycaps_layer'):
    # 这步输出形状[batch_size, 6, 6, 256]
    capsules = tf.layers.conv2d(conv1,
                                filters=32 * 8,
                                kernel_size=9,
                                strides=2,
                                padding='valid',
                                activation=tf.nn.relu)
    # 这步输出形状[batch_size, 1152, 8, 1]，相当于1152个胶囊
    capsules = tf.reshape(capsules, (batch_size, -1, 8, 1))
    # 这步输出形状[batch_size, 1152, 8, 1]
    caps1 = squash(capsules)
with tf.variable_scope('digitcaps_layer'):
    # 这步输出形状[batch_size, 1152, 1, 8, 1]
    digitcaps_input = tf.reshape(caps1,
                                 shape=(batch_size, -1, 1, caps1.shape[-2], 1))
    with tf.variable_scope('routing'):
        # 这步输出形状[batch_size, num_caps_l, num_caps_l_plus_1, 1, 1]
        b_IJ = tf.constant(
            np.zeros([batch_size, caps1.shape[1], 10, 1, 1], dtype=np.float32))
        # 这步输出形状[128, 1, 10, 16, 1]
        capsules = routing(digitcaps_input, b_IJ, num_outputs=10, num_dims=16)
        # 这步输出形状[batch_size, 10, 16, 1]
        caps2 = tf.squeeze(capsules, axis=1)
with tf.variable_scope('masking'):
    # [batch_size, 10, 16, 1] => [batch_size, 10, 1, 1]
    v_length = tf.sqrt(
        tf.reduce_sum(tf.square(caps2), axis=2, keepdims=True) + 1e-9)
    # (128, 10, 1, 1)
    softmax_v = tf.nn.softmax(v_length, axis=1)
    argmax_idx = tf.to_int32(tf.argmax(softmax_v, axis=1))
    argmax_idx = tf.reshape(argmax_idx, shape=(batch_size, ))
    masked_v = tf.multiply(
        tf.squeeze(caps2),
        tf.reshape(y_onehot,
                   (-1, 10, 1)))
    # 这步输出形状[128, 10, 1, 1]
    v_length = tf.sqrt(
        tf.reduce_sum(tf.square(caps2), axis=2, keepdims=True) + 1e-9)
with tf.variable_scope('decoder'):
    vector_j = tf.reshape(masked_v, shape=(batch_size, -1))
    fc1 = tf.layers.dense(vector_j, units=512, activation=tf.nn.relu)
    fc2 = tf.layers.dense(fc1, units=1024)
    decoded = tf.layers.dense(fc2,
                              units=28 * 28 * 1,
                              activation=tf.nn.sigmoid)

Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use `tf.cast` instead.
Instructions for updating:
Use keras.layers.Dense instead.


In [8]:
# loss
## margin loss
max_l = tf.square(tf.maximum(0., m_plus - v_length))
max_r = tf.square(tf.maximum(0., v_length - m_minus))
max_l = tf.reshape(max_l, shape=(batch_size, -1))
max_r = tf.reshape(max_r, shape=(batch_size, -1))
T_c = tf.reshape(y_onehot, shape=(batch_size, -1))
L_c = T_c * max_l + lambda_val * (1 - T_c) * max_r
margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))
## reconstruction loss
origin = tf.reshape(x_input, shape=(batch_size, -1))
squared = tf.square(decoded - origin)
reconstruction_err = tf.reduce_mean(squared)
total_loss = margin_loss + regularization_scale * reconstruction_err
train_op = tf.train.AdamOptimizer().minimize(total_loss)

In [9]:
# accuracy
correct_prediction = tf.equal(tf.to_int32(y_input), argmax_idx)
print(correct_prediction.get_shape())
accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

(128,)


In [10]:
def load_mnist(batch_size, is_training=True):
    path = os.path.join('data','mnist')
    if is_training:
        fd = open(os.path.join(path, 'train-images-idx3-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        trainX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float32)
        fd = open(os.path.join(path, 'train-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        trainY = loaded[8:].reshape((60000)).astype(np.int32)

        trX = trainX[:55000] / 255.
        trY = trainY[:55000]

        valX = trainX[55000:, ] / 255.
        valY = trainY[55000:]

        num_tr_batch = 55000 // batch_size
        num_val_batch = 5000 // batch_size

        return trX, trY, num_tr_batch, valX, valY, num_val_batch
    else:
        fd = open(os.path.join(path, 't10k-images-idx3-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float)

        fd = open(os.path.join(path, 't10k-labels-idx1-ubyte'))
        loaded = np.fromfile(file=fd, dtype=np.uint8)
        teY = loaded[8:].reshape((10000)).astype(np.int32)

        num_te_batch = 10000 // batch_size
        return teX / 255., teY, num_te_batch

In [11]:
trX, trY, num_tr_batch, valX, valY, num_val_batch = load_mnist(
    batch_size, is_training=True)
print(trY)

[5 0 4 ... 0 4 0]


In [12]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(epoch):
        for step in range(num_tr_batch):
            start = step * batch_size
            end = start + batch_size
            end = end if end < len(trX) else len(trX)
            global_step = i * num_tr_batch + step
            train_x = trX[start:end]
            train_y = trY[start:end]
            feed_dict = {x_input: train_x, y_input: train_y}
            l, o = sess.run([total_loss, train_op], feed_dict)
            if step % 20 == 0:
                print('train-- ', 'i: ', i, 'step: ', step, 'loss: ', l)
            if global_step!=0 and global_step % val_sum_freq == 0:
                val_acc = 0
                for j in range(num_val_batch):
                    val_start = j * batch_size
                    val_end = val_start + batch_size
                    val_end = val_end if val_end < len(valX) else len(valX)
                    acc = sess.run(
                        [accuracy], {
                            x_input: valX[val_start:val_end],
                            y_input: valY[val_start:val_end]
                        })
                    val_acc += acc[0]
                val_acc = val_acc / (batch_size * num_val_batch)
                print('val acc: ', val_acc)

train--  i:  0 step:  0 loss:  0.6515322
train--  i:  0 step:  20 loss:  0.16547106
train--  i:  0 step:  40 loss:  0.13412395
train--  i:  0 step:  60 loss:  0.10704443
train--  i:  0 step:  80 loss:  0.11876654
train--  i:  0 step:  100 loss:  0.082290396
train--  i:  0 step:  120 loss:  0.06788447
train--  i:  0 step:  140 loss:  0.064370625
train--  i:  0 step:  160 loss:  0.050058324
train--  i:  0 step:  180 loss:  0.06635042
train--  i:  0 step:  200 loss:  0.059301212
train--  i:  0 step:  220 loss:  0.05640088
train--  i:  0 step:  240 loss:  0.052899007
train--  i:  0 step:  260 loss:  0.05980049
train--  i:  0 step:  280 loss:  0.05478416
train--  i:  0 step:  300 loss:  0.050402783
train--  i:  0 step:  320 loss:  0.0560572
train--  i:  0 step:  340 loss:  0.04512105
train--  i:  0 step:  360 loss:  0.050112695
train--  i:  0 step:  380 loss:  0.041668862
train--  i:  0 step:  400 loss:  0.07171888
train--  i:  0 step:  420 loss:  0.045154043
train--  i:  1 step:  0 loss:  

KeyboardInterrupt: 