In [49]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as K
import matplotlib.pyplot as plt

from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, UpSampling2D

In [50]:
# Hyper-parameters shared by the model-building and training cells.
filters = [32, 32, 16]  # channel counts of the three encoder conv stages, in order
max_epochs = 20  # value passed as `epochs` to model.fit
batch_size = 128  # value passed as `batch_size` to model.fit

In [51]:
# Load MNIST; the labels are discarded because the autoencoder only needs images.
(xtrain, _), (xtest, _) = K.datasets.mnist.load_data()

# Divide by 255 so pixel values are fractions of the maximum 8-bit intensity.
xtrain = xtrain / 255.0
xtest = xtest / 255.0

# Add a trailing channel axis: Conv2D expects (batch, height, width, channels).
xtrain = xtrain.reshape(-1, 28, 28, 1)
xtest = xtest.reshape(-1, 28, 28, 1)

noise = 0.5  # multiplier applied to the unit-variance Gaussian noise

# Seeded generator so the corrupted images are the same on every
# Restart & Run All; an unseeded RNG makes every run train on different data.
noise_rng = np.random.default_rng(42)


def _add_gaussian_noise(images, noise_level, rng):
  """Return `images` corrupted with additive Gaussian noise.

  Args:
    images: array of pixel values.
    noise_level: multiplier applied to standard-normal samples.
    rng: `numpy.random.Generator` the noise is drawn from.

  Returns:
    float32 array with the same shape as `images`, clipped to [0, 1].
  """
  noisy = images + noise_level * rng.normal(loc=0.0, scale=1.0, size=images.shape)
  return np.clip(noisy, 0, 1).astype('float32')


# Train noise is drawn first, then test noise, from the same generator.
xtrain_noisy = _add_gaussian_noise(xtrain, noise, noise_rng)
xtest_noisy = _add_gaussian_noise(xtest, noise, noise_rng)

In [52]:
class Encoder(K.layers.Layer):
  """Convolutional encoder made of three conv + max-pool stages.

  Every stage is a 3x3, stride-1, 'same'-padded ReLU convolution followed by
  2x2 max pooling. 'same' convolutions keep the spatial size and each pooling
  halves it (rounding down), so a 28x28 input shrinks 28 -> 14 -> 7 -> 3 and
  the encoding has filters[2] channels.

  Args:
    filters: indexable of ints; filters[0], filters[1] and filters[2] are the
      channel counts of the first, second and third convolution.
  """

  def __init__(self, filters):
    super().__init__()

    # Settings common to all three convolutions.
    conv_args = dict(kernel_size=3, strides=1, activation='relu', padding='same')
    self.conv1 = Conv2D(filters=filters[0], **conv_args)
    self.conv2 = Conv2D(filters=filters[1], **conv_args)
    self.conv3 = Conv2D(filters=filters[2], **conv_args)
    # One pooling layer instance is applied after every convolution.
    self.pool = MaxPooling2D(pool_size=(2, 2))

  def call(self, x):
    """Encode `x` by running conv -> pool three times and return the result."""
    for conv in (self.conv1, self.conv2, self.conv3):
      x = self.pool(conv(x))
    return x

In [53]:
class Decoder(K.layers.Layer):
  """Convolutional decoder that mirrors the encoder's filter order.

  Three stages of 3x3, stride-1, 'same'-padded ReLU convolution followed by
  2x upsampling, then a 5x5 'valid' transposed convolution with a single
  sigmoid output channel. For a 3x3 encoding the spatial size grows
  3 -> 6 -> 12 -> 24 through the upsampling steps, and the final transposed
  convolution adds kernel_size - 1 = 4 pixels, giving 28x28.

  Args:
    filters: indexable of ints; used in reverse (filters[2], filters[1],
      filters[0]) for the first, second and third convolution.
  """

  def __init__(self, filters):
    super().__init__()

    # Settings common to the three ordinary convolutions.
    conv_args = dict(kernel_size=3, strides=1, activation='relu', padding='same')
    self.conv1 = Conv2D(filters=filters[2], **conv_args)
    self.conv2 = Conv2D(filters=filters[1], **conv_args)
    self.conv3 = Conv2D(filters=filters[0], **conv_args)
    # Output layer: one channel, sigmoid activation, 'valid' padding so the
    # 5x5 kernel enlarges the feature map instead of preserving its size.
    self.conv4 = K.layers.Conv2DTranspose(
        filters=1, kernel_size=5, strides=1, activation='sigmoid', padding='valid')
    # One upsampling layer instance is applied after every ordinary convolution.
    self.up = UpSampling2D(size=(2, 2))

  def call(self, x):
    """Decode `x`: conv -> upsample three times, then the output layer."""
    for conv in (self.conv1, self.conv2, self.conv3):
      x = self.up(conv(x))
    return self.conv4(x)

In [54]:
class Autoencoder(K.Model):
  """Model that chains an Encoder and a Decoder built from the same `filters`.

  Args:
    filters: passed unchanged to both Encoder and Decoder.
  """

  def __init__(self, filters):
    super().__init__()

    self.encoder = Encoder(filters)
    self.decoder = Decoder(filters)

  def call(self, x):
    """Return the decoder's output for the encoding of `x`."""
    return self.decoder(self.encoder(x))

In [None]:
# Build the autoencoder and fit it with noisy images as inputs and the clean
# images as targets; 20% of the training data is held out for validation.
model = Autoencoder(filters)
model.compile(optimizer='adam', loss='binary_crossentropy')
# `loss` keeps whatever model.fit returns so later cells can inspect it.
loss = model.fit(
    x=xtrain_noisy,
    y=xtrain,
    batch_size=batch_size,
    epochs=max_epochs,
    validation_split=0.2,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20

In [None]:
number = 10  # how many digits we will display

# Run inference once, outside the loop, and only on the images being shown.
# Calling the model on the full test set for every subplot repeats the same
# forward pass `number` times and pushes the whole test set through as one batch.
noisy_samples = xtest_noisy[:number]
denoised_samples = model(noisy_samples)

plt.figure(figsize=(20, 4))
for index in range(number):
  # top row: the noisy image that is fed to the model
  ax = plt.subplot(2, number, index + 1)
  plt.imshow(noisy_samples[index].reshape(28, 28), cmap='gray')
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  # bottom row: the model's reconstruction of that image
  ax = plt.subplot(2, number, index + 1 + number)
  plt.imshow(tf.reshape(denoised_samples[index], (28, 28)), cmap='gray')
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()