In [1]:
import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np

# --- Hyperparameters ---------------------------------------------------------
img_shape = (28, 28, 1)  # MNIST images: 28x28 pixels, one grayscale channel
batch_size = 16
latent_dim = 2  # dimensionality of the latent space: a 2-D plane

# --- GPU memory configuration -----------------------------------------------
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. In TF2 this is only honoured if it is set BEFORE the GPU is
# initialised, i.e. before the first layer / variable is created, which is why
# it lives here at the top rather than next to the training code.
# The "Failed to get convolution algorithm ... cuDNN failed to initialize"
# error is commonly caused by memory being exhausted up front.
# NOTE(review): a CUDA/cuDNN version mismatch produces the same message —
# confirm on the target machine.
import tensorflow as tf

for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as err:
        # Raised when the GPU is already initialised (e.g. this cell is re-run in
        # a live kernel); the setting can no longer be changed at that point.
        print('Could not enable memory growth on', _gpu.name, '-', err)

# --- Encoder -----------------------------------------------------------------
input_img = keras.Input(shape=img_shape)

x = layers.Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)  # halves H and W
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
# Remembered so the decoder can reshape its Dense output back to this shape.
shape_before_flattening = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)

# The encoder outputs the parameters of a diagonal Gaussian over the latent
# space: its mean and its log-variance.
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

In [2]:
def sampling(args):
    """Draw a latent sample with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors with the same shape
            (the encoder's mean and log-variance outputs).

    Returns:
        ``z_mean + sigma * epsilon``, where ``sigma = exp(z_log_var / 2)`` is the
        standard deviation and ``epsilon`` is standard-normal noise shaped like
        ``z_mean``. Randomness lives only in ``epsilon``, so gradients flow
        through ``z_mean`` and ``z_log_var``.
    """
    z_mean, z_log_var = args
    # Noise shaped exactly like z_mean, so this no longer depends on the
    # global latent_dim.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.)
    # z_log_var is log(sigma^2): the KL term treats exp(z_log_var) as the
    # variance. The noise must therefore be scaled by the standard deviation
    # exp(0.5 * z_log_var), not by the variance exp(z_log_var).
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Wrap the sampling function in a Lambda layer so the random draw becomes a
# node of the Keras graph; z is the sampled latent point fed to the decoder.
sampling_layer = layers.Lambda(sampling)
z = sampling_layer([z_mean, z_log_var])

In [3]:
# Decoder: maps a latent point z back to an image.
decoder_input = layers.Input(shape=K.int_shape(z)[1:])

# Project z up to as many units as the encoder's last feature map (the one
# right before Flatten) has, then reshape into that feature map.
feature_map_shape = shape_before_flattening[1:]
x = layers.Dense(np.prod(feature_map_shape), activation='relu')(decoder_input)
x = layers.Reshape(feature_map_shape)(x)

# Strided transposed convolution doubles H and W, undoing the encoder's one
# stride-2 convolution; the final 1-channel sigmoid conv yields a map the same
# size as the original input image.
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu',
                           strides=(2, 2))(x)
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)

# Stand-alone decoder model (decoder_input -> decoded image), then applied to
# the sampled z to close the VAE graph.
decoder = Model(decoder_input, x)
z_decoded = decoder(z)

In [4]:
class CustomVariationalLayer(keras.layers.Layer):
    """Pass-through layer that attaches the VAE loss to the model via add_loss.

    Call it on ``[x, z_decoded, z_mean, z_log_var]`` so the KL term is built
    from the layer's own inputs. The original 2-input form ``[x, z_decoded]``
    is still accepted; it falls back to the module-level ``z_mean`` /
    ``z_log_var`` tensors. The layer returns ``x`` unchanged: its only purpose
    is the loss side effect.
    """

    def vae_loss(self, x, z_decoded, mean=None, log_var=None):
        """Return the scalar VAE loss: reconstruction cross-entropy + weighted KL.

        Args:
            x: original images.
            z_decoded: the decoder's reconstruction of ``x``.
            mean: encoder mean tensor; defaults to the global ``z_mean``.
            log_var: encoder log-variance tensor; defaults to the global
                ``z_log_var``.
        """
        if mean is None:
            mean = z_mean
        if log_var is None:
            log_var = z_log_var
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        # Both tensors are flattened to 1-D, so this is the binary cross-entropy
        # averaged over every pixel of the whole batch.
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        # Per-sample term proportional to KL(N(mean, exp(log_var)) || N(0, I));
        # the -5e-4 factor sets its weight relative to the reconstruction term.
        kl_loss = -5e-4 * K.mean(
            1 + log_var - K.square(mean) - K.exp(log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        # Optional 3rd / 4th inputs: the encoder's mean and log-variance.
        mean = inputs[2] if len(inputs) > 2 else None
        log_var = inputs[3] if len(inputs) > 3 else None
        loss = self.vae_loss(x, z_decoded, mean, log_var)
        self.add_loss(loss, inputs=inputs)
        # x is not used downstream, but a layer must return something.
        return x


# Pass the encoder's mean / log-variance in explicitly, so the loss depends only
# on this layer's inputs rather than on tensors captured from the enclosing scope.
y = CustomVariationalLayer()([input_img, z_decoded, z_mean, z_log_var])

In [6]:
# Training
import tensorflow as tf
# NOTE(review): presumably required because the loss layer builds its loss from
# symbolic tensors; running tf.functions eagerly slows training — confirm
# whether it can be dropped on the TF version in use.
tf.config.experimental_run_functions_eagerly(True)
# Allocate GPU memory dynamically, as a workaround for running out of memory.
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)


from keras.datasets import mnist
vae = Model(input_img, y)
# The loss is attached inside CustomVariationalLayer via add_loss, so no loss
# function is passed here.
vae.compile(optimizer='rmsprop', loss=None, experimental_run_tf_function=False)
vae.summary()


def _prepare_images(images):
    """Scale MNIST uint8 images to float32 in [0, 1] and append a channel axis.

    (N, 28, 28) -> (N, 28, 28, 1), matching the encoder's input shape.
    """
    images = images.astype('float32') / 255.
    return images.reshape(images.shape + (1,))


(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = _prepare_images(x_train)
x_test = _prepare_images(x_test)

# y=None: the model is trained against its own input through the added loss.
# The History is kept for later inspection (and not echoed as a repr).
history = vae.fit(x=x_train, y=None,
                  shuffle=True,
                  epochs=10,
                  batch_size=batch_size,
                  validation_data=(x_test, None))


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 64)   18496       conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_1[0][0]                   
____________________________________________________________________________________________

UnknownError: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above. [Op:Conv2D]

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import norm

n = 15  # 15x15 grid of digits, 225 in total
digit_size = 28
# Evenly spaced quantiles passed through the inverse normal CDF (ppf). This
# gives latent values laid out according to the Gaussian prior on z.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# Every latent point in row-major order: grid row i uses grid_y[i], grid column
# j uses grid_x[j]. One predict() call decodes them all, instead of n*n calls
# that each decode a tiled batch of identical points.
z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
x_decoded = decoder.predict(z_grid, batch_size=batch_size)

figure = np.zeros((digit_size * n, digit_size * n))
for k, decoded in enumerate(x_decoded):
    i, j = divmod(k, n)  # grid row / column of the k-th latent point
    # Drop the channel axis: 28x28x1 -> 28x28.
    figure[i * digit_size:(i + 1) * digit_size,
           j * digit_size:(j + 1) * digit_size] = decoded.reshape(digit_size, digit_size)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(figure, cmap='Greys_r')
ax.set_title('Digits decoded from a %dx%d grid over the 2-D latent space' % (n, n))
ax.axis('off')
plt.show()