In [1]:
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt


print(tf.__version__)

# 作图函数
def save_images(imgs, name):
    new_im = Image.new('L', (280, 280))
    index = 0
    for i in range(0, 280, 28):
        for j in range(0, 280, 28):
            im = imgs[index]
            im = Image.fromarray(im, mode='L')
            new_im.paste(im, (i, j))
            index += 1
    new_im.save(name)

2.0.0


In [2]:
h_dim = 20
batchsz = 512
learning_rate = 1e-3

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255, x_test.astype(np.float32) / 255

In [3]:
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(20).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [4]:
z_dim = 10


from tensorflow.keras import Sequential

class VAE(keras.Model):
    def __init__(self):
        super(VAE, self).__init__()
        # encoder
        self.e1 = keras.layers.Dense(128)
        print(self.e1)
        self.e2 = keras.layers.Dense(z_dim)   # get mean prediction
        self.e3 = keras.layers.Dense(z_dim)   #  get mean prediction
        
        # decoder
        self.fc4 = keras.layers.Dense(128)
        self.fc5 = keras.layers.Dense(784)
    
    def encoder(self, inputs):
        h = self.e1(inputs)
        h = tf.nn.relu(h)
        # get_mean
        mean = self.e2(h)
        # get_variance
        log_var = self.e3(h)    # 一般方差做一个log，方便计算
        
        return mean, log_var
    
    def decoder(self, z):
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)
        
        return out
    
    def reparameterize(self, mean, log_var):
        eps = tf.random.normal(tf.shape(log_var))
        
        std = tf.exp(log_var)**0.5    # 开根号
        
        z = mean + std * eps
        
        return z
        
    
    def call(self, inputs, training=None):
        # [b, 784] -> [b, z_dim], [b, z_dim]
        mean, log_var = self.encoder(inputs)
        # trick: reparameterization trick，采样
        z = self.reparameterize(mean, log_var)
        x_hat = self.decoder(z)
        
        # 返回 x 的同时，返回 mean和val作为约束
        return x_hat, mean, log_var

In [None]:
model = VAE()

model.build(input_shape=(None, 784))
model.summary()
optimizer = tf.optimizers.Adam(0.001)

for epoch in range(1000):
    for step, x in enumerate(train_db):
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            x_rec_logits, mean, log_var = model(x)
            rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
            # compute kl divergence (mean, val) ~ N(0, 1)
            kl_div = -0.5 * (log_var + 1 - mean ** 2 - tf.exp(log_var))
            kl_div = tf.reduce_mean(kl_div) / x.shape[0]
            # loss 
            loss = rec_loss + 1.0 * kl_div
        
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        if step % 10 == 0:
            print(epoch, step, 'kl_div:', float(kl_div), 'rec_loss:', rec_loss)

<tensorflow.python.keras.layers.core.Dense object at 0x000001E70461BDD8>
Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  100480    
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
_________________________________________________________________
dense_3 (Dense)              multiple                  1408      
_________________________________________________________________
dense_4 (Dense)              multiple                  101136    
Total params: 205,604
Trainable params: 205,604
Non-trainable params: 0
_________________________________________________________________
0 0 kl_div: 0.0004456816241145134 rec_loss: tf.Ten