In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from help_printing import *
import tensorflow as tf

In [2]:
data_path = 'data/celeba_split/'
data_list = os.listdir(data_path)
print(data_list)

['celeba_140000.npy', 'celeba_60000.npy', 'celeba_40000.npy', 'celeba_160000.npy', 'celeba_200000.npy', 'celeba_80000.npy', 'celeba_120000.npy', 'celeba_100000.npy', 'celeba_180000.npy', 'celeba_20000.npy']


In [3]:
dataset = [np.load(data_path+data) for data in data_list[:4]]
#x_data = np.load(data_path+data_list[0])
x_data = np.concatenate(dataset)
x_data.shape

(80000, 64, 64, 3)

In [39]:
from sklearn.model_selection import train_test_split

x_train, x_test = train_test_split(x_data, train_size = 0.9)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(10000).batch(256)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test)).batch(256)

In [40]:
from Models import build_vae

encoder, decoder, vae = build_vae()
#encoder.summary()
#decoder.summary()
vae.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
encoder (Model)              [(None, 10), (None, 10),  230900    
_________________________________________________________________
decoder (Model)              (None, 64, 64, 3)         228803    
Total params: 459,703
Trainable params: 459,703
Non-trainable params: 0
_________________________________________________________________


In [41]:
optimizer = tf.keras.optimizers.Adam(0.001)
train_loss =  tf.keras.metrics.Mean(name='train_loss')


def get_rec_loss(inputs, predictions):
    rec_loss = tf.keras.losses.mean_squared_error(inputs, predictions)
    rec_loss *= 64*64
    rec_loss = K.mean(rec_loss)
    return rec_loss

def get_kl_loss(z_log_var, z_mean):
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    return kl_loss

In [42]:
from tensorflow.keras import backend as K

def plot_recimg(save_dir, epoch):
    org_img = x_train[:100]
    rec_img = vae(x_train[:100], training=False)
    fig, ax = plt.subplots(6, 10, figsize=(20, 10))
    for i in range(10):
        for j in range(6):
            if j%2 ==0: img=org_img 
            else: img=rec_img
            ax[j][i].set_axis_off()
            ax[j][i].imshow(img[10*(j//2)+i])
    
    plt.savefig('%s/recimg_%i.png'%(save_dir,epoch))
    if epoch%show_term==0:
        plt.show()
    plt.close('all')

@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        
        z_log_var, z_mean, z = encoder(inputs)
        predictions = decoder(z)
        
        rec_loss = get_rec_loss(inputs, predictions)
        kl_loss = get_kl_loss(z_log_var, z_mean)
        loss = K.mean(rec_loss + kl_loss)
    
    varialbes = vae.trainable_variables
    gradients = tape.gradient(loss, varialbes)
    optimizer.apply_gradients(zip(gradients, varialbes))
    
    train_loss(loss)

In [49]:
save_dir = 'results_VAE/train_6'
if_not_make(save_dir)
epochs = 300
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
manager = tf.train.CheckpointManager(checkpoint, save_dir, max_to_keep=5)

In [50]:
for epoch in range(epochs):
    for inputs, outputs in train_ds:
        train_step(inputs)
        
    print("* Epoch: %i, loss: %f "%(epoch, train_loss.result()))
    manager.save()
    plot_recimg(save_dir, epoch)

* Epoch: 0, loss: 86.537964 
* Epoch: 1, loss: 86.533669 
* Epoch: 2, loss: 86.529282 
* Epoch: 3, loss: 86.525078 
* Epoch: 4, loss: 86.520752 
* Epoch: 5, loss: 86.516350 
* Epoch: 6, loss: 86.512047 
* Epoch: 7, loss: 86.507675 
* Epoch: 8, loss: 86.503471 
* Epoch: 9, loss: 86.499283 
* Epoch: 10, loss: 86.495171 
* Epoch: 11, loss: 86.490944 
* Epoch: 12, loss: 86.486610 
* Epoch: 13, loss: 86.482376 
* Epoch: 14, loss: 86.478195 
* Epoch: 15, loss: 86.473869 
* Epoch: 16, loss: 86.469635 
* Epoch: 17, loss: 86.465492 
* Epoch: 18, loss: 86.461372 
* Epoch: 19, loss: 86.457199 
* Epoch: 20, loss: 86.453171 
* Epoch: 21, loss: 86.449219 
* Epoch: 22, loss: 86.445107 
* Epoch: 23, loss: 86.441124 
* Epoch: 24, loss: 86.437057 
* Epoch: 25, loss: 86.433174 
* Epoch: 26, loss: 86.429108 
* Epoch: 27, loss: 86.425255 
* Epoch: 28, loss: 86.421295 
* Epoch: 29, loss: 86.417336 
* Epoch: 30, loss: 86.413567 
* Epoch: 31, loss: 86.409584 
* Epoch: 32, loss: 86.405602 
* Epoch: 33, loss: 8

* Epoch: 268, loss: 85.723404 
* Epoch: 269, loss: 85.721275 
* Epoch: 270, loss: 85.719048 
* Epoch: 271, loss: 85.716911 
* Epoch: 272, loss: 85.714767 
* Epoch: 273, loss: 85.712646 
* Epoch: 274, loss: 85.710510 
* Epoch: 275, loss: 85.708237 
* Epoch: 276, loss: 85.706093 
* Epoch: 277, loss: 85.704033 
* Epoch: 278, loss: 85.701843 
* Epoch: 279, loss: 85.699791 
* Epoch: 280, loss: 85.697647 
* Epoch: 281, loss: 85.695473 
* Epoch: 282, loss: 85.693382 
* Epoch: 283, loss: 85.691315 
* Epoch: 284, loss: 85.689339 
* Epoch: 285, loss: 85.687248 
* Epoch: 286, loss: 85.685188 
* Epoch: 287, loss: 85.683220 
* Epoch: 288, loss: 85.681076 
* Epoch: 289, loss: 85.679001 
* Epoch: 290, loss: 85.676918 
* Epoch: 291, loss: 85.674858 
* Epoch: 292, loss: 85.672798 
* Epoch: 293, loss: 85.670685 
* Epoch: 294, loss: 85.668571 
* Epoch: 295, loss: 85.666443 
* Epoch: 296, loss: 85.664452 
* Epoch: 297, loss: 85.662407 
* Epoch: 298, loss: 85.660378 
* Epoch: 299, loss: 85.658325 
