In [None]:
import tensorflow as tf
import opennmt as onmt
from opennmt.utils.misc import count_lines
from opennmt.utils.parallel import GraphDispatcher
from opennmt import constants
from opennmt.utils.losses import cross_entropy_sequence_loss
from opennmt.utils.optim import *
from opennmt.utils.evaluator import *
import os
import yaml
import time
import numpy as np
import datetime
from opennmt.utils.cell import build_cell

In [None]:
def load_vocab(vocab_path, vocab_size):
    """Build a token-to-id lookup table from a vocabulary file.

    Args:
      vocab_path: Path to the vocabulary file.
      vocab_size: Total vocabulary size, including the one extra id used
        for unknown tokens. When falsy (``None`` or ``0``) it is computed as
        the number of lines in ``vocab_path`` plus one.

    Returns:
      A tuple ``(table, vocab_size)`` where ``table`` is the lookup table
      created by ``tf.contrib.lookup.index_table_from_file`` with a single
      out-of-vocabulary bucket, and ``vocab_size`` is the (possibly
      computed) total size.
    """
    if not vocab_size:
        # The file lists the known tokens; one more id is reserved for UNK.
        vocab_size = 1 + count_lines(vocab_path)
        print("vocab size of", vocab_path, ":", vocab_size)

    lookup_table = tf.contrib.lookup.index_table_from_file(
        vocab_path,
        vocab_size=vocab_size - 1,
        num_oov_buckets=1)
    return lookup_table, vocab_size

def get_dataset_size(data_file):
    """Return the size of the dataset stored in ``data_file``.

    The size is whatever ``count_lines`` reports for the file.
    """
    num_examples = count_lines(data_file)
    return num_examples

def get_padded_shapes(dataset):
    """Return ``dataset.output_shapes`` with every shape converted to a list.

    The nesting of ``dataset.output_shapes`` is preserved; only the leaves
    are replaced by the result of their ``as_list()`` method. The result is
    used in this file as the ``padded_shapes`` argument of ``padded_batch``.
    """
    def _shape_to_list(shape):
        return shape.as_list()

    return tf.contrib.framework.nest.map_structure(
        _shape_to_list, dataset.output_shapes)

def filter_irregular_batches(multiple):
    """Return a dataset transformation that drops irregular batches.

    A batch is kept only when the first dimension of its first (flattened)
    tensor is a multiple of ``multiple``.

    Args:
      multiple: The required divisor of the batch size. With ``1`` every
        batch qualifies and the returned transformation is the identity.

    Returns:
      A callable taking a dataset and returning a dataset.
    """
    def _keep_everything(dataset):
        return dataset

    def _is_regular(*batch):
        first_tensor = tf.contrib.framework.nest.flatten(batch)[0]
        remainder = tf.mod(tf.shape(first_tensor)[0], multiple)
        return tf.equal(remainder, 0)

    def _drop_irregular(dataset):
        return dataset.filter(_is_regular)

    return _keep_everything if multiple == 1 else _drop_irregular

def prefetch_element(buffer_size=None):
    """Return a dataset transformation that prefetches elements.

    Args:
      buffer_size: Value passed to ``dataset.prefetch``. When ``None`` and
        the installed TensorFlow exposes neither ``tf.data.experimental``
        nor ``tf.contrib.data.AUTOTUNE``, a buffer of 1 is used instead;
        otherwise ``None`` is forwarded unchanged.

    Returns:
      A callable taking a dataset and returning the prefetching dataset.
    """
    has_auto_tuning = (
        hasattr(tf.data, "experimental")
        or hasattr(tf.contrib.data, "AUTOTUNE"))
    if buffer_size is None and not has_auto_tuning:
        buffer_size = 1

    def _prefetch(dataset):
        return dataset.prefetch(buffer_size)

    return _prefetch

def create_embeddings(vocab_size, depth=512):
    """Return the variable named ``"embedding"`` in the current variable scope.

    Args:
      vocab_size: Number of rows of the embedding matrix.
      depth: Number of columns (embedding dimension).

    Returns:
      The result of ``tf.get_variable`` for shape ``[vocab_size, depth]``.
    """
    embedding_shape = [vocab_size, depth]
    return tf.get_variable("embedding", shape=embedding_shape)

def kl_coeff(i, midpoint=20000, width=5000):
    """Compute a tanh-shaped annealing coefficient in the open interval (0, 1).

    The value is ``(tanh((i - midpoint) / width) + 1) / 2``: close to 0 for
    ``i`` far below ``midpoint``, exactly 0.5 at ``midpoint`` and close to 1
    far above it.

    Args:
      i: Current step, as a Python number or a tensor.
      midpoint: Step at which the coefficient equals 0.5. Defaults to the
        previously hard-coded 20000.
      width: Scale of the ramp; larger values give a slower transition.
        Defaults to the previously hard-coded 5000.

    Returns:
      The coefficient cast to ``tf.float32``.

    The schedule that used to be kept as a commented-out alternative is
    ``kl_coeff(i, midpoint=3500, width=1000)``.
    """
    coeff = (tf.tanh((i - midpoint) / width) + 1) / 2
    return tf.cast(coeff, tf.float32)

In [None]:
def load_data(src_path, src_vocab, batch_size=32, batch_type ="examples", batch_multiplier = 1, tgt_path=None, tgt_vocab=None, 
              max_len=50, bucket_width = 1, mode="Training", padded_shapes = None, 
              shuffle_buffer_size = None, prefetch_buffer_size = 100000, num_threads = 4, version=None, distribution=None, tf_idf_table=None):
    """Build the ``tf.data`` input pipeline and return an initializable iterator.

    Every line of a text file is split on whitespace and mapped to ids with
    the given lookup table. In ``"Training"`` mode source and target files
    are zipped, optionally shuffled, filtered by length, batched (bucketed by
    length unless ``bucket_width`` is ``None``), repeated forever and
    prefetched. In ``"Inference"`` and ``"Predict"`` mode only the source
    file is read and batched with padding.

    Args:
      src_path: Path to the source text file.
      src_vocab: Lookup table (object with a ``lookup`` method) for source tokens.
      batch_size: Batch size per replica; multiplied by ``batch_multiplier``.
      batch_type: ``"examples"`` or ``"tokens"`` (training with bucketing only).
      batch_multiplier: Factor applied to ``batch_size``; training batches
        whose size is not a multiple of it are dropped.
      tgt_path: Path to the target text file (required for training).
      tgt_vocab: Lookup table for target tokens (required for training).
      max_len: Training examples with a side longer than this are dropped.
      bucket_width: Width of the length buckets, or ``None`` to disable bucketing.
      mode: ``"Training"``, ``"Inference"`` or ``"Predict"``.
      padded_shapes: Optional explicit shapes for ``padded_batch``.
      shuffle_buffer_size: ``None`` or ``0`` disables shuffling; a negative
        value shuffles with a buffer as large as the dataset.
      prefetch_buffer_size: Buffer size handed to ``prefetch_element``.
      num_threads: ``num_parallel_calls`` of the feature-building ``map``.
      version: Data processing version; only ``None`` is implemented.
      distribution: Unused; kept for call compatibility.
      tf_idf_table: Unused; kept for call compatibility.

    Returns:
      The iterator returned by ``dataset.make_initializable_iterator()``.
      Elements are dicts with ``src_raw``, ``src_ids`` and ``src_length``
      and, in training, also ``tgt_raw``, ``tgt_ids``, ``tgt_ids_in``,
      ``tgt_ids_out`` and ``tgt_length``.

    Raises:
      ValueError: On an unsupported ``version``, ``mode`` or ``batch_type``,
        or when training is requested without ``tgt_path``/``tgt_vocab``.
    """
    # Fail fast with a clear message: these cases used to fall through to an
    # unbound ``dataset`` at the final return.
    if version is not None:
        raise ValueError(
            "Unsupported dataprocessing version: '{}'; only None is implemented".format(version))
    if mode not in ("Training", "Inference", "Predict"):
        raise ValueError(
            "Invalid mode: '{}'; should be 'Training', 'Inference' or 'Predict'".format(mode))
    if mode == "Training" and (tgt_path is None or tgt_vocab is None):
        raise ValueError("tgt_path and tgt_vocab are required in Training mode")

    batch_size = batch_size * batch_multiplier
    print("batch_size", batch_size)

    def _make_dataset(text_path):
        """Read a text file as a dataset of whitespace-split token vectors."""
        dataset = tf.data.TextLineDataset(text_path)
        return dataset.map(lambda line: tf.string_split([line]).values)

    def _batch_func(dataset):
        """Batch with padding up to the longest element of each batch."""
        return dataset.padded_batch(
            batch_size,
            padded_shapes=padded_shapes or get_padded_shapes(dataset))

    def _key_func(example):
        """Bucket id: the larger of the two lengths, in units of bucket_width."""
        bucket_id = tf.constant(0, dtype=tf.int32)
        bucket_id = tf.maximum(bucket_id, example["src_length"] // bucket_width)
        bucket_id = tf.maximum(bucket_id, example["tgt_length"] // bucket_width)
        return tf.cast(bucket_id, tf.int64)

    def _reduce_func(unused_key, dataset):
        return _batch_func(dataset)

    def _window_size_func(key):
        """Number of examples per batch for a bucket when batching by tokens."""
        if bucket_width > 1:
            key += 1  # For bucket_width == 1, key 0 is unassigned.
        size = batch_size // (key * bucket_width)
        if batch_multiplier > 1:
            # Make the window size a multiple of batch_multiplier.
            size = size + batch_multiplier - size % batch_multiplier
        return tf.cast(tf.maximum(size, batch_multiplier), tf.int64)

    bos = tf.constant([constants.START_OF_SENTENCE_ID], dtype=tf.int64)
    eos = tf.constant([constants.END_OF_SENTENCE_ID], dtype=tf.int64)

    print("old dataprocessing version")
    src_dataset = _make_dataset(src_path)

    if mode != "Training":
        # "Inference" and "Predict" share the same source-only pipeline.
        def _source_features(x):
            src_ids = src_vocab.lookup(x)
            return {
                "src_raw": x,
                "src_ids": src_ids,
                "src_length": tf.shape(src_ids)[0],
            }

        dataset = src_dataset.map(_source_features, num_parallel_calls=num_threads)
        dataset = dataset.apply(_batch_func)
        return dataset.make_initializable_iterator()

    tgt_dataset = _make_dataset(tgt_path)
    dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

    def _parallel_features(x, y):
        # Look each side up once and reuse the ids for every derived feature.
        src_ids = src_vocab.lookup(x)
        tgt_ids = tgt_vocab.lookup(y)
        return {
            "src_raw": x,
            "tgt_raw": y,
            "src_ids": src_ids,
            "tgt_ids": tgt_ids,
            "tgt_ids_in": tf.concat([bos, tgt_ids], axis=0),
            "tgt_ids_out": tf.concat([tgt_ids, eos], axis=0),
            "src_length": tf.shape(src_ids)[0],
            "tgt_length": tf.shape(tgt_ids)[0],
        }

    dataset = dataset.map(_parallel_features, num_parallel_calls=num_threads)

    if shuffle_buffer_size is not None and shuffle_buffer_size != 0:
        dataset_size = get_dataset_size(src_path)
        if dataset_size is not None:
            if shuffle_buffer_size < 0:
                shuffle_buffer_size = dataset_size
            elif shuffle_buffer_size < dataset_size:
                # The buffer is smaller than the dataset: visit the dataset
                # in randomly ordered shards before shuffling. This branch
                # was previously attached to the wrong ``if`` and never ran.
                from opennmt.utils.data import random_shard
                dataset = dataset.apply(random_shard(shuffle_buffer_size, dataset_size))
        dataset = dataset.shuffle(shuffle_buffer_size)

    def _within_length_limits(example):
        src_len = example["src_length"]
        tgt_len = example["tgt_length"]
        non_empty = tf.logical_and(tf.greater(src_len, 0), tf.greater(tgt_len, 0))
        short_enough = tf.logical_and(
            tf.less_equal(src_len, max_len), tf.less_equal(tgt_len, max_len))
        return tf.logical_and(non_empty, short_enough)

    dataset = dataset.filter(_within_length_limits)

    if bucket_width is None:
        dataset = dataset.apply(_batch_func)
    else:
        if hasattr(tf.data, "experimental"):
            group_by_window_fn = tf.data.experimental.group_by_window
        else:
            group_by_window_fn = tf.contrib.data.group_by_window
        print("batch type: ", batch_type)
        if batch_type == "examples":
            dataset = dataset.apply(
                group_by_window_fn(_key_func, _reduce_func, window_size=batch_size))
        elif batch_type == "tokens":
            dataset = dataset.apply(
                group_by_window_fn(_key_func, _reduce_func, window_size_func=_window_size_func))
        else:
            raise ValueError(
                "Invalid batch type: '{}'; should be 'examples' or 'tokens'".format(batch_type))

    dataset = dataset.apply(filter_irregular_batches(batch_multiplier))
    dataset = dataset.repeat()
    dataset = dataset.apply(prefetch_element(buffer_size=prefetch_buffer_size))

    return dataset.make_initializable_iterator()

In [None]:
class Model:
    def _compute_loss(self, outputs, tgt_ids_batch, tgt_length, params, mode, mu_states, logvar_states):
        
        if mode == "Training":
            mode = tf.estimator.ModeKeys.TRAIN            
        else:
            mode = tf.estimator.ModeKeys.EVAL            
          
        if self.Loss_type == "Cross_Entropy":
            if isinstance(outputs, dict):
                logits = outputs["logits"]
                attention = outputs.get("attention")
            else:
                logits = outputs
                attention = None 
                            
            loss, loss_normalizer, loss_token_normalizer = cross_entropy_sequence_loss(
                logits,
                tgt_ids_batch, 
                tgt_length + 1,                                                         
                label_smoothing = params.get("label_smoothing", 0.0),
                average_in_time = params.get("average_loss_in_time", True),
                mode = mode
            )
            
            
            #----- Calculating kl divergence --------

            kld_loss = -0.5 * tf.reduce_sum(logvar_states - self.logvar_0 - tf.pow(mu_states-self.mu_0, 2)/tf.exp(self.logvar_0) - tf.exp(logvar_states)/tf.exp(self.logvar_0) + 1, 1)

            return loss, loss_normalizer, loss_token_normalizer, kld_loss
        
    
    def _initializer(self, params):
        
        if params["Architecture"] == "Transformer":
            print("tf.variance_scaling_initializer")
            return tf.variance_scaling_initializer(
        mode="fan_avg", distribution="uniform", dtype=self.dtype)
        else:            
            param_init = params.get("param_init")
            if param_init is not None:
                print("tf.random_uniform_initializer")
                return tf.random_uniform_initializer(
              minval=-param_init, maxval=param_init, dtype=self.dtype)
        return None
        
    def __init__(self, config_file, mode, test_feature_file=None):

        def _normalize_loss(num, den=None):
            """Normalizes the loss."""
            if isinstance(num, list):  # Sharded mode.
                if den is not None:
                    assert isinstance(den, list)
                    return tf.add_n(num) / tf.add_n(den) #tf.reduce_mean([num_/den_ for num_,den_ in zip(num, den)]) #tf.add_n(num) / tf.add_n(den)
                else:
                    return tf.reduce_mean(num)
            elif den is not None:
                return num / den
            else:
                return num

        def _extract_loss(loss, Loss_type="Cross_Entropy"):
            """Extracts and summarizes the loss."""
            losses = None
            print("loss numb:", len(loss))
            if Loss_type=="Cross_Entropy":
                if not isinstance(loss, tuple):                    
                    print(1)
                    actual_loss = _normalize_loss(loss)
                    tboard_loss = actual_loss
                    tf.summary.scalar("loss", tboard_loss)
                    losses = actual_loss                    
                else:                         
                    actual_loss = _normalize_loss(loss[0], den=loss[1])
                    tboard_loss = _normalize_loss(loss[0], den=loss[2]) if len(loss) > 2 else actual_loss
                    losses = actual_loss
                    loss_kd = _normalize_loss(loss[3])
                    tf.summary.scalar("loss", tboard_loss)
                    tf.summary.scalar("kl_loss", loss_kd)

            return losses,loss_kd                         

        def _loss_op(inputs, params, mode):
            """Single callable to compute the loss."""
            logits, _, tgt_ids_out, tgt_length, mu_states, logvar_states  = self._build(inputs, params, mode)
            losses = self._compute_loss(logits, tgt_ids_out, tgt_length, params, mode, mu_states, logvar_states)
            
            return losses

        with open(config_file, "r") as stream:
            config = yaml.load(stream)
        
        Loss_type = config.get("Loss_Function","Cross_Entropy")
        
        self.Loss_type = Loss_type
        self.config = config 
        self.using_tf_idf = config.get("using_tf_idf", False)
        
        train_batch_size = config["training_batch_size"]   
        eval_batch_size = config["eval_batch_size"]
        
        self.latent_variable_size = config.get("latent_variable_size",128)
        
        max_len = config["max_len"]
        
        example_sampling_distribution = config.get("example_sampling_distribution",None)
        self.dtype = tf.float32
        
        # Input pipeline:
        # Return lookup table of type index_table_from_file
        src_vocab, src_vocab_size = load_vocab(config["src_vocab_path"], config.get("src_vocab_size", None))
        tgt_vocab, tgt_vocab_size = load_vocab(config["tgt_vocab_path"], config.get("tgt_vocab_size", None))
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        
        # define mu0_ and logvar0_
        self.mu_0 = .0
        self.logvar_0 = .0
        
        load_data_version = config.get("dataprocess_version",None)
        
        if mode == "Training":    
            print("num_devices", config.get("num_devices",1))
            
            dispatcher = GraphDispatcher(
                config.get("num_devices",1), 
                daisy_chain_variables=config.get("daisy_chain_variables",False), 
                devices= config.get("devices",None)
            ) 
            
            batch_multiplier = config.get("num_devices", 1)
            num_threads = config.get("num_threads", 4)
            
            if Loss_type == "Wasserstein":
                self.using_tf_idf = True
                
            if self.using_tf_idf:
                tf_idf_table = build_tf_idf_table(
                    config["tgt_vocab_path"], 
                    self.src_vocab_size, 
                    config["domain_numb"], 
                    config["training_feature_file"])           
                self.tf_idf_table = tf_idf_table
                
            iterator = load_data(
                config["training_label_file"], 
                src_vocab, 
                batch_size = train_batch_size, 
                batch_type=config["training_batch_type"], 
                batch_multiplier = batch_multiplier, 
                tgt_path=config["training_feature_file"], 
                tgt_vocab=tgt_vocab, 
                max_len = max_len, 
                mode=mode, 
                shuffle_buffer_size = config["sample_buffer_size"], 
                num_threads = num_threads, 
                version = load_data_version, 
                distribution = example_sampling_distribution
            )
            
            inputs = iterator.get_next()
            data_shards = dispatcher.shard(inputs)

            with tf.variable_scope(config["Architecture"], initializer=self._initializer(config)):
                losses_shards = dispatcher(_loss_op, data_shards, config, mode)

            self.loss = _extract_loss(losses_shards, Loss_type=Loss_type) 

        elif mode == "Inference": 
            assert test_feature_file != None
            
            iterator = load_data(
                test_feature_file, 
                src_vocab, 
                batch_size = eval_batch_size, 
                batch_type = "examples", 
                batch_multiplier = 1, 
                max_len = max_len, 
                mode = mode, 
                version = load_data_version
            )
            
            inputs = iterator.get_next() 
            
            with tf.variable_scope(config["Architecture"]):
                _ , self.predictions, _, _, _, _ = self._build(inputs, config, mode)
            
        self.iterator = iterator
        self.inputs = inputs
        
    def loss_(self):
        return self.loss
    
    def prediction_(self):
        return self.predictions
   
    def inputs_(self):
        return self.inputs
    
    def iterator_initializers(self):
        if isinstance(self.iterator,list):
            return [iterator.initializer for iterator in self.iterator]
        else:
            return [self.iterator.initializer]        
           
    def _build(self, inputs, config, mode):        

        debugging = config.get("debugging", False)
        Loss_type = self.Loss_type       
        print("Loss_type: ", Loss_type)           

        hidden_size = config["hidden_size"]       
        print("hidden size: ", hidden_size)
                
        tgt_vocab_rev = tf.contrib.lookup.index_to_string_table_from_file(config["tgt_vocab_path"], vocab_size= int(self.tgt_vocab_size) - 1, default_value=constants.UNKNOWN_TOKEN)
        
        end_token = constants.END_OF_SENTENCE_ID
        
        # Embedding        
        size_src = config.get("src_embedding_size",512)
        size_tgt = config.get("tgt_embedding_size",512)
        latent_variable_size = self.latent_variable_size
        
        with tf.variable_scope("src_embedding"):
            src_emb = create_embeddings(self.src_vocab_size, depth=size_src)

        with tf.variable_scope("tgt_embedding"):
            tgt_emb = create_embeddings(self.tgt_vocab_size, depth=size_tgt)

        self.tgt_emb = tgt_emb
        self.src_emb = src_emb

        # Build encoder, decoder
#---------------------------------------------GRU-----------------------------------------#
        if config["Architecture"] == "GRU":
            nlayers = config.get("nlayers",4)
            
#==============================ENCODER==================================
            encoder = onmt.encoders.BidirectionalRNNEncoder(
                nlayers, 
                hidden_size, 
                reducer=onmt.layers.ConcatReducer(), 
                cell_class = tf.contrib.rnn.GRUCell, 
                dropout=0.1, 
                residual_connections=True
            )
#==============================DECODER==================================
            decoder = onmt.decoders.AttentionalRNNDecoder(
                nlayers, 
                hidden_size, 
                bridge=onmt.layers.CopyBridge(), 
                cell_class=tf.contrib.rnn.GRUCell, 
                dropout=0.1, 
                residual_connections=True
            )
    
#---------------------------------------------LSTM-----------------------------------------# 
        elif config["Architecture"] == "LSTM":
            nlayers = config.get("nlayers",4)
            
            assert hidden_size % 2 == 0, \
                "To use BidirectionalRNN, hidden_size must be devided by 2."

#==============================ENCODER==================================
            encoder_src = onmt.encoders.BidirectionalRNNEncoder(
                nlayers, 
                num_units=hidden_size, 
                reducer=onmt.layers.ConcatReducer(), 
                cell_class=tf.nn.rnn_cell.LSTMCell,
                dropout=0.1, 
                residual_connections=True
            )
        
#==============================DECODER==================================
            decoder = onmt.decoders.AttentionalRNNDecoder(
                nlayers, 
                num_units=hidden_size, 
                bridge=onmt.layers.CopyBridge(), 
                attention_mechanism_class=tf.contrib.seq2seq.LuongAttention,
                cell_class=tf.nn.rnn_cell.LSTMCell, 
                dropout=0.1, 
                residual_connections=True
            )
    
#---------------------------------------------TRANSFORMER-----------------------------------------#
        elif config["Architecture"] == "Transformer":
            nlayers = config.get("nlayers",6)

#==============================ENCODER==================================
# Requires 2 encoder for SVAE
            encoder_src = onmt.encoders.self_attention_encoder.SelfAttentionEncoder(
                nlayers, 
                num_units=hidden_size, 
                num_heads=8, 
                ffn_inner_dim=2048, 
                dropout=0.1, 
                attention_dropout=0.1, 
                relu_dropout=0.1
            )  
            
            encoder_tgt = onmt.encoders.self_attention_encoder.SelfAttentionEncoder(
                nlayers, 
                num_units=hidden_size, 
                num_heads=8, 
                ffn_inner_dim=2048, 
                dropout=0.1, 
                attention_dropout=0.1, 
                relu_dropout=0.1
            )
#==============================DECODER==================================
            decoder = onmt.decoders.self_attention_decoder.SelfAttentionDecoder(
                nlayers, 
                num_units=hidden_size, 
                num_heads=8, 
                ffn_inner_dim=2048, 
                dropout=0.1, 
                attention_dropout=0.1, 
                relu_dropout=0.1
            )

        print("Model type: ", config["Architecture"])
        
        output_layer = None

#-----------------------------------------------------TRAINING MODE-----------------------------------------------#
        if mode =="Training":            
            print("Building model in Training mode")
            
            src_length = inputs["src_length"]
            tgt_length = inputs["tgt_length"]
            
            emb_src_batch = tf.nn.embedding_lookup(src_emb, inputs["src_ids"]) # dim = [batch, length, depth]
            emb_tgt_batch = tf.nn.embedding_lookup(tgt_emb, inputs["tgt_ids"])  
            emb_tgt_batch_in = tf.nn.embedding_lookup(tgt_emb, inputs["tgt_ids_in"])   
            
            self.emb_tgt_batch = emb_tgt_batch
            self.emb_src_batch = emb_src_batch
            
            print("emb_src_batch: ", emb_src_batch)
            print("emb_tgt_batch: ", emb_tgt_batch)
            
            tgt_ids_batch = inputs["tgt_ids_out"]
                        
            #========ENCODER_PROCESS======================
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                
                if config["Architecture"] == "Transformer":
                    
                    encoder_output_src = encoder_src.encode(
                        emb_src_batch, 
                        sequence_length = src_length, 
                        mode=tf.estimator.ModeKeys.TRAIN
                    )

                    encoder_output_tgt = encoder_tgt.encode(
                        emb_tgt_batch, 
                        sequence_length = tgt_length, 
                        mode=tf.estimator.ModeKeys.TRAIN
                    )

                    self.encoder_output_src = encoder_output_src
                    self.encoder_output_tgt = encoder_output_tgt

                    encoder_outputs_src, encoder_states_src, encoder_seq_length_src = encoder_output_src
                    encoder_outputs_tgt, encoder_states_tgt, _ = encoder_output_tgt
                        # encoder_outputs: [batch_size, max_length, hidden_units_size]
                        # encoder_states: nlayers ° [batch_size, hidden_units_size] (tuple of 2d array)

                    encoder_states_combined = tf.concat([encoder_states_src[-1], encoder_states_tgt[-1]], 1)
                        # dim = [batch_size, hidden_size + hidden_size]
                        
                elif config["Architecture"] == "LSTM" :
                    
                    encoder_output_src = encoder_src.encode(
                        emb_src_batch, 
                        sequence_length = src_length, 
                        mode=tf.estimator.ModeKeys.TRAIN
                    )
                    
                    encoder_outputs_src, encoder_states_src, encoder_seq_length_src = encoder_output_src
                                        
                    encoder_states_combined = tf.reshape(tf.transpose(encoder_states_src, perm=[2, 0, 1, 3]), [tf.shape(encoder_outputs_src)[0], -1])
                        # dim = [batch_size, 2 * nlayers * hidden_size]
                    print("encoder_states_combined",encoder_states_combined)

                # important infos
                batch_size = tf.shape(encoder_outputs_src)[0]
                max_length = tf.shape(encoder_outputs_src)[1]
                    
            #======GENERATIVE_PROCESS====================
            
            with tf.variable_scope("generator"):
                
                if config["Architecture"] == "Transformer":
                    input_latent_size = 2 * hidden_size
                        # 2 hidden states from source and target sentences
                else:
                    input_latent_size = 2 * nlayers * hidden_size
                        # in a RNN model, the hidden states from source is passed into target's encoder process
                        # still *2 if use a Bidirectional RNN model
                        

                W_out_to_mu = tf.get_variable('output_to_mu_weight', shape = [input_latent_size, latent_variable_size])
                b_out_to_mu = tf.get_variable('output_to_mu_bias', shape = [latent_variable_size])
              
                W_out_to_logvar = tf.get_variable('output_to_logvar_weight', shape = [input_latent_size, latent_variable_size])
                b_out_to_logvar = tf.get_variable('output_to_logvar_bias', shape = [latent_variable_size])

                mu = tf.nn.sigmoid(tf.add(tf.matmul(encoder_states_combined, W_out_to_mu), b_out_to_mu))
                logvar = tf.nn.sigmoid(tf.add(tf.matmul(encoder_states_combined, W_out_to_logvar), b_out_to_logvar))
            
                std = tf.exp(0.5 * logvar)

                z = tf.random_normal([batch_size, latent_variable_size])
                    # z, mu, logvar shape [batch_size, latent_size]
                z = z * std + mu 
                    #Rappel z: [batch_size, latent_size]
                    
                if config["Architecture"] == "Transformer":
                    
                    z = tf.reshape(z, [batch_size,1,-1]) 
                        # shape [batch_size, 1, latent_size]
                    
                    zz = tf.tile(z,[1, max_length, 1])
                        # shape [batch_size, max_length, latent_size]

                    new_memory_inputs = tf.concat([encoder_outputs_src, zz], 2)

                    new_memory_inputs = tf.reshape(new_memory_inputs,[batch_size, -1, hidden_size + latent_variable_size])

                    new_encoder_states_src = encoder_states_src
                    
                    print("encoder_outputs_src",encoder_outputs_src)
                    print("new_memory_inputs",new_memory_inputs)
                    
                elif config["Architecture"] == "LSTM":

                    new_memory_inputs = encoder_outputs_src
                    
                    W_z_to_state = tf.get_variable('z_to_state_weight', shape = [latent_variable_size, input_latent_size])
                    b_z_to_state = tf.get_variable('z_to_state_bias', shape = [input_latent_size])
                    states_from_z = tf.add(tf.matmul(z, W_z_to_state), b_z_to_state)
                    
                    t_states_rebuild = tf.transpose(tf.reshape(states_from_z, [batch_size, nlayers, 2, -1]), perm=[1,2,0,3])
                    
                    t_states_rebuild = tf.unstack(t_states_rebuild)
                                                            
                    new_encoder_states_src = tuple([tf.nn.rnn_cell.LSTMStateTuple
                                                    (tf.reshape(t_states_rebuild[i][0],[batch_size, hidden_size]),
                                                    tf.reshape(t_states_rebuild[i][1],[batch_size, hidden_size])) 
                                                                                            for i in range(nlayers)]
                                                  )

                    print("encoder_outputs_src",encoder_outputs_src)
                    print("new_memory_inputs",new_memory_inputs)
               
            #======DECODER_PROCESS====================
            with tf.variable_scope("decoder"): 

                logits, dec_states, dec_length, attention = decoder.decode(
                                          emb_tgt_batch_in, 
                                          tgt_length + 1,
                                          vocab_size = int(self.tgt_vocab_size),
                                          initial_state = new_encoder_states_src,
                                          output_layer = output_layer,                                              
                                          mode = tf.estimator.ModeKeys.TRAIN,
                                          memory = new_memory_inputs,
                                          memory_sequence_length = encoder_seq_length_src,
                                          return_alignment_history = True
                )
                
                outputs = {
                        "logits": logits
                        }
                
            predictions = None

#-----------------------------------------------------INFERENCE MODE-----------------------------------------------#
        elif mode == "Inference":
            
            print("Build model in Inference mode")
            
            beam_width = config.get("beam_width", 5)
            
            src_length = inputs["src_length"]
            emb_src_batch = tf.nn.embedding_lookup(src_emb, inputs["src_ids"])
                        
            start_tokens = tf.fill([tf.shape(inputs["src_ids"])[0]], constants.START_OF_SENTENCE_ID)
            
           #========ENCODER_PROCESS======================
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                
                encoder_output_src = encoder_src.encode(
                    emb_src_batch, 
                    sequence_length = src_length, 
                    mode=tf.estimator.ModeKeys.TRAIN
                )
                
                encoder_outputs_src, encoder_states_src, encoder_seq_length_src = encoder_output_src
                
                encoder_states_combined = tf.reshape(tf.transpose(encoder_states_src, perm=[2, 0, 1, 3]), [tf.shape(encoder_outputs_src)[0], -1])
                    # dim = [batch_size, 2 * nlayers * hidden_size]
                print("encoder_states_combined",encoder_states_combined)
                
                # important infos
                batch_size = tf.shape(encoder_outputs_src)[0]
                max_length = tf.shape(encoder_outputs_src)[1]

                input_latent_size = 2 * nlayers * hidden_size
                 
            #======GENERATIVE_PROCESS====================
            with tf.variable_scope("generator"):

                z = tf.random_normal([batch_size, latent_variable_size])
                    # z, mu, logvar shape [batch_size, latent_size]
                
                # Inference from normale
                std_0 = tf.exp(0.5 * self.logvar_0)

                z_0 = z* std_0 + self.mu_0 
                
                # Inference from input
                    
                W_out_to_mu = tf.get_variable('output_to_mu_weight', shape = [input_latent_size, latent_variable_size])
                b_out_to_mu = tf.get_variable('output_to_mu_bias', shape = [latent_variable_size])

                W_out_to_logvar = tf.get_variable('output_to_logvar_weight', shape = [input_latent_size, latent_variable_size])
                b_out_to_logvar = tf.get_variable('output_to_logvar_bias', shape = [latent_variable_size])

                mu = tf.nn.sigmoid(tf.add(tf.matmul(encoder_states_combined, W_out_to_mu), b_out_to_mu))
                logvar = tf.nn.sigmoid(tf.add(tf.matmul(encoder_states_combined, W_out_to_logvar), b_out_to_logvar))
                std = tf.exp(0.5 * logvar)

                z_1 = z * std + mu 
                    
                if config["Architecture"] == "Transformer":
                    
                    z = tf.reshape(z, [batch_size,1,-1]) 
                        # shape [batch_size, 1, latent_size]
                    
                    zz = tf.tile(z,[1, max_length, 1])
                        # shape [batch_size, max_length, latent_size]

                    new_memory_inputs = tf.concat([encoder_outputs_src, zz], 2)

                    new_memory_inputs = tf.reshape(new_memory_inputs,[batch_size, -1, hidden_size + latent_variable_size])

                    new_encoder_states_src = encoder_states_src
                    
                elif config["Architecture"] == "LSTM":
                    
                    W_z_to_state = tf.get_variable('z_to_state_weight', shape = [latent_variable_size, input_latent_size])
                    b_z_to_state = tf.get_variable('z_to_state_bias', shape = [input_latent_size])
                    
                    # Inference from normale
                    states_from_z_0 = tf.add(tf.matmul(z_0, W_z_to_state), b_z_to_state)
                    t_states_rebuild = tf.transpose(tf.reshape(states_from_z_0, [batch_size, nlayers, 2, -1]), perm=[1,2,0,3])
                    t_states_rebuild = tf.unstack(t_states_rebuild)
                    
                    new_encoder_states_src_0 = tuple([tf.nn.rnn_cell.LSTMStateTuple
                                                    (tf.reshape(t_states_rebuild[i][0],[batch_size, hidden_size]),
                                                    tf.reshape(t_states_rebuild[i][1],[batch_size, hidden_size])) 
                                                                                            for i in range(nlayers)]
                                                  )
                    
                    
                    
                    # Inference from input
                    states_from_z_1 = tf.add(tf.matmul(z_1, W_z_to_state), b_z_to_state)
                    t_states_rebuild = tf.transpose(tf.reshape(states_from_z_1, [batch_size, nlayers, 2, -1]), perm=[1,2,0,3])
                    t_states_rebuild = tf.unstack(t_states_rebuild)
                    
                    new_encoder_states_src_1 = tuple([tf.nn.rnn_cell.LSTMStateTuple
                                                    (tf.reshape(t_states_rebuild[i][0],[batch_size, hidden_size]),
                                                    tf.reshape(t_states_rebuild[i][1],[batch_size, hidden_size])) 
                                                                                            for i in range(nlayers)]
                                                  )
                    
                    new_memory_inputs = encoder_outputs_src
            
            #======DECODER_PROCESS====================
                    
                
            print("Inference with beam width %d"%(beam_width))
            maximum_iterations = config.get("maximum_iterations", tf.round(2 * max_length))

            if beam_width <= 1:  
                
                with tf.variable_scope("decoder"):
                    sampled_ids_0, _, sampled_length_0, log_probs_0, alignment_0 = decoder.dynamic_decode(
                                                                    tgt_emb,
                                                                    start_tokens,
                                                                    end_token,
                                                                    vocab_size=int(self.tgt_vocab_size),
                                                                    initial_state=new_encoder_states_src_0,
                                                                    maximum_iterations=maximum_iterations,
                                                                    output_layer = output_layer,
                                                                    mode=tf.estimator.ModeKeys.PREDICT,
                                                                    memory=new_memory_inputs,
                                                                    memory_sequence_length=encoder_seq_length_src,
                                                                    dtype=tf.float32,
                                                                    return_alignment_history=True
                                                                    )
                with tf.variable_scope("decoder", reuse=True): 
                    sampled_ids_1, _, sampled_length_1, log_probs_1, alignment_1 = decoder.dynamic_decode(
                                                                    tgt_emb,
                                                                    start_tokens,
                                                                    end_token,
                                                                    vocab_size=int(self.tgt_vocab_size),
                                                                    initial_state=new_encoder_states_src_1,
                                                                    maximum_iterations=maximum_iterations,
                                                                    output_layer = output_layer,
                                                                    mode=tf.estimator.ModeKeys.PREDICT,
                                                                    memory=new_memory_inputs,
                                                                    memory_sequence_length=encoder_seq_length_src,
                                                                    dtype=tf.float32,
                                                                    return_alignment_history=True
                                                                    )
            else:
                length_penalty = config.get("length_penalty", 0)
                
                with tf.variable_scope("decoder"):
                    sampled_ids_0, _, sampled_length_0, log_probs_0, alignment_0 = decoder.dynamic_decode_and_search(
                                                            tgt_emb,
                                                            start_tokens,
                                                            end_token,
                                                            vocab_size = int(self.tgt_vocab_size),
                                                            initial_state = new_encoder_states_src_0,
                                                            beam_width = beam_width,
                                                            length_penalty = length_penalty,
                                                            maximum_iterations = maximum_iterations,
                                                            output_layer = output_layer,
                                                            mode = tf.estimator.ModeKeys.PREDICT,
                                                            memory = new_memory_inputs,
                                                            memory_sequence_length = encoder_seq_length_src,
                                                            dtype=tf.float32,
                                                            return_alignment_history = True)
                with tf.variable_scope("decoder", reuse=True): 
                    sampled_ids_1, _, sampled_length_1, log_probs_1, alignment_1 = decoder.dynamic_decode_and_search(
                                                            tgt_emb,
                                                            start_tokens,
                                                            end_token,
                                                            vocab_size = int(self.tgt_vocab_size),
                                                            initial_state = new_encoder_states_src_1,
                                                            beam_width = beam_width,
                                                            length_penalty = length_penalty,
                                                            maximum_iterations = maximum_iterations,
                                                            output_layer = output_layer,
                                                            mode = tf.estimator.ModeKeys.PREDICT,
                                                            memory = new_memory_inputs,
                                                            memory_sequence_length = encoder_seq_length_src,
                                                            dtype=tf.float32,
                                                            return_alignment_history = True)

            target_tokens_0 = tgt_vocab_rev.lookup(tf.cast(sampled_ids_0, tf.int64))
            target_tokens_1 = tgt_vocab_rev.lookup(tf.cast(sampled_ids_1, tf.int64))
            
            predictions = [
                {
              "tokens": target_tokens_0,
              "length": sampled_length_0,
              "log_probs": log_probs_0,
              "alignment": alignment_0,
                            },
                {
              "tokens": target_tokens_1,
              "length": sampled_length_1,
              "log_probs": log_probs_1,
              "alignment": alignment_1,
                            } ]
            
            tgt_ids_batch = None
            tgt_length = None
            mu = None
            logvar = None
            outputs = None
            
        self.outputs = outputs
        self.mu = mu
        self.logvar = logvar
        
        return outputs, predictions, tgt_ids_batch, tgt_length, mu, logvar

In [None]:
def inference(config_file, checkpoint_path=None, test_feature_file=None):
    """Decode `test_feature_file` with a checkpoint and write both outputs to disk.

    A fresh graph and session are created, an inference-mode `Model` is built
    from `config_file`, and the checkpoint is restored into it. The model
    exposes two prediction dicts; the first is written to a `.normal` file and
    the second to a `.frominput` file under `<model_dir>/eval`.

    Args:
        config_file: path to the YAML configuration (must contain `model_dir`).
        checkpoint_path: checkpoint to restore; when None, the latest
            checkpoint found in `config["model_dir"]` is used.
        test_feature_file: source file to decode. Required despite the default.

    Returns:
        `(normal_eval_path, from_input_eval_path, pred_dict)`; `pred_dict` is a
        placeholder that is always None.

    Raises:
        AssertionError: if `test_feature_file` is None.
        ValueError: if no checkpoint could be resolved.
    """
    with open(config_file, "r") as stream:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # recent PyYAML releases, and the config only needs plain YAML types.
        config = yaml.safe_load(stream)

    assert test_feature_file is not None, "test_feature_file is required"

    from opennmt.utils.misc import print_bytes

    def _write_first_hypotheses(batch_tokens, batch_lengths, stream):
        # Index 0 on axis 1 selects the first hypothesis of each example
        # (presumably the best beam -- confirm against the decoder). The last
        # counted token is dropped; presumably the end-of-sentence marker.
        for b in range(batch_tokens.shape[0]):
            pred_toks = batch_tokens[b][0][:batch_lengths[b][0] - 1]
            print_bytes(b" ".join(pred_toks), stream)

    graph = tf.Graph()
    session_config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))

    with tf.Session(graph=graph, config=session_config) as sess_:

        eval_model = Model(config_file, "Inference", test_feature_file)
        saver = tf.train.Saver()
        tf.tables_initializer().run()
        tf.global_variables_initializer().run()

        if checkpoint_path is None:
            checkpoint_path = tf.train.latest_checkpoint(config["model_dir"])
        if checkpoint_path is None:
            # Fail here with a readable message instead of inside restore().
            raise ValueError("No checkpoint found in %s" % config["model_dir"])

        print(("Evaluating model %s"%checkpoint_path))
        saver.restore(sess_, checkpoint_path)

        predictions_0, predictions_1 = eval_model.prediction_()

        tokens_0 = predictions_0["tokens"]
        length_0 = predictions_0["length"]

        tokens_1 = predictions_1["tokens"]
        length_1 = predictions_1["length"]

        sess_.run(eval_model.iterator_initializers())

        pred_dict = None

        # All output files share this prefix; build it once.
        eval_dir = os.path.join(config["model_dir"], "eval")
        os.makedirs(eval_dir, exist_ok=True)
        output_prefix = os.path.join(eval_dir, os.path.basename(test_feature_file) + ".trans." + os.path.basename(checkpoint_path))

        print("write to :%s"%output_prefix)

        normal_eval_path = output_prefix + ".normal"
        from_input_eval_path = output_prefix + ".frominput"

        with open(normal_eval_path,"w") as output_0_, \
            open(from_input_eval_path,"w") as output_1_:
            # Drain the dataset iterator; OutOfRangeError marks its end.
            while True:
                try:
                    _tokens_1, _length_1, _tokens_0, _length_0 = sess_.run([tokens_1, length_1, tokens_0, length_0])
                except tf.errors.OutOfRangeError:
                    break
                _write_first_hypotheses(_tokens_0, _length_0, output_0_)
                _write_first_hypotheses(_tokens_1, _length_1, output_1_)

    return normal_eval_path, from_input_eval_path, pred_dict

In [None]:
# Training setup: load the config, build the training graph, create the
# session/savers/writers, and optionally resume from the latest checkpoint.
config_file = 'config/SVAE_zS_config_tuanh_WMT.yml'

with open(config_file, "r") as stream:
    # safe_load: yaml.load() without an explicit Loader is rejected by recent
    # PyYAML releases, and the config only needs plain YAML types.
    config = yaml.safe_load(stream)

# "eval" stores prediction files, "important ckpts" stores the best/last-big-KLD
# checkpoints written by the training loop.
for subdir in ("eval", "important ckpts"):
    os.makedirs(os.path.join(config["model_dir"], subdir), exist_ok=True)

training_model = Model(config_file, "Training")

global_step = tf.train.create_global_step()

if config.get("Loss_Function","Cross_Entropy")=="Cross_Entropy":
    loss, kl_loss = training_model.loss_()
    use_kl_weight = config.get("use_kl_weight", True)
    if use_kl_weight:
        # While the reconstruction loss is above 2 the annealed coefficient is
        # capped at 0.001; otherwise it may grow up to 1.
        kl_weight = tf.cond(loss > 2., lambda: tf.minimum(kl_coeff(global_step),0.001), lambda: tf.minimum(kl_coeff(global_step), 1))
        # Once the KL term has dropped to 1 or below, fall back to the 0.001 cap.
        kl_weight = tf.cond(kl_loss > 1., lambda: kl_weight, lambda: tf.minimum(kl_coeff(global_step),0.001))
    else:
        # Float constant: an int32 `tf.constant(1)` cannot be multiplied with a
        # floating-point loss tensor (kl_loss is assumed to be float32).
        kl_weight = tf.constant(1.0, dtype=tf.float32)

    tf.summary.scalar("kl_weight", kl_weight)
    generator_total_loss = loss + kl_loss * kl_weight

inputs = training_model.inputs_()

if config["mode"] == "Training":
    optimizer_params = config["optimizer_parameters"]
    with tf.variable_scope("main_training"):
        train_op, accum_vars_ = optimize_loss(generator_total_loss, optimizer_params)

# One BLEU evaluator and one pair of summary writers (prior / from-input) per
# validation set.
Eval_dataset_numb = len(config["eval_label_file"])
print("Number of validation set: ", Eval_dataset_numb)
external_evaluator = [None] * Eval_dataset_numb
writer_bleu = [None] * Eval_dataset_numb
for i in range(Eval_dataset_numb):
    external_evaluator[i] = BLEUEvaluator(config["eval_label_file"][i], config["model_dir"])
    writer_bleu[i] = [tf.summary.FileWriter(os.path.join(config["model_dir"],"BLEU","domain_%d"%i,"from_normal")),
                      tf.summary.FileWriter(os.path.join(config["model_dir"],"BLEU","domain_%d"%i,"from_input"))]

sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True)))
writer = tf.summary.FileWriter(config["model_dir"])

# Four savers over the same variables: the rolling one plus three that each
# keep a single "important" checkpoint.
var_list_ = tf.global_variables()
saver = tf.train.Saver(var_list_, max_to_keep=config["max_to_keep"])
saver_max0 = tf.train.Saver(var_list_, max_to_keep = 1)
saver_max1 = tf.train.Saver(var_list_, max_to_keep = 1)
saver_maxkld = tf.train.Saver(var_list_, max_to_keep = 1)
checkpoint_path = tf.train.latest_checkpoint(config["model_dir"])

sess.run(global_step.initializer)
sess.run(tf.global_variables_initializer())
training_summary = tf.summary.merge_all()
global_step_ = sess.run(global_step)

if checkpoint_path:
    try:
        print("Continue training:...")
        print("Load parameters from %s"%checkpoint_path)
        saver.restore(sess, checkpoint_path)
        global_step_ = sess.run(global_step)
        print("global_step: ", global_step_)

        # Score the restored checkpoint on every validation set before resuming.
        for i in range(Eval_dataset_numb):
            normal_eval_path, from_input_eval_path, prediction_dict = inference(config_file, checkpoint_path, config["eval_feature_file"][i])
            score_0 = external_evaluator[i].score(config["eval_label_file"][i], normal_eval_path)
            score_1 = external_evaluator[i].score(config["eval_label_file"][i], from_input_eval_path)
            print("=========================================EVALUATION SCORE==========================================")
            print("From Normal BLEU at checkpoint %s for testset %s: %f"%(checkpoint_path, config["eval_feature_file"][i], score_0))
            print("From Input BLEU at checkpoint %s for testset %s: %f"%(checkpoint_path, config["eval_feature_file"][i], score_1))
            print("=====================================================================================")
    except TypeError:
        # Best-effort: a TypeError here is reported and training still proceeds.
        print("There is a TypeError, the output maybe empty !")

else:
    print("Training from scratch")

sess.run(tf.tables_initializer())
sess.run(training_model.iterator_initializers())

In [None]:
# State carried across training-loop iterations: the best BLEU seen so far for
# each of the two decoding variants, and the running list of per-step losses.
best_bleu_0 = best_bleu_1 = 0.0
total_loss = list()

In [None]:
# Main training loop: one optimizer step per iteration, with periodic logging,
# summaries, checkpointing and BLEU evaluation driven by the config frequencies.
run_time = 0.

print("Start training from step {:d}...".format(global_step_))

while global_step_ <= config["iteration_number"]:

    #=================== 1 iteration=======================
    start_time = time.time()

    if use_kl_weight:
        ce_loss_, kl_loss_,loss_, global_step_, _, kl_weight_ = sess.run([loss, kl_loss, generator_total_loss, global_step, train_op, kl_weight])
    else:
        ce_loss_, kl_loss_,loss_, global_step_, _ = sess.run([loss, kl_loss, generator_total_loss, global_step, train_op])
        kl_weight_ = 1

    run_time += time.time() - start_time

    total_loss.append(loss_)

    #==================printing things======================
    if (np.mod(global_step_, config["printing_freq"])) == 0:
        print((datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        print("*************Step: {:d} - RTime: {:f}***************".format(global_step_,run_time))
        print("CE Loss: {:4f} - KL Loss: {:4f} - KL Weight: {:4f}".format(ce_loss_, kl_loss_,kl_weight_))
        print("TotalLoss at step {:d}: {:4f}".format(global_step_, np.mean(total_loss)))
        run_time = 0.

    if (np.mod(global_step_, config["summary_freq"])) == 0:
        training_summary_ = sess.run(training_summary)
        writer.add_summary(training_summary_, global_step=global_step_)
        writer.flush()
        # The running loss average restarts after every summary.
        total_loss = []

    if (np.mod(global_step_, config["save_freq"])) == 0 and global_step_ > 0:
        print((datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        checkpoint_path = os.path.join(config["model_dir"], 'model.ckpt')
        print(("save to %s"%(checkpoint_path)))
        saver.save(sess, checkpoint_path, global_step = global_step_)

        if kl_loss_ > 1.5 :
            # Separate name so `checkpoint_path` keeps pointing at the regular
            # checkpoint.
            kld_ckpt_prefix = os.path.join(config["model_dir"],"important ckpts", 'lastBigKLD.model.ckpt')
            saver_maxkld.save(sess, kld_ckpt_prefix, global_step = global_step_)

    if (np.mod(global_step_, config["eval_freq"])) == 0 and global_step_ >0:
        try :
            checkpoint_path = tf.train.latest_checkpoint(config["model_dir"])
            for i in range(Eval_dataset_numb):
                normal_eval_path, from_input_eval_path, prediction_dict = inference(config_file, checkpoint_path, config["eval_feature_file"][i])
                score_0 = external_evaluator[i].score(config["eval_label_file"][i], normal_eval_path)
                score_1 = external_evaluator[i].score(config["eval_label_file"][i], from_input_eval_path)

                print("=========================================EVALUATION SCORE==========================================")
                print("From Normal BLEU at checkpoint %s for testset %s: %f"%(checkpoint_path, config["eval_feature_file"][i], score_0))
                print("From Input BLEU at checkpoint %s for testset %s: %f"%(checkpoint_path, config["eval_feature_file"][i], score_1))
                print("=====================================================================================")

                # The best-BLEU save prefixes must not overwrite
                # `checkpoint_path`: it is passed to inference() again for the
                # next validation set, and the files written by save() carry a
                # "-<global_step>" suffix, so the bare prefix is not a
                # restorable checkpoint path.
                if score_0 > best_bleu_0 :
                    print((datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
                    best_ckpt_prefix = os.path.join(config["model_dir"],"important ckpts", 'bestBLEU_0.model.ckpt')
                    print(("save to %s"%(best_ckpt_prefix)))
                    saver_max0.save(sess, best_ckpt_prefix, global_step = global_step_)
                    best_bleu_0 = score_0

                if score_1 > best_bleu_1 :
                    print((datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
                    best_ckpt_prefix = os.path.join(config["model_dir"],"important ckpts", 'bestBLEU_1.model.ckpt')
                    print(("save to %s"%(best_ckpt_prefix)))
                    saver_max1.save(sess, best_ckpt_prefix, global_step = global_step_)
                    best_bleu_1 = score_1

                score_summary = tf.Summary(value=[tf.Summary.Value(tag="eval_score_%d"%i, simple_value=score_0)])
                writer_bleu[i][0].add_summary(score_summary, global_step_)
                writer_bleu[i][0].flush()

                score_summary = tf.Summary(value=[tf.Summary.Value(tag="eval_score_%d"%i, simple_value=score_1)])
                writer_bleu[i][1].add_summary(score_summary, global_step_)
                writer_bleu[i][1].flush()

        except TypeError:
            # Best-effort: a TypeError during evaluation is reported and
            # training continues.
            print("There is a TypeError, the output maybe empty !")

In [None]:
# kl_weight = tf.cond(loss > 2., lambda: tf.minimum(kl_coeff(global_step),0.01), lambda: tf.minimum(kl_coeff(global_step), 1))
# tf.summary.scalar("kl_weight", kl_weight)
# generator_total_loss = loss + kl_loss * kl_weight
# training_summary = tf.summary.merge_all()

In [None]:
def inference_now(config_file,test_feature_file, checkpoint_path=None, num_iter=10, beam_width=None):
    """Decode `test_feature_file` repeatedly and print the generated sentences.

    Builds an inference-mode `Model` in a fresh graph, restores a checkpoint,
    then runs the decoder `num_iter` times and prints every hypothesis of the
    second prediction dict returned by `Model.prediction_()` next to its source
    line.

    Args:
        config_file: path to the YAML configuration (must contain `model_dir`).
        test_feature_file: text file with one source sentence per line.
        checkpoint_path: checkpoint to restore; when None, the latest
            checkpoint in `config["model_dir"]` is used.
        num_iter: number of decoding passes over the file.
        beam_width: optional override stored in the locally loaded config.
            NOTE(review): `Model` is built from `config_file`, not from this
            dict, so whether the override reaches the decoder must be
            confirmed against `Model`.

    Returns:
        None; output goes to stdout.
    """
    with open(config_file, "r") as stream:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # recent PyYAML releases, and the config only needs plain YAML types.
        config = yaml.safe_load(stream)

    # `beam_width` used to be read here without being defined anywhere in the
    # function; it is now an optional keyword argument.
    if beam_width is not None :
        config["beam_width"] = beam_width

    def print_bytes_sentence(byted_list) :
        # Decode the byte tokens and drop padding / sentence-boundary markers.
        sen = []
        for b in byted_list:
            sb = b.decode("utf-8")
            if sb not in ['<blank>','<s>','</s>'] :
                sen.append(sb)
        print(" ".join(sen))
        return None

    graph = tf.Graph()

    with tf.Session(graph=graph,config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True))) as sess_:

        eval_model = Model(config_file, "Inference", test_feature_file)
        saver = tf.train.Saver()
        tf.tables_initializer().run()
        tf.global_variables_initializer().run()

        if checkpoint_path is None:
            checkpoint_dir = config["model_dir"]
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)

        print(("Evaluating model %s"%checkpoint_path))
        saver.restore(sess_, checkpoint_path)

        _,predictions = eval_model.prediction_()

        tokens = predictions["tokens"]

        # The source file does not change between passes; read it once.
        with open(test_feature_file) as fp:
            lines = fp.readlines()

        for _pass in range(num_iter) :

            sess_.run(eval_model.iterator_initializers())

            # A single run fetches one batch only, so pair source lines with
            # however many examples were actually decoded instead of indexing
            # `preds` by the line count.
            preds = sess_.run(tokens)

            for source_line, hypotheses in zip(lines, preds):
                print("----------------------------------")
                print("Source sentence: ")
                print(source_line)
                print("Generated sentences:")
                for tokenized_line in hypotheses:
                    print_bytes_sentence(tokenized_line)

    return None

In [None]:
# Ad-hoc source sentences to decode; each entry carries its own trailing newline.
Questions = ['What is the best start up idea ? \n']

# Dump the sentences to a scratch file that serves as the decoder's input.
with open('infer_test', 'w') as file_:
    file_.write("".join(Questions))

# Decode the scratch file ten times with the latest checkpoint.
inference_now(config_file, 'infer_test', checkpoint_path=None, num_iter=10)