In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import pickle

from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Layer
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import cifar10


Using TensorFlow backend.


In [2]:
import numpy as np
import os
from PIL import Image
from sklearn.model_selection import train_test_split

show_name = "all"
src_dir = "../img/%s" % show_name
input_files = [os.path.join(src_dir, f) for f in os.listdir(src_dir) if f.endswith(".jpg")]
#imgs = [Image.open(im).resize(260, 172, 3) for im in input_files]

lwf_immatrix = np.array([np.array(Image.open(im)) for im in input_files])
print("Input size", lwf_immatrix.shape)


x_train, x_test, _, _ = train_test_split(lwf_immatrix, [0]*len(input_files), test_size=0.20, random_state=42)

original_img_size = lwf_immatrix[0].shape

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
print(original_img_size)
print(x_train.shape)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape((x_test.shape[0],) + original_img_size)


Input size (195, 263, 175, 3)
(263, 175, 3)
(156, 263, 175, 3)


In [3]:
# import parameters
img_rows, img_cols, img_chns = 263, 175, 3
latent_dim = 6
intermediate_dim = 128
epsilon_std = 1.0
epochs = 20
filters = 5
num_conv = 3
batch_size = 20

# tensorflow uses channels_last
# theano uses channels_first
"""
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)
"""
#original_img_size = (img_rows, img_cols, img_chns)

# encoder architecture
x = Input(shape=original_img_size)
conv_1 = Conv2D(img_chns,
                kernel_size=(2, 2),
                padding='same', activation='relu')(x)
conv_2 = Conv2D(filters,
                kernel_size=(2, 2),
                padding='same', activation='relu',
                strides=(2, 2))(conv_1)
conv_3 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_2)
conv_4 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_3)
flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

# mean and variance for latent variables
z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)

# sampling layer
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])


# decoder architecture
decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(int(filters * round(img_rows / 2) * round(img_cols / 2)), activation='relu')

output_shape = (batch_size, int(round(img_rows / 2)), int(round(img_cols / 2)), filters)
"""
if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, int(img_rows / 2), int(img_cols / 2))
else:
    output_shape = (batch_size, int(img_rows / 2), int(img_cols / 2), filters)
"""    

decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_3_upsamp = Conv2DTranspose(filters,
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          padding='valid',
                                          activation='relu')
decoder_mean_squash = Conv2D(img_chns,
                             kernel_size=2,
                             padding='valid',
                             activation='sigmoid')

hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)


# Custom loss layer
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean_squash):
        x = K.flatten(x)
        x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
        xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
        kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean_squash = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean_squash)
        self.add_loss(loss, inputs=inputs)
        return x

y = CustomVariationalLayer()([x, x_decoded_mean_squash])

# entire model
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 263, 175, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 263, 175, 3)  39          input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 132, 88, 5)   65          conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 132, 88, 5)   230         conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (

In [4]:
# training
history = vae.fit(x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, None))

# encoder from learned model
encoder = Model(x, z_mean)

# generator / decoder from learned model
decoder_input = Input(shape=(latent_dim,))
_hid_decoded = decoder_hid(decoder_input)
_up_decoded = decoder_upsample(_hid_decoded)
_reshape_decoded = decoder_reshape(_up_decoded)
_deconv_1_decoded = decoder_deconv_1(_reshape_decoded)
_deconv_2_decoded = decoder_deconv_2(_deconv_1_decoded)
_x_decoded_relu = decoder_deconv_3_upsamp(_deconv_2_decoded)
_x_decoded_mean_squash = decoder_mean_squash(_x_decoded_relu)
generator = Model(decoder_input, _x_decoded_mean_squash)

# save all 3 models for future use - especially generator
vae.save('../models/cifar10_ld_%d_conv_%d_id_%d_e_%d_vae.h5' % (latent_dim, num_conv, intermediate_dim, epochs))
encoder.save('../models/cifar10_ld_%d_conv_%d_id_%d_e_%d_encoder.h5' % (latent_dim, num_conv, intermediate_dim, epochs))
generator.save('../models/cifar10_ld_%d_conv_%d_id_%d_e_%d_generator.h5' % (latent_dim, num_conv, intermediate_dim, epochs))

# save training history
fname = '../models/cifar10_ld_%d_conv_%d_id_%d_e_%d_history.pkl' % (latent_dim, num_conv, intermediate_dim, epochs)
with open(fname, 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

Train on 156 samples, validate on 39 samples
Epoch 1/20


InvalidArgumentError: Incompatible shapes: [2787840] vs. [2761500]
	 [[Node: training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@custom_variational_layer_1/logistic_loss/mul"], _device="/job:localhost/replica:0/task:0/cpu:0"](training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/Shape, training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/Shape_1)]]

Caused by op 'training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/BroadcastGradientArgs', defined at:
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/asyncio/base_events.py", line 421, in run_forever
    self._run_once()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/asyncio/base_events.py", line 1425, in _run_once
    handle._run()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/asyncio/events.py", line 127, in _run
    self._callback(*self._args)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tornado/platform/asyncio.py", line 122, in _handle_events
    handler_func(fileobj, events)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2901, in run_ast_nodes
    if self.run_code(code, result):
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-d3df2a6a65f3>", line 6, in <module>
    validation_data=(x_test, None))
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/engine/training.py", line 1008, in fit
    self._make_train_function()
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/engine/training.py", line 498, in _make_train_function
    loss=self.total_loss)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/optimizers.py", line 255, in get_updates
    grads = self.get_gradients(loss, params)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/optimizers.py", line 89, in get_gradients
    grads = K.gradients(loss, params)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 2708, in gradients
    return tf.gradients(loss, variables, colocate_gradients_with_ops=True)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 542, in gradients
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 348, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py", line 542, in <lambda>
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py", line 713, in _MulGrad
    rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 393, in _broadcast_gradient_args
    name=name)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

...which was originally created as op 'custom_variational_layer_1/logistic_loss/mul', defined at:
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
[elided 22 identical lines from previous traceback]
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-6e0230684b7e>", line 117, in <module>
    y = CustomVariationalLayer()([x, x_decoded_mean_squash])
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "<ipython-input-3-6e0230684b7e>", line 113, in call
    loss = self.vae_loss(x, x_decoded_mean_squash)
  File "<ipython-input-3-6e0230684b7e>", line 106, in vae_loss
    xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/losses.py", line 77, in binary_crossentropy
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 3302, in binary_crossentropy
    logits=output)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/nn_impl.py", line 174, in sigmoid_cross_entropy_with_logits
    relu_logits - logits * labels,
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 865, in binary_op_wrapper
    return func(x, y, name=name)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 1088, in _mul_dispatch
    return gen_math_ops._mul(x, y, name=name)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 1449, in _mul
    result = _op_def_lib.apply_op("Mul", x=x, y=y, name=name)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/sw/lsa/centos7/python-anaconda-arc-connect/created-20170421/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [2787840] vs. [2761500]
	 [[Node: training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@custom_variational_layer_1/logistic_loss/mul"], _device="/job:localhost/replica:0/task:0/cpu:0"](training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/Shape, training/RMSprop/gradients/custom_variational_layer_1/logistic_loss/mul_grad/Shape_1)]]
