In [1]:
# !/usr/bin/env python3
import os
import sys
sys.path.append('./bert/')
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import rnn_cell_impl
import bert
from bert import modeling
from bert import tokenization
from bert.modeling import BertConfig, BertModel

In [2]:
class MyOutputProjectionWrapper(tf.contrib.rnn.RNNCell):
    """RNN cell wrapper that projects the wrapped cell's output with a given matrix.

    Unlike a wrapper that creates its own projection weights, this one is
    handed the projection matrix `W` by the caller and multiplies every step's
    output by the transpose of `W`.

    Note: it is often cheaper to skip this wrapper, concatenate the outputs of
    all time steps, project that batch once, and split afterwards (or feed the
    result straight into a softmax).
    """

    def __init__(self, cell, output_size, W, activation=None, reuse=None):
        """Wrap `cell` so that its outputs are projected through `W`.

        Args:
          cell: the RNNCell whose outputs are projected.
          output_size: integer, reported as this cell's `output_size`.
          W: projection matrix; each step computes `output @ transpose(W)`.
            Presumably shaped [output_size, cell_output_dim] -- this is not
            checked here, verify against the caller.
          activation: (optional) callable applied to the projected output.
          reuse: (optional) Python boolean forwarded to the base cell as
            `_reuse`, controlling variable reuse in an existing scope.
        Raises:
          TypeError: if cell is not an RNNCell.
          ValueError: if output_size is not positive.
        """
        super().__init__(_reuse=reuse)
        rnn_cell_impl.assert_like_rnncell("cell", cell)
        if output_size < 1:
            message = "Parameter output_size must be > 0: %d." % output_size
            raise ValueError(message)
        self._W = W
        self._activation = activation
        self._output_size = output_size
        self._cell = cell

    @property
    def state_size(self):
        """State size of the wrapped cell (the wrapper adds no state)."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Size of the projected output, as given at construction."""
        return self._output_size

    def zero_state(self, batch_size, dtype):
        """Return the wrapped cell's zero state, built under a dedicated name scope."""
        scope_name = type(self).__name__ + "ZeroState"
        with ops.name_scope(scope_name, values=[batch_size]):
            return self._cell.zero_state(batch_size, dtype)

    def call(self, inputs, state):
        """Step the wrapped cell, then project (and optionally activate) its output."""
        cell_output, next_state = self._cell(inputs, state)
        projection = tf.transpose(self._W)
        projected = tf.matmul(cell_output, projection)
        if not self._activation:
            return projected, next_state
        return self._activation(projected), next_state


In [3]:
class ChatModel:
    """Seq2seq chat graph: a BERT encoder feeding an attentional GRU decoder (TF 1.x).

    The constructor creates the input placeholders and the BERT encoder.
    `inference()` then adds a scheduled-sampling training decoder and a
    beam-search prediction decoder that share variables under the ``decode``
    variable scope. Decoder logits are produced by multiplying with the
    transposed BERT embedding table (tied input/output embeddings).
    """

    def __init__(self, config_file, x_max_len, y_max_len, max_decode_len, vocab, ckpt_file=None,
                 beam_width=4, is_training=True):
        """Create placeholders and the BERT encoder.

        Args:
          config_file: path to the BERT JSON config; supplies hidden and vocab sizes.
          x_max_len: fixed width of the encoder input placeholders.
          y_max_len: fixed width of the decoder target placeholder.
          max_decode_len: maximum number of decoding iterations.
          vocab: mapping from token to id; must contain '<S>' and '<T>', which
            `inference()` uses as start and end tokens.
          ckpt_file: (optional) BERT checkpoint used to compute the variable
            assignment map.
          beam_width: (optional) beam size of the prediction decoder.
          is_training: (optional) forwarded to `BertModel`.
        Raises:
          ValueError: if beam_width is not positive.
        """
        if beam_width < 1:
            raise ValueError("Parameter beam_width must be > 0: %d." % beam_width)
        self.x_max_len = x_max_len
        self.y_max_len = y_max_len
        self.vocab = vocab
        self.beam_width = beam_width
        self.max_decode_len = max_decode_len

        # Encoder inputs (token ids, attention mask, segment ids, true lengths).
        self.x = tf.placeholder(tf.int32, shape=[None, self.x_max_len], name='x')
        self.x_mask = tf.placeholder(tf.int32, shape=[None, self.x_max_len], name='x_mask')
        self.x_seg = tf.placeholder(tf.int32, shape=[None, self.x_max_len], name='x_seg')
        self.x_len = tf.placeholder(tf.int32, shape=[None], name='x_len')
        # Decoder targets and their true lengths.
        self.y = tf.placeholder(tf.int32, shape=[None, self.y_max_len], name='y')
        self.y_len = tf.placeholder(tf.int32, shape=[None], name='y_len')

        self.bert_config = BertConfig.from_json_file(config_file)
        self.hidden_size = self.bert_config.hidden_size
        self.vocab_size = self.bert_config.vocab_size
        self.bert_model = BertModel(
            config=self.bert_config,
            input_ids=self.x,
            input_mask=self.x_mask,
            token_type_ids=self.x_seg,
            is_training=is_training,
            use_one_hot_embeddings=False)

        # Always define these so callers can read them even without a checkpoint.
        self.assignment_map = {}
        self.initialized_variable_map = {}
        if ckpt_file is not None:
            # NOTE(review): this only computes the variable-name mapping; the
            # weights are presumably restored elsewhere (e.g. via
            # tf.train.init_from_checkpoint) -- confirm that happens.
            tvars = tf.trainable_variables()
            self.assignment_map, self.initialized_variable_map = modeling.get_assignment_map_from_checkpoint(tvars, ckpt_file)

    def _build_decoder_cell(self, rnn_cell, memory, memory_len, reuse):
        """Wrap `rnn_cell` with Bahdanau attention over `memory` and the tied-embedding projection.

        Must be called inside the decoder variable scope. Building the
        attention mechanism and its wrapper together guarantees the wrapper
        always attends over the memory it was given.
        """
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            num_units=self.hidden_size,
            memory=memory,
            memory_sequence_length=memory_len)
        attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            rnn_cell,
            attention_mechanism,
            attention_layer_size=self.hidden_size)
        return MyOutputProjectionWrapper(attention_cell, self.vocab_size, self.embeddings, reuse=reuse)

    def inference(self):
        """Build the training and beam-search decoders on top of the encoder.

        Returns:
          A pair `(t_output, p_output)`: `t_output` is the final output of the
          training decoder as returned by `dynamic_decode`; `p_output` is the
          int tensor named 'predictions' holding beam index 0 of the
          beam-search decoder's `predicted_ids`.
        """
        batch_size = tf.shape(self.x)[0]
        sequence_output = self.bert_model.get_sequence_output()
        self.embeddings = self.bert_model.get_embedding_table()

        # Decoder input is the target sequence prefixed with the start token.
        start_tokens = tf.fill([batch_size], self.vocab['<S>'])
        end_token = self.vocab['<T>']
        train_output = tf.concat([tf.expand_dims(start_tokens, 1), self.y], 1)
        output_emb = tf.nn.embedding_lookup(self.embeddings, train_output)
        output_len = 1 + self.y_len

        # Position 0 of the BERT output seeds the decoder state; the remaining
        # positions form the attention memory.
        # NOTE(review): the memory is one step shorter than x but is masked
        # with x_len unchanged -- confirm x_len already excludes the first token.
        encoder_state = sequence_output[:, 0, :]
        encoder_output = sequence_output[:, 1:, :]

        # 0.1 is the scheduled-sampling probability passed to the helper.
        train_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(output_emb, output_len, self.embeddings, 0.1)
        cell = tf.contrib.rnn.GRUCell(num_units=self.hidden_size)

        with tf.variable_scope('decode'):
            out_cell = self._build_decoder_cell(cell, encoder_output, self.x_len, reuse=False)
            initial_state = out_cell.zero_state(dtype=tf.float32, batch_size=batch_size)
            initial_state = initial_state.clone(cell_state=encoder_state)
            decoder = tf.contrib.seq2seq.BasicDecoder(cell=out_cell, initial_state=initial_state, helper=train_helper)
            t_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                impute_finished=True,
                output_time_major=False,
                maximum_iterations=self.max_decode_len)

        with tf.variable_scope('decode', reuse=True):
            # Beam search runs batch_size * beam_width hypotheses, so every
            # encoder tensor the cell touches must be tiled to match -- including
            # the memory behind the attention mechanism handed to the wrapper.
            tiled_encoder_output = tf.contrib.seq2seq.tile_batch(encoder_output, multiplier=self.beam_width)
            tiled_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=self.beam_width)
            tiled_x_len = tf.contrib.seq2seq.tile_batch(self.x_len, multiplier=self.beam_width)
            out_cell = self._build_decoder_cell(cell, tiled_encoder_output, tiled_x_len, reuse=True)
            initial_state = out_cell.zero_state(dtype=tf.float32, batch_size=batch_size * self.beam_width)
            initial_state = initial_state.clone(cell_state=tiled_encoder_state)
            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=out_cell,
                beam_width=self.beam_width,
                coverage_penalty_weight=0.001,
                embedding=self.embeddings,
                initial_state=initial_state,
                start_tokens=start_tokens,
                end_token=end_token)
            p_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                maximum_iterations=self.max_decode_len)

        p_output = tf.identity(p_output.predicted_ids[:, :, 0], name='predictions')
        return t_output, p_output
        
                
        
        

In [4]:
# Build the BERT tokenizer from the vocabulary file in the pretrained model directory.
vocab_file = './model/chinese_L-12_H-768_A-12/vocab.txt'
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)

In [5]:
# Sequence-length limits: encoder input width, decoder target width, decode cap.
x_max_len, y_max_len, max_decode_len = 100, 50, 50

# Pretrained BERT config and checkpoint.
config_file = './model/chinese_L-12_H-768_A-12/bert_config.json'
ckpt_file = './model/chinese_L-12_H-768_A-12/bert_model.ckpt'

# Build the encoder, then attach the training and beam-search decoders.
chat_model = ChatModel(
    config_file=config_file,
    x_max_len=x_max_len,
    y_max_len=y_max_len,
    max_decode_len=max_decode_len,
    vocab=tokenizer.vocab,
    ckpt_file=ckpt_file,
)
t_output, p_output = chat_model.inference()

In [6]:
# Scratch check: the id of the '<T>' token wrapped as a constant tensor.
a = tf.convert_to_tensor(value=tokenizer.vocab['<T>'])