# Pretrain BERT with a SentencePiece model

## References

### Estimator API
https://www.tensorflow.org/guide/estimators

### TPUEstimator API
https://cloud.google.com/tpu/docs/tutorials/migrating-to-tpuestimator-api

In [1]:
import tensorflow as tf
from bert import modeling
from bert import optimization
import tempfile
import json

In [2]:
### Configuration parameters

output_dir = 'model_test'

train_batch_size = 6
max_seq_length = 512
max_predictions_per_seq = 20
num_train_steps = 2000
num_warmup_steps = 1000
save_checkpoints_steps = 2000
learning_rate = 1e-4
max_eval_steps = 100

# Comma-separated list of TFRecord paths / glob patterns. Whitespace and
# newlines around each entry are stripped before globbing.
tf_record_files="""data/wiki-kor/AA/all-maxseq512.tfrecord,
data/wiki-kor/AB/all-maxseq512.tfrecord,
data/wiki-kor/AC/all-maxseq512.tfrecord,
data/wiki-kor/AD/all-maxseq512.tfrecord,
data/wiki-kor/AE/all-maxseq512.tfrecord,
data/wiki-kor/AF/all-maxseq512.tfrecord,
data/wiki-kor/AG/all-maxseq512.tfrecord"""

input_files = []
for raw_pattern in tf_record_files.split(","):
    pattern = raw_pattern.strip()
    if not pattern:
        # Tolerate a trailing comma / blank entry instead of globbing "".
        continue
    input_files.extend(tf.gfile.Glob(pattern))
if not input_files:
    # Fail here with a clear message rather than later inside tf.data.
    raise ValueError("No input files matched: %s" % tf_record_files)
print(input_files)


bert_config = modeling.BertConfig.from_json_file('model/bert_config.json')

['data/wiki-kor/AA/all-maxseq512.tfrecord', 'data/wiki-kor/AB/all-maxseq512.tfrecord', 'data/wiki-kor/AC/all-maxseq512.tfrecord', 'data/wiki-kor/AD/all-maxseq512.tfrecord', 'data/wiki-kor/AE/all-maxseq512.tfrecord', 'data/wiki-kor/AF/all-maxseq512.tfrecord', 'data/wiki-kor/AG/all-maxseq512.tfrecord']


# 1. Write a model function

   * The model function used here has the following call signature:
   ```python
   def my_model_fn(
       features, # This is batch_features from input_fn
       labels,   # This is batch_labels from input_fn
       mode,     # An instance of tf.estimator.ModeKeys
       params):  # Additional configuration
   ```


###  1.1  Masked LM Task output function

INFO:tensorflow:  name = bert/pooler/dense/kernel:0, shape = (768, 768)
INFO:tensorflow:  name = bert/pooler/dense/bias:0, shape = (768,)
INFO:tensorflow:  name = cls/predictions/transform/dense/kernel:0, shape = (768, 768)
INFO:tensorflow:  name = cls/predictions/transform/dense/bias:0, shape = (768,)
INFO:tensorflow:  name = cls/predictions/transform/LayerNorm/beta:0, shape = (768,)
INFO:tensorflow:  name = cls/predictions/transform/LayerNorm/gamma:0, shape = (768,)
INFO:tensorflow:  name = cls/predictions/output_bias:0, shape = (32000,)
INFO:tensorflow:  name = cls/seq_relationship/output_weights:0, shape = (2, 768)
INFO:tensorflow:  name = cls/seq_relationship/output_bias:0, shape = (2,)

In [3]:
def get_masked_lm_output(bert_config, input_tensor, output_weights,
                         positions,label_ids, label_weights):
    """Build the masked-LM head and return its loss terms.

    Args:
        bert_config: `modeling.BertConfig` supplying hidden_size, hidden_act,
            initializer_range and vocab_size.
        input_tensor: encoder sequence output; must be rank 3 (enforced by
            `get_shape_list(..., expected_rank=3)`).
        output_weights: matrix multiplied (transposed) against the transformed
            hidden states to produce vocabulary logits. The caller in this
            notebook passes the embedding table.
        positions: token positions to predict for each example in the batch.
        label_ids: target token ids for those positions.
        label_weights: 1.0 for real predictions, 0.0 for padded slots.

    Returns:
        Tuple `(loss, per_example_loss, log_probs)`. `loss` is the
        weight-normalised mean of `per_example_loss`.
    """

    def _select_positions(sequence, index):
        # Flatten [rows, steps, dims] to [rows*steps, dims] and shift each
        # row's positions by row*steps so a single gather collects them all.
        dims = modeling.get_shape_list(sequence, expected_rank=3)
        n_rows, n_steps, n_dims = dims[0], dims[1], dims[2]
        row_offsets = tf.reshape(
            tf.range(0, n_rows, dtype=tf.int32) * n_steps, [-1, 1])
        flat_index = tf.reshape(index + row_offsets, [-1])
        flat_sequence = tf.reshape(sequence, [n_rows * n_steps, n_dims])
        return tf.gather(flat_sequence, flat_index)

    hidden = _select_positions(input_tensor, positions)

    with tf.variable_scope("cls/predictions"):
        # One extra dense + LayerNorm projection before the output layer;
        # these variables are only used during pre-training.
        with tf.variable_scope("transform"):
            hidden = modeling.layer_norm(
                tf.layers.dense(
                    hidden,
                    units=bert_config.hidden_size,
                    activation=modeling.get_activation(bert_config.hidden_act),
                    kernel_initializer=modeling.create_initializer(
                        bert_config.initializer_range)))

        # Output-only bias per vocabulary entry; the weight matrix comes from
        # the caller.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())

        log_probs = tf.nn.log_softmax(
            tf.nn.bias_add(
                tf.matmul(hidden, output_weights, transpose_b=True),
                output_bias),
            axis=-1)

        flat_ids = tf.reshape(label_ids, [-1])
        flat_weights = tf.reshape(label_weights, [-1])
        targets = tf.one_hot(
            flat_ids, depth=bert_config.vocab_size, dtype=tf.float32)

        # Padded prediction slots carry weight 0.0, so they drop out of both
        # the numerator and the denominator. The 1e-5 avoids dividing by zero
        # when every slot is padding.
        per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
        loss = (tf.reduce_sum(flat_weights * per_example_loss)
                / (tf.reduce_sum(flat_weights) + 1e-5))

    return (loss, per_example_loss, log_probs)

### 1.2  Next Sentence Task output function

In [4]:
def get_next_sentence_output(bert_config, input_tensor, labels):
    """Build the next-sentence-prediction head and return its loss terms.

    This is a binary classifier over `input_tensor`; the caller in this
    notebook passes the pooled output. Label 0 means "next sentence" and
    label 1 means "random sentence". The weights created here are only used
    during pre-training.

    Returns:
        Tuple `(loss, per_example_loss, log_probs)`. `loss` is the unweighted
        mean of `per_example_loss`.
    """
    with tf.variable_scope("cls/seq_relationship"):
        weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range))
        bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        scores = tf.nn.bias_add(
            tf.matmul(input_tensor, weights, transpose_b=True), bias)
        log_probs = tf.nn.log_softmax(scores, axis=-1)

        targets = tf.one_hot(
            tf.reshape(labels, [-1]), depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
        return (tf.reduce_mean(per_example_loss), per_example_loss, log_probs)


### 1.3 Metric function: evaluation loss and accuracy

In [5]:
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
              masked_lm_weights, next_sentence_example_loss,
              next_sentence_log_probs, next_sentence_labels):
    """Return eval metric ops for the masked-LM and next-sentence tasks.

    Masked-LM accuracy and mean loss are weighted by `masked_lm_weights`, so
    padded prediction slots are ignored. Next-sentence metrics are
    unweighted.

    Returns:
        Dict of `tf.metrics` results keyed by metric name, suitable for
        `EstimatorSpec(eval_metric_ops=...)`.
    """
    def _flat_argmax(log_probs):
        # Collapse leading dims so each row is one distribution, then take the
        # most likely class as an int32 prediction.
        rows = tf.reshape(log_probs, [-1, log_probs.shape[-1]])
        return tf.argmax(rows, axis=-1, output_type=tf.int32)

    mlm_weights = tf.reshape(masked_lm_weights, [-1])
    mlm_accuracy = tf.metrics.accuracy(
        labels=tf.reshape(masked_lm_ids, [-1]),
        predictions=_flat_argmax(masked_lm_log_probs),
        weights=mlm_weights)
    mlm_loss = tf.metrics.mean(
        values=tf.reshape(masked_lm_example_loss, [-1]), weights=mlm_weights)

    nsp_accuracy = tf.metrics.accuracy(
        labels=tf.reshape(next_sentence_labels, [-1]),
        predictions=_flat_argmax(next_sentence_log_probs))
    nsp_loss = tf.metrics.mean(values=next_sentence_example_loss)

    return {
        "masked_lm_accuracy": mlm_accuracy,
        "masked_lm_loss": mlm_loss,
        "next_sentence_accuracy": nsp_accuracy,
        "next_sentence_loss": nsp_loss,
    }

### 1.4.  Implementing Model function Builder for Estimator

The tf.estimator.EstimatorSpec returned for evaluation typically contains the following information:

   * loss, which is the model's loss
   * eval_metric_ops, which is an optional dictionary of metrics.

Here `metric_fn` returns a dictionary with four metrics: masked-LM accuracy, masked-LM loss, next-sentence accuracy and next-sentence loss. We pass that dictionary as the `eval_metric_ops` argument of `tf.estimator.EstimatorSpec`.

In [6]:
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps):
    """Return a `model_fn` closure for `tf.estimator.Estimator`.

    Args:
        bert_config: `modeling.BertConfig` for the encoder and both heads.
        init_checkpoint: optional checkpoint path. When truthy, trainable
            variables found in it are initialised from it.
        learning_rate, num_train_steps, num_warmup_steps: forwarded to
            `optimization.create_optimizer` in TRAIN mode.

    The returned `model_fn` supports TRAIN and EVAL modes only.
    """

    def model_fn(features, labels, mode, params):
        """Build the BERT pre-training graph for `mode`.

        `labels` and `params` belong to the Estimator signature but are not
        used: every target is read from `features`.

        Raises:
            ValueError: if `mode` is neither TRAIN nor EVAL (e.g. PREDICT),
                because no prediction output is defined. Previously this
                returned None, which the Estimator rejected with an
                unhelpful message.
        """
        if mode not in (tf.estimator.ModeKeys.TRAIN,
                        tf.estimator.ModeKeys.EVAL):
            raise ValueError(
                "Only TRAIN and EVAL modes are supported: %s" % mode)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=False)

        # The masked-LM head reuses the embedding table as its output weights.
        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
            bert_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
            bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        if init_checkpoint:
            # Only the assignment map is needed here; the set of initialised
            # variable names is not used.
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tf.trainable_variables(), init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps,
                num_warmup_steps, False)  # presumably use_tpu=False; confirm
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op)

        # EVAL (the only remaining mode after the guard above).
        eval_metric_ops = metric_fn(
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights, next_sentence_example_loss,
            next_sentence_log_probs, next_sentence_labels)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metric_ops=eval_metric_ops)

    return model_fn


### Input Function Builder

### Estimator 

In [7]:
# Train from scratch: there is no init_checkpoint to warm-start from.
model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=None,
    learning_rate=learning_rate,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
)

In [8]:
# Checkpoints go to output_dir every save_checkpoints_steps global steps.
run_config = tf.estimator.RunConfig(
    model_dir=output_dir, save_checkpoints_steps=save_checkpoints_steps)

In [9]:
# Passed through the Estimator to model_fn and input_fn; input_fn reads
# "batch_size" from it.
params = dict(
    batch_size=train_batch_size,
    max_seq_length=max_seq_length,
    max_predictions_per_seq=max_predictions_per_seq,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
)
estimator = tf.estimator.Estimator(
    model_fn=model_fn, config=run_config, params=params)

INFO:tensorflow:Using config: {'_model_dir': 'model_test', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f5dc13a4470>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


# 2. Write an Input function

  The input function reads records from the `.tfrecord` files,
  parses each record, and groups the parsed examples into batches for the model.

In [10]:
def input_fn_builder(input_files,
                     max_seq_length,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4,
                     shuffle_buffer_size=100):
    """Create an `input_fn` closure for `tf.estimator.Estimator`.

    Previously every argument except `input_files` was ignored:
    - parsing used the notebook-level globals;
    - training data was never shuffled;
    - records were parsed on a single thread.
    All of them are now honoured.

    Args:
        input_files: list of TFRecord file paths.
        max_seq_length: length of input_ids / input_mask / segment_ids.
        max_predictions_per_seq: length of the masked_lm_* features.
        is_training: if True, records are shuffled before repeating.
        num_cpu_threads: number of records parsed in parallel.
        shuffle_buffer_size: shuffle buffer size used when `is_training`.

    Returns:
        `input_fn(params)` returning an endlessly repeating dataset of batches
        of `params["batch_size"]` feature dicts.
    """
    name_to_features = {
        "input_ids":
            tf.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask":
            tf.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids":
            tf.FixedLenFeature([max_seq_length], tf.int64),
        "masked_lm_positions":
            tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
        "masked_lm_ids":
            tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
        "masked_lm_weights":
            tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
        "next_sentence_labels":
            tf.FixedLenFeature([1], tf.int64),
    }

    def _decode_record(record):
        # Parse one serialized example using this builder's feature lengths,
        # then cast int64 features to int32. The masked-LM head adds the
        # positions to int32 offsets, so they must be int32.
        example = tf.parse_single_example(record, name_to_features)
        return {name: tf.to_int32(t) if t.dtype == tf.int64 else t
                for name, t in example.items()}

    def input_fn(params):
        """Build the batched dataset; reads only `params["batch_size"]`."""
        batch_size = params["batch_size"]

        d = tf.data.TFRecordDataset(input_files)
        if is_training:
            # Without this, training sees the records in file order every pass.
            d = d.shuffle(buffer_size=shuffle_buffer_size)
        # Repeat forever; the Estimator's max_steps / steps bounds the run.
        d = d.repeat()
        d = d.map(_decode_record, num_parallel_calls=num_cpu_threads)
        d = d.batch(batch_size)
        return d

    return input_fn


def parse(record, seq_length=None, predictions_per_seq=None):
    """Decode one serialized tf.Example into a dict of tensors.

    The feature lengths were previously taken silently from the notebook
    globals. They can now be passed explicitly; the globals remain the
    defaults, so existing `parse(record)` calls behave as before.

    Args:
        record: one serialized example, e.g. an element of a TFRecordDataset.
        seq_length: length of input_ids / input_mask / segment_ids. Defaults
            to the module-level `max_seq_length`.
        predictions_per_seq: length of the masked_lm_* features. Defaults to
            the module-level `max_predictions_per_seq`.

    Returns:
        Dict mapping feature name to tensor. int64 features are cast to
        int32; masked_lm_weights stays float32.
    """
    if seq_length is None:
        seq_length = max_seq_length
    if predictions_per_seq is None:
        predictions_per_seq = max_predictions_per_seq

    name_to_features = {
            "input_ids":
                tf.FixedLenFeature([seq_length], tf.int64),
            "input_mask":
                tf.FixedLenFeature([seq_length], tf.int64),
            "segment_ids":
                tf.FixedLenFeature([seq_length], tf.int64),
            "masked_lm_positions":
                tf.FixedLenFeature([predictions_per_seq], tf.int64),
            "masked_lm_ids":
                tf.FixedLenFeature([predictions_per_seq], tf.int64),
            "masked_lm_weights":
                tf.FixedLenFeature([predictions_per_seq], tf.float32),
            "next_sentence_labels":
                tf.FixedLenFeature([1], tf.int64),
        }

    example = tf.parse_single_example(record, name_to_features)

    # Cast int64 -> int32; the masked-LM head adds the positions to int32
    # offsets, so they must be int32.
    return {name: tf.to_int32(t) if t.dtype == tf.int64 else t
            for name, t in example.items()}


# 3. Train 

### 3.1 Input Function Build for Training


In [11]:
# Training input pipeline over all configured TFRecord files.
train_input_fn = input_fn_builder(
    input_files=input_files,
    is_training=True,
    max_seq_length=max_seq_length,
    max_predictions_per_seq=max_predictions_per_seq,
)

### 3.2 Training using estimator

In [12]:
# max_steps is an absolute global-step limit: if the checkpoint in output_dir
# has already reached it, training is skipped.
estimator.train(max_steps=num_train_steps, input_fn=train_input_fn)

INFO:tensorflow:Skipping training since max_steps has already saved.


<tensorflow.python.estimator.estimator.Estimator at 0x7f5dc13a4dd8>

# 4. Evaluation

### 4.1 Input Function Build for Evaluation

In [13]:
# NOTE(review): evaluation reads the same files used for training, so these
# metrics measure fit to training data, not generalisation.
eval_input_fn = input_fn_builder(
    input_files=input_files,
    is_training=False,
    max_seq_length=max_seq_length,
    max_predictions_per_seq=max_predictions_per_seq,
)

### 4.2 Evaluation using estimators

In [14]:
# Evaluate for the configured number of batches. The eval input repeats
# indefinitely, so `steps` is what bounds evaluation. Previously this was
# hard-coded to 100, ignoring max_eval_steps.
result = estimator.evaluate(
    input_fn=eval_input_fn, steps=max_eval_steps)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-04-08-06:54:24
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_test/model.ckpt-2000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2019-04-08-06:54:40
INFO:tensorflow:Saving dict for global step 2000: global_step = 2000, loss = 8.713631, masked_lm_accuracy = 0.054181494, masked_lm_loss = 8.11758, next_sentence_accuracy = 0.71166664, next_sentence_loss = 0.5982522
INFO:tensorflow:Saving 'checkpoint_path' sum