In [14]:
# Inspired by https://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder/

from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras import regularizers
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
import keras.backend as K
from keras import regularizers

In [None]:
# Width of the bottleneck: number of units in the 'encoder' layer.
encoding_dim = 256

# One hidden layer: 784 inputs -> encoding_dim ReLU units -> 784 sigmoid outputs.
input_img = Input(shape=(784,))
encoded = Dense(
    units=encoding_dim,
    activation='relu',
    name='encoder',
)(input_img)
decoded = Dense(
    units=784,
    activation='sigmoid',
    name='decoder',
)(encoded)

# End-to-end model mapping an input vector to its reconstruction.
autoencoder = Model(inputs=input_img, outputs=decoded)

In [None]:
from keras.losses import binary_crossentropy

# Weight of the contractive penalty relative to the reconstruction error.
lam = 0.001

def contractive_loss(y_true, y_pred):
    """Binary cross-entropy plus a contractive penalty on the 'encoder' layer.

    The penalty is ``lam`` times the squared Frobenius norm of the Jacobian of
    the encoder activations with respect to the encoder input. For a ReLU
    encoder ``h = relu(x W + b)`` this is
    ``sum_j 1[h_j > 0] * sum_i W[i, j] ** 2``.

    Args:
        y_true: Target tensor supplied by Keras.
        y_pred: Reconstruction produced by ``autoencoder``.

    Returns:
        Tensor holding the reconstruction error plus the contractive penalty.
        NOTE(review): ``h`` is tied to ``autoencoder``'s own input tensor, so
        the penalty is presumably evaluated on the same batch as ``y_pred`` -
        confirm for the Keras/TensorFlow version in use.
    """
    base_error = binary_crossentropy(y_true, y_pred)

    encoder_layer = autoencoder.get_layer('encoder')
    h = encoder_layer.output
    # Use the layer's live kernel tensor. Wrapping get_weights() in a new
    # K.variable would freeze a copy of the weights when the loss is built,
    # so the penalty would neither follow training nor produce any gradient
    # for the real kernel.
    # Assumes the Dense kernel is laid out as (input_dim, units); after the
    # transpose each row holds one hidden unit's incoming weights.
    W = K.transpose(encoder_layer.kernel)
    # Derivative of ReLU: 1 for active units, 0 otherwise. Because h >= 0,
    # (sign(h) + 1) / 2 would wrongly give 0.5 for inactive units.
    factor = K.cast(K.greater(h, 0), K.floatx())
    contractive_error = lam*K.sum(factor * K.sum(W**2, axis=1), axis=1)

    return base_error + contractive_error

In [None]:
# Stand-alone encoder: reuses the autoencoder's input and 'encoder' activations.
encoder = Model(inputs=input_img, outputs=encoded)

# Stand-alone decoder: a fresh code-sized input routed through the
# autoencoder's own 'decoder' layer, so both models share those weights.
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.get_layer('decoder')
decoder = Model(inputs=encoded_input, outputs=decoder_layer(encoded_input))

In [None]:
# Optimise the custom contractive loss with Adam (learning rate 0.001, decay 0.0001).
autoencoder.compile(
    loss=contractive_loss,
    optimizer=Adam(lr=0.001, decay=0.0001),
)

In [None]:
from keras.datasets import mnist
import numpy as np

# Load MNIST; the labels are discarded because only the images are used.
(x_train, _), (x_test, _) = mnist.load_data()

# Convert to float32, divide by 255, and flatten each image into one vector.
x_train = (x_train.astype('float32') / 255).reshape(len(x_train), -1)
x_test = (x_test.astype('float32') / 255).reshape(len(x_test), -1)

In [None]:
# Inputs double as targets: the network is trained to reconstruct its own input.
autoencoder.fit(
    x_train,
    x_train,
    epochs=50,
    batch_size=256,
    shuffle=True,
    validation_data=(x_test, x_test),
)

In [None]:
# Kernel and bias of the 'encoder' layer as NumPy arrays.
weights, biases = autoencoder.get_layer('encoder').get_weights()
# Assumes the kernel is laid out as (input pixels, hidden units), as for the
# Dense(encoding_dim) layer applied to the 784-dimensional input.
# Transpose exactly once so each row holds one hidden unit's incoming weights,
# then view every row as a 28x28 image. Transposing a second time before the
# reshape would cut across units and scramble every filter.
weights = weights.T.reshape((-1, 28, 28))

# Binarise for display: positive weights -> 255, negative -> 0, exact zeros -> 127.5.
pixeled_weights = np.heaviside(weights, 0.5)*255

In [None]:
# Encode the test images, then reconstruct them from those codes.
encoder_imgs = encoder.predict(x=x_test)
decoder_imgs = decoder.predict(x=encoder_imgs)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
def show_imgs(n, r, c, imgs):
    """Display the first ``n`` entries of ``imgs`` on an ``r`` x ``c`` grid.

    Args:
        n: Number of images to draw, one per subplot in row-major order.
        r: Number of grid rows.
        c: Number of grid columns.
        imgs: Indexable collection whose items can be passed to ``plt.imshow``.
    """
    # figsize is (width, height): width follows the column count and height
    # follows the row count, so a 1 x 10 strip is not drawn on a square canvas.
    plt.figure(figsize=(2*c, 2*r))
    # Switch the default colormap to grayscale once, before any image is drawn.
    plt.gray()
    for i in range(n):
        ax = plt.subplot(r, c, i+1)
        plt.imshow(imgs[i])
        # Hide tick marks and labels; pixel coordinates carry no information here.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
# First figure: ten original test digits; second figure: their reconstructions.
show_imgs(n=10, r=1, c=10, imgs=x_test[:10].reshape(10, 28, 28))
show_imgs(n=10, r=1, c=10, imgs=decoder_imgs[:10].reshape(10, 28, 28))

In [None]:
# One tile per hidden unit: the 16 x 16 grid holds all encoding_dim = 256 filters.
show_imgs(n=encoding_dim, r=16, c=16, imgs=pixeled_weights)