In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Load the MNIST digit dataset; load_data() returns uint8 pixel arrays.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]. Cast to float32 *before* dividing:
# uint8 / 255.0 would otherwise yield float64 arrays, doubling memory for no gain,
# since Keras models compute in float32 by default.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Append a channel axis, (N, 28, 28) -> (N, 28, 28, 1), to match the model's
# (28, 28, 1) input shape.
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]

# Leaky ReLU activation
def lrelu(x, alpha=0.1):
    """Leaky ReLU: the element-wise maximum of ``alpha * x`` and ``x``.

    For the default ``alpha=0.1``, positive inputs pass through unchanged and
    negative inputs are scaled by ``alpha``.

    Args:
        x: input tensor.
        alpha: slope applied to negative inputs.

    Returns:
        A tensor of the same shape as ``x``.
    """
    leaked = alpha * x
    return tf.maximum(leaked, x)

# Build the model
class DenoisingAutoencoder(tf.keras.Model):
    """Convolutional autoencoder that reconstructs clean images from noisy ones.

    Encoder: two 3x3 convolutions (32 filters, ``lrelu``), each followed by
    2x2 max-pooling, taking a ``(28, 28, 1)`` input to a ``(7, 7, 32)`` map.
    Decoder: two stride-2 3x3 transposed convolutions (32 filters, ``lrelu``)
    back up to 28x28, then a 3x3 convolution with a sigmoid that emits a
    single-channel image with values in ``(0, 1)``.
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers

        # Encoder: 28x28x1 -> 14x14x32 -> 7x7x32.
        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=(28, 28, 1)),
            layers.Conv2D(32, (3, 3), activation=lrelu, padding='same'),
            layers.MaxPooling2D((2, 2), padding='same'),
            layers.Conv2D(32, (3, 3), activation=lrelu, padding='same'),
            layers.MaxPooling2D((2, 2), padding='same'),
        ])

        # Decoder: 7x7x32 -> 14x14x32 -> 28x28x32 -> 28x28x1.
        self.decoder = tf.keras.Sequential([
            layers.Conv2DTranspose(32, kernel_size=3, strides=2, activation=lrelu, padding='same'),
            layers.Conv2DTranspose(32, kernel_size=3, strides=2, activation=lrelu, padding='same'),
            layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x):
        """Encode the batch ``x`` and return the decoded reconstruction."""
        return self.decoder(self.encoder(x))

# Create the model and configure training: Adam optimizer, with binary
# cross-entropy between the sigmoid reconstruction and the clean target image.
autoencoder = DenoisingAutoencoder()
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')

# Add noise
def add_noise(images, noise_factor=0.5, rng=None):
    """Return a copy of ``images`` corrupted with clipped additive Gaussian noise.

    Zero-mean, unit-variance Gaussian noise scaled by ``noise_factor`` is added
    element-wise, and the result is clipped to ``[0, 1]`` so it stays in the
    same range as the normalised pixel values.

    Args:
        images: array-like of pixel values; it is not modified.
        noise_factor: standard deviation of the added noise (before clipping).
        rng: optional ``numpy.random.Generator``. When given, noise is drawn
            from it, so results are reproducible independently of NumPy's
            global random state. When ``None`` (the default), the global
            ``np.random`` state is used, exactly as before.

    Returns:
        An ndarray with the same shape as ``images`` and values in ``[0, 1]``.
        Floating-point inputs keep their dtype (e.g. float32 stays float32);
        any other input yields float64.
    """
    images = np.asarray(images)
    draw_normal = np.random.normal if rng is None else rng.normal
    noise = draw_normal(loc=0.0, scale=1.0, size=images.shape)
    noisy_images = np.clip(images + noise_factor * noise, 0.0, 1.0)
    # The noise is float64; casting back avoids silently doubling the memory
    # of float32 datasets. float64 inputs are unaffected.
    if np.issubdtype(images.dtype, np.floating):
        noisy_images = noisy_images.astype(images.dtype, copy=False)
    return noisy_images

# Prepare the data: corrupt both splits once, up front. Train is noised before
# test, so draws from the global NumPy RNG happen in the same order.
train_images_noisy = add_noise(train_images)
test_images_noisy = add_noise(test_images)

# Train the network to map noisy inputs back to their clean originals; the
# noisy test split doubles as validation data.
# NOTE(review): training noise is sampled only once, so every epoch sees the
# same corrupted images; resampling noise per epoch may generalise better.
autoencoder.fit(
    x=train_images_noisy,
    y=train_images,
    batch_size=64,
    epochs=10,
    validation_data=(test_images_noisy, test_images),
)

# Display results
def display_images(original, noisy, reconstructed, n=10, image_shape=(28, 28)):
    """Plot clean, noisy and reconstructed images in a 3-row grid and show it.

    Row 1 holds the clean originals, row 2 their noisy versions and row 3 the
    model's reconstructions, one column per image. Like the original
    implementation, this also sets matplotlib's default colormap to gray.

    Args:
        original: indexable collection of clean images.
        noisy: indexable collection of noisy images.
        reconstructed: indexable collection of reconstructed images.
        n: maximum number of columns to draw. It is clamped to the shortest of
            the three inputs, so passing fewer than ``n`` images no longer
            raises ``IndexError``.
        image_shape: 2-D shape each image is reshaped to before plotting.
            Item ``i`` of every input must have a compatible number of elements.
    """
    n = min(n, len(original), len(noisy), len(reconstructed))
    if n == 0:
        return  # nothing to draw

    rows = (original, noisy, reconstructed)
    plt.figure(figsize=(20, 4))
    # Sets the default colormap (a global matplotlib setting) once, instead of
    # once per subplot; the resulting images are the same.
    plt.gray()
    for row, images in enumerate(rows):
        for i in range(n):
            ax = plt.subplot(len(rows), n, row * n + i + 1)
            plt.imshow(np.reshape(images[i], image_shape))
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

# Denoise the noisy test images and plot them next to the clean originals.
test_images_reconstructed = autoencoder.predict(test_images_noisy)
display_images(
    original=test_images,
    noisy=test_images_noisy,
    reconstructed=test_images_reconstructed,
)






Epoch 1/10

Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
 81/938 [=>............................] - ETA: 27s - loss: 0.0971