# Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm
import matplotlib.image as mpimg


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing import image
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
from sklearn.model_selection import train_test_split

# Run TensorFlow in graph (non-eager) mode to avoid compatibility issues between
# Keras and TensorFlow.
# NOTE(review): the custom loss defined later in this notebook reads the symbolic
# `mu` / `sigma` tensors from the enclosing scope, which presumably only works in
# graph mode -- confirm before removing this call.
tf.compat.v1.disable_eager_execution()


# Load the Data

In [None]:
# glob.glob returns paths in arbitrary, filesystem-dependent order. Sort them so
# that the subsets sliced from these lists are the same images on every run.
train_images_glob = sorted(glob.glob('../input/coco2017/train2017/train2017/*.jpg'))
test_images_glob = sorted(glob.glob('../input/coco2017/val2017/val2017/*.jpg'))

# Sanity check: print how many files were found in each directory.
print(len(train_images_glob))
print(len(test_images_glob))

# Display one sample image from each split.
plt.imshow(mpimg.imread(train_images_glob[2]))
plt.show()
plt.imshow(mpimg.imread(test_images_glob[2]))
plt.show()


# Copy data into numpy arrays

In [None]:
"""Define train and test data and their sizes"""
x_train = []
x_val = []
train_size = 1000   # number of training images to load
val_size = 200      # number of validation images to load
image_size = 128    # every image is resized to image_size x image_size pixels

# Load each image, resize it and collect it as an array.
# `target_size` is documented as a (height, width) pair, so pass exactly two
# values; the channel dimension is produced by the loader itself.
for path in tqdm(train_images_glob[:train_size]):
  img = image.load_img(path, target_size=(image_size, image_size))
  x_train.append(image.img_to_array(img))

for path in tqdm(test_images_glob[:val_size]):
  img = image.load_img(path, target_size=(image_size, image_size))
  x_val.append(image.img_to_array(img))

# Stack the per-image arrays into single arrays and confirm the counts.
x_train = np.array(x_train)
x_val = np.array(x_val)
print(len(x_train), len(x_val))

# Define Variables

In [None]:
# The stacked image array is later reshaped to
# (samples, img_height, img_width, num_channels), so axis 1 is the height and
# axis 2 is the width. (The original assignment had the two names swapped, which
# only went unnoticed because the images are square.)
img_height, img_width = x_train.shape[1], x_train.shape[2]
batch_size = 128          # samples per gradient step
no_epochs = 50            # epochs for the main training runs
validation_split = 0.2    # fraction of the training data held out by fit()
verbosity = 1
latent_dim = 100          # size of the latent vector z
num_channels = 3          # colour channels per image

# Data Preprocessing

In [None]:
# Shape of a single network input: (height, width, channels).
input_shape = (img_height, img_width, num_channels)

# Force the (samples, height, width, channels) layout, cast to float32 and
# divide by 255 in one pass per array.
x_train = x_train.reshape((x_train.shape[0],) + input_shape).astype("float32") / 255
x_val = x_val.reshape((x_val.shape[0],) + input_shape).astype("float32") / 255


# Configure the VAE

# Sampling Function


In [None]:
# Reparameterization trick: z = mu + exp(sigma / 2) * eps, with eps drawn from a
# standard normal distribution.
def sample_z(args):
  """Draw a latent sample from the two encoder outputs.

  Args:
    args: pair ``(mu, sigma)`` of 2-D tensors. ``sigma`` enters the formula as
      ``exp(sigma / 2)``, i.e. it is treated as a log-variance.

  Returns:
    Tensor ``mu + exp(sigma / 2) * eps`` where ``eps`` is standard-normal noise
    with one row per row of ``mu`` and one column per column of ``mu``.
  """
  mean, log_var = args
  n_rows = K.shape(mean)[0]        # dynamic first dimension
  n_cols = K.int_shape(mean)[1]    # static second dimension
  noise = K.random_normal(shape=(n_rows, n_cols))
  return mean + K.exp(log_var / 2) * noise

## Build the encoder layer

In [None]:
i = Input(shape=input_shape, name='encoder_input')

# Convolutional trunk: five stride-2 Conv2D layers with growing filter counts,
# each followed by batch normalization.
cx = i
for n_filters in (16, 32, 64, 128, 256):
  cx = Conv2D(filters=n_filters, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
  cx = BatchNormalization()(cx)

# Remember the shape of the last conv feature map; the decoder uses it to
# rebuild a tensor of the same shape before its Conv2DTranspose layers.
conv_shape = K.int_shape(cx)

# Dense head producing the two latent parameter vectors.
x = Flatten()(cx)
x = Dense(100, activation='relu')(x)
x = BatchNormalization()(x)
mu = Dense(latent_dim, name='latent_mu')(x)
sigma = Dense(latent_dim, name='latent_sigma')(x)



In [None]:
# Draw the latent vector z from (mu, sigma) via the reparameterization trick.
z = Lambda(sample_z, output_shape=(latent_dim,), name='z')([mu, sigma])

# The encoder maps an image to all three tensors: mu, sigma and the sample z.
encoder = Model(inputs=i, outputs=[mu, sigma, z], name='encoder')
encoder.summary()


# Build the Decoder

In [None]:
# Decoder: latent vector -> dense projection -> feature map with the encoder's
# final conv shape -> stack of stride-2 transposed convolutions -> image.
d_i = Input(shape=(latent_dim,), name='decoder_input')

feat_h, feat_w, feat_c = conv_shape[1], conv_shape[2], conv_shape[3]
x = Dense(feat_h * feat_w * feat_c, activation='relu')(d_i)
x = BatchNormalization()(x)
cx = Reshape((feat_h, feat_w, feat_c))(x)

# Mirror of the encoder trunk: filter counts shrink while each layer uses stride 2.
for n_filters in (256, 128, 64, 32, 16):
  cx = Conv2DTranspose(filters=n_filters, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
  cx = BatchNormalization()(cx)

# Final stride-1 layer maps to num_channels with a sigmoid activation.
o = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)

# Instantiate decoder
decoder = Model(inputs=d_i, outputs=o, name='decoder')
decoder.summary()


In [None]:
# VAE loss: reconstruction term plus KL-divergence term.
def kl_reconstruction_loss(true, pred):
  """Return the mean of (reconstruction loss + KL divergence).

  The reconstruction term is the binary cross-entropy between the flattened
  target and prediction, multiplied by ``img_width * img_height``. The KL term
  is computed from the notebook-level encoder tensors ``mu`` and ``sigma``
  (``sigma`` is used as a log-variance). The two terms are simply added; no
  extra weighting is applied to either.

  Args:
    true: tensor of target images.
    pred: tensor of reconstructed images.

  Returns:
    Scalar tensor with the mean total loss.
  """
  # Reconstruction term.
  bce = binary_crossentropy(K.flatten(true), K.flatten(pred))
  rec_term = bce * img_width * img_height
  # KL term: -0.5 * sum(1 + sigma - mu^2 - exp(sigma)) over the latent axis.
  kl_term = -0.5 * K.sum(1 + sigma - K.square(mu) - K.exp(sigma), axis=-1)
  return K.mean(rec_term + kl_term)

In [None]:
# Chain encoder and decoder: the encoder returns [mu, sigma, z] and only the
# sampled z (index 2) is fed to the decoder.
encoded = encoder(i)
vae_outputs = decoder(encoded[2])
vae = Model(inputs=i, outputs=vae_outputs, name='vae')
vae.summary()

## Compile the VAE

In [None]:
# Train with RMSprop against the combined reconstruction + KL loss.
vae.compile(
  loss=kl_reconstruction_loss,
  optimizer='RMSprop',
)


## Train the Model
### Splitting training across several cells lets us stop between runs and change or re-run other parts of the notebook as needed, without a keyboard interrupt

In [None]:
# Autoencoder training: the images serve as both input and target.
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=no_epochs,
  validation_split=validation_split,
)


In [None]:
# Continue training for 10 epochs on all of x_train (no validation hold-out).
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=10,
  validation_split=0,
)

In [None]:
# Continue training for 40 epochs, again holding out a validation fraction.
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=40,
  validation_split=validation_split,
)

In [None]:
# Final 10 epochs on all of x_train (no validation hold-out).
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=10,
  validation_split=0,
)

# Make Predictions on Test Data

In [None]:
# Reconstruct the validation images. Reuse the notebook-wide `batch_size`
# setting instead of repeating its value as a literal.
predictions = vae.predict(x_val, batch_size=batch_size)

# Display Results

In [None]:
# Show original validation images (top row) above their reconstructions (bottom row).
n = 5
first_idx = 40  # position in x_val / predictions of the first image shown
fig, axes = plt.subplots(2, n, figsize=(30, 20), squeeze=False)

# The loop variable is `col`, not `i`: `i` is the encoder's Input tensor, and
# overwriting it here would break re-running the model-building cells.
for col in range(n):
  axes[0, col].imshow(x_val[first_idx + col])
  axes[1, col].imshow(predictions[first_idx + col])
  for ax in axes[:, col]:
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

# Try CelebA Dataset

In [None]:
# glob.glob returns paths in arbitrary order; sort them so the subset sliced
# from this list is the same set of images on every run.
train_images_glob = sorted(glob.glob('../input/celeba-dataset/img_align_celeba/img_align_celeba/*.jpg'))

# Sanity check: print the number of files found and show one sample image.
print(len(train_images_glob))
plt.imshow(mpimg.imread(train_images_glob[2]))
plt.show()

# Load Data

In [None]:
all_images = []
image_size = 128    # every image is resized to image_size x image_size pixels
train_size = 6000   # number of images to load

# The loop variable is `path`, not `i`: `i` is the encoder's Input tensor, and
# overwriting it would break re-running the model-building cells.
# `target_size` is documented as a (height, width) pair, so pass exactly two values.
for path in tqdm(train_images_glob[:train_size]):
  img = image.load_img(path, target_size=(image_size, image_size))
  all_images.append(image.img_to_array(img))

# Stack the per-image arrays into one array and confirm the count.
all_images = np.array(all_images)
print(len(all_images))

## Split into train and test data

In [None]:
# Shuffle and hold out 20% of the images as a test set (fixed seed).
x_train, x_test = train_test_split(all_images, test_size=0.2, shuffle=True, random_state=42)

# Shape of a single network input: (height, width, channels).
input_shape = (image_size, image_size, num_channels)

# Force the (samples, height, width, channels) layout, cast to float32 and
# divide by 255 in one pass per array.
x_train = x_train.reshape((x_train.shape[0],) + input_shape).astype("float32") / 255
x_test = x_test.reshape((x_test.shape[0],) + input_shape).astype("float32") / 255

# Train Model on New Data

In [None]:
# Keep training the same VAE on the new data; images are both input and target.
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=no_epochs,
  validation_split=validation_split,
)

In [None]:
# Continue training for 10 epochs on all of x_train (no validation hold-out).
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=10,
  validation_split=0,
)

In [None]:
# Continue training for 30 epochs, again holding out a validation fraction.
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=30,
  validation_split=validation_split,
)

In [None]:
# Final 10 epochs on all of x_train (no validation hold-out).
vae.fit(
  x=x_train,
  y=x_train,
  batch_size=batch_size,
  epochs=10,
  validation_split=0,
)

# Make Predictions on Test Data

In [None]:
# Reconstruct the held-out test images. Reuse the notebook-wide `batch_size`
# setting instead of repeating its value as a literal.
predictions = vae.predict(x_test, batch_size=batch_size)

In [None]:
# Show original test images (top row) above their reconstructions (bottom row).
n = 5
first_idx = 40  # position in x_test / predictions of the first image shown
fig, axes = plt.subplots(2, n, figsize=(30, 20), squeeze=False)

# The loop variable is `col`, not `i`: `i` is the encoder's Input tensor, and
# overwriting it here would break re-running the model-building cells.
for col in range(n):
  axes[0, col].imshow(x_test[first_idx + col])
  axes[1, col].imshow(predictions[first_idx + col])
  for ax in axes[:, col]:
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

# Save the Weights for Future Use

In [None]:
from random import randint as r
# The file name gets a random integer suffix drawn from randint(0, 3653).
vae.save_weights(f"vae-weights_6k_celeba{r(0, 3653)}.h5")