In [1]:
import tensorflow as tf
import numpy as np

In [2]:
num_features = 784

learning_rate = 0.01
training_steps = 20000
batch_size = 256
disply_step = 1000

num_hidden_1 = 128
num_hidden_2 = 64

In [3]:
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train.astype(np.float32), x_test.astype(np.float32)

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

x_train, x_test = x_train/255., x_test/255.

In [4]:
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(10000).batch(batch_size).prefetch(1)

test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_data = test_data.repeat().batch(batch_size).prefetch(1)

In [60]:
random_normal = tf.initializers.RandomNormal()

weights = {
    'encoder_h1': tf.Variable(random_normal([num_features, num_hidden_1])),
    'encoder_h2': tf.Variable(random_normal([num_hidden_1, num_hidden_2])),
    'decoder_h1': tf.Variable(random_normal([num_hidden_2, num_hidden_1])),
    'decoder_h2': tf.Variable(random_normal([num_hidden_1, nums_features])),
}

biases = {
    'encoder_b1': tf.Variable(random_normal([num_hidden_1])),
    'encoder_b2': tf.Variable(random_normal([num_hidden_2])),
    'decoder_b1': tf.Variable(random_normal([num_hidden_1])),
    'decoder_b2': tf.Variable(random_normal([num_features])),
}

In [61]:
weights['encoder_h1'].shape

TensorShape([784, 128])

In [62]:
def encoder(x):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                biases['encoder_b1']))
    
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                biases['encoder_b2']))
    
    return layer_2

def decoder(x):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                  biases['decoder_b1']))
    
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                  biases['decoder_b2']))
    
    return layer_2

In [63]:
def mean_square(reconstructed, original):
    return tf.reduce_mean(tf.pow(original -reconstructed, 2))

optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

In [97]:
def run_optimization(x):
    with tf.GradientTape() as g:
        reconstructed_image = decoder(encoder(x))
        loss = mean_square(reconstructed_image, x)
        
    trainable_variables = list(set(weights.values()) | set(biases.values()))

    gradients = g.gradient(loss, trainable_variables)
    
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    
    return loss

In [98]:
for step, (batch_x, _) in enumerate(train_data.take(training_steps+1)):
    
    loss = run_optimization(batch_x)
    
    if step % display_step == 0:
        print("step: %i, loss: %f" % (step, loss))

step: 0, loss: 0.236298
step: 1000, loss: 0.017984
step: 2000, loss: 0.011823
step: 3000, loss: 0.008842
step: 4000, loss: 0.008418
step: 5000, loss: 0.007441
step: 6000, loss: 0.006401
step: 7000, loss: 0.006229
step: 8000, loss: 0.005942
step: 9000, loss: 0.005629
step: 10000, loss: 0.005322
step: 11000, loss: 0.004981
step: 12000, loss: 0.005280
step: 13000, loss: 0.004751
step: 14000, loss: 0.004349
step: 15000, loss: 0.004742
step: 16000, loss: 0.004616
step: 17000, loss: 0.004141
step: 18000, loss: 0.004663
step: 19000, loss: 0.004185
step: 20000, loss: 0.003846


In [99]:
import matplotlib.pyplot as plt

In [103]:
n = 4
canvas_orig = np.empty((28*n, 28*n))
canvas_recon = np.empty((28*n, 28*n))

for i, (batch_x, _) in enumerate(test_data.take(n)):
    reconstructed_images = decoder(encoder(batch_x))
    for j in range(n):
        img = batch_x[j].numpy().reshape([28,28])
        canvas_orig[i*28:(i+1)*28, j*28:(j+1)*28] = img
        
        for j in range(n):
            reconstr_img = reconstructed_images[j].numpy().reshape([28,28])
            canvas_recon[i*28:(i+1)*28, j*28:(j+1)*28] = reconstr_img 
            
print('original images')
plt.figure(figsize=(n,n))
plt.imshow(canvas_orig, origin="uppder", cmap="gray")
plt.show()


print('reconstructed images')
plt.figure(figsize=(n,n))
plt.imshow(canvas_recon, origin="uppder", cmap="gray")
plt.show()

ValueError: cannot reshape array of size 700 into shape (28,28)