In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import keras.backend as K
from keras import layers, models, optimizers
from keras import callbacks
from PIL import Image

from utils import combine_images
from capsule_layers import CapsuleLayer, PrimaryCap, Length, Mask

# Configure the Keras backend so the channel axis is the last dimension of image tensors.
# The input_shape passed to CapsNet later in this notebook, (250, 250, 1), follows this layout.
K.set_image_data_format('channels_last')

Using TensorFlow backend.


In [9]:
def CapsNet(input_shape, n_class, routings):
    """
    Build the capsule network and the three Keras models that share its layers.

    :param input_shape: shape of one input sample, 3d, [width, height, channels]
    :param n_class: number of classes (one output capsule per class)
    :param routings: number of routing iterations passed to the capsule layer
    :return: a tuple ``(train_model, eval_model, manipulate_model)``:
             * ``train_model``      -- inputs ``[image, label]``, outputs ``[capsule lengths, reconstruction]``
             * ``eval_model``       -- input ``image``, outputs ``[capsule lengths, reconstruction]``
             * ``manipulate_model`` -- inputs ``[image, label, noise]``, output is the reconstruction decoded
                                       from the class capsules after ``noise`` has been added to them
    """
    # Dimension of each class capsule; the decoder input size and the noise input depend on it.
    caps_dim = 16

    image_in = layers.Input(shape=input_shape)

    # Conventional convolutional feature extractor.
    conv_features = layers.Conv2D(
        filters=256,
        kernel_size=9,
        strides=1,
        padding='valid',
        activation='relu',
        name='conv1',
    )(image_in)

    # Primary capsules: per the original notes, a Conv2D with `squash` activation
    # reshaped to [None, num_capsule, dim_capsule] (implemented in capsule_layers).
    primary = PrimaryCap(
        conv_features,
        dim_capsule=8,
        n_channels=32,
        kernel_size=9,
        strides=2,
        padding='valid',
    )

    # Class capsules: this is where the routing algorithm runs.
    class_caps = CapsuleLayer(
        num_capsule=n_class,
        dim_capsule=caps_dim,
        routings=routings,
        name='caption_caps',
    )(primary)

    # Auxiliary layer that replaces every capsule by its length so the output
    # matches the shape of the true label. The name 'capsnet' is the output name
    # that the training metrics refer to.
    caps_lengths = Length(name='capsnet')(class_caps)

    # Decoder inputs: during training the true label selects the capsule to decode,
    # during prediction the capsule with maximal length is used instead.
    label_in = layers.Input(shape=(n_class,))
    masked_with_label = Mask()([class_caps, label_in])
    masked_by_length = Mask()(class_caps)

    # One decoder instance, shared by every model built below.
    decoder = models.Sequential(
        [
            layers.Dense(512, activation='relu', input_dim=caps_dim * n_class),
            layers.Dense(1024, activation='relu'),
            layers.Dense(np.prod(input_shape), activation='sigmoid'),
            layers.Reshape(target_shape=input_shape, name='out_recon'),
        ],
        name='decoder',
    )

    # Training and evaluation (prediction) models.
    train_model = models.Model(
        [image_in, label_in],
        [caps_lengths, decoder(masked_with_label)],
    )
    eval_model = models.Model(image_in, [caps_lengths, decoder(masked_by_length)])

    # Manipulation model: perturb the class capsules with an extra noise input
    # before masking and decoding them.
    noise_in = layers.Input(shape=(n_class, caps_dim))
    perturbed_caps = layers.Add()([class_caps, noise_in])
    masked_perturbed = Mask()([perturbed_caps, label_in])
    manipulate_model = models.Model(
        [image_in, label_in, noise_in],
        decoder(masked_perturbed),
    )

    return train_model, eval_model, manipulate_model

In [10]:
def margin_loss(y_true, y_pred):
    """
    Margin loss (referred to as Eq.(4) in the original notes).

    For every class the loss is
    ``y * max(0, 0.9 - p)^2 + 0.5 * (1 - y) * max(0, p - 0.1)^2``,
    summed over the class axis and averaged over the batch. Nothing in the
    formula requires a single `1` per row of `y_true`, so multi-label targets
    should work as well; that case has not been tested.

    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    # Penalty for classes that are present but predicted with length below 0.9.
    present_penalty = y_true * K.square(K.maximum(0., 0.9 - y_pred))
    # Down-weighted penalty for absent classes predicted with length above 0.1.
    absent_penalty = 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    per_class_loss = present_penalty + absent_penalty
    per_sample_loss = K.sum(per_class_loss, 1)
    return K.mean(per_sample_loss)

In [11]:
def train(model, data, args):
    """
    Compile and fit the CapsuleNet training model, then save its weights and plot the log.

    :param model: the CapsuleNet model taking `[x, y]` as inputs and producing `[class output, reconstruction]`
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: mapping of settings; the keys read here are 'save_dir', 'batch_size', 'debug',
                 'lr', 'lr_decay', 'lam_recon' and 'epochs'
    :return: The trained model
    """
    (x_train, y_train), (x_test, y_test) = data
    out_dir = args['save_dir']

    def _decayed_lr(epoch):
        # Exponential decay: lr * lr_decay^epoch, read from `args` on every call.
        return args['lr'] * (args['lr_decay'] ** epoch)

    # Callbacks: CSV log, TensorBoard, best-only weight checkpoints, learning-rate schedule.
    csv_logger = callbacks.CSVLogger(out_dir + '/log.csv')
    tensorboard = callbacks.TensorBoard(
        log_dir=out_dir + '/tensorboard-logs',
        batch_size=args['batch_size'],
        histogram_freq=int(args['debug']),
    )
    best_weights = callbacks.ModelCheckpoint(
        out_dir + '/weights-{epoch:02d}.h5',
        monitor='val_capsnet_acc',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )
    lr_scheduler = callbacks.LearningRateScheduler(schedule=_decayed_lr)

    # Two losses: margin loss on the class output, MSE on the reconstruction
    # (weighted by 'lam_recon'). Accuracy is tracked on the 'capsnet' output only.
    model.compile(
        optimizer=optimizers.Adam(lr=args['lr']),
        loss=[margin_loss, 'mse'],
        loss_weights=[1., args['lam_recon']],
        metrics={'capsnet': 'accuracy'},
    )

    # The label is both an input (for masking) and a target; the image is both
    # an input and the reconstruction target.
    model.fit(
        [x_train, y_train],
        [y_train, x_train],
        batch_size=args['batch_size'],
        epochs=args['epochs'],
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[csv_logger, tensorboard, best_weights, lr_scheduler],
    )

    model.save_weights(out_dir + '/trained_model.h5')
    print("Trained model saved to '%s/trained_model.h5'" % out_dir)

    from utils import plot_log
    plot_log(out_dir + '/log.csv', show=True)

    return model

In [12]:
def test(model, data, args):
    """
    Evaluate a CapsuleNet evaluation model and save a real-vs-reconstructed image grid.

    Prints the top-1 accuracy (argmax of predictions vs. argmax of labels), writes
    `<save_dir>/real_and_recon.png` containing the first 50 test images followed by
    their reconstructions, and displays that image with matplotlib.

    :param model: model whose `predict(x)` returns `(y_pred, x_recon)`
    :param data: a tuple `(x_test, y_test)`
    :param args: settings holding the output directory, either as a mapping with a
                 'save_dir' key (the form built in this notebook and used by `train`)
                 or as an object with a `save_dir` attribute
    :return: None
    """
    x_test, y_test = data

    # `args` is a plain dict in this notebook, so `args.save_dir` raised AttributeError.
    # Resolve the directory once, up front, and keep attribute-style objects working.
    save_dir = args['save_dir'] if isinstance(args, dict) else args.save_dir
    recon_path = save_dir + "/real_and_recon.png"

    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('-'*30 + 'Begin: test' + '-'*30)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

    # Tile originals and reconstructions into one image and rescale it for 8-bit output.
    img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(recon_path)
    print()
    print('Reconstructed images are saved to %s/real_and_recon.png' % save_dir)
    print('-' * 30 + 'End: test' + '-' * 30)
    plt.imshow(plt.imread(recon_path))
    plt.show()

In [None]:
def load_coco():
    """
    Load the pre-split training and validation pickles from the local `dataset/` directory.

    The image arrays get a trailing channel axis, are cast to float32 and divided by 255
    (assumes 8-bit pixel values, giving a [0, 1] range -- confirm against the script that
    produced the pickles). The label arrays are returned exactly as stored.

    :return: `((x_train, y_train), (x_test, y_test))`, where the "test" pair is read from
             `x_val.pickle` / `y_val.pickle`
    """
    def _read_pickle(path):
        # The context manager closes the file handle even if unpickling raises.
        # NOTE(security): pickle.load can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _to_float_images(images):
        # Append the channel axis expected by the network and scale by 1/255.
        return np.expand_dims(images, axis=-1).astype('float32') / 255.

    x_train = _to_float_images(_read_pickle('dataset/x_train.pickle'))
    y_train = _read_pickle('dataset/y_train.pickle')

    x_test = _to_float_images(_read_pickle('dataset/x_val.pickle'))
    y_test = _read_pickle('dataset/y_val.pickle')

    return (x_train, y_train), (x_test, y_test)

In [None]:
# Load the dataset splits into notebook-level variables.
# The "test" pair returned by load_coco() is read from the validation pickles (x_val / y_val).
(x_train, y_train), (x_test, y_test) = load_coco()

In [13]:
# Build the training, evaluation and manipulation models, then print the training model's summary.
# The input shape is hard-coded to (250, 250, 1); the commented-out call below derives it from the
# loaded data instead (presumably disabled so this cell can run without the dataset -- confirm).
# NOTE(review): the summary output saved with this cell lists conv1/conv2/conv3 layers with 96 filters,
# which CapsNet as defined in this notebook does not create -- the output looks stale; re-run after
# Restart Kernel -> Run All.
# model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:], n_class=12, routings=3)
model, eval_model, manipulate_model = CapsNet(input_shape=(250, 250, 1), n_class=12, routings=3)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            (None, 250, 250, 1)  0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 60, 60, 96)   16320       input_8[0][0]                    
__________________________________________________________________________________________________
conv2 (Conv2D)                  (None, 28, 28, 96)   230496      conv1[0][0]                      
__________________________________________________________________________________________________
conv3 (Conv2D)                  (None, 20, 20, 256)  1990912     conv2[0][0]                      
__________________________________________________________________________________________________
primarycap

In [None]:
# Run configuration shared by the training / testing helpers.
args = {
    'epochs': 50,
    'batch_size': 100,
    'lr': 0.001,               # initial learning rate
    'lr_decay': 0.9,           # per-epoch multiplicative decay of the learning rate
    'lam_recon': 0.392,        # weight of the reconstruction loss
    'routings': 3,
    'shift_fraction': 0.1,
    'debug': False,
    'save_dir': './result',
    'testing': False,
    'super': 5,
    'weights': None,
}

In [None]:
# Fit the training model on the loaded splits using the settings in `args`;
# train() writes its log, checkpoints and final weights under args['save_dir'].
train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)