In [1]:
import numpy as np
import os
import pandas as pd

import tensorflow as tf
from models import *
from def_dict import *

In [6]:
data_path = 'data/w58y67/step_2/'

x_train = np.load(data_path+'x_train.npy')
x_train = x_train/255
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(10000).batch(256)
train_ds

<BatchDataset shapes: ((None, 216, 64, 3), (None, 216, 64, 3)), types: (tf.float64, tf.float64)>

In [7]:
encoder, decoder, vae = build_vae(x_train, 10)

In [8]:
vae.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 216, 64, 3)]      0         
_________________________________________________________________
encoder (Model)              [(None, 10), (None, 10),  1860052   
_________________________________________________________________
decoder (Model)              (None, 216, 64, 3)        1864355   
Total params: 3,724,407
Trainable params: 3,724,407
Non-trainable params: 0
_________________________________________________________________


In [10]:
optimizer = tf.keras.optimizers.Adam(0.001)
train_loss = tf.keras.metrics.Mean(name='train_loss')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')

def get_rec_loss(inputs, predictions):
    rec_loss = tf.keras.losses.binary_crossentropy(inputs, predictions)
    rec_loss = tf.reduce_mean(rec_loss)
    rec_loss *= 200*48
    return rec_loss

def get_kl_loss(z_log_var, z_mean):
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = tf.reduce_mean(kl_loss)
    kl_loss *= -0.5
    
    return kl_loss

In [11]:
@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        
        # Get model ouputs
        z_log_var, z_mean, z = encoder(inputs)
        predictions = decoder(z)
        
        # Compute losses
        rec_loss = get_rec_loss(inputs, predictions)
        kl_loss = get_kl_loss(z_log_var, z_mean)
        loss = rec_loss + kl_loss
    
    # Compute gradients
    varialbes = vae.trainable_variables
    gradients = tape.gradient(loss, varialbes)
    # Update weights
    optimizer.apply_gradients(zip(gradients, varialbes))
    
    # Update train loss
    train_loss(loss)

In [12]:
save_dir = 'train_result/train_1/'
if_not_make(save_dir)

def reduce_lr(pre_v_loss, v_loss, count, lr, patience, factor, min_lr):
    if v_loss < pre_v_loss:
        count = 0
    else:
        count += 1
        if count >= patience: 
            lr = lr*factor
            if lr < min_lr: 
                lr = min_lr
            count = 0
            print('reduce learning rate..', lr)    
    return count, lr

checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
csv_logger = tf.keras.callbacks.CSVLogger(save_dir+'/training.log')

In [13]:
checkpoint.restore(tf.train.latest_checkpoint('train_result/train_1/ckp/'))

<tensorflow.python.training.tracking.util.InitializationOnlyStatus at 0x7f82e9a4e790>

In [None]:
epochs = 300

# Initialize values
best_loss, count = float('inf'), 0

# Start epoch loop
for epoch in range(epochs):
    for inputs, outputs in train_ds:
        train_step(inputs)
    
    # Get loss and leraning rate at this epoch
    t_loss = train_loss.result().numpy() 
    l_rate = optimizer.learning_rate.numpy()

    # Control learning rate
    count, lr  = reduce_lr(best_loss, t_loss, count, l_rate, 5, 0.2, 0.00001)
    optimizer.learning_rate = lr
    
    # Save checkpoint if best v_loss 
    if t_loss < best_loss:
        best_loss = t_loss
        checkpoint.save(file_prefix=os.path.join(save_dir+'/ckp/', 'ckp'))
    
    # Save loss, lerning rate
    print("* %i * loss: %f,  best_loss: %f, l_rate: %f, lr_count: %i"%(epoch, t_loss, best_loss, l_rate, count ))
    df = pd.DataFrame({'epoch':[epoch], 'loss':[t_loss], 'best_loss':[best_loss], 'l_rate':[l_rate]  } )
    df.to_csv(save_dir+'/process.csv', mode='a', header=False)
    
    # Reset loss
    train_loss.reset_states()

* 0 * loss: 6102.498535,  best_loss: 6102.498535, l_rate: 0.001000, lr_count: 0
* 1 * loss: 5815.026855,  best_loss: 5815.026855, l_rate: 0.001000, lr_count: 0
* 2 * loss: 5755.753418,  best_loss: 5755.753418, l_rate: 0.001000, lr_count: 0
* 3 * loss: 5730.976074,  best_loss: 5730.976074, l_rate: 0.001000, lr_count: 0
* 4 * loss: 5721.758301,  best_loss: 5721.758301, l_rate: 0.001000, lr_count: 0
* 5 * loss: 5716.935547,  best_loss: 5716.935547, l_rate: 0.001000, lr_count: 0
* 6 * loss: 5713.729980,  best_loss: 5713.729980, l_rate: 0.001000, lr_count: 0
* 7 * loss: 5711.673340,  best_loss: 5711.673340, l_rate: 0.001000, lr_count: 0
* 8 * loss: 5709.840332,  best_loss: 5709.840332, l_rate: 0.001000, lr_count: 0
* 9 * loss: 5708.485352,  best_loss: 5708.485352, l_rate: 0.001000, lr_count: 0
* 10 * loss: 5707.375000,  best_loss: 5707.375000, l_rate: 0.001000, lr_count: 0
* 11 * loss: 5706.532715,  best_loss: 5706.532715, l_rate: 0.001000, lr_count: 0
* 12 * loss: 5705.679199,  best_loss: 