In [1]:
import tensorflow as tf
import numpy as np
import os
import sys
import math

# Set CUDA_VISIBLE_DEVICES to '0' (by CUDA convention this exposes only GPU 0
# to this process).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Put the parent directory on the import search path; presumably that is where
# the local vae / config / batchloader modules imported below live -- TODO confirm.
sys.path.append("../")

  from ._conv import register_converters as _register_converters


In [2]:
from vae import VAE
from config import FLAGS
from batchloader import BatchLoader

In [3]:
def train(save_dir="model", patience=15):
    """Train the VAE with KL-weight annealing and early stopping.

    Builds a training VAE and a weight-sharing evaluation VAE inside the
    "VAE" variable scope, then runs up to ``FLAGS.EPOCH`` epochs over the
    training batches. After every epoch one validation minibatch is
    evaluated; whenever its reconstruction loss improves on the best value
    seen so far, a checkpoint ``model_<epoch>.ckpt`` is written to
    ``save_dir``.

    Args:
        save_dir: Directory that receives the checkpoints. Created if it
            does not exist. Defaults to "model".
        patience: Early-stopping threshold. Training stops on the first
            non-improving epoch at which the counter of consecutive
            non-improving epochs has already reached this value.
            Defaults to 15.

    Returns:
        None. Progress is reported via ``print`` and checkpoints on disk.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    tf.reset_default_graph()

    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        batchloader = BatchLoader()
        with tf.variable_scope("VAE"):
            vae = VAE(sess=sess, batchloader=batchloader, learning_rate=FLAGS.LEARNING_RATE, training=True, ru=False)

        # Second model instance built with reuse=True so it shares the
        # variables created above; used only for validation.
        with tf.variable_scope("VAE", reuse=True):
            vae_test = VAE(sess=sess, batchloader=batchloader, learning_rate=FLAGS.LEARNING_RATE, training=False, ru=True)

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        # Running lists of per-step values, averaged and reset at each report.
        loss_sum = []
        reconst_loss_sum = []
        kld_sum = []

        step = 0
        # Defined up front so the validation feed below is valid even when
        # the training set yields zero full batches.
        kld_weight = (math.tanh((step - 3500) / 1000.0) + 1) / 2

        train_batch = batchloader.make_batch(FLAGS.BATCH_SIZE, is_training=True)
        # Only full minibatches are used; a trailing partial one is dropped.
        total_batch = len(train_batch) // FLAGS.BATCH_SIZE

        val_batch_idx = 0
        # Infinity guarantees the first finite validation loss becomes the
        # baseline (and is checkpointed), whatever the scale of the loss.
        best_loss = float("inf")
        patient = 0
        should_stop = False
        # Path of the most recent checkpoint; stays None until one is written.
        save_path = None

        print("Start Learning!!")
        for epoch in range(FLAGS.EPOCH):

            for batch in range(total_batch):
                step += 1

                # KL annealing: tanh ramp from ~0 to ~1 centred on step 3500.
                # The float divisor keeps the result identical under
                # Python 2 integer division.
                kld_weight = (math.tanh((step - 3500) / 1000.0) + 1) / 2

                batch_idx = batch * FLAGS.BATCH_SIZE
                minibatch = train_batch[batch_idx:batch_idx + FLAGS.BATCH_SIZE]
                encoder_input, decoder_input, target = batchloader.prepro_minibatch(minibatch, dropword=True)

                feed_dict = {vae.encoder_input: encoder_input,
                             vae.decoder_input: decoder_input,
                             vae.target: target,
                             vae.KL_d_weight: kld_weight,
                             vae.step: step}

                # Fetch only what is consumed; train_op is run for its side
                # effect (the parameter update).
                loss, reconst_loss, kld, _ = sess.run(
                    [vae.loss, vae.reconst_loss, vae.KL_d, vae.train_op],
                    feed_dict=feed_dict)

                reconst_loss_sum.append(reconst_loss)
                kld_sum.append(kld)
                loss_sum.append(loss)

                # Report at batch 99, 299, 499, ... of each epoch.
                if batch % 200 == 99:
                    avg_loss = np.average(loss_sum)
                    avg_reconst_loss = np.average(reconst_loss_sum)
                    avg_kld = np.average(kld_sum)

                    print('[Epoch {}] loss: {}, reconst_loss: {}, kld: {}'.format(epoch, avg_loss, avg_reconst_loss, avg_kld))
                    loss_sum = []
                    reconst_loss_sum = []
                    kld_sum = []

            # Validation on a single minibatch; the minibatch advances by one
            # every epoch and wraps around so the slice never runs off the end
            # of the validation data.
            # NOTE(review): assumes make_batch returns a sliceable sequence
            # with len(), as the training path above already relies on.
            val_batch = batchloader.make_batch(FLAGS.BATCH_SIZE, is_training=False)
            num_val_batches = max(1, len(val_batch) // FLAGS.BATCH_SIZE)
            val_start = (val_batch_idx % num_val_batches) * FLAGS.BATCH_SIZE
            val_minibatch = val_batch[val_start:val_start + FLAGS.BATCH_SIZE]
            val_encoder_input, val_decoder_input, val_target = batchloader.prepro_minibatch(val_minibatch, dropword=False)

            val_loss, val_prediction = sess.run(
                [vae_test.reconst_loss, vae_test.decoder_prediction],
                feed_dict={vae_test.encoder_input: val_encoder_input,
                           vae_test.decoder_input: val_decoder_input,
                           vae_test.target: val_target,
                           vae_test.KL_d_weight: kld_weight})

            input_text = batchloader.input2str(val_encoder_input)
            output_text = batchloader.pred2str(val_prediction)

            print('[input] {}\n[output] {}'.format(input_text[0], output_text[0]))

            if val_loss < best_loss:
                # New best: checkpoint and reset the patience counter.
                best_loss = val_loss
                filename = os.path.join(save_dir, "model_{}.ckpt".format(epoch))
                save_path = saver.save(sess, filename)
                patient = 0
            else:
                if patient >= patience:
                    should_stop = True
                patient += 1

            print('[Epoch {}] best loss: {}, current loss: {}, patient: {}'.format(epoch, best_loss, val_loss, patient))

            if should_stop:
                # save_path is None only if no epoch ever improved (e.g. the
                # validation loss was NaN throughout).
                if save_path is not None:
                    print('Model saved in file {}'.format(save_path))
                else:
                    print('No checkpoint was saved')
                print("Finish Learning!!")
                break

            val_batch_idx += 1

In [4]:
# Run training when this cell/module is executed as the main program.
if __name__ == "__main__":
    # Overwrite argv with a single placeholder program name. Presumably this
    # keeps FLAGS from seeing the notebook kernel's own command-line
    # arguments -- TODO confirm against config.py.
    sys.argv = ['ddd']
    # Show the configured learning rate before training starts.
    print(FLAGS.LEARNING_RATE)
    train()

0.001
10002
(32, 60, 353)
Instructions for updating:
Use the retry module or similar alternatives.
[decoder_input] (32, 60, 353)
[latent] (32, 60, 13)
[decoder_logit] (1920, 10002)
[reconstruction_loss] (1920,)
[reconst_loss] ()
(32, 60, 353)
[start] (32, 353)
[reconstruction_loss] (60, 32)
[reconst_loss] ()
Start Learning!!
[Epoch 0] loss: 501.04034423828125, reconst_loss: 500.8358459472656, kld: 203.59214782714844
[Epoch 0] loss: 477.5327453613281, reconst_loss: 477.4717102050781, kld: 45.83076477050781
[Epoch 0] loss: 459.8138427734375, reconst_loss: 459.7102355957031, kld: 50.5080680847168
[Epoch 0] loss: 443.9869384765625, reconst_loss: 443.8168029785156, kld: 55.88475036621094
[Epoch 0] loss: 429.9681396484375, reconst_loss: 429.73004150390625, kld: 52.77970886230469
[Epoch 0] loss: 416.9253845214844, reconst_loss: 416.5958251953125, kld: 48.985862731933594
[Epoch 0] loss: 403.61993408203125, reconst_loss: 403.1771240234375, kld: 44.298988342285156
[input] consumers may want to m

KeyboardInterrupt: 