In [1]:
import os

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import Dense

def load_dataset(batch_size=512, shuffle_buffer=1000):
    """Load Fashion-MNIST images (labels discarded) as batched tf.data pipelines.

    Args:
        batch_size: number of images per batch, used for both splits.
        shuffle_buffer: buffer size used to shuffle the training split.

    Returns:
        (train_db, test_db): datasets of image batches mapped through
        `preprocess`. Only the training split is shuffled.
    """
    # The autoencoder reconstructs its own input, so class labels are unused.
    (x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
    train_db = (tf.data.Dataset.from_tensor_slices(x_train)
                .shuffle(shuffle_buffer)
                .map(preprocess)
                .batch(batch_size))
    test_db = (tf.data.Dataset.from_tensor_slices(x_test)
               .map(preprocess)
               .batch(batch_size))
    return train_db, test_db
    
def preprocess(x):
    """Cast raw pixel values to float32 and divide them by 255.

    Args:
        x: image tensor/array of pixel values (presumably uint8 in 0..255).

    Returns:
        float32 tensor with the same shape as `x`, scaled by 1/255.
    """
    as_float = tf.cast(x, tf.float32)
    scaled = as_float / 255.
    return scaled
   
def save_images(imgs, name):
    """Tile up to 100 grayscale 28x28 images into one 280x280 image and save it.

    Tiles are filled column by column: imgs[0..9] go down the left-most column,
    imgs[10..19] down the next column, and so on. If fewer than 100 images are
    given, the remaining tiles stay black; images beyond the 100th are ignored.

    Args:
        imgs: indexable collection of 2-D 28x28 images with pixel values 0..255
            (uint8 expected; other dtypes are cast to uint8).
        name: output file path. Its parent directory must already exist.
    """
    tile = 28   # edge length of one image, in pixels
    grid = 10   # tiles per row and per column
    canvas = Image.new('L', (tile * grid, tile * grid))  # black 280x280 canvas

    count = min(len(imgs), grid * grid)
    for index in range(count):
        # Column-major placement: PIL's paste box is (x, y), so index // grid
        # selects the column and index % grid the row within it.
        x = (index // grid) * tile
        y = (index % grid) * tile
        im = np.asarray(imgs[index])
        if im.dtype != np.uint8:
            im = im.astype(np.uint8)
        # A 2-D uint8 array is read as mode 'L'; passing mode= is deprecated.
        canvas.paste(Image.fromarray(im), (x, y))

    canvas.save(name)

In [5]:
class AutoEncoder(keras.Model):
    """Fully connected autoencoder for flattened images.

    Encoder: input_dim -> 256 -> 128 -> latent_dim.
    Decoder: latent_dim -> 128 -> 128 -> input_dim.
    The last layer of each part has no activation, so the decoder outputs raw
    logits. Callers apply a sigmoid (or use a from_logits loss) themselves.
    """

    def __init__(self, input_dim=784, latent_dim=20, **kwargs):
        """
        Args:
            input_dim: length of each flattened input vector (784 = 28*28).
            latent_dim: size of the bottleneck code.
            **kwargs: forwarded to keras.Model (e.g. name=).
        """
        super(AutoEncoder, self).__init__(**kwargs)
        # Encoder: [b, input_dim] => [b, latent_dim]
        self.Encoder = Sequential([
            keras.Input(shape=(input_dim,)),
            Dense(256, activation="relu"),
            Dense(128, activation="relu"),
            Dense(latent_dim),
        ])

        # Decoder: [b, latent_dim] => [b, input_dim] (logits).
        # NOTE(review): the first decoder layer is 128 wide, while the encoder's
        # first hidden layer is 256. A mirrored design would use 128 -> 256 here.
        # Kept as-is to preserve the existing architecture.
        self.Decoder = Sequential([
            keras.Input(shape=(latent_dim,)),
            Dense(128, activation="relu"),
            Dense(128, activation="relu"),
            Dense(input_dim),
        ])

    def forward_propagation(self, input_data):
        """Alias of `call`, kept for backward compatibility."""
        return self.call(input_data)

    def call(self, input_data, training=None):
        """Encode `input_data` to the latent code, then decode it to logits.

        Args:
            input_data: batch of flattened inputs, [b, input_dim].
            training: accepted for Keras compatibility; unused (no dropout/BN).

        Returns:
            Reconstruction logits, [b, input_dim].
        """
        code = self.Encoder(input_data)
        return self.Decoder(code)
    
    
def build_model():
    """Create the AutoEncoder, create its weights, and print a summary.

    Returns:
        The built AutoEncoder instance.
    """
    # Create the network object.
    model = AutoEncoder()
    # Create the weights for 784-long inputs. The batch dimension does not
    # affect Dense weights, so it is left unspecified (None) rather than being
    # tied to one particular batch size.
    model.build(input_shape=(None, 784))
    # Print the network structure and parameter counts.
    model.summary()
    return model

def train(train_db, optimizer, model, epoch, log_every=100):
    """Run one epoch of reconstruction training over `train_db`.

    Args:
        train_db: dataset yielding batches of images; each batch is flattened
            to [b, 784] before being fed to the model.
        optimizer: Keras optimizer used to apply the gradients.
        model: the autoencoder. Its outputs are treated as logits.
        epoch: epoch index, used only in the log line.
        log_every: print the training loss every `log_every` steps.

    Returns:
        The same model object, whose variables were updated in place.
    """
    for step, x in enumerate(train_db):  # enumerate yields (batch index, batch)
        # Flatten [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            logits = model(x)
            # from_logits=True tells the loss that `logits` are raw scores, not
            # probabilities. The sigmoid is applied inside the loss, which is
            # more numerically stable than applying it here.
            loss = keras.losses.binary_crossentropy(x, logits, from_logits=True)
            # Per-example losses => scalar batch mean.
            loss = tf.reduce_mean(loss)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % log_every == 0:
            # Periodically report the training loss.
            print(epoch, step, float(loss))

    return model

In [6]:
def evaluation(test_db, model, epoch, out_dir='./ae_images'):
    """Reconstruct one test batch and save inputs next to reconstructions.

    Takes the first batch of `test_db` and runs it through `model`. It then
    writes an image grid via `save_images` to
    `<out_dir>/rec_epoch_<epoch>.png`. The grid holds the first 50 test inputs
    followed by their 50 reconstructions.

    Args:
        test_db: dataset yielding batches of images shaped [b, 28, 28],
            with b >= 50 assumed.
        model: autoencoder whose outputs are logits.
        epoch: epoch index, used in the output file name.
        out_dir: directory for the output image. It is created if missing.
    """
    # Sample one batch from the test set.
    x = next(iter(test_db))
    # Flatten to [b, 784] and feed through the autoencoder (outputs are logits).
    logits = model(tf.reshape(x, [-1, 784]))
    # Convert logits to pixel intensities in (0, 1).
    x_hat = tf.sigmoid(logits)
    # Restore image shape: [b, 784] => [b, 28, 28]
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # First 50 inputs + first 50 reconstructions: [100, 28, 28]
    x_concat = tf.concat([x[:50], x_hat[:50]], axis=0)
    # Scale back to 0..255. Round before casting: float32 products such as
    # 254.99998 would otherwise truncate to 254.
    x_concat = np.clip(np.rint(x_concat.numpy() * 255.), 0, 255).astype(np.uint8)

    # PIL's save() does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    save_images(x_concat, os.path.join(out_dir, 'rec_epoch_%d.png' % epoch))
    
def main(epochs=20, lr=1e-3):
    """Train the autoencoder and save a reconstruction grid after every epoch.

    Args:
        epochs: number of training epochs.
        lr: Adam learning rate.
    """
    train_db, test_db = load_dataset()
    # `learning_rate` is the supported keyword. The legacy `lr` alias is
    # deprecated in TF 2.x optimizers and rejected by Keras 3.
    optimizer = keras.optimizers.Adam(learning_rate=lr)
    model = build_model()
    for epoch in range(epochs):
        model = train(train_db, optimizer, model, epoch)
        evaluation(test_db, model, epoch)

In [None]:
# Script/notebook entry point. Inside a Jupyter kernel __name__ is "__main__",
# so executing this cell starts training.
if "__main__" == __name__:
    main()

Model: "auto_encoder_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_2 (Sequential)    (None, 20)                236436    
_________________________________________________________________
sequential_3 (Sequential)    (None, 784)               120336    
Total params: 356,772
Trainable params: 356,772
Non-trainable params: 0
_________________________________________________________________
0 0 0.6930687427520752
