# In [1]:
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell, seq2seq
from tqdm import *
from tensorflow.python.ops import control_flow_ops
import sys
import random
import time

from Utils import Utils
import copy_mechanism
from hyperboard import Agent

    
def _feed_dict(outputs, datas, keep_prob_value, sample_rate_value):
    """Map one batch returned by ``Utils.next_batch()`` onto the model inputs.

    Args:
        outputs: dict returned by ``copy_mechanism.model``; its 'source',
            'defendant', 'defendant_length', 'label', 'decoder_inputs',
            'loss_weights', 'keep_prob' and 'sample_rate' entries are used as
            feed keys.
        datas: batch dict with keys 'source', 'defendant', 'defendant_length',
            'ground_truth', 'label' and 'loss_weights'.
        keep_prob_value: value fed to ``outputs['keep_prob']`` (training feeds
            ``prob``, validation feeds 1.0 -- presumably a dropout keep
            probability).
        sample_rate_value: value fed to ``outputs['sample_rate']``.

    Returns:
        A ``feed_dict`` for ``tf.Session.run``.
    """
    return {
        outputs['source']: datas['source'],
        outputs['defendant']: datas['defendant'],
        outputs['defendant_length']: datas['defendant_length'],
        outputs['label']: datas['label'],
        outputs['decoder_inputs']: datas['ground_truth'],
        outputs['loss_weights']: datas['loss_weights'],
        outputs['keep_prob']: keep_prob_value,
        outputs['sample_rate']: sample_rate_value,
    }


def _evaluate(sess, valid_utils, outputs):
    """Run one pass over ``valid_utils`` and return ``(mean_cost, mean_accu)``.

    Batches are pulled until ``valid_utils.next_batch()`` reports ``is_again``;
    the batch returned together with ``is_again`` is discarded.  Accuracy per
    batch comes from ``valid_utils.i2t(..., to_print=True)``.  If no batch is
    evaluated, ``(0.0, 0.0)`` is returned instead of dividing by zero.
    """
    total_cost = 0.
    total_accu = 0.
    n_batches = 0
    while True:
        datas, is_again = valid_utils.next_batch()
        if is_again:
            break
        word_, cost_ = sess.run([outputs['words_prediction'], outputs['cost']],
                                feed_dict=_feed_dict(outputs, datas, 1., 1.))
        total_accu += valid_utils.i2t(word_, to_print=True)
        total_cost += cost_
        n_batches += 1
    if n_batches == 0:
        return 0., 0.
    return total_cost / n_batches, total_accu / n_batches


def train(train_utils, valid_utils, source_len, oseq_len, n_echos, simplified_len, decoder_hidden, encoder_hidden,
          embedding_size, batch_size, display_step, prob, lstm_layer, learning_rate,
          model_dir, source_nfilters, defendant_nfilters, source_width, defendant_width,
          restore_dir=None):
    """Train the copy-mechanism model with periodic validation and checkpoints.

    Builds the graph with ``copy_mechanism.model(..., is_train=True)``.  It then
    optimises ``outputs['cost']`` with Adam, clipping every gradient tensor
    independently to L2 norm 5.  Every ``display_step`` training batches it:

    - evaluates a full pass of ``valid_utils``;
    - saves a checkpoint into ``model_dir`` whose name encodes the metrics;
    - prints a progress line;
    - reports the metrics to a hyperboard ``Agent`` on port 5000.

    Training stops after ``n_echos`` passes over ``train_utils``.  A pass ends
    when ``train_utils.next_batch()`` reports ``is_again``.

    Args:
        train_utils, valid_utils: ``Utils`` batch providers.
        prob: value fed to ``keep_prob`` during training (1.0 in validation).
        display_step: number of training batches between validations.
        model_dir: directory for checkpoints; created if missing.
        restore_dir: directory passed to ``tf.train.latest_checkpoint`` to
            resume training.  Defaults to ``model_dir``.  If it holds no
            checkpoint, all variables are freshly initialised.
        Remaining arguments are forwarded to ``copy_mechanism.model`` or used
        as optimiser settings.
    """
    hyperparameters = {
        'learning rate': learning_rate,
        'batch size': batch_size,
        'criteria': '',
        'decoder hidden': decoder_hidden,
        'encoder_hidden': encoder_hidden,
        'prob': prob,
        'source_nfilters': source_nfilters,
        'defendant_nfilters': defendant_nfilters,
        'source_width': source_width,
        'defendant_width': defendant_width,
        'lstm_layer': lstm_layer,
        'embedding_size': embedding_size,
    }

    # One hyperboard curve per criteria; only the 'criteria' key differs
    # between the four registrations.
    agent = Agent(port=5000)

    hyperparameters['criteria'] = 'valid loss'
    name_valid_loss = agent.register(hyperparameters, 'cross entropy')

    hyperparameters['criteria'] = 'valid accu'
    name_valid_accu = agent.register(hyperparameters, 'accuracy')

    hyperparameters['criteria'] = 'train loss'
    name_train_loss = agent.register(hyperparameters, 'cross entropy')

    hyperparameters['criteria'] = 'train accu'
    name_train_accu = agent.register(hyperparameters, 'accuracy')

    # makedirs (not mkdir) so a model_dir with missing parents also works.
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    words_size = train_utils.get_words_size()

    outputs = copy_mechanism.model(words_size=words_size,
                                   embedding_size=embedding_size,
                                   source_len=source_len,
                                   simplified_len=simplified_len,
                                   oseq_len=oseq_len,
                                   decoder_hidden=decoder_hidden,
                                   encoder_hidden=encoder_hidden,
                                   source_nfilters=source_nfilters,
                                   defendant_nfilters=defendant_nfilters,
                                   source_width=source_width,
                                   defendant_width=defendant_width,
                                   lstm_layer=lstm_layer,
                                   batch_size=batch_size,
                                   is_train=True)

    cost = outputs['cost']

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='adam_optimizer')
    grads_and_vars = optimizer.compute_gradients(cost)
    # compute_gradients yields (None, var) for variables that do not influence
    # `cost`; tf.clip_by_norm cannot take None, so those pairs are dropped.
    capped_grads_and_vars = [(tf.clip_by_norm(g, 5), v)
                             for g, v in grads_and_vars if g is not None]
    train_op = optimizer.apply_gradients(capped_grads_and_vars)

    init = tf.initialize_all_variables()

    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    gpu_option = tf.GPUOptions(per_process_gpu_memory_fraction=0.99)
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False,
                                  gpu_options=gpu_option)
    # Previously the config was built but not passed, so it had no effect.
    sess = tf.Session(config=session_conf)

    try:
        # Resume from the newest checkpoint when one exists; otherwise start
        # from fresh variables instead of failing in restore(sess, None).
        pre_saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(model_dir if restore_dir is None else restore_dir)
        if ckpt is not None:
            pre_saver.restore(sess, ckpt)
        else:
            sess.run(init)

        saver = tf.train.Saver(max_to_keep=100)

        for v in tf.trainable_variables():
            print(v.name)
        print('init done')

        global_steps = 0  # number of completed display/validation rounds
        echos = 0
        batch = 0
        train_cost = 0.
        # Training accuracy is not computed per batch, so this stays 0 and is
        # reported as such.
        train_accu = 0.

        while True:
            datas, is_again = train_utils.next_batch()

            if is_again:
                # End of a pass over the training data: reset running sums.
                train_cost = 0.
                train_accu = 0.
                echos += 1
                batch = 0
                if echos == n_echos:
                    break
                continue

            batch += 1

            # Fed to the model's 'sample_rate' tensor; ramps from 0.39 and is
            # capped at 0.5 as validation rounds accumulate.
            sample_rate_ = min(.5, 0.39 + 0.005 * global_steps)

            cost_, _ = sess.run([cost, train_op],
                                feed_dict=_feed_dict(outputs, datas, prob, sample_rate_))
            train_cost += cost_

            if batch % display_step != 0:
                continue

            # Mean over the last `display_step` training batches.
            train_cost /= display_step
            train_accu /= display_step

            valid_cost, valid_accu = _evaluate(sess, valid_utils, outputs)

            ckpt_name = ("sample_rate-" + str(sample_rate_) +
                         "-train_accu-{:.4f}".format(train_accu) +
                         "-train_cost-{:.4f}".format(train_cost) +
                         "-valid_accu-{:.4f}".format(valid_accu) +
                         "-valid_cost-{:.4f}".format(valid_cost) +
                         "-model.ckpt")
            # os.path.join keeps checkpoints inside model_dir even when it
            # has no trailing slash.
            saver.save(sess, os.path.join(model_dir, ckpt_name), global_step=global_steps)

            print("Echo: " + str(echos) + " Iters: " + str(batch) +
                  " Sample: " + "{:.3f}".format(sample_rate_) +
                  " Train loss: " + "{:.4f}".format(train_cost) +
                  " Valid loss: " + "{:.4f}".format(valid_cost) +
                  " Train accu: " + "{:.4f}".format(train_accu) +
                  " Valid accu: " + "{:.4f}".format(valid_accu))

            agent.append(name_valid_loss, global_steps, valid_cost)
            agent.append(name_valid_accu, global_steps, valid_accu)
            agent.append(name_train_loss, global_steps, train_cost)
            agent.append(name_train_accu, global_steps, train_accu)

            global_steps += 1

            # Reset the running sums so the next report averages only the
            # next `display_step` batches (previously the old mean leaked in).
            train_cost = 0.
            train_accu = 0.
    finally:
        sess.close()
            
        
            
if __name__ == '__main__':

    # Dataset locations.
    words_path = '/home/xuwenshen/data/big_data/2017_3_13/words'
    train_path = '/home/xuwenshen/data/big_data/2017_3_13/train.h5'
    valid_path = '/home/xuwenshen/data/big_data/2017_3_13/valid.h5'

    # Sequence lengths.
    oseq_len, source_len, simplified_len = 200, 1000, 150

    # Optimisation schedule; `prob` is fed as keep_prob during training.
    batch_size, display_step, n_echos = 64, 400, 100
    prob, learning_rate = 0.75, 0.001

    # Architecture sizes.
    decoder_hidden, encoder_hidden = 600, 256
    source_nfilters, defendant_nfilters = 128, 32
    source_width, defendant_width = 3, 3
    lstm_layer, embedding_size = 1, 200

    # Batch providers for the training and validation splits.
    train_utils = Utils(words_path, train_path, batch_size, nb_samples=1600000)
    valid_utils = Utils(words_path, valid_path, batch_size, nb_samples=640)

    train(train_utils=train_utils, valid_utils=valid_utils,
          model_dir='/home/xuwenshen/2017_3_13/model_v2_2/',
          source_len=source_len, simplified_len=simplified_len, oseq_len=oseq_len,
          batch_size=batch_size, display_step=display_step, n_echos=n_echos,
          prob=prob, learning_rate=learning_rate,
          decoder_hidden=decoder_hidden, encoder_hidden=encoder_hidden,
          source_nfilters=source_nfilters, defendant_nfilters=defendant_nfilters,
          source_width=source_width, defendant_width=defendant_width,
          lstm_layer=lstm_layer, embedding_size=embedding_size)
    
    
    

# KeyboardInterrupt: