In [1]:
# Inspired by https://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder/

from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras import regularizers
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
import keras.backend as K
from keras import regularizers

Using TensorFlow backend.


In [2]:
encoding_dim = 256  # number of hidden units in the bottleneck layer

# One sigmoid hidden layer ('encoder') and a sigmoid output layer ('decoder')
# that reconstructs the 784-value flattened input image.
input_img = Input(shape=(784,))
encoded = Dense(units=encoding_dim, activation='sigmoid', name='encoder')(input_img)
decoded = Dense(units=784, activation='sigmoid', name='decoder')(encoded)

autoencoder = Model(inputs=input_img, outputs=decoded)

In [3]:
from keras.losses import binary_crossentropy

lam = 0.001  # weight of the contractive penalty relative to the reconstruction loss


def contractive_loss(y_true, y_pred):
    """Reconstruction loss plus a contractive (encoder-Jacobian) penalty.

    For the sigmoid encoder h = sigmoid(x W + b), the Jacobian entries are
    dh_j/dx_i = h_j (1 - h_j) W_ij, so its squared Frobenius norm is
    sum_j (h_j (1 - h_j))^2 * sum_i W_ij^2.

    Args:
        y_true: target images (in this notebook, the inputs themselves).
        y_pred: reconstructions produced by ``autoencoder``.

    Returns:
        Loss tensor with one value per sample: binary cross-entropy plus
        ``lam`` times the squared Frobenius norm of the encoder Jacobian.
    """
    base_error = binary_crossentropy(y_true, y_pred)

    encoder_layer = autoencoder.get_layer('encoder')
    h = encoder_layer.output
    # Use the layer's live kernel tensor. Copying get_weights() into a new
    # K.variable would freeze W at compile time, so the penalty would neither
    # follow training nor propagate gradients into the encoder weights.
    W = K.transpose(encoder_layer.kernel)  # one row per hidden unit
    factor = h * (1 - h)  # derivative of the sigmoid activation
    contractive_error = lam * K.sum(factor ** 2 * K.sum(W ** 2, axis=1), axis=1)

    return base_error + contractive_error

In [4]:
# Encoder model
encoder = Model(input_img, encoded)
encoded_input = Input(shape=(encoding_dim,))

# Decoder model
decoder_layer = autoencoder.get_layer('decoder')
decoder = Model(encoded_input, decoder_layer(encoded_input))

In [5]:
# Adam with a small per-update learning-rate decay, optimizing the
# reconstruction + contractive-penalty objective.
autoencoder.compile(
    optimizer=Adam(lr=0.001, decay=0.0001),
    loss=contractive_loss,
)

In [6]:
from keras.datasets import mnist
import numpy as np

# Labels are discarded: the autoencoder is trained to reproduce its own input.
(x_train, _), (x_test, _) = mnist.load_data()


def _scale_and_flatten(images):
    """Cast to float32, divide by 255, and flatten each image into one row.

    Assumes 8-bit pixel intensities (0..255), so the result lies in [0, 1].
    """
    n_images = len(images)
    n_pixels = np.prod(images.shape[1:])
    return (images.astype('float32') / 255).reshape((n_images, n_pixels))


x_train = _scale_and_flatten(x_train)
x_test = _scale_and_flatten(x_test)

In [7]:
# Inputs double as targets: the network learns to reconstruct its input.
autoencoder.fit(
    x_train, x_train,
    epochs=50,
    batch_size=256,
    shuffle=True,
    validation_data=(x_test, x_test),
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x7fba89003e50>

In [None]:
# Turn each hidden unit's input weights into a 28x28 image for display.
weights, biases = autoencoder.get_layer('encoder').get_weights()
# The Dense kernel is (784, encoding_dim). Transpose once so each row holds the
# 784 input weights of one hidden unit, then view every row as a 28x28 image.
# Transposing back before the reshape would interleave weights from different
# hidden units inside each image.
weights = weights.T
weights = weights.reshape((len(weights), 28, 28))
assert weights.shape == (encoding_dim, 28, 28)

# Binarize by sign for display: positive -> 255, exactly zero -> 127.5, negative -> 0.
pixeled_weights = np.heaviside(weights, 0.5)*255

In [None]:
# Round-trip the test set: images -> codes (encoder) -> reconstructions (decoder).
encoder_imgs = encoder.predict(x=x_test)
decoder_imgs = decoder.predict(x=encoder_imgs)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
def show_imgs(n, r, c, imgs):
    """Draw the first ``n`` images of ``imgs`` in grayscale on an ``r`` x ``c`` grid.

    Parameters
    ----------
    n : int
        Number of images to draw; must not exceed ``r * c``.
    r, c : int
        Number of rows and columns in the subplot grid.
    imgs : indexable of 2-D arrays
        ``imgs[i]`` is passed to ``imshow`` for ``i`` in ``range(n)``.

    Raises
    ------
    ValueError
        If ``n`` is larger than the number of grid cells.
    """
    if n > r * c:
        raise ValueError("cannot fit {} images in a {}x{} grid".format(n, r, c))
    # About 2 inches per cell in each direction. The height follows the row
    # count, so a single-row strip stays short instead of square.
    fig = plt.figure(figsize=(2 * c, 2 * r))
    for i in range(n):
        ax = fig.add_subplot(r, c, i + 1)
        ax.imshow(imgs[i], cmap='gray')
        # Hide ticks and tick labels; they carry no information for images.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
# Two strips: the first ten test digits, then their reconstructions.
for _batch in (x_test, decoder_imgs):
    show_imgs(10, 1, 10, _batch[:10].reshape((10, 28, 28)))

In [None]:
# One tile per hidden unit, laid out on the smallest square grid that holds
# encoding_dim tiles (16x16 when encoding_dim == 256).
grid_side = int(np.ceil(np.sqrt(encoding_dim)))
show_imgs(encoding_dim, grid_side, grid_side, pixeled_weights)