In [19]:
import numpy as np
import os
import pandas as pd

import tensorflow as tf
from models import *
from def_dict import *

In [16]:
data_path = 'data/w58y67_prepro/'

x_train = np.load(data_path+'x_train.npy')
x_train = x_train/255
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(10000).batch(256)
train_ds

<BatchDataset shapes: ((None, 200, 48, 3), (None, 200, 48, 3)), types: (tf.float64, tf.float64)>

In [5]:
encoder, decoder, vae = build_vae(x_train, 10)

In [6]:
vae.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 200, 48, 3)]      0         
_________________________________________________________________
encoder (Model)              [(None, 10), (None, 10),  1319380   
_________________________________________________________________
decoder (Model)              (None, 200, 48, 3)        1321571   
Total params: 2,640,951
Trainable params: 2,640,951
Non-trainable params: 0
_________________________________________________________________


In [13]:
checkpoint.restore(tf.train.latest_checkpoint(data_path+'/train_1/ckp/'))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbc3e6e6bd0>

In [14]:
optimizer = tf.keras.optimizers.Adam(0.001)
train_loss = tf.keras.metrics.Mean(name='train_loss')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')

def get_rec_loss(inputs, predictions):
    rec_loss = tf.keras.losses.binary_crossentropy(inputs, predictions)
    rec_loss = tf.reduce_mean(rec_loss)
    rec_loss *= 200*48
    return rec_loss

def get_kl_loss(z_log_var, z_mean):
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = tf.reduce_mean(kl_loss)
    kl_loss *= -0.5
    
    return kl_loss

In [15]:
@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        
        # Get model ouputs
        z_log_var, z_mean, z = encoder(inputs)
        predictions = decoder(z)
        
        # Compute losses
        rec_loss = get_rec_loss(inputs, predictions)
        kl_loss = get_kl_loss(z_log_var, z_mean)
        loss = rec_loss + kl_loss
    
    # Compute gradients
    varialbes = vae.trainable_variables
    gradients = tape.gradient(loss, varialbes)
    # Update weights
    optimizer.apply_gradients(zip(gradients, varialbes))
    
    # Update train loss
    train_loss(loss)

In [11]:
save_dir = 'data/w58y67_prepro/train_2/'
if_not_make(save_dir)

def reduce_lr(pre_v_loss, v_loss, count, lr, patience, factor, min_lr):
    if v_loss < pre_v_loss:
        count = 0
    else:
        count += 1
        if count >= patience: 
            lr = lr*factor
            if lr < min_lr: 
                lr = min_lr
            count = 0
            print('reduce learning rate..', lr)    
    return count, lr

checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
csv_logger = tf.keras.callbacks.CSVLogger(save_dir+'/training.log')

In [21]:
epochs = 1000

# Initialize values
best_loss, count = float('inf'), 0

# Start epoch loop
for epoch in range(epochs):
    for inputs, outputs in train_ds:
        train_step(inputs)
    
    # Get loss and leraning rate at this epoch
    t_loss = train_loss.result().numpy() 
    l_rate = optimizer.learning_rate.numpy()

    # Control learning rate
    count, lr  = reduce_lr(best_loss, t_loss, count, l_rate, 5, 0.2, 0.00001)
    optimizer.learning_rate = lr
    
    # Save checkpoint if best v_loss 
    if t_loss < best_loss:
        best_loss = t_loss
        checkpoint.save(file_prefix=os.path.join(save_dir+'/ckp/', 'ckp'))
    
    # Save loss, lerning rate
    print("* %i * loss: %f,  best_loss: %f, l_rate: %f, lr_count: %i"%(epoch, t_loss, best_loss, l_rate, count ))
    df = pd.DataFrame({'epoch':[epoch], 'loss':[t_loss], 'best_loss':[best_loss], 'l_rate':[l_rate]  } )
    df.to_csv(save_dir+'/process.csv', mode='a', header=False)
    
    # Reset loss
    train_loss.reset_states()

* 0 * loss: 5689.645020,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 0
* 1 * loss: 5689.657227,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 1
* 2 * loss: 5689.658203,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 2
* 3 * loss: 5689.668457,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 3
* 4 * loss: 5689.664062,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 4
reduce learning rate.. 0.00020000000949949026
* 5 * loss: 5689.674805,  best_loss: 5689.645020, l_rate: 0.001000, lr_count: 0
* 6 * loss: 5688.839355,  best_loss: 5688.839355, l_rate: 0.000200, lr_count: 0
* 7 * loss: 5688.720215,  best_loss: 5688.720215, l_rate: 0.000200, lr_count: 0
* 8 * loss: 5688.690430,  best_loss: 5688.690430, l_rate: 0.000200, lr_count: 0
* 9 * loss: 5688.682617,  best_loss: 5688.682617, l_rate: 0.000200, lr_count: 0
* 10 * loss: 5688.670410,  best_loss: 5688.670410, l_rate: 0.000200, lr_count: 0
* 11 * loss: 5688.652832,  best_loss: 5688.652832, l_rate: 0.000200, lr_c

* 100 * loss: 5688.169434,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 0
* 101 * loss: 5688.187500,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 1
* 102 * loss: 5688.186523,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 2
* 103 * loss: 5688.184082,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 3
* 104 * loss: 5688.187500,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 105 * loss: 5688.181152,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 0
* 106 * loss: 5688.185547,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 1
* 107 * loss: 5688.184082,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 2
* 108 * loss: 5688.190918,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 3
* 109 * loss: 5688.180176,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 110 * loss: 5688.190430,  best_loss: 5688.169434, l_rate: 0.000010, lr_count: 0
* 111 * loss: 5688.182617,  best_loss: 5

* 195 * loss: 5688.155273,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 2
* 196 * loss: 5688.171387,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 3
* 197 * loss: 5688.166504,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 198 * loss: 5688.163574,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 0
* 199 * loss: 5688.160156,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 1
* 200 * loss: 5688.165039,  best_loss: 5688.151367, l_rate: 0.000010, lr_count: 2
* 201 * loss: 5688.142578,  best_loss: 5688.142578, l_rate: 0.000010, lr_count: 0
* 202 * loss: 5688.159668,  best_loss: 5688.142578, l_rate: 0.000010, lr_count: 1
* 203 * loss: 5688.170898,  best_loss: 5688.142578, l_rate: 0.000010, lr_count: 2
* 204 * loss: 5688.154297,  best_loss: 5688.142578, l_rate: 0.000010, lr_count: 3
* 205 * loss: 5688.159668,  best_loss: 5688.142578, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 206 * loss: 5688.153320,  best_loss: 5

KeyboardInterrupt: 