## 2. 加载数据集

In [None]:
from keras.datasets import mnist

# Only the images are used; the label arrays returned by load_data() are discarded.
(X_train, _), (X_test, _) = mnist.load_data()

# Divide by 255.0 (presumably mapping 8-bit pixel values into [0, 1]) and
# append a trailing channel axis so each sample is 28 x 28 x 1.
X_train = (X_train / 255.0).reshape((-1, 28, 28, 1))
X_test = (X_test / 255.0).reshape((-1, 28, 28, 1))

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
img = X_train[12].reshape(28, 28)
plt.imshow(img)

In [None]:
# Add noise to the images
import numpy as np

noise_factor = 0.5
# Dedicated, seeded generator: the noisy datasets (and therefore the trained
# model and every figure derived from them) are identical on each
# Restart & Run All instead of depending on the global NumPy random state.
noise_rng = np.random.default_rng(42)


def _add_gaussian_noise(images, factor, rng):
    """Return a noisy copy of `images`, clipped to the range [0.0, 1.0].

    Parameters
    ----------
    images : numpy.ndarray
        Array of pixel values; it is not modified.
    factor : float
        Multiplier applied to the standard-normal noise before it is added.
    rng : numpy.random.Generator
        Source of the noise samples.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `images`.
    """
    noisy = images + factor * rng.normal(loc=0.0, scale=1.0, size=images.shape)
    # Clip the noisy pixel values back into the [0.0, 1.0] range.
    return np.clip(noisy, 0.0, 1.0)


# Train noise is drawn before test noise, both from the same generator.
X_train_noisy = _add_gaussian_noise(X_train, noise_factor, noise_rng)
X_test_noisy = _add_gaussian_noise(X_test, noise_factor, noise_rng)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
img = X_test_noisy[4].reshape(28, 28)
plt.imshow(img)

## 3. 定义自编码模型

In [None]:
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose
from keras.models import Model
from keras.optimizers import Adam

# Encoder: two convolution + max-pooling stages on a 28 x 28 single-channel input.
inputs = Input(shape=(28, 28, 1))
x = Conv2D(10, (5, 5), activation='relu')(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(20, (2, 2), activation='relu')(x)
encoding = MaxPooling2D(pool_size=(2, 2))(x)

# Decoder: two upsampling + transposed-convolution stages, followed by a
# final transposed convolution that produces a single sigmoid channel.
x = UpSampling2D(size=(2, 2))(encoding)
x = Conv2DTranspose(20, (2, 2), activation='relu')(x)
x = UpSampling2D(size=(2, 2))(x)
# NOTE(review): this hidden layer uses sigmoid while the other hidden layers
# use relu -- confirm that this is intentional.
x = Conv2DTranspose(10, (5, 5), activation='sigmoid')(x)
outputs = Conv2DTranspose(1, (3, 3), activation='sigmoid')(x)

# Assemble the autoencoder from input to reconstruction and print its layers.
model = Model(inputs=inputs, outputs=outputs)
model.summary()

## 4. 模型的编译与训练

In [None]:
# Compile with binary cross-entropy loss and a default-configured Adam
# optimizer; no additional metrics are tracked.
model.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=None)

# Denoising objective: the noisy images are the inputs and the clean images
# are the targets. A 0.2 fraction of the training data is used for validation.
model.fit(
    x=X_train_noisy,
    y=X_train,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    verbose=2,
)

In [None]:
# Denoise one noisy test digit and display the reconstruction.
# The sample is read directly from X_test_noisy (index 4, the same digit the
# noisy preview cell shows) rather than from the shared `img` variable: `img`
# is assigned by both preview cells, so relying on it would make this cell's
# input depend on which preview cell happened to run last.
noisy_sample = X_test_noisy[4].reshape(1, 28, 28, 1)
result = model.predict(noisy_sample)
plt.imshow(result.reshape(28, 28))