# Adversarial Autoencoders using Keras

In [1]:
import tensorflow
import numpy as np
import matplotlib.pyplot as plt
import keras

from keras.layers import Input, Dense, Activation, LeakyReLU
from keras.models import Sequential, Model
from keras.datasets import mnist
from keras.optimizers import Adam


Using TensorFlow backend.


### Parameters and hyperparameters
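- `rows`, `cols` and `channels` describe the shape of the input images.
- `latent_size` is the size of the "squished" (encoded) representation.
- `optimizer` is the Adam optimizer that is passed to every model we compile.
- `epochs` and `batch_size` are training hyperparameters. Note that each "epoch" in the training loop below processes one random batch, not a full pass over the dataset.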

In [2]:
# shared settings used by every cell below

# Image geometry: height, width and channel count of one input image.
rows, cols, channels = 28, 28, 1
img_shape = (rows, cols, channels)
# Number of values in one image once it is flattened.
img_size = rows * cols * channels

# Width of the encoded (latent) vector.
latent_size = 10

# Single Adam instance handed to every compile() call; per the Keras
# signature the two positional arguments are the learning rate and beta_1.
optimizer = Adam(0.0002, 0.5)

# Training-loop settings: samples per batch and number of loop iterations.
batch_size = 128
epochs = 5000


Instructions for updating:
Colocations handled automatically by placer.


## Encoder
The encoder takes in a full-sized (flattened) image and compresses it to a smaller dimensionality; in this case it goes from `img_size` values down to `latent_size` values.

In [3]:
# create the encoder
def create_encoder(latent_size, img_size):
  """Build the encoder network.

  The network is Dense(512) -> LeakyReLU(0.2) -> Dense(512) ->
  LeakyReLU(0.2) -> Dense(latent_size); the last layer has no activation.

  Args:
    latent_size: number of units in the final Dense layer.
    img_size: length of the flattened input vector.

  Returns:
    A Model mapping an input of shape (img_size,) to the output of the
    layer stack.
  """
  stack = Sequential([
      Dense(512, input_dim=img_size),
      LeakyReLU(alpha=0.2),
      Dense(512),
      LeakyReLU(alpha=0.2),
      Dense(latent_size),
  ])

  flat_image = Input(shape=(img_size,))
  return Model(flat_image, stack(flat_image))

## Decoder
The decoder takes in a compressed version of the image and tries to recreate the original image. Its output layer uses a sigmoid activation (range 0 to 1) because the images are normalized to [0, 1].

In [4]:
# create the decoder
def create_decoder(latent_size, img_size):
  """Build the decoder network.

  The network is Dense(512) -> LeakyReLU(0.2) -> Dense(512) ->
  LeakyReLU(0.2) -> Dense(img_size) with a sigmoid on the last layer.

  Args:
    latent_size: length of the encoded input vector.
    img_size: number of units in the final Dense layer.

  Returns:
    A Model mapping an input of shape (latent_size,) to the output of the
    layer stack.
  """
  stack = Sequential([
      Dense(512, input_dim=latent_size),
      LeakyReLU(alpha=0.2),
      Dense(512),
      LeakyReLU(alpha=0.2),
      Dense(img_size, activation='sigmoid'),
  ])

  code = Input(shape=(latent_size,))
  return Model(code, stack(code))

## Discriminator
This is the main part of *adversarial* autoencoders. The discriminator first learns to tell whether a latent vector was drawn from a given prior distribution or was produced by the encoder. The encoder then uses the discriminator's output to make its encoded representations fit that prior distribution.

In [5]:
# create the discriminator
def create_discriminator(latent_size):
  """Build the discriminator that scores latent vectors.

  The network is Dense(512) -> LeakyReLU(0.2) -> Dense(256) ->
  LeakyReLU(0.2) -> Dense(1) with a sigmoid on the last layer.

  Args:
    latent_size: length of the latent vector fed to the network.

  Returns:
    A Model mapping an input of shape (latent_size,) to a single sigmoid
    output.
  """
  stack = Sequential([
      Dense(512, input_dim=latent_size),
      LeakyReLU(alpha=0.2),
      Dense(256),
      LeakyReLU(alpha=0.2),
      Dense(1, activation='sigmoid'),
  ])

  code = Input(shape=(latent_size,))
  return Model(code, stack(code))

In [6]:
# putting it together: build the three networks and wire them into the
# two trainable models (the stand-alone discriminator and the combined
# "adversarial" autoencoder).

# The discriminator is compiled on its own, with binary cross-entropy and an
# accuracy metric, before its `trainable` flag is changed further down.
discriminator = create_discriminator(latent_size)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

encoder = create_encoder(latent_size, img_size)
decoder = create_decoder(latent_size, img_size)

# Symbolic input for one flattened image.
img = Input(shape=(img_size,))

# image -> latent code
encoded_repr = encoder(img)

# latent code -> reconstructed image
decoded = decoder(encoded_repr)

# Set after discriminator.compile() and before adversarial.compile().
# NOTE(review): presumably this freezes the discriminator only inside the
# combined model while the stand-alone discriminator stays trainable; that
# depends on Keras capturing the flag at compile time -- confirm for the
# Keras version in use.
discriminator.trainable = False

# latent code -> discriminator score
validity = discriminator(encoded_repr)

# Combined model with two outputs: the reconstruction (mse loss, weight 0.999)
# and the discriminator score of the latent code (binary cross-entropy,
# weight 0.001). The same `optimizer` object used for the discriminator is
# passed here as well.
adversarial = Model(img, [decoded, validity])
adversarial.compile(loss=['mse', 'binary_crossentropy'], 
                          loss_weights=[0.999, 0.001], 
                          optimizer=optimizer)

### Getting and preprocessing the data
We only need the training images because this AAE doesn't care about classification (some variants do, but not this one).
Then we normalize the pixel values to [0, 1] and finally flatten each 28x28 image into a 784-element vector.

In [7]:
# load MNIST and prepare it for the dense networks

# Only the training images are kept; labels and the test split are discarded.
(x_train, _), (_, _) = mnist.load_data()

# Flatten every image to img_size values and scale the pixels by 1/255.
x_train = x_train.reshape(len(x_train), img_size).astype(np.float64) / 255

In [8]:
# Discriminator targets: ones for samples labelled "valid", zeros for "fake".
fake = np.zeros((batch_size, 1))
valid = np.ones((batch_size, 1))

# Each pass of this loop trains on one random batch (the counter is still
# called `epoch`).
for epoch in range(epochs):
    # Pick batch_size random row indices and gather the matching images.
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]

    # "Real" latent vectors are sampled from a standard normal distribution;
    # "fake" ones are the encoder's output for the current batch.
    latent_real = np.random.normal(size=(batch_size, latent_size))
    latent_fake = encoder.predict(imgs)

    # Discriminator step: one update on the real codes, one on the fake
    # codes, then average the two [loss, accuracy] results.
    d_loss_real = discriminator.train_on_batch(latent_real, valid)
    d_loss_fake = discriminator.train_on_batch(latent_fake, fake)
    d_loss = np.add(d_loss_real, d_loss_fake) * 0.5

    # Autoencoder step: the combined model is asked to reconstruct the batch
    # and to have its latent codes scored as "valid".
    g_loss = adversarial.train_on_batch(imgs, [imgs, valid])

    # Report progress on every 10th iteration only.
    if epoch % 10 != 0:
        continue
    print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.690654, acc: 34.77%] [G loss: 0.231352, mse: 0.230906]
10 [D loss: 0.212701, acc: 97.66%] [G loss: 0.115636, mse: 0.110882]
20 [D loss: 0.108238, acc: 99.61%] [G loss: 0.074324, mse: 0.069094]
30 [D loss: 0.072832, acc: 100.00%] [G loss: 0.075493, mse: 0.070393]
40 [D loss: 0.084835, acc: 98.44%] [G loss: 0.076633, mse: 0.070186]
50 [D loss: 0.068241, acc: 97.66%] [G loss: 0.070625, mse: 0.063932]
60 [D loss: 0.087013, acc: 96.48%] [G loss: 0.066440, mse: 0.060119]
70 [D loss: 0.300746, acc: 86.72%] [G loss: 0.070006, mse: 0.061908]
80 [D loss: 0.823973, acc: 62.50%] [G loss: 0.061947, mse: 0.056382]
90 [D loss: 2.009838, acc: 32.42%] [G loss: 0.065888, mse: 0.059843]
100 [D loss: 0.849680, acc: 57.81%] [G loss: 0.065149, mse: 0.059014]
110 [D loss: 0.651141, acc: 65.23%] [G loss: 0.056076, mse: 0.051943]
120 [D loss: 0.549159, acc: 69.14%] [G loss: 0.052877, mse: 0.048291]
130 [D loss: 0.466202, acc: 69.14%] [G loss: 0.049886, mse: 0.046258]
140 [D loss: 0.412697, acc: 76

1170 [D loss: 0.684275, acc: 55.86%] [G loss: 0.024441, mse: 0.023619]
1180 [D loss: 0.644888, acc: 62.50%] [G loss: 0.024425, mse: 0.023591]
1190 [D loss: 0.653634, acc: 62.50%] [G loss: 0.023553, mse: 0.022728]
1200 [D loss: 0.655416, acc: 61.33%] [G loss: 0.023969, mse: 0.023133]
1210 [D loss: 0.643297, acc: 64.45%] [G loss: 0.024343, mse: 0.023477]
1220 [D loss: 0.661585, acc: 62.50%] [G loss: 0.022657, mse: 0.021802]
1230 [D loss: 0.641748, acc: 63.67%] [G loss: 0.023387, mse: 0.022502]
1240 [D loss: 0.654229, acc: 60.94%] [G loss: 0.024241, mse: 0.023384]
1250 [D loss: 0.676565, acc: 60.94%] [G loss: 0.025082, mse: 0.024261]
1260 [D loss: 0.633439, acc: 61.33%] [G loss: 0.023918, mse: 0.023026]
1270 [D loss: 0.606891, acc: 68.75%] [G loss: 0.022780, mse: 0.021877]
1280 [D loss: 0.666462, acc: 60.16%] [G loss: 0.021597, mse: 0.020750]
1290 [D loss: 0.656757, acc: 59.77%] [G loss: 0.023214, mse: 0.022384]
1300 [D loss: 0.589286, acc: 73.44%] [G loss: 0.020907, mse: 0.020000]
1310 [

2330 [D loss: 0.608908, acc: 68.75%] [G loss: 0.022844, mse: 0.022013]
2340 [D loss: 0.603668, acc: 69.53%] [G loss: 0.020977, mse: 0.020073]
2350 [D loss: 0.570980, acc: 73.44%] [G loss: 0.019521, mse: 0.018607]
2360 [D loss: 0.596574, acc: 70.31%] [G loss: 0.019979, mse: 0.019107]
2370 [D loss: 0.601010, acc: 68.75%] [G loss: 0.020169, mse: 0.019281]
2380 [D loss: 0.589112, acc: 73.05%] [G loss: 0.019498, mse: 0.018629]
2390 [D loss: 0.580403, acc: 72.27%] [G loss: 0.018924, mse: 0.018008]
2400 [D loss: 0.598034, acc: 68.36%] [G loss: 0.020223, mse: 0.019387]
2410 [D loss: 0.596334, acc: 68.75%] [G loss: 0.020400, mse: 0.019551]
2420 [D loss: 0.643203, acc: 64.84%] [G loss: 0.021247, mse: 0.020440]
2430 [D loss: 0.610695, acc: 69.14%] [G loss: 0.019834, mse: 0.018992]
2440 [D loss: 0.605267, acc: 69.92%] [G loss: 0.019622, mse: 0.018757]
2450 [D loss: 0.581479, acc: 71.48%] [G loss: 0.020712, mse: 0.019804]
2460 [D loss: 0.573392, acc: 73.83%] [G loss: 0.020947, mse: 0.020056]
2470 [

3490 [D loss: 0.516818, acc: 75.39%] [G loss: 0.020315, mse: 0.019295]
3500 [D loss: 0.525777, acc: 76.95%] [G loss: 0.018792, mse: 0.017686]
3510 [D loss: 0.508118, acc: 76.95%] [G loss: 0.019174, mse: 0.018062]
3520 [D loss: 0.522576, acc: 78.12%] [G loss: 0.018441, mse: 0.017323]
3530 [D loss: 0.512389, acc: 75.00%] [G loss: 0.018431, mse: 0.017286]
3540 [D loss: 0.519514, acc: 77.34%] [G loss: 0.018227, mse: 0.017144]
3550 [D loss: 0.510173, acc: 77.73%] [G loss: 0.017435, mse: 0.016270]
3560 [D loss: 0.544854, acc: 73.83%] [G loss: 0.019286, mse: 0.018238]
3570 [D loss: 0.543173, acc: 75.39%] [G loss: 0.017781, mse: 0.016664]
3580 [D loss: 0.519036, acc: 74.61%] [G loss: 0.020251, mse: 0.019168]
3590 [D loss: 0.494329, acc: 78.91%] [G loss: 0.018216, mse: 0.017040]
3600 [D loss: 0.532226, acc: 75.00%] [G loss: 0.017852, mse: 0.016784]
3610 [D loss: 0.482759, acc: 82.03%] [G loss: 0.018789, mse: 0.017642]
3620 [D loss: 0.526595, acc: 75.78%] [G loss: 0.018996, mse: 0.017856]
3630 [

4650 [D loss: 0.529944, acc: 75.78%] [G loss: 0.018960, mse: 0.017817]
4660 [D loss: 0.488874, acc: 78.91%] [G loss: 0.017565, mse: 0.016324]
4670 [D loss: 0.492241, acc: 76.17%] [G loss: 0.017680, mse: 0.016400]
4680 [D loss: 0.502978, acc: 76.95%] [G loss: 0.018049, mse: 0.016813]
4690 [D loss: 0.528250, acc: 73.05%] [G loss: 0.018912, mse: 0.017731]
4700 [D loss: 0.547764, acc: 73.83%] [G loss: 0.018679, mse: 0.017494]
4710 [D loss: 0.511847, acc: 76.17%] [G loss: 0.017181, mse: 0.015986]
4720 [D loss: 0.499838, acc: 78.12%] [G loss: 0.018252, mse: 0.017008]
4730 [D loss: 0.561771, acc: 68.36%] [G loss: 0.019047, mse: 0.017871]
4740 [D loss: 0.500712, acc: 74.61%] [G loss: 0.017457, mse: 0.016247]
4750 [D loss: 0.546242, acc: 73.05%] [G loss: 0.018057, mse: 0.016856]
4760 [D loss: 0.473482, acc: 80.08%] [G loss: 0.019090, mse: 0.017894]
4770 [D loss: 0.529951, acc: 75.39%] [G loss: 0.016348, mse: 0.015164]
4780 [D loss: 0.501831, acc: 74.61%] [G loss: 0.017877, mse: 0.016603]
4790 [