# 用TensorFlow实现胶囊网络
包含三层CapsNet（Conv1、PrimaryCaps、DigitCaps）和由三层全连接层构成的重构网络（Decoder）。训练以三层CapsNet的Margin loss为主，重构损失只以0.0005的小权重加入总损失（解码器参数同样会被更新）；本例只实现了训练流程。

In [None]:
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm # 进度条提示

# ---- Numerical / training settings ----
epsilon = 1e-9   # added under tf.sqrt() in squashing and capsule lengths to avoid sqrt(0)
batch_size = 8   # number of samples read per training step
epoch = 1        # number of passes over the full training set

# ---- Margin-loss parameters: lambda, m+ and m- ----
lambda_val, m_plus, m_minus = 0.5, 0.9, 0.1

# ---- Dynamic routing ----
iter_routing = 3  # number of routing-by-agreement iterations

# ---- Paths and run mode ----
logdir = 'logdir'            # directory for TensorBoard summaries and checkpoints
dataset_path = 'MNIST_data'  # directory holding the raw MNIST idx files
is_training = True           # True: train the network; False: use a trained network on test data

In [None]:
# 定义加载mnist数据集的函数
def load_mnist(path, is_training):
    """Load one split of the raw MNIST idx files stored in ``path``.

    Args:
        path: directory containing 'train-images.idx3-ubyte',
            'train-labels.idx1-ubyte', 't10k-images.idx3-ubyte' and
            't10k-labels.idx1-ubyte'.
        is_training: True returns the training split, False the test split.

    Returns:
        Training split: (trX, trY). trX is a float32 tensor [60000, 28, 28, 1]
        with pixels scaled to [0, 1]; trY is a float32 one-hot tensor [60000, 10].
        Test split: (teX, teY). teX is a float64 NumPy array [10000, 28, 28, 1]
        with pixels scaled to [0, 1]; teY is a float32 one-hot tensor [10000, 10].
    """
    def _read_idx(filename, header_bytes, shape, dtype):
        # Read one idx file as raw bytes, drop its header and reshape the payload.
        # The file is opened in binary mode and closed right after reading.
        with open(os.path.join(path, filename), 'rb') as fd:
            loaded = np.fromfile(file=fd, dtype=np.uint8)
        return loaded[header_bytes:].reshape(shape).astype(dtype)

    if is_training:
        # idx3 image files have a 16-byte header, idx1 label files an 8-byte header.
        trX = _read_idx('train-images.idx3-ubyte', 16, (60000, 28, 28, 1), np.float64)
        # Labels must stay integers: tf.one_hot rejects floating-point indices.
        trY = _read_idx('train-labels.idx1-ubyte', 8, (60000,), np.int32)

        # All training images as one 4-D tensor with pixels scaled to [0, 1].
        trX = tf.convert_to_tensor(trX / 255., tf.float32)
        # One-hot labels: [num_samples, 10].
        trY = tf.one_hot(trY, depth=10, axis=1, dtype=tf.float32)
        return trX, trY

    # Test split: only the 10000 t10k images/labels are read.
    teX = _read_idx('t10k-images.idx3-ubyte', 16, (10000, 28, 28, 1), np.float64)
    teY = _read_idx('t10k-labels.idx1-ubyte', 8, (10000,), np.int32)
    teY = tf.one_hot(teY, depth=10, axis=1, dtype=tf.float32)
    return teX / 255., teY

def get_batch_data():
    """Build a shuffled, queue-based mini-batch pipeline over the MNIST training split.

    Returns:
        (X, Y): one batch of images and one-hot labels, each with ``batch_size``
        rows (a final partial batch is never produced).
    """
    images, labels = load_mnist(dataset_path, True)

    # Producer that hands out one (image, label) slice at a time from the in-memory tensors.
    single_example = tf.train.slice_input_producer([images, labels])

    # Shuffle through a buffer of queued examples and group them into batches.
    queue_capacity = batch_size * 64
    min_buffer = batch_size * 32
    X, Y = tf.train.shuffle_batch(single_example,
                                  batch_size=batch_size,
                                  capacity=queue_capacity,
                                  min_after_dequeue=min_buffer,
                                  allow_smaller_final_batch=False)
    return X, Y

In [None]:
# 通过定义CapsLayer类构建PrimaryCaps层和DigitCaps层
class CapsLayer(object):
    """Capsule layer used to build both the PrimaryCaps and the DigitCaps layer.

    Args:
        num_outputs: for PrimaryCaps, the number of convolutional capsule
            channels; for DigitCaps, the number of output capsules.
        vec_len: length of each capsule's output vector.
        layer_type: 'CONV' builds a convolutional capsule layer (PrimaryCaps),
            'FC' a fully connected capsule layer (DigitCaps).
        with_routing: whether the output vectors are obtained from the lower
            layer through dynamic routing.

    Only two configurations are implemented: ('CONV', with_routing=False) and
    ('FC', with_routing=True). Any other configuration raises ValueError when
    the layer is called.
    """

    def __init__(self, num_outputs, vec_len, layer_type='FC', with_routing=True):
        self.num_outputs = num_outputs
        self.vec_len = vec_len
        self.with_routing = with_routing
        self.layer_type = layer_type

    def __call__(self, input_x, kernel_size=None, stride=None):
        """Apply the layer to ``input_x`` (makes instances callable like functions).

        ``kernel_size`` and ``stride`` are only used when ``layer_type`` is 'CONV'.

        Returns:
            PrimaryCaps: squashed capsules, [batch_size, 6*6*32, 8, 1].
            DigitCaps: routed output capsules, [batch_size, 10, 16, 1].

        Raises:
            ValueError: for an unsupported (layer_type, with_routing) combination.
        """
        if self.layer_type == 'CONV' and not self.with_routing:
            return self._primary_caps(input_x, kernel_size, stride)
        if self.layer_type == 'FC' and self.with_routing:
            return self._digit_caps(input_x)
        # Previously these configurations silently returned None, which only
        # failed later with an unrelated shape error.
        raise ValueError(
            'Unsupported capsule layer configuration: layer_type=%r, with_routing=%r'
            % (self.layer_type, self.with_routing))

    def _primary_caps(self, input_x, kernel_size, stride):
        # PrimaryCaps: convolutional capsules, no routing.
        self.kernel_size = kernel_size  # convolution kernel size
        self.stride = stride            # convolution stride

        # Expected input: the Conv1 output [batch_size, 20, 20, 256].
        assert input_x.get_shape() == [batch_size, 20, 20, 256]

        # One convolution with num_outputs * vec_len channels (32 * 8 here).
        capsules = tf.contrib.layers.conv2d(input_x, self.num_outputs * self.vec_len,
                                            self.kernel_size, self.stride, padding="VALID")
        # Regroup the channels into capsules of length vec_len.
        capsules = tf.reshape(capsules, (batch_size, -1, self.vec_len, 1))

        # Output: [batch_size, 6*6*32, 8, 1] for the configuration used in CapsNet.
        return squashing(capsules)

    def _digit_caps(self, input_x):
        # DigitCaps: fully connected capsules computed by dynamic routing.
        # Reshape input to [batch_size, num_caps_in, 1, vec_len_in, 1] (e.g. [batch_size, 6*6*32, 1, 8, 1]).
        self.input_x = tf.reshape(input_x, shape=(batch_size, -1, 1, input_x.shape[-2].value, 1))

        with tf.variable_scope('routing'):
            # Routing logits b_ij start at zero: [1, num_caps_in, num_outputs, 1, 1].
            b_IJ = tf.constant(np.zeros([1, input_x.shape[1].value, self.num_outputs, 1, 1], dtype=np.float32))
            v_J = routing(self.input_x, b_IJ)
            # Drop the size-1 axis 1: [batch_size, 1, 10, 16, 1] -> [batch_size, 10, 16, 1].
            capsules = tf.squeeze(v_J, axis=1)

        return capsules

In [None]:
# 定义squashing激活函数
def squashing(vector):
    """Squash capsule vectors so that each vector's length lies in [0, 1).

    Computes v = ||s||^2 / (1 + ||s||^2) * s / ||s||, where the norm is taken
    over axis -2 (the capsule-vector axis). That axis is kept so the scale
    broadcasts back against ``vector``.

    Args:
        vector: capsule tensor; 4-D for PrimaryCaps, 5-D inside DigitCaps routing.
    Returns:
        A tensor with the same shape as ``vector``.
    """
    squared_norm = tf.reduce_sum(tf.square(vector), -2, keep_dims=True)  # ||s||^2 (squared norm)
    # epsilon keeps the square root, and the division by it, away from zero.
    scale = squared_norm / (1 + squared_norm) / tf.sqrt(squared_norm + epsilon)
    return scale * vector

In [None]:
# 动态路由算法
def routing(input_x, b_IJ):
    """Dynamic routing-by-agreement between two capsule layers.

    Args:
        input_x: lower-level capsules, [batch_size, num_caps_in, 1, len_in, 1]
            ([batch_size, 6*6*32, 1, 8, 1] when called from DigitCaps).
        b_IJ: initial routing logits, [1 or batch_size, num_caps_in, num_caps_out, 1, 1]
            (CapsLayer passes zeros of shape [1, 6*6*32, 10, 1, 1]).

    Returns:
        v_J: output capsules after the last iteration, [batch_size, 1, num_caps_out, 16, 1].

    Raises:
        ValueError: if the module-level ``iter_routing`` is smaller than 1
            (the loop would then never define ``v_J``).
    """
    if iter_routing < 1:
        raise ValueError('iter_routing must be at least 1, got %r' % (iter_routing,))

    num_caps_in = input_x.shape[1].value   # 6*6*32 for PrimaryCaps input
    len_in = input_x.shape[3].value        # 8 for PrimaryCaps input
    num_caps_out = b_IJ.shape[2].value     # 10 digit capsules
    len_out = 16                           # length of each output capsule vector

    # --- Prediction vectors: u_hat_{j|i} = W_ij^T u_i ---
    # One len_in x len_out matrix per (input capsule i, output capsule j) pair,
    # shared across the batch.
    W = tf.get_variable('Weight', shape=(1, num_caps_in, num_caps_out, len_in, len_out),
                        dtype=tf.float32,
                        initializer=tf.random_normal_initializer(stddev=0.01))
    # W -> [batch_size, num_caps_in, num_caps_out, len_in, len_out]
    W = tf.tile(W, [batch_size, 1, 1, 1, 1])
    # input_x -> [batch_size, num_caps_in, num_caps_out, len_in, 1]
    input_x = tf.tile(input_x, [1, 1, num_caps_out, 1, 1])
    assert input_x.get_shape() == [batch_size, num_caps_in, num_caps_out, len_in, 1]
    # [len_in, len_out]^T x [len_in, 1] -> [len_out, 1] for every (i, j).
    u_hat = tf.matmul(W, input_x, transpose_a=True)
    assert u_hat.get_shape() == [batch_size, num_caps_in, num_caps_out, len_out, 1]

    # --- Routing logits: one set per example ---
    # Previously agreements were summed over the batch into a single shared
    # b_ij, which made each example's routing depend on the rest of the batch.
    if b_IJ.shape[0].value == 1:
        b_IJ = tf.tile(b_IJ, [batch_size, 1, 1, 1, 1])

    for r_iter in range(iter_routing):
        with tf.variable_scope('iter_' + str(r_iter)):
            # c_ij = softmax_j(b_ij): normalise over the output capsules (axis 2) so
            # that each input capsule's coupling coefficients sum to 1. The previous
            # dim=3 pointed at a size-1 axis, made every c_ij equal 1 and disabled routing.
            c_IJ = tf.nn.softmax(b_IJ, dim=2)
            assert c_IJ.get_shape() == [batch_size, num_caps_in, num_caps_out, 1, 1]

            # s_j = sum_i c_ij * u_hat_{j|i}
            s_J = tf.reduce_sum(tf.multiply(c_IJ, u_hat), axis=1, keep_dims=True)
            assert s_J.get_shape() == [batch_size, 1, num_caps_out, len_out, 1]

            # v_j = squash(s_j)
            v_J = squashing(s_J)
            assert v_J.get_shape() == [batch_size, 1, num_caps_out, len_out, 1]

            # b_ij += u_hat_{j|i} . v_j (agreement); not needed after the last iteration.
            if r_iter < iter_routing - 1:
                v_J_tiled = tf.tile(v_J, [1, num_caps_in, 1, 1, 1])
                u_produce_v = tf.matmul(u_hat, v_J_tiled, transpose_a=True)
                assert u_produce_v.get_shape() == [batch_size, num_caps_in, num_caps_out, 1, 1]
                b_IJ += u_produce_v

    return v_J

In [None]:
# 构建胶囊网络
class CapsNet():
    """Three-layer CapsNet (Conv1 -> PrimaryCaps -> DigitCaps) plus a three-layer FC decoder.

    Args:
        is_training: if True, build the queue-fed training graph (inputs,
            losses and an Adam train op). Otherwise build an inference graph
            whose images are fed through the ``X`` placeholder.
    """

    def __init__(self, is_training=True):
        self.is_training = is_training
        self.graph = tf.Graph()
        with self.graph.as_default():
            if is_training:
                self.X, self.Y = get_batch_data()  # one queued batch of training data

                self.build_arch()  # capsule network + decoder
                self.loss()        # margin + reconstruction losses

                # Train with the Adam optimizer.
                self.optimizer = tf.train.AdamOptimizer()
                self.global_step = tf.Variable(0, name='global_step', trainable=False)  # global step counter
                self.train_op = self.optimizer.minimize(self.total_loss, global_step=self.global_step)

            else:
                self.X = tf.placeholder(tf.float32, shape=(batch_size, 28, 28, 1))
                self.build_arch()

        tf.logging.info('Setting up the main structure')

    def build_arch(self):
        """Build Conv1, PrimaryCaps, DigitCaps, the masking step and the decoder."""
        # Conv1: 256 9x9 kernels, stride 1 -> [batch_size, 20, 20, 256].
        with tf.variable_scope('Conv1_layer'):
            conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256, kernel_size=9, stride=1, padding='VALID')
            assert conv1.get_shape() == [batch_size, 20, 20, 256]

        # PrimaryCaps: 32 channels of 8-D capsules, 9x9 kernels, stride 2 -> [batch_size, 6*6*32, 8, 1].
        with tf.variable_scope('PrimaryCaps_layer'):
            primaryCaps = CapsLayer(num_outputs=32, vec_len=8, with_routing=False, layer_type='CONV')
            caps1 = primaryCaps(conv1, kernel_size=9, stride=2)
            assert caps1.get_shape() == [batch_size, 6*6*32, 8, 1]

        # DigitCaps: the last layer, one 16-D capsule per digit class -> [batch_size, 10, 16, 1].
        with tf.variable_scope('DigitCaps_layer'):
            digitCaps = CapsLayer(num_outputs=10, vec_len=16, with_routing=True, layer_type='FC')
            self.caps2 = digitCaps(caps1)

        with tf.variable_scope('Masking'):
            # Capsule lengths ||v_c||: [batch_size, 10, 1, 1].
            self.v_length = tf.sqrt(tf.reduce_sum(tf.square(self.caps2), axis=2, keep_dims=True) + epsilon)
            # Predicted class = index of the longest digit capsule: [batch_size].
            self.argmax_idx = tf.argmax(tf.reshape(self.v_length, (batch_size, 10)), axis=1)

            # Mask with the true label when labels exist (training). The inference graph
            # has no self.Y, so it masks with the predicted class instead.
            labels = getattr(self, 'Y', None)
            if labels is None:
                labels = tf.one_hot(self.argmax_idx, depth=10, dtype=tf.float32)

            # Keep only the selected class capsule: [batch_size, 16, 10] x [batch_size, 10, 1]
            # -> [batch_size, 16, 1]. Squeeze only the last axis so batch_size == 1 also works.
            self.masked_v = tf.matmul(tf.squeeze(self.caps2, axis=3),
                                      tf.reshape(labels, (-1, 10, 1)), transpose_a=True)

        # Decoder: 3 fully connected layers with 512, 1024 and 784 units.
        # [batch_size, 16, 1] => [batch_size, 16] => [batch_size, 512] => [batch_size, 1024] => [batch_size, 784]
        with tf.variable_scope('Decoder'):
            vector_j = tf.reshape(self.masked_v, shape=(batch_size, -1))
            fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
            assert fc1.get_shape() == [batch_size, 512]
            fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
            assert fc2.get_shape() == [batch_size, 1024]
            self.decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)

    # Loss: margin loss + reconstruction loss.
    def loss(self):
        """Build margin, reconstruction and total losses plus their TensorBoard summaries."""
        # --- Margin loss ---
        # max_l = max(0, m_plus - ||v_c||)^2, shape [batch_size, 10, 1, 1]
        max_l = tf.square(tf.maximum(0., m_plus - self.v_length))
        # max_r = max(0, ||v_c|| - m_minus)^2, shape [batch_size, 10, 1, 1]
        max_r = tf.square(tf.maximum(0., self.v_length - m_minus))
        assert max_l.get_shape() == [batch_size, 10, 1, 1]
        # Flatten to [batch_size, 10] to line up with the one-hot labels.
        max_l = tf.reshape(max_l, shape=(batch_size, -1))
        max_r = tf.reshape(max_r, shape=(batch_size, -1))

        # T_c: one-hot labels, [batch_size, 10].
        T_c = self.Y
        # Per-class loss L_c, [batch_size, 10].
        L_c = T_c * max_l + lambda_val * (1 - T_c) * max_r
        # Sum over classes, average over the batch.
        self.margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))

        # --- Reconstruction loss ---
        # Squared difference between the sigmoid decoder output and the input pixels.
        origin = tf.reshape(self.X, shape=(batch_size, -1))
        squared = tf.square(self.decoded - origin)
        self.reconstruction_err = tf.reduce_mean(squared)

        # --- Total loss ---
        # Scaled by 0.0005 so that reconstruction does not dominate the margin loss.
        # NOTE(review): the paper scales the *sum* of squared pixel errors by 0.0005,
        # while reconstruction_err is a *mean* over pixels. The effective weight is
        # therefore 784x smaller than in the paper; kept as-is, confirm intent.
        self.total_loss = self.margin_loss + 0.0005 * self.reconstruction_err

        # TensorBoard summaries.
        tf.summary.scalar('margin_loss', self.margin_loss)
        tf.summary.scalar('reconstruction_loss', self.reconstruction_err)
        tf.summary.scalar('total_loss', self.total_loss)
        recon_img = tf.reshape(self.decoded, shape=(batch_size, 28, 28, 1))
        tf.summary.image('reconstruction_img', recon_img)
        self.merged_sum = tf.summary.merge_all()  # merge every summary op defined above

In [None]:
if __name__ == "__main__":
    # Make the tf.logging.info progress messages below visible.
    tf.logging.set_verbosity(tf.logging.INFO)

    capsNet = CapsNet(is_training=is_training)
    tf.logging.info('Graph loaded')

    if not is_training:
        # This script only implements the training loop. A CapsNet built with
        # is_training=False has no train_op / global_step to run.
        tf.logging.info('is_training is False: no training loop to run')
    else:
        # logdir holds checkpoints and summaries; see
        # https://blog.csdn.net/mijiaoxiaosan/article/details/75021279
        sv = tf.train.Supervisor(graph=capsNet.graph,
                                 logdir=logdir,
                                 save_model_secs=0)

        # managed_session restores a checkpoint from logdir if one exists, otherwise initializes variables.
        with sv.managed_session() as sess:
            num_batch = 60000 // batch_size
            # Separate loop variable so the module-level `epoch` setting is not overwritten.
            for epoch_idx in range(epoch):
                if sv.should_stop():
                    break
                for _ in tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
                    sess.run(capsNet.train_op)

                global_step = sess.run(capsNet.global_step)
                sv.saver.save(sess, os.path.join(logdir, 'model_epoch_%04d_step_%02d' % (epoch_idx, global_step)))

        tf.logging.info('Training done')