In [1]:
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten, Lambda
from keras.layers import Reshape, Conv2DTranspose, BatchNormalization
from keras.models import Model
from keras.datasets import cifar10
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
from keras.callbacks import CSVLogger
from keras.optimizers import Adam, RMSprop

import numpy as np
import matplotlib.pyplot as plt
import argparse
from os import walk
import matplotlib.image as mpimg
import cv2

Using TensorFlow backend.


In [0]:
# dimension of the latent space (min 2)
latent_dim = 512
image_dim = 64
input_shape = (64, 64, 3)

# dataset

In [0]:
!unzip augmented.zip

Archive:  augmented.zip
 extracting: augmented/1.png         
 extracting: augmented/10.png        
 extracting: augmented/100.png       
 extracting: augmented/1000.png      
 extracting: augmented/10000.png     
 extracting: augmented/10001.png     
 extracting: augmented/10002.png     
 extracting: augmented/10003.png     
 extracting: augmented/10004.png     
 extracting: augmented/10005.png     
 extracting: augmented/10006.png     
 extracting: augmented/10007.png     
 extracting: augmented/10008.png     
 extracting: augmented/10009.png     
 extracting: augmented/1001.png      
 extracting: augmented/10010.png     
 extracting: augmented/10011.png     
 extracting: augmented/10012.png     
 extracting: augmented/10013.png     
 extracting: augmented/10014.png     
 extracting: augmented/10015.png     
 extracting: augmented/10016.png     
 extracting: augmented/10017.png     
 extracting: augmented/10018.png     
 extracting: augmented/10019.png     
 extracting: augmented/100

In [6]:
#----------------------------------------------------------------------------
# Read the dataset
#----------------------------------------------------------------------------
imgs = []

f = []
for (dirpath, dirnames, filenames) in walk('augmented'):
    f.extend(filenames)

for src in f:
    #read PNG images and convert them to RGB tensors
    img = mpimg.imread('augmented/' + src)[:,:,0:3]
    img = cv2.resize(img, (image_dim,image_dim))
    imgs.append(img)
    
imgs = np.array(imgs)

# 64
image_size = imgs.shape[1]
# 64x64x3
input_shape = (image_size, image_size, 3)

# show some of them
f, axarr = plt.subplots(1, 5)
f.set_size_inches(16, 6)
for i in range(5):
    img = imgs[i]
    axarr[i].imshow(img)
plt.show()

IndexError: ignored

# AAE MODEL

## ENCODER

In [0]:
# reparameterization trick
# z = mean + sqrt(var)*eps
# where only eps is random, we can backprop on the other parameters
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [8]:
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
x = Conv2D(filters=32, kernel_size=3, activation='relu', strides=2, padding='same')(x)
x = Conv2D(filters=64, kernel_size=2, activation='relu', strides=2, padding='same')(x)
x = Conv2D(filters=128, kernel_size=3, activation='relu', strides=2, padding='same')(x)     
x = Conv2D(filters=256, kernel_size=3, activation='relu', strides=2, padding='same')(x)   

# shape info needed to build decoder model
shape = K.int_shape(x)

# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(latent_dim*2, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()





Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 32)   896         encoder_input[0][0]              
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 16, 16, 64)   8256        conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 8, 8, 128)    73856       conv2d_2[0][0]                   
________________________________________________________________________________________

## DECODER

In [9]:
# input = z latent space
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=256, kernel_size=5, activation='relu', strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.9)(x)
x = Conv2DTranspose(filters=128, kernel_size=5, activation='relu', strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.9)(x)
x = Conv2DTranspose(filters=64, kernel_size=3, activation='relu', strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.9)(x)
x = Conv2DTranspose(filters=32, kernel_size=2, activation='relu', strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.9)(x)
outputs = Conv2DTranspose(filters=3, kernel_size=2, activation='sigmoid', padding='same', name='decoder_output')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()









Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_sampling (InputLayer)      (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 4096)              2101248   
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 256)         1638656   
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 16, 16, 128)       819328    
_________________________________________________________________
batch_normalization_2 (Batch (None, 16, 16, 128)   

## DISCRIMINATOR

In [10]:
disc_inputs = Input(shape=(latent_dim,))
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(disc_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
x = Conv2D(filters=64, kernel_size=2, activation='relu', strides=2, padding='same')(x)
x = Conv2D(filters=32, kernel_size=3, activation='relu', strides=2, padding='same')(x)
x = Flatten()(x)     
disc_out = Dense(1, activation='sigmoid')(x)

# instantiate discriminator model
discriminator = Model(disc_inputs, disc_out, name='discriminator')
discriminator.summary()

# instantiate encoder_discriminator model
encoder_discriminator = Model(inputs, discriminator(encoder(inputs)[2]))

# compile the trainable discriminator
discriminator.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

# compile the encoder discriminator model, where only the encoder is trainable
for layer in discriminator.layers:
    layer.trainable = False
encoder_discriminator.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 4096)              2101248   
_________________________________________________________________
reshape_2 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 2, 2, 64)          65600     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 1, 1, 32)          18464     
_________________________________________________________________
flatten_2 (Flatten)          (None, 32)                0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)               

## AAE

In [0]:
# out = D(E(x))
# where E(x)=z
outputs = decoder(encoder(inputs)[2])
aae = Model(inputs, outputs, name='aae')

In [0]:
# now AAE loss is only = xent_loss
reconstruction_loss = binary_crossentropy(K.flatten(inputs),K.flatten(outputs))
reconstruction_loss *= image_size * image_size
aae_loss = K.mean(reconstruction_loss)

In [0]:
# compile AAE with loss
aae.add_loss(aae_loss)
aae.compile(optimizer='adam')
aae.summary()

Model: "aae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 128, 128, 3)       0         
_________________________________________________________________
encoder (Model)              [(None, 1024), (None, 102 38131008  
_________________________________________________________________
decoder (Model)              (None, 128, 128, 3)       19335907  
Total params: 57,466,915
Trainable params: 57,465,955
Non-trainable params: 960
_________________________________________________________________


# TRAIN

In [0]:
def train( x_train, batch_size=256, epochs=1200):
        trainG, trainD = True, True
        half_batch = int(batch_size / 2)
        for epoch in range(epochs):
            #---------------Train Discriminator -------------
            # on half a batch real and half a batch fake data
            idx = np.random.randint(0, x_train.shape[0], half_batch)
            imgs = x_train[idx]
            latent_fake = encoder.predict(imgs)[2]
            latent_real = np.random.normal(size=(half_batch, latent_dim))
            valid = np.ones((half_batch, 1))
            fake = np.zeros((half_batch, 1))
            if (trainD):
              d_loss_real = discriminator.train_on_batch(latent_real, valid)
              d_loss_fake = discriminator.train_on_batch(latent_fake, fake)
              d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            #---------------Train Generator -------------
            # on one batch fake data
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            imgs = x_train[idx]
            valid_y = np.ones((batch_size, 1))
            if (trainG):
              g_logg_similarity = encoder_discriminator.train_on_batch(imgs, valid_y)

            # with this the adversarial part can learn the latent distribution better, but the generated images are better without it
            # if one model of the adversarial part is too strong, then we dont let it train
            if (d_loss[1]*0.8 > g_logg_similarity[1]):
              trainG = True
              trainD = False
            elif (g_logg_similarity[1]*0.8 > d_loss[1]):
              trainG = False
              trainD = True
            else:
              trainG = True
              trainD = True

            #---------------Train Autoencoder -------------
            g_loss_reconstruction = aae.fit(imgs, validation_data=(imgs, None), verbose=0)

            #---------------Plot the progress---------------
            print ("%d [D loss: %f, acc: %.2f%%] [G acc: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1],
                   g_logg_similarity[1], g_loss_reconstruction.history['val_loss'][0]))

In [0]:
train(imgs)
aae.save_weights('aae.h5')




  'Discrepancy between trainable weights and collected trainable'





  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 1.904352, acc: 26.56%] [G acc: 0.699219, mse: 15505.765381]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 8.239217, acc: 35.94%] [G acc: 0.699219, mse: 11849.808716]
2 [D loss: 7.728102, acc: 43.36%] [G acc: 0.699219, mse: 15785.075439]
3 [D loss: 1.985193, acc: 46.09%] [G acc: 0.699219, mse: 15220.691284]
4 [D loss: 0.263995, acc: 94.92%] [G acc: 0.699219, mse: 11041.561157]
5 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 12318.646118]
6 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 10203.781738]
7 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 8103.719788]
8 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 8442.456299]
9 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 8976.224854]
10 [D loss: 0.263995, acc: 94.92%] [G acc: 0.000000, mse: 7138.197510]
11 [D loss: 0.263995, acc: 94.92%] [G acc: 0.230469, mse: 7592.805115]
12 [D loss: 0.263995, acc: 94.92%] [G acc: 1.000000, mse: 6679.739441]
13 [D loss: 7.104215, acc: 45.70%] [G acc: 0.000000, mse: 8414.326416]
14 [D loss: 7.104215, acc: 45.70%] [G acc: 0.000000, mse: 6469.190552]
15 [D los

# PLOT

In [0]:
def plot_results(encoder, decoder, x_test, y_test, batch_size=128, model_name="aae_cifar"):
    
    # display a nxn 2D manifold of images
    n = 10
    img_size = image_dim
    figure = np.zeros((img_size * n, img_size * n,3))
    # linearly spaced coordinates corresponding to the 2D plot
    grid_x = np.linspace(-1, 1, n)
    grid_y = np.linspace(-1, 1, n)[::-1]

  # display a 2D plot of the image classes in the first 2 dimensions of the latent space
    z_mean, _, _ = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(n, n))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()

  # display the generated images along the first 2 dimensions of the latent space
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.random.normal(0.0, 4, (1,latent_dim))
            # the first two dim is given
            #z_sample[0,0]=xi
            #z_sample[0,1]=yi
            # generate an image from the latent representation
            x_decoded = decoder.predict(z_sample)
            img = x_decoded[0].reshape(img_size, img_size, 3)

            #sharpen image
            #kernel_sharpening = np.array([[-1,-1,-1], [-1, 9,-1], [-1,-1,-1]])
            #img = cv2.filter2D(img, -1, kernel_sharpening)

            figure[i * img_size: (i + 1) * img_size,
                   j * img_size: (j + 1) * img_size,] = img

    plt.figure(figsize=(20, 20))
    start_range = img_size // 2
    end_range = n * img_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, img_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.show()



In [0]:
# now we have only one class
plot_results(encoder, decoder, imgs, np.ones((imgs.shape[0])), batch_size=128, model_name="aae_cnn")

Output hidden; open in https://colab.research.google.com to view.