In [1]:
import numpy as np
import os
import pandas as pd

import tensorflow as tf
from models import *
from def_dict import *

In [2]:
data_path = 'data/w58y67/preprocessing_2/step_2/'

data = np.load(data_path+'data_same.npy')
data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
data_info = np.load(data_path+'data_info.npy')
data.shape, data_info.shape

((51721, 128, 64, 1), (51721, 3))

In [7]:
save_path = 'train_result/train_5/'
if_not_make(save_path)

In [8]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(data, data_info, test_size=0.1, random_state=321)
print('* Training data shape: ', x_train.shape)
#print('* Validataion data shape : ', x_valid.shape)
print('* Test data shape : ', x_test.shape)

np.save(save_path+'x_train', x_train)
np.save(save_path+'y_train', y_train)
np.save(save_path+'x_test', x_test)
np.save(save_path+'y_test', y_test)

* Training data shape:  (46548, 128, 64, 1)
* Test data shape :  (5173, 128, 64, 1)


In [9]:
x_train = x_train/255
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).batch(256)
train_ds

<BatchDataset shapes: ((None, 128, 64, 1), (None, 128, 64, 1)), types: (tf.float64, tf.float64)>

In [10]:
encoder, decoder, vae = build_vae(x_train, 10)

In [11]:
encoder.summary()
vae.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 64, 1)] 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 64, 32, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 64, 32, 32)   0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 16, 32)   9248        leaky_re_lu[0][0]                
____________________________________________________________________________________________

In [12]:
ckp_dir = save_path+'/ckp/'
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
checkpoint.restore(tf.train.latest_checkpoint(ckp_dir))

<tensorflow.python.training.tracking.util.InitializationOnlyStatus at 0x7f9d6f58b590>

In [13]:
optimizer = tf.keras.optimizers.Adam(0.001)
train_loss = tf.keras.metrics.Mean(name='train_loss')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')

def get_rec_loss(inputs, predictions):
    rec_loss = tf.keras.losses.binary_crossentropy(inputs, predictions)
    rec_loss = tf.reduce_mean(rec_loss)
    rec_loss *= x_train.shape[1]*x_train.shape[2]
    return rec_loss

def get_kl_loss(z_log_var, z_mean):
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = tf.reduce_mean(kl_loss)
    kl_loss *= -0.5
    
    return kl_loss

In [14]:
@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        
        # Get model ouputs
        z_log_var, z_mean, z = encoder(inputs)
        predictions = decoder(z)
        
        # Compute losses
        rec_loss = get_rec_loss(inputs, predictions)
        kl_loss = get_kl_loss(z_log_var, z_mean)
        loss = rec_loss + kl_loss
    
    # Compute gradients
    varialbes = vae.trainable_variables
    gradients = tape.gradient(loss, varialbes)
    # Update weights
    optimizer.apply_gradients(zip(gradients, varialbes))
    
    # Update train loss
    train_loss(loss)

In [16]:
def reduce_lr(pre_v_loss, v_loss, count, lr, patience, factor, min_lr):
    if v_loss < pre_v_loss:
        count = 0
    else:
        count += 1
        if count >= patience: 
            lr = lr*factor
            if lr < min_lr: 
                lr = min_lr
            count = 0
            print('reduce learning rate..', lr)    
    return count, lr

checkpoint = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder, decoder=decoder, vae=vae)
csv_logger = tf.keras.callbacks.CSVLogger(save_path+'/training.log')

In [17]:
epochs = 500

# Initialize values
best_loss, count = float('inf'), 0

# Start epoch loop
for epoch in range(epochs):
    for inputs, outputs in train_ds:
        train_step(inputs)
    
    # Get loss and leraning rate at this epoch
    t_loss = train_loss.result().numpy() 
    l_rate = optimizer.learning_rate.numpy()

    # Control learning rate
    count, lr  = reduce_lr(best_loss, t_loss, count, l_rate, 5, 0.2, 0.00001)
    optimizer.learning_rate = lr
    
    # Save checkpoint if best v_loss 
    if t_loss < best_loss:
        best_loss = t_loss
        checkpoint.save(file_prefix=os.path.join(save_path+'/ckp/', 'ckp'))
    
    # Save loss, lerning rate
    print("* %i * loss: %f,  best_loss: %f, l_rate: %f, lr_count: %i"%(epoch, t_loss, best_loss, l_rate, count ))
    df = pd.DataFrame({'epoch':[epoch], 'loss':[t_loss], 'best_loss':[best_loss], 'l_rate':[l_rate]  } )
    df.to_csv(save_path+'/process.csv', mode='a', header=False)
    
    # Reset loss
    train_loss.reset_states()

* 0 * loss: 5382.562012,  best_loss: 5382.562012, l_rate: 0.001000, lr_count: 0
* 1 * loss: 5283.356934,  best_loss: 5283.356934, l_rate: 0.001000, lr_count: 0
* 2 * loss: 5254.441895,  best_loss: 5254.441895, l_rate: 0.001000, lr_count: 0
* 3 * loss: 5234.719727,  best_loss: 5234.719727, l_rate: 0.001000, lr_count: 0
* 4 * loss: 5225.055176,  best_loss: 5225.055176, l_rate: 0.001000, lr_count: 0
* 5 * loss: 5219.903320,  best_loss: 5219.903320, l_rate: 0.001000, lr_count: 0
* 6 * loss: 5216.562988,  best_loss: 5216.562988, l_rate: 0.001000, lr_count: 0
* 7 * loss: 5214.093750,  best_loss: 5214.093750, l_rate: 0.001000, lr_count: 0
* 8 * loss: 5212.419434,  best_loss: 5212.419434, l_rate: 0.001000, lr_count: 0
* 9 * loss: 5211.175781,  best_loss: 5211.175781, l_rate: 0.001000, lr_count: 0
* 10 * loss: 5209.964844,  best_loss: 5209.964844, l_rate: 0.001000, lr_count: 0
* 11 * loss: 5208.466309,  best_loss: 5208.466309, l_rate: 0.001000, lr_count: 0
* 12 * loss: 5207.567383,  best_loss: 

* 102 * loss: 5194.838379,  best_loss: 5194.838379, l_rate: 0.001000, lr_count: 0
* 103 * loss: 5194.937012,  best_loss: 5194.838379, l_rate: 0.001000, lr_count: 1
* 104 * loss: 5194.965332,  best_loss: 5194.838379, l_rate: 0.001000, lr_count: 2
* 105 * loss: 5194.795410,  best_loss: 5194.795410, l_rate: 0.001000, lr_count: 0
* 106 * loss: 5194.719727,  best_loss: 5194.719727, l_rate: 0.001000, lr_count: 0
* 107 * loss: 5194.660156,  best_loss: 5194.660156, l_rate: 0.001000, lr_count: 0
* 108 * loss: 5194.770020,  best_loss: 5194.660156, l_rate: 0.001000, lr_count: 1
* 109 * loss: 5194.568848,  best_loss: 5194.568848, l_rate: 0.001000, lr_count: 0
* 110 * loss: 5194.729980,  best_loss: 5194.568848, l_rate: 0.001000, lr_count: 1
* 111 * loss: 5194.738281,  best_loss: 5194.568848, l_rate: 0.001000, lr_count: 2
* 112 * loss: 5194.744141,  best_loss: 5194.568848, l_rate: 0.001000, lr_count: 3
* 113 * loss: 5194.624512,  best_loss: 5194.568848, l_rate: 0.001000, lr_count: 4
* 114 * loss: 51

* 202 * loss: 5193.037109,  best_loss: 5193.037109, l_rate: 0.000200, lr_count: 0
* 203 * loss: 5193.040039,  best_loss: 5193.037109, l_rate: 0.000200, lr_count: 1
* 204 * loss: 5193.035645,  best_loss: 5193.035645, l_rate: 0.000200, lr_count: 0
* 205 * loss: 5193.057129,  best_loss: 5193.035645, l_rate: 0.000200, lr_count: 1
* 206 * loss: 5193.034180,  best_loss: 5193.034180, l_rate: 0.000200, lr_count: 0
* 207 * loss: 5193.027344,  best_loss: 5193.027344, l_rate: 0.000200, lr_count: 0
* 208 * loss: 5193.014648,  best_loss: 5193.014648, l_rate: 0.000200, lr_count: 0
* 209 * loss: 5193.028320,  best_loss: 5193.014648, l_rate: 0.000200, lr_count: 1
* 210 * loss: 5193.005859,  best_loss: 5193.005859, l_rate: 0.000200, lr_count: 0
* 211 * loss: 5192.998535,  best_loss: 5192.998535, l_rate: 0.000200, lr_count: 0
* 212 * loss: 5193.026367,  best_loss: 5192.998535, l_rate: 0.000200, lr_count: 1
* 213 * loss: 5192.989746,  best_loss: 5192.989746, l_rate: 0.000200, lr_count: 0
* 214 * loss: 51

* 300 * loss: 5192.611816,  best_loss: 5192.600098, l_rate: 0.000010, lr_count: 1
* 301 * loss: 5192.604980,  best_loss: 5192.600098, l_rate: 0.000010, lr_count: 2
* 302 * loss: 5192.606445,  best_loss: 5192.600098, l_rate: 0.000010, lr_count: 3
* 303 * loss: 5192.603516,  best_loss: 5192.600098, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 304 * loss: 5192.609375,  best_loss: 5192.600098, l_rate: 0.000010, lr_count: 0
* 305 * loss: 5192.599121,  best_loss: 5192.599121, l_rate: 0.000010, lr_count: 0
* 306 * loss: 5192.606934,  best_loss: 5192.599121, l_rate: 0.000010, lr_count: 1
* 307 * loss: 5192.596680,  best_loss: 5192.596680, l_rate: 0.000010, lr_count: 0
* 308 * loss: 5192.604492,  best_loss: 5192.596680, l_rate: 0.000010, lr_count: 1
* 309 * loss: 5192.599121,  best_loss: 5192.596680, l_rate: 0.000010, lr_count: 2
* 310 * loss: 5192.600098,  best_loss: 5192.596680, l_rate: 0.000010, lr_count: 3
* 311 * loss: 5192.602051,  best_loss: 5192.596680, l_rate: 0.000010,

* 395 * loss: 5192.576660,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 1
* 396 * loss: 5192.572266,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 2
* 397 * loss: 5192.580566,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 3
* 398 * loss: 5192.571777,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 399 * loss: 5192.568848,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 0
* 400 * loss: 5192.573242,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 1
* 401 * loss: 5192.575684,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 2
* 402 * loss: 5192.580566,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 3
* 403 * loss: 5192.571777,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 404 * loss: 5192.576172,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 0
* 405 * loss: 5192.573242,  best_loss: 5192.567871, l_rate: 0.000010, lr_count: 1
* 406 * loss: 5192.580078,  best_loss: 5

* 491 * loss: 5192.552734,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 492 * loss: 5192.547852,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 0
* 493 * loss: 5192.548340,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 1
* 494 * loss: 5192.541016,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 2
* 495 * loss: 5192.549316,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 3
* 496 * loss: 5192.547852,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 4
reduce learning rate.. 1e-05
* 497 * loss: 5192.553711,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 0
* 498 * loss: 5192.542480,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 1
* 499 * loss: 5192.544922,  best_loss: 5192.541016, l_rate: 0.000010, lr_count: 2
