In [1]:
import keras
from keras import backend as K
from keras.layers import Input, Lambda, Dense, Conv2D, Conv2DTranspose, MaxPool2D, Flatten, Reshape
from keras.models import Model
from keras import metrics
from keras.datasets import mnist
import tensorflow as tf
from tensorflow.python import debug as tf_debug
import numpy as np

Using TensorFlow backend.


In [2]:
batch_size = 128
epochs = 50
image_size = (28,28,1)
latent_dimension = 3 ## to view representation clusters in 3 dimensions

In [3]:
## defining the input for mnist images
input_image = Input(shape=image_size)

In [4]:
## defining the inference network
## this is the network that will produce a latent space representation of the original image
## 5 layer convolutional network
encoder = Conv2D(16, (3,3), activation='relu', padding='same')(input_image)
encoder = MaxPool2D((2,2), padding="same")(encoder)
encoder = Conv2D(8, (3,3), activation='relu', padding='same')(encoder)
encoder = MaxPool2D((2,2), padding="same")(encoder)
encoder = Conv2D(4, (3,3), activation='relu', padding='same')(encoder)
encoder_shape = K.int_shape(encoder)

encoder = Flatten()(encoder) ## turns output to size of (None, 112)
# testing without further reduction of dimensions
#encoder = Dense(32)(encoder)
z_mean = Dense(latent_dimension)(encoder)
z_var = Dense(latent_dimension)(encoder)

Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [5]:
print(encoder_shape)

(None, 7, 7, 4)


In [15]:
## defining the sampling method for the generator network
def normal_sample(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dimension),
                              mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var) * epsilon
z = Lambda(normal_sample)([z_mean, z_var])
sample_shape = K.int_shape(z)
print(sample_shape)

(None, 3)


In [16]:
## defining the generator network
## this is the network that takes the latent space representation and upsamples it to create a reconstruction
## 5 layer convolutional upsampling network
decoder_input = Input(shape=sample_shape[1:])
decoder = Dense(np.prod(encoder_shape[1:]), activation="relu")(decoder_input)
decoder = Reshape(encoder_shape[1:])(decoder)
print(decoder.shape)
## figure out the shaping problem
decoder  = Conv2DTranspose(32, 3,
                           padding='same', activation='relu',
                           strides=(2, 2))(decoder)
print(decoder.shape)
decoder = Conv2D(1, 3, padding='same', activation='sigmoid')(decoder)
print(decoder.shape)
decoder = Model(decoder_input, decoder)

(?, 7, 7, 4)
(?, ?, ?, 32)
(?, ?, ?, 1)


In [17]:
## variational layer for reconstruction loss
class CustomVariationalLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        kl_loss = -5e-4 * K.mean(
            1 + z_var - K.square(z_mean) - K.exp(z_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        # We don't use this output.
        return x

In [18]:
## reconstruct the sample taken from the latent space
decoded_sample = decoder(z) ## generated sample
y = CustomVariationalLayer()([input_image, decoded_sample]) ## reconstruction loss applied

In [19]:
from keras.datasets import mnist

vae = Model(input_image, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

# Train the VAE on MNIST digits
(x_train, _), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

vae.fit(x=x_train, y=None,
        shuffle=True,
        epochs=10,
        batch_size=batch_size,
        validation_data=(x_test, None))

  after removing the cwd from sys.path.


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 16)   160         input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 14, 14, 16)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 8)    1160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_poolin

InvalidArgumentError: Incompatible shapes: [25088] vs. [100352]
	 [[Node: training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@custom_variational_layer_2/logistic_loss/mul"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/Shape, training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/Shape_1)]]

Caused by op 'training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/BroadcastGradientArgs', defined at:
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2808, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-cb320b8c5d0c>", line 19, in <module>
    validation_data=(x_test, None))
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1634, in fit
    self._make_train_function()
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 990, in _make_train_function
    loss=self.total_loss)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/optimizers.py", line 225, in get_updates
    grads = self.get_gradients(loss, params)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/optimizers.py", line 73, in get_gradients
    grads = K.gradients(loss, params)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2394, in gradients
    return tf.gradients(loss, variables, colocate_gradients_with_ops=True)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 488, in gradients
    gate_gradients, aggregation_method, stop_gradients)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 625, in _GradientsHelper
    lambda: grad_fn(op, *out_grads))
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 379, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 625, in <lambda>
    lambda: grad_fn(op, *out_grads))
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py", line 881, in _MulGrad
    rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 671, in _broadcast_gradient_args
    "BroadcastGradientArgs", s0=s0, s1=s1, name=name)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
    op_def=op_def)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

...which was originally created as op 'custom_variational_layer_2/logistic_loss/mul', defined at:
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
[elided 16 identical lines from previous traceback]
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2802, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-18-46cd04c78a97>", line 3, in <module>
    y = CustomVariationalLayer()([input_image, decoded_sample]) ## reconstruction loss applied
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/engine/topology.py", line 603, in __call__
    output = self.call(inputs, **kwargs)
  File "<ipython-input-17-3d4fb9c0aeca>", line 15, in call
    loss = self.vae_loss(x, z_decoded)
  File "<ipython-input-17-3d4fb9c0aeca>", line 7, in vae_loss
    xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/losses.py", line 66, in binary_crossentropy
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2953, in binary_crossentropy
    logits=output)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py", line 181, in sigmoid_cross_entropy_with_logits
    relu_logits - logits * labels,
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 971, in binary_op_wrapper
    return func(x, y, name=name)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 1198, in _mul_dispatch
    return gen_math_ops.mul(x, y, name=name)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 4689, in mul
    "Mul", x=x, y=y, name=name)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
    op_def=op_def)
  File "/Users/ShishirJakati/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [25088] vs. [100352]
	 [[Node: training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@custom_variational_layer_2/logistic_loss/mul"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/Shape, training_1/RMSprop/gradients/custom_variational_layer_2/logistic_loss/mul_grad/Shape_1)]]
