In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from help_printing import *
import tensorflow as tf

In [2]:
data_path = 'data/celeba_split/'
data_list = os.listdir(data_path)
print(data_list)

['celeba_140000.npy', 'celeba_60000.npy', 'celeba_40000.npy', 'celeba_160000.npy', 'celeba_200000.npy', 'celeba_80000.npy', 'celeba_120000.npy', 'celeba_100000.npy', 'celeba_180000.npy', 'celeba_20000.npy']


In [3]:
dataset = [np.load(data_path+data) for data in data_list[:4]]
#x_data = np.load(data_path+data_list[0])
x_data = np.concatenate(dataset)
x_data.shape

(80000, 64, 64, 3)

In [39]:
from sklearn.model_selection import train_test_split

x_train, x_test = train_test_split(x_data, train_size = 0.9)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(10000).batch(256)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test)).batch(256)

In [40]:
from Models import build_vae

encoder, decoder, vae = build_vae()
#encoder.summary()
#decoder.summary()
vae.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
encoder (Model)              [(None, 10), (None, 10),  230900    
_________________________________________________________________
decoder (Model)              (None, 64, 64, 3)         228803    
Total params: 459,703
Trainable params: 459,703
Non-trainable params: 0
_________________________________________________________________


In [41]:
optimizer = tf.keras.optimizers.Adam(0.001)
train_loss =  tf.keras.metrics.Mean(name='train_loss')


def get_rec_loss(inputs, predictions):
    rec_loss = tf.keras.losses.mean_squared_error(inputs, predictions)
    rec_loss *= 64*64
    rec_loss = K.mean(rec_loss)
    return rec_loss

def get_kl_loss(z_log_var, z_mean):
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    return kl_loss

In [42]:
from tensorflow.keras import backend as K

def plot_recimg(save_dir, epoch):
    org_img = x_train[:100]
    rec_img = vae(x_train[:100], training=False)
    fig, ax = plt.subplots(6, 10, figsize=(20, 10))
    for i in range(10):
        for j in range(6):
            if j%2 ==0: img=org_img 
            else: img=rec_img
            ax[j][i].set_axis_off()
            ax[j][i].imshow(img[10*(j//2)+i])
    
    plt.savefig('%s/recimg_%i.png'%(save_dir,epoch))
    #plt.show()
    plt.close('all')

@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        
        z_log_var, z_mean, z = encoder(inputs)
        predictions = decoder(z)
        
        rec_loss = get_rec_loss(inputs, predictions)
        kl_loss = get_kl_loss(z_log_var, z_mean)
        loss = K.mean(rec_loss + kl_loss)
    
    varialbes = vae.trainable_variables
    gradients = tape.gradient(loss, varialbes)
    optimizer.apply_gradients(zip(gradients, varialbes))
    
    train_loss(loss)

In [47]:
save_dir = 'results_VAE/train_5'
if_not_make(save_dir)
epochs = 300
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
manager = tf.train.CheckpointManager(checkpoint, save_dir, max_to_keep=5)

In [None]:
for epoch in range(epochs):
    for inputs, outputs in train_ds:
        train_step(inputs)
        
    print("* Epoch: %i, loss: %f "%(epoch, train_loss.result()))
    manager.save()
    plot_recimg(save_dir, epoch)

* Epoch: 0, loss: 89.752388 
* Epoch: 1, loss: 89.722214 
* Epoch: 2, loss: 89.692825 
* Epoch: 3, loss: 89.663979 
* Epoch: 4, loss: 89.635452 
* Epoch: 5, loss: 89.607536 
* Epoch: 6, loss: 89.579430 
* Epoch: 7, loss: 89.551949 
* Epoch: 8, loss: 89.524406 
* Epoch: 9, loss: 89.497520 
* Epoch: 10, loss: 89.470558 
* Epoch: 11, loss: 89.444824 
* Epoch: 12, loss: 89.418747 
* Epoch: 13, loss: 89.393196 
* Epoch: 14, loss: 89.367386 
* Epoch: 15, loss: 89.342560 
* Epoch: 16, loss: 89.317436 
* Epoch: 17, loss: 89.293327 
* Epoch: 18, loss: 89.269669 
* Epoch: 19, loss: 89.245941 
* Epoch: 20, loss: 89.222412 
* Epoch: 21, loss: 89.199074 
* Epoch: 22, loss: 89.175934 
* Epoch: 23, loss: 89.153168 
* Epoch: 24, loss: 89.130676 
* Epoch: 25, loss: 89.108467 
* Epoch: 26, loss: 89.086380 
* Epoch: 27, loss: 89.064499 
* Epoch: 28, loss: 89.042534 
* Epoch: 29, loss: 89.021202 
* Epoch: 30, loss: 89.000252 
* Epoch: 31, loss: 88.979195 
* Epoch: 32, loss: 88.958473 
* Epoch: 33, loss: 8