In [None]:
import tensorflow as tf
import numpy as np
import keras

Using TensorFlow backend.


In [None]:
from keras.datasets import cifar100
(x_train, y_train), (x_test, y_test) = cifar100.load_data()

# Scale pixels to [0, 1]: the decoder ends in a sigmoid and is trained with MSE
# against these images, so raw uint8 values (0-255) could never be matched.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

# load_data() yields integer class ids; the label Input, margin_loss and the
# accuracy checks (np.argmax(y_test, 1)) all expect one-hot rows of 100 classes.
y_train = tf.keras.utils.to_categorical(y_train, 100)
y_test = tf.keras.utils.to_categorical(y_test, 100)

In [None]:
# Hyper-parameters shared by the model-building cells below.
batch_size = 32                  # static batch size given to the image Input
input_image_shape = [32, 32, 3]  # height, width, channels
num_label_classes = 100          # one output capsule per class
dim_capsule = 16                 # length of each output capsule vector
num_routings = 3                 # routing-by-agreement iterations

In [None]:
# Image input with a fixed (static) batch dimension.
x = tf.keras.layers.Input(batch_size=batch_size, shape=input_image_shape)

In [None]:
# First feature extractor: 9x9 ReLU convolution over the raw image.
conv1 = tf.keras.layers.Conv2D(
    name='conv1',
    filters=256,
    kernel_size=9,
    strides=1,
    padding='valid',
    activation='relu',
)(x)

In [None]:
# Strided 9x9 convolution whose channels are regrouped into primary capsules next;
# no activation here because squash is applied after the reshape.
primary_capsule = tf.keras.layers.Conv2D(
    name='primarycap_conv2d',
    filters=256,
    kernel_size=9,
    strides=2,
    padding='valid',
)(conv1)

In [None]:
# Flatten the spatial grid and split channels into 8-component capsule vectors.
primary_capsule = tf.keras.layers.Reshape(
    name='primarycap_reshape',
    target_shape=[-1, 8],
)(primary_capsule)

In [None]:
def squash(vectors, axis=-1, epsilon=1e-21):
    """Capsule "squash" non-linearity: v = |s|^2 / (1 + |s|^2) * s / |s|.

    Rescales every vector along `axis` to a length below 1 while keeping its
    direction.

    Args:
        vectors: tensor of capsule vectors.
        axis: axis holding the capsule components (default: last axis).
        epsilon: small constant added under the square root so an all-zero
            vector does not divide by zero.
            NOTE(review): Keras' backend epsilon (1e-7) is the more common
            choice and gives much smaller gradients for near-zero vectors.

    Returns:
        Tensor with the same shape as `vectors`.
    """
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + epsilon)
    return scale * vectors

In [None]:
# Apply squash to each primary capsule vector.
primary_capsule_output = tf.keras.layers.Lambda(
    function=squash,
    name='primarycap_squash',
)(primary_capsule)

In [None]:
# Number of primary capsules (axis 1 after the reshape).
primary_capsule_output.shape.as_list()[1]

2048

In [None]:
class CapsuleLayer(tf.keras.layers.Layer):
    """Fully connected capsule layer with routing-by-agreement.

    The transformation tensor `W` is created with `add_weight` on this layer, so
    Keras tracks it as a trainable weight of every model built from the layer.

    inputs: shape=[batch, input_num_capsule, input_dim_capsule]
    output: shape=[batch, num_capsule, dim_capsule]
    """

    def __init__(self, num_capsule, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        if routings < 1:
            raise ValueError('routings must be >= 1, got %r' % (routings,))
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        self.input_num_capsule = int(input_shape[1])
        self.input_dim_capsule = int(input_shape[2])
        # One [dim_capsule, input_dim_capsule] matrix per (output, input) capsule pair.
        self.W = self.add_weight(
            name='W',
            shape=[self.num_capsule, self.input_num_capsule,
                   self.dim_capsule, self.input_dim_capsule],
            initializer='glorot_uniform',
            trainable=True)
        super(CapsuleLayer, self).build(input_shape)

    def call(self, inputs):
        # [batch, 1, input_num_capsule, input_dim_capsule, 1]
        inputs_expanded = tf.expand_dims(tf.expand_dims(inputs, 1), -1)
        # matmul broadcasts W over batch and the inputs over output capsules:
        # prediction vectors u_hat, [batch, num_capsule, input_num_capsule, dim_capsule]
        inputs_hat = tf.squeeze(tf.matmul(self.W, inputs_expanded), axis=-1)

        # Routing logits [batch, num_capsule, 1, input_num_capsule]; the batch
        # size is read at run time instead of being fixed when the graph is built.
        b = tf.zeros(shape=[tf.shape(inputs)[0], self.num_capsule, 1, self.input_num_capsule],
                     dtype=inputs_hat.dtype)
        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=1)  # couplings normalised over output capsules
            outputs = squash(tf.matmul(c, inputs_hat))  # [batch, num_capsule, 1, dim_capsule]
            if i < self.routings - 1:
                b = b + tf.matmul(outputs, inputs_hat, transpose_b=True)

        # Drop only the singleton routing axis (a bare squeeze would also drop batch=1).
        return tf.squeeze(outputs, axis=2)

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    def get_config(self):
        config = super(CapsuleLayer, self).get_config()
        config.update({'num_capsule': self.num_capsule,
                       'dim_capsule': self.dim_capsule,
                       'routings': self.routings})
        return config


digit_caps = CapsuleLayer(num_capsule=num_label_classes, dim_capsule=dim_capsule,
                          routings=num_routings, name='digitcaps')(primary_capsule_output)

In [None]:
class Length(tf.keras.layers.Layer):
    """
    Euclidean length of every capsule vector.

    The result has the same shape as the one-hot `y_true` consumed by margin_loss, so this
    layer can be a model output and labels can be predicted with
    `y_pred = np.argmax(model.predict(x), 1)`.

    inputs: shape=[None, num_vectors, dim_vector]
    output: shape=[None, num_vectors]
    """
    def call(self, inputs, **kwargs):
        squared_length = tf.reduce_sum(tf.square(inputs), -1)
        # tiny offset keeps sqrt finite-gradient for all-zero vectors
        return tf.sqrt(squared_length + 1e-13)

    def compute_output_shape(self, input_shape):
        # the trailing vector axis is reduced away
        return input_shape[:-1]

    def get_config(self):
        return super(Length, self).get_config()

In [None]:
# Capsule lengths serve as class scores; compile() refers to this output by the name 'capsnet'.
length_layer = Length(name='capsnet')
out_caps = length_layer(digit_caps)

In [None]:
#out_caps = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), -1) + 0.0000000000001, name = 'capsnet')

In [None]:
# Static shape of the class-score output: [batch, num_label_classes].
tf.TensorShape(out_caps.shape)

TensorShape([32, 100])

In [None]:
# One-hot true labels, used to mask the digit capsules by the true class.
y = tf.keras.layers.Input(shape=[num_label_classes])

In [None]:
import tensorflow.keras.backend as K

In [None]:
class Mask(tf.keras.layers.Layer):
    """
    Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an additional
    input mask. Except the max-length capsule (or specified capsule), all vectors are masked to zeros. Then flatten the
    masked Tensor.

    Inputs:
        either a capsule tensor of shape=[None, num_capsule, dim_capsule], or a list/tuple
        [capsules, mask] where mask is one-hot with shape=[None, num_capsule].
    Output:
        shape=[None, num_capsule * dim_capsule]

    For example:
        ```
        x = keras.layers.Input(shape=[3, 2], batch_size=8)  # 8 samples, each with 3 capsules of dim_vector=2
        y = keras.layers.Input(shape=[3], batch_size=8)     # True labels. 8 samples, 3 classes, one-hot coding.
        out = Mask()(x)  # out.shape=[8, 6]
        # or
        out2 = Mask()([x, y])  # out2.shape=[8,6]. Masked with true labels y. Of course y can also be manipulated.
        ```
    """
    def call(self, inputs, **kwargs):
        if isinstance(inputs, (list, tuple)):
            # True label is provided with shape = [None, n_classes], i.e. one-hot code.
            if len(inputs) != 2:
                raise ValueError('Mask expects [capsules, mask], got %d inputs' % len(inputs))
            inputs, mask = inputs
        else:
            # No true label: mask by the longest capsule. Mainly used for prediction.
            lengths = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            # One-hot mask, shape=[None, num_capsule].
            mask = tf.one_hot(indices=tf.argmax(lengths, 1), depth=lengths.shape[1])

        # Labels may arrive with a different dtype than the capsules.
        mask = tf.cast(mask, inputs.dtype)
        # inputs.shape=[None, num_capsule, dim_capsule]
        # mask.shape=[None, num_capsule]
        # masked.shape=[None, num_capsule * dim_capsule]
        masked = K.batch_flatten(inputs * tf.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        # With a true-label mask Keras passes one shape per input (tuples, lists or
        # TensorShapes); only the capsule shape determines the output.
        if isinstance(input_shape, (list, tuple)) and len(input_shape) > 0 \
                and not isinstance(input_shape[0], (int, type(None))):
            input_shape = input_shape[0]
        num_capsule, dim_vector = tf.TensorShape(input_shape).as_list()[1:3]
        if num_capsule is None or dim_vector is None:
            return tuple([None, None])
        return tuple([None, num_capsule * dim_vector])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config


In [None]:
# Decoder network inputs.
# Training: keep only the capsule of the true class (mask supplied by `y`).
# Prediction: keep only the longest capsule.
mask_by_label, mask_by_length = Mask(), Mask()
masked_by_y = mask_by_label([digit_caps, y])
masked = mask_by_length(digit_caps)


In [None]:
#masked_by_y = K.batch_flatten(digit_caps * tf.expand_dims(y, -1))

In [None]:
#z = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), -1))

In [None]:
#mask = tf.one_hot(indices=tf.argmax(z, 1), depth=z.shape[1])

In [None]:
#masked = K.batch_flatten(digit_caps * tf.expand_dims(mask, -1))

In [None]:
# Reconstruction decoder: masked, flattened class capsules -> image in [0, 1].
# Input width is derived from the capsule config instead of a hard-coded 16.
decoder = tf.keras.models.Sequential(name='decoder')
decoder.add(tf.keras.layers.Dense(512, activation='relu', input_dim=dim_capsule * num_label_classes))
decoder.add(tf.keras.layers.Dense(1024, activation='relu'))
decoder.add(tf.keras.layers.Dense(np.prod(input_image_shape), activation='sigmoid'))
decoder.add(tf.keras.layers.Reshape(target_shape=input_image_shape, name='out_recon'))

In [None]:
# Both models are built from the same layers, so they share all weights.
train_model = tf.keras.models.Model(inputs=[x, y],
                                    outputs=[out_caps, decoder(masked_by_y)])
eval_model = tf.keras.models.Model(inputs=x,
                                   outputs=[out_caps, decoder(masked)])

In [None]:
# Per-capsule perturbation, same shape as the digit capsules ([num_classes, dim_capsule]).
noise = tf.keras.layers.Input(shape=(num_label_classes, dim_capsule))

In [None]:
# Digit capsules shifted by the supplied noise.
add_noise = tf.keras.layers.Add()
noised_digit_caps = add_noise([digit_caps, noise])

In [None]:
 masked_noised_y = Mask()([noised_digit_caps, y])

In [None]:
#masked_noised_y = K.batch_flatten(noised_digit_caps * tf.expand_dims(y, -1))

In [None]:
# Reconstructs an image from the true-class capsule after adding `noise`.
manipulate_model = tf.keras.models.Model(
    inputs=[x, y, noise],
    outputs=decoder(masked_noised_y),
)

In [None]:
def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, down_weight=0.5):
    """Capsule margin loss.

    For each class k:
        L_k = T_k * max(0, m_plus - p_k)^2
              + down_weight * (1 - T_k) * max(0, p_k - m_minus)^2
    summed over classes (axis 1) and averaged over the batch.

    Args:
        y_true: one-hot targets T, shape [batch, num_classes].
        y_pred: capsule lengths p, shape [batch, num_classes].
        m_plus: lower margin for the present class.
        m_minus: upper margin for absent classes.
        down_weight: weight of the absent-class term.

    Returns:
        Scalar loss tensor.
    """
    # Targets may arrive as integers or float64; align them with the predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \
        down_weight * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus))

    return tf.reduce_mean(tf.reduce_sum(L, 1))

In [None]:
log = tf.keras.callbacks.CSVLogger('/log.csv')

# tf.keras names the metric '<output layer>_<metric>', i.e. 'capsnet_accuracy'
# (compile uses metrics={'capsnet': 'accuracy'}). With 'val_capsnet_acc' the key
# never appears in the logs, Keras only warns, and no checkpoint is ever written.
checkpoint = tf.keras.callbacks.ModelCheckpoint('/weights-{epoch:02d}.h5', monitor='val_capsnet_accuracy',
                                                save_best_only=True, save_weights_only=True, verbose=1)

# Exponential decay: lr = 0.001 * 0.9 ** epoch.
lr_decay = tf.keras.callbacks.LearningRateScheduler(schedule=lambda epoch: 0.001 * (0.9 ** epoch))

In [None]:
# compile the model
# Losses map positionally onto the outputs [out_caps, reconstruction].
# NOTE(review): 0.392 equals 0.0005 * 784 (28*28 pixels), but these inputs have
# 32*32*3 values each — confirm the intended reconstruction weight.
train_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                    loss=[margin_loss, 'mse'],
                    loss_weights=[1., 0.392],
                    metrics={'capsnet': 'accuracy'})

In [None]:
# Layer-by-layer overview of the training graph (default line width).
train_model.summary(line_length=None)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(32, 32, 32, 3)]    0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (32, 24, 24, 256)    62464       input_1[0][0]                    
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (32, 8, 8, 256)      5308672     conv1[0][0]                      
__________________________________________________________________________________________________
primarycap_reshape (Reshape)    (32, 2048, 8)        0           primarycap_conv2d[0][0]          
______________________________________________________________________________________________

In [None]:
# The image Input was built with a fixed batch size, so every batch must hold
# exactly `batch_size` samples: drop the remainder of each split.
n_train = len(x_train) // batch_size * batch_size
n_test = len(x_test) // batch_size * batch_size
train_model.fit([x_train[:n_train], y_train[:n_train]], [y_train[:n_train], x_train[:n_train]],
                batch_size=batch_size, epochs=50,
                # documented form of validation_data is a tuple (inputs, targets)
                validation_data=([x_test[:n_test], y_test[:n_test]], [y_test[:n_test], x_test[:n_test]]),
                callbacks=[log, checkpoint, lr_decay])

Epoch 1/50


In [None]:





# `model` is not defined in this notebook. train_model and eval_model are built
# from the same layers, so saving train_model's weights covers both.
train_model.save_weights('/trained_model.h5')
print('Trained model saved to \'/trained_model.h5\'')

from utils import plot_log
plot_log('/log.csv', show=True)

## Test the trained model: accuracy and reconstructions

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
# NOTE(review): assumes the local `utils` module (source of plot_log) also
# provides combine_images — confirm.
from utils import combine_images

save_dir = '.'  # replaces `args.save_dir`; no `args` object exists in this notebook

# `model` is not defined here; eval_model takes images only and returns
# [capsule lengths, reconstructions]. Its image Input has a fixed batch size,
# so predict on a multiple of batch_size samples.
n_eval = len(x_test) // batch_size * batch_size
y_pred, x_recon = eval_model.predict(x_test[:n_eval], batch_size=batch_size)
print('-' * 30 + 'Begin: test' + '-' * 30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test[:n_eval], 1)) / n_eval)

img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
image = img * 255
Image.fromarray(image.astype(np.uint8)).save(save_dir + "/real_and_recon.png")
print()
print('Reconstructed images are saved to %s/real_and_recon.png' % save_dir)
print('-' * 30 + 'End: test' + '-' * 30)
plt.imshow(plt.imread(save_dir + "/real_and_recon.png"))
plt.show()

## Manipulate capsule dimensions and reconstruct

In [None]:
from PIL import Image
# NOTE(review): assumes the local `utils` module (source of plot_log) also
# provides combine_images — confirm.
from utils import combine_images

save_dir = '.'     # replaces `args.save_dir`; no `args` object exists in this notebook
target_class = 0   # class whose capsule is perturbed; replaces `args.digit`
perturbations = [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]

print('-' * 30 + 'Begin: manipulate' + '-' * 30)
# Pick one random test image of the target class. Fresh names keep the Keras
# Inputs `x`, `y`, `noise` and the test arrays from being overwritten.
candidates = np.flatnonzero(np.argmax(y_test, 1) == target_class)
pick = candidates[np.random.randint(len(candidates))]  # high is exclusive: every candidate is eligible
x_sample, y_sample = x_test[pick:pick + 1], y_test[pick:pick + 1]

# One noise tensor per (capsule dimension, perturbation) pair, dimension-major,
# shaped like the `noise` Input: [num_label_classes, dim_capsule].
noise_variants = np.zeros([dim_capsule * len(perturbations), num_label_classes, dim_capsule], dtype='float32')
row = 0
for dim in range(dim_capsule):
    for r in perturbations:
        noise_variants[row, :, dim] = r
        row += 1

# manipulate_model's image Input has a fixed batch size: pad the variants up to
# a multiple of batch_size, predict in one call, then drop the padding.
n_variants = len(noise_variants)
n_padded = -(-n_variants // batch_size) * batch_size
noise_batch = np.concatenate(
    [noise_variants, np.zeros([n_padded - n_variants, num_label_classes, dim_capsule], dtype='float32')])
x_recons = manipulate_model.predict(
    [np.repeat(x_sample, n_padded, axis=0), np.repeat(y_sample, n_padded, axis=0), noise_batch],
    batch_size=batch_size)[:n_variants]

img = combine_images(x_recons, height=dim_capsule)
image = img * 255
Image.fromarray(image.astype(np.uint8)).save(save_dir + '/manipulate-%d.png' % target_class)
print('manipulated result saved to %s/manipulate-%d.png' % (save_dir, target_class))
print('-' * 30 + 'End: manipulate' + '-' * 30)

## Test the evaluation model (eval_model)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
# NOTE(review): assumes the local `utils` module (source of plot_log) also
# provides combine_images — confirm.
from utils import combine_images

save_dir = '.'  # replaces `args.save_dir`; no `args` object exists in this notebook

# eval_model's image Input has a fixed batch size, so batch_size=100 cannot be
# used; predict on a multiple of batch_size samples instead.
n_eval = len(x_test) // batch_size * batch_size
y_pred, x_recon = eval_model.predict(x_test[:n_eval], batch_size=batch_size)
print('-' * 30 + 'Begin: test' + '-' * 30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test[:n_eval], 1)) / n_eval)

img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
image = img * 255
Image.fromarray(image.astype(np.uint8)).save(save_dir + "/real_and_recon.png")
print()
print('Reconstructed images are saved to %s/real_and_recon.png' % save_dir)
print('-' * 30 + 'End: test' + '-' * 30)
plt.imshow(plt.imread(save_dir + "/real_and_recon.png"))
plt.show()