# BERT finetuning tasks in 5 minutes with Cloud TPU



**BERT**, or **B**idirectional **E**ncoder **R**epresentations from **T**ransformers, is a new method of pre-training language representations that obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. The academic paper can be found here: https://arxiv.org/abs/1810.04805.

This Colab demonstrates using a free Colab Cloud TPU to fine-tune pretrained BERT models on sentence and sentence-pair classification tasks. Here the task is query-passage relevance classification on Robust04.

**Note:** You will need a GCP (Google Cloud Platform) account and a GCS (Google Cloud Storage) bucket for this Colab to run.

Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) to create a GCP account and a GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.

Once you finish the setup, let's start!

**Firstly**, we need to set up the Colab TPU runtime, verify that a TPU device is successfully connected, and upload credentials to the TPU so that it can use the GCS bucket.

In [None]:
import datetime
import json
import os
import pprint
import random
import string
import sys
import tensorflow as tf

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

from google.colab import auth
auth.authenticate_user()
with tf.Session(TPU_ADDRESS) as session:
  print('TPU devices:')
  pprint.pprint(session.list_devices())

  # Upload credentials to TPU.
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.

TPU address is grpc://10.23.109.42:8470
TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 3369616037017264270),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 2622471493146608794),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1798304136644219600),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 10581187294090197416),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 12329653481336770435),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 14528859872554398856),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 17707723227843171150),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 3881497008498220576),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 815797905580

W0825 17:52:06.277870 139820775229312 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.




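**Secondly**, clone the [SIGIR19-BERT-IR](https://github.com/AdeDZY/SIGIR19-BERT-IR) repository, which provides the BERT modules (`modeling`, `optimization`, `tokenization`) and the `run_qe_classifier` helpers used below, and add it to the Python path.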

In [None]:
import sys
!rm -rf bert_repo6
!test -d bert_repo6 || git clone https://github.com/AdeDZY/SIGIR19-BERT-IR bert_repo6
if 'bert_repo6' not in sys.path:
  sys.path += ['bert_repo6']
if '.' not in sys.path:
  sys.path += ['.']

Cloning into 'bert_repo6'...
remote: Enumerating objects: 69, done.

In [None]:
sys.path = ['',
 '/env/python',
 '/usr/lib/python36.zip',
 '/usr/lib/python3.6',
 '/usr/lib/python3.6/lib-dynload',
 '/usr/local/lib/python3.6/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.6/dist-packages/IPython/extensions',
 '/root/.ipython',
 '.',
 'bert_repo6']

In [None]:
sys.path

['',
 '/env/python',
 '/usr/lib/python36.zip',
 '/usr/lib/python3.6',
 '/usr/lib/python3.6/lib-dynload',
 '/usr/local/lib/python3.6/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.6/dist-packages/IPython/extensions',
 '/root/.ipython',
 '.',
 'bert_repo6']

In [None]:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import os
import modeling
import optimization
import tokenization
import tensorflow as tf
import random
import json
from run_qe_classifier import *


W0825 17:52:48.086066 139820775229312 deprecation_wrapper.py:119] From bert_repo6/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.



In [None]:
FOLD=5 #@param {type:"integer"}
QUERY_FIELD="desc" #@param {type:"string"}
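`FOLD` selects the cross-validation split, and `QUERY_FIELD` names the query fields (space-separated) to read from `queries.json`. The fold arithmetic in `MyRobust04Processor.__init__` can be checked in isolation with a minimal pure-Python sketch; `fold_split` is a hypothetical helper that copies those three expressions. For `FOLD=5` it reproduces the split logged when the processor is created (train `[1, 2, 3, 4]`, dev `4`, test `5`), and it shows that the test fold always equals `FOLD` while the dev fold is one of the training folds.

```python
def fold_split(fold, n_folds=5):
    """Hypothetical helper mirroring the fold arithmetic in MyRobust04Processor."""
    train_folds = [(fold + i) % n_folds + 1 for i in range(n_folds - 1)]
    dev_fold = (fold + n_folds - 2) % n_folds + 1
    test_fold = (fold + n_folds - 1) % n_folds + 1
    return train_folds, dev_fold, test_fold


# QUERY_FIELD may list several fields separated by spaces, e.g. "title desc".
query_fields = "title desc".split(' ')

for fold in range(1, 6):
    train_folds, dev_fold, test_fold = fold_split(fold)
    print(fold, train_folds, dev_fold, test_fold)
```

Because the dev fold is drawn from the training folds, dev metrics here are not a held-out estimate; only the test fold is unseen during training.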

In [None]:
class MyRobust04Processor(DataProcessor):

    def __init__(self):
        self.max_test_depth = 100
        self.max_train_depth = 1000
        self.n_folds = 5
        self.fold = FOLD
        self.q_fields = QUERY_FIELD.split(' ')
        tf.logging.info("Using query fields {}".format(' '.join(self.q_fields)))

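        # Rotate the folds so that fold FOLD is the test fold and the other four are training folds.
        # Note: the dev fold works out to the last training fold, so it overlaps the training data.
        self.train_folds = [(self.fold + i) % self.n_folds + 1 for i in range(self.n_folds - 1)]
        self.dev_fold = (self.fold + self.n_folds - 2) % self.n_folds + 1
        self.test_folds = (self.fold + self.n_folds - 1) % self.n_folds + 1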
        tf.logging.info("Train Folds: {}".format(str(self.train_folds)))
        tf.logging.info("Dev Fold: {}".format(str(self.dev_fold)))
        tf.logging.info("Test Fold: {}".format(str(self.test_folds)))

    def get_train_examples(self, data_dir):
        examples = []
        train_files = ["{}.trec.with_json".format(i) for i in self.train_folds]

        qrel_file = tf.gfile.Open(os.path.join(data_dir, "qrels"))
        qrels = self._read_qrel(qrel_file)
        tf.logging.info("Qrel size: {}".format(len(qrels)))

        query_file = tf.gfile.Open(os.path.join(data_dir, "queries.json"))
        qid2queries = self._read_queries(query_file)
        tf.logging.info("Loaded {} queries.".format(len(qid2queries)))

        for file_name in train_files:
            train_file = tf.gfile.Open(os.path.join(data_dir, file_name))
            for i, line in enumerate(train_file):
               
                items = line.strip().split('#')
                trec_line = items[0]

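                qid, _, docid, r, _, _ = trec_line.strip().split(' ')

                # Downsample training passages: always keep a document's first passage
                # (index 0), and keep each later passage with probability 0.1.
                if int(docid.split('_')[-1].split('-')[-1]) != 0 and random.random() > 0.1:
                    continue
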
                assert qid in qid2queries, "QID {} not found".format(qid)
                q_json_dict = qid2queries[qid]
                q_text_list = [tokenization.convert_to_unicode(q_json_dict[field]) for field in self.q_fields]

                json_dict = json.loads('#'.join(items[1:]))
                d = tokenization.convert_to_unicode(json_dict["doc"]["body"])

                r = int(r)
                if r > self.max_train_depth:
                    continue
                label = tokenization.convert_to_unicode("0")
                if (qid, docid) in qrels or (qid, docid.split('_')[0]) in qrels:
                    label = tokenization.convert_to_unicode("1")
                guid = "train-%s-%s" % (qid, docid)
                examples.append(
                    InputExample(guid=guid, text_a_list=q_text_list, text_b=d, label=label)
                )
            train_file.close()
        random.shuffle(examples)
        return examples

    def get_dev_examples(self, data_dir):
        examples = []
        dev_file = tf.gfile.Open(os.path.join(data_dir, "{}.trec.with_json".format(self.dev_fold)))
        qrel_file = tf.gfile.Open(os.path.join(data_dir, "qrels"))
        qrels = self._read_qrel(qrel_file)
        tf.logging.info("Qrel size: {}".format(len(qrels)))

        query_file = tf.gfile.Open(os.path.join(data_dir, "queries.json"))
        qid2queries = self._read_queries(query_file)
        tf.logging.info("Loaded {} queries.".format(len(qid2queries)))
        
        flag = False
        for i, line in enumerate(dev_file):
            items = line.strip().split('#')
            trec_line = items[0]

            qid, _, docid, r, _, _ = trec_line.strip().split(' ')
            assert qid in qid2queries, "QID {} not found".format(qid)
            q_json_dict = qid2queries[qid]
            q_text_list = [tokenization.convert_to_unicode(q_json_dict[field]) for field in self.q_fields]

            json_dict = json.loads('#'.join(items[1:]))
            d = tokenization.convert_to_unicode(json_dict["doc"]["body"])

            r = int(r)
            if r > self.max_test_depth:
                continue
            label = tokenization.convert_to_unicode("0")
            if (qid, docid) in qrels or (qid, docid.split('_')[0]) in qrels:
                label = tokenization.convert_to_unicode("1")
                flag = True
            guid = "dev-%s-%s" % (qid, docid)
            examples.append(
                InputExample(guid=guid, text_a_list=q_text_list, text_b=d, label=label)
            )
        dev_file.close()
        if not flag:
            tf.logging.warning("No relevant document is labeled!")
        return examples

    def get_test_examples(self, data_dir):
        examples = []
        dev_file = tf.gfile.Open(os.path.join(data_dir, "{}.trec.with_json".format(self.test_folds)))
        qrel_file = tf.gfile.Open(os.path.join(data_dir, "qrels"))
        qrels = self._read_qrel(qrel_file)
        tf.logging.info("Qrel size: {}".format(len(qrels)))

        query_file = tf.gfile.Open(os.path.join(data_dir, "queries.json"))
        qid2queries = self._read_queries(query_file)
        tf.logging.info("Loaded {} queries.".format(len(qid2queries)))

        for i, line in enumerate(dev_file):
            items = line.strip().split('#')
            trec_line = items[0]

            qid, _, docid, r, _, _ = trec_line.strip().split(' ')
            assert qid in qid2queries, "QID {} not found".format(qid)
            q_json_dict = qid2queries[qid]
            q_text_list = [tokenization.convert_to_unicode(q_json_dict[field]) for field in self.q_fields]

            json_dict = json.loads('#'.join(items[1:]))
            d = tokenization.convert_to_unicode(json_dict["doc"]["body"])

            r = int(r)
            if r > self.max_test_depth:
                continue
            label = tokenization.convert_to_unicode("0")
            if (qid, docid) in qrels or (qid, docid.split('_')[0]) in qrels:
                label = tokenization.convert_to_unicode("1")
            guid = "test-%s-%s" % (qid, docid)
            examples.append(
                InputExample(guid=guid, text_a_list=q_text_list, text_b=d, label=label)
            )
        dev_file.close()
        return examples

    def _read_qrel(self, qrel_file):
        qrels = set()
        for line in qrel_file:
            qid, _, docid, rel = line.strip().split(' ')
            rel = int(rel)
            if rel > 0:
                qrels.add((qid, docid))
        return qrels

    def _read_queries(self, query_file):
        qid2queries = {}
        for i, line in enumerate(query_file):
            json_dict = json.loads(line)
            qid = json_dict['qid']
            qid2queries[qid] = json_dict
            if i < 3:
              tf.logging.info("Example Q: {}".format(json_dict))
        return qid2queries
   
    def get_labels(self):
        return ["0", "1"]
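Each line of a `.trec.with_json` file is a six-field TREC run line followed by `#` and a JSON payload whose `doc.body` holds the passage text. The helpers below are a hypothetical, TensorFlow-free sketch of the same parsing, passage downsampling, and relevance rules used in `get_train_examples` and `_read_qrel`, assuming passage IDs of the form `<docid>_<passage index>`; the sample line and IDs are made up. Rejoining everything after the first `#` keeps any `#` characters inside the passage text, and a passage counts as relevant if either the passage ID or its parent document ID appears in the qrels.

```python
import json
import random


def parse_trec_with_json(line):
    """Split one `.trec.with_json` line into (qid, docid, rank, passage text)."""
    items = line.strip().split('#')
    qid, _, docid, rank, _, _ = items[0].strip().split(' ')
    # The JSON payload may itself contain '#', so rejoin everything after the first one.
    payload = json.loads('#'.join(items[1:]))
    return qid, docid, int(rank), payload["doc"]["body"]


def passage_index(docid):
    """Passage number after the last '_' (and the last '-') of the passage ID."""
    return int(docid.split('_')[-1].split('-')[-1])


def keep_for_training(docid, rng, keep_rate=0.1):
    """Always keep passage 0; keep later passages with probability keep_rate."""
    return passage_index(docid) == 0 or rng.random() <= keep_rate


def read_qrels(lines):
    """Collect (qid, docid) pairs with a positive relevance grade."""
    qrels = set()
    for line in lines:
        qid, _, docid, rel = line.strip().split(' ')
        if int(rel) > 0:
            qrels.add((qid, docid))
    return qrels


def is_relevant(qid, docid, qrels):
    """Match on the passage ID or on its parent document ID."""
    return (qid, docid) in qrels or (qid, docid.split('_')[0]) in qrels


sample = 'q1 Q0 DOC123_2 7 12.5 run#{"doc": {"body": "text with # inside"}}'
qid, docid, rank, body = parse_trec_with_json(sample)
qrels = read_qrels(["q1 0 DOC123 1", "q1 0 DOC999 0"])
print(qid, docid, rank, body, is_relevant(qid, docid, qrels))
```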


**Thirdly**, prepare for training:

*  Specify the task and list the training data in the GCS bucket.
*  Specify the pretrained BERT model.
*  Specify the GCS bucket, and create an output directory for model checkpoints and eval results.
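The pretrained checkpoint names follow the pattern `<casing>_L-<layers>_H-<hidden size>_A-<attention heads>`, so `uncased_L-12_H-768_A-12` is BERT-Base (12 layers, hidden size 768, 12 heads) and `uncased_L-24_H-1024_A-16` is BERT-Large. A small hypothetical helper makes that explicit; its `do_lower_case` value matches the notebook's `DO_LOWER_CASE = BERT_MODEL.startswith('uncased')`, which has to agree with how the checkpoint's vocabulary was built.

```python
import re


def parse_bert_model_name(name):
    """Hypothetical helper: decode casing and size from a BERT checkpoint name."""
    match = re.fullmatch(r"(cased|uncased)_L-(\d+)_H-(\d+)_A-(\d+)", name)
    if match is None:
        raise ValueError("unrecognized BERT model name: {}".format(name))
    casing, layers, hidden, heads = match.groups()
    return {
        "do_lower_case": casing == "uncased",
        "num_layers": int(layers),
        "hidden_size": int(hidden),
        "num_heads": int(heads),
    }


print(parse_bert_model_name('uncased_L-12_H-768_A-12'))
```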



In [None]:
TASK = 'robust-descinit-passage' #@param {type:"string"}

# Available pretrained model checkpoints:
#   uncased_L-12_H-768_A-12: uncased BERT base model
#   uncased_L-24_H-1024_A-16: uncased BERT large model
#   cased_L-12_H-768_A-12: cased BERT base model
BERT_MODEL = 'uncased_L-12_H-768_A-12' #@param {type:"string"}
BERT_PRETRAINED_DIR = 'gs://cloud-tpu-checkpoints/bert/' + BERT_MODEL
print('***** BERT pretrained directory: {} *****'.format(BERT_PRETRAINED_DIR))
!gsutil ls $BERT_PRETRAINED_DIR

BUCKET = 'bertir' #@param {type:"string"}
assert BUCKET, 'Must specify an existing GCS bucket name'
OUTPUT_DIR = 'gs://{}/bert/models/{}-fold{}'.format(BUCKET, TASK, FOLD)
tf.gfile.MakeDirs(OUTPUT_DIR)
print('***** Model output directory: {} *****'.format(OUTPUT_DIR))

DATA_DIR = "robust/cv_descinit_passages/" #@param {type:"string"}
TASK_DATA_DIR = 'gs://{}/{}'.format(BUCKET, DATA_DIR) 

!gsutil ls $TASK_DATA_DIR


***** BERT pretrained directory: gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12 *****
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_config.json
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.index
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.meta
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/checkpoint
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/vocab.txt
***** Model output directory: gs://bertir/bert/models/robust-descinit-passage-fold5 *****
gs://bertir/robust/cv_descinit_passages/1.trec.with_json
gs://bertir/robust/cv_descinit_passages/2.trec.with_json
gs://bertir/robust/cv_descinit_passages/3.trec.with_json
gs://bertir/robust/cv_descinit_passages/4.trec.with_json
gs://bertir/robust/cv_descinit_passages/5.trec.with_json
gs://bertir/robust/cv_descinit_passages/qrels
gs://bertir/robust/cv_descinit_

**Now, let's play!**

In [None]:
# Setup task specific model and TPU running config.

import modeling
import optimization
import tokenization

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' 

tf.logging.set_verbosity(tf.logging.INFO)


# Model Hyper Parameters
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 8
PREDICT_BATCH_SIZE = 32

LEARNING_RATE = 1e-5
NUM_TRAIN_EPOCHS = 1.0
WARMUP_PROPORTION = 0.1
MAX_SEQCONCAT_LENGTH = 256

# Model configs
SAVE_CHECKPOINTS_STEPS = 20000
ITERATIONS_PER_LOOP = 1000
NUM_TPU_CORES = 8
VOCAB_FILE = os.path.join(BERT_PRETRAINED_DIR, 'vocab.txt')
CONFIG_FILE = os.path.join(BERT_PRETRAINED_DIR, 'bert_config.json')
INIT_CHECKPOINT = os.path.join(BERT_PRETRAINED_DIR, 'bert_model.ckpt')
DO_LOWER_CASE = BERT_MODEL.startswith('uncased')


processor = MyRobust04Processor()
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, do_lower_case=DO_LOWER_CASE)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=OUTPUT_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=ITERATIONS_PER_LOOP,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))



I0825 17:54:57.573493 139820775229312 <ipython-input-8-74853bd7f2a8>:9] Using query fields desc
I0825 17:54:57.577228 139820775229312 <ipython-input-8-74853bd7f2a8>:14] Train Folds: [1, 2, 3, 4]
I0825 17:54:57.579711 139820775229312 <ipython-input-8-74853bd7f2a8>:15] Dev Fold: 4
I0825 17:54:57.582394 139820775229312 <ipython-input-8-74853bd7f2a8>:16] Test Fold: 5
W0825 17:54:57.584985 139820775229312 deprecation_wrapper.py:119] From bert_repo6/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.
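`TPUEstimator` treats `TRAIN_BATCH_SIZE`, `EVAL_BATCH_SIZE`, and `PREDICT_BATCH_SIZE` as global batch sizes that are split evenly across the `NUM_TPU_CORES = 8` cores, so each should be a multiple of 8. The sketch below uses a hypothetical helper (not part of the TPU API) to show the per-core sizes for the values configured above.

```python
def per_core_batch(global_batch, num_cores):
    """Hypothetical helper: examples each TPU core processes per step."""
    if global_batch % num_cores != 0:
        raise ValueError(
            "batch size {} is not divisible by {} cores".format(global_batch, num_cores))
    return global_batch // num_cores


num_cores = 8
for name, size in [("train", 16), ("eval", 8), ("predict", 32)]:
    print(name, per_core_batch(size, num_cores))
```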



In [None]:
train_examples = processor.get_train_examples(TASK_DATA_DIR)
num_train_steps = int(len(train_examples) / TRAIN_BATCH_SIZE * NUM_TRAIN_EPOCHS)
num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)

model_fn = model_fn_builder(
    bert_config=modeling.BertConfig.from_json_file(CONFIG_FILE),
    num_labels=len(label_list),
    init_checkpoint=INIT_CHECKPOINT,
    learning_rate=LEARNING_RATE,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    use_tpu=True,
    use_one_hot_embeddings=True)

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=True,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE,
    predict_batch_size=PREDICT_BATCH_SIZE)

I0825 17:55:07.429852 139820775229312 <ipython-input-8-74853bd7f2a8>:24] Qrel size: 17412
I0825 17:55:07.624071 139820775229312 <ipython-input-8-74853bd7f2a8>:155] Example Q: {'title': 'Islamic Revolution', 'qid': '669', 'question': 'what is Islamic Revolution', 'narr': 'Relevant documents must discuss the reasons that relations between the Islamic world and the United States have deteriorated.', 'desc_short': 'causes Islamic Revolution relative relations US', 'desc': 'What were the causes for the Islamic Revolution relative to relations with the U.S.?'}
I0825 17:55:07.625161 139820775229312 <ipython-input-8-74853bd7f2a8>:155] Example Q: {'title': 'poverty, disease', 'qid': '668', 'question': 'what is the relation ship between poverty and disease', 'narr': 'Documents that do not link poverty to diseases directly but mention a link between poverty and health care are relevant. Documents that simply mention poverty and disease but do not draw a connection are not relevant.', 'desc_short'

KeyboardInterrupt: ignored
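The training schedule is plain arithmetic: `num_train_steps` is the number of training examples divided by `TRAIN_BATCH_SIZE`, times `NUM_TRAIN_EPOCHS`, truncated to an integer, and the first `WARMUP_PROPORTION` of those steps warm up the learning rate. The training log reports `Num steps = 23407`; the example count of 374,512 used below is only one value consistent with that figure, since the real count is not printed.

```python
def training_schedule(num_examples, batch_size, num_epochs, warmup_proportion):
    """Same arithmetic as the notebook's num_train_steps / num_warmup_steps."""
    num_train_steps = int(num_examples / batch_size * num_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


# 374,512 is a hypothetical example count consistent with the logged 23407 steps.
print(training_schedule(374512, 16, 1.0, 0.1))  # (23407, 2340)
```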

In [None]:
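# Convert the training examples to TFRecord features in OUTPUT_DIR.
# The conversion is commented out here; uncomment it if train.tf_record does not exist yet.
train_file = os.path.join(OUTPUT_DIR, "train.tf_record")
#train_features = file_based_convert_examples_to_features(
#    train_examples, label_list, MAX_SEQCONCAT_LENGTH, tokenizer, train_file)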



In [None]:
train_file = os.path.join(OUTPUT_DIR, "train.tf_record")

print('***** Started training at {} *****'.format(datetime.datetime.now()))
print('  Batch size = {}'.format(TRAIN_BATCH_SIZE))
tf.logging.info("  Num steps = %d", num_train_steps)
train_input_fn = file_based_input_fn_builder(
    input_file=train_file,
    seq_length=MAX_SEQCONCAT_LENGTH,
    is_training=True,
    drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
print('***** Finished training at {} *****'.format(datetime.datetime.now()))

***** Started training at 2019-02-18 14:04:37.471168 *****
  Batch size = 16
INFO:tensorflow:  Num steps = 23407
INFO:tensorflow:Querying Tensorflow master (grpc://10.21.22.90:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 13578720447639949201)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 14271873016053651699)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1639536778088073386)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 10243009842499693073)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/tas


In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
tf.logging.set_verbosity(tf.logging.INFO)

predict_examples = processor.get_test_examples(TASK_DATA_DIR)
num_actual_predict_examples = len(predict_examples)
assert num_actual_predict_examples > 0
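# Pad to a multiple of the prediction batch size, since prediction on the TPU
# uses drop_remainder=True and would otherwise drop the last partial batch.
while len(predict_examples) % PREDICT_BATCH_SIZE != 0:
  predict_examples.append(PaddingInputExample())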

predict_file = os.path.join(OUTPUT_DIR, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
                                        MAX_SEQCONCAT_LENGTH, tokenizer,
                                        predict_file)

INFO:tensorflow:Qrel size: 17412
INFO:tensorflow:Example Q: {'title': 'Islamic Revolution', 'qid': '669', 'question': 'what is Islamic Revolution', 'narr': 'Relevant documents must discuss the reasons that relations between the Islamic world and the United States have deteriorated.', 'desc_short': 'causes Islamic Revolution relative relations US', 'desc': 'What were the causes for the Islamic Revolution relative to relations with the U.S.?'}
INFO:tensorflow:Example Q: {'title': 'poverty, disease', 'qid': '668', 'question': 'what is the relation ship between poverty and disease', 'narr': 'Documents that do not link poverty to diseases directly but mention a link between poverty and health care are relevant. Documents that simply mention poverty and disease but do not draw a connection are not relevant.', 'desc_short': 'relationship poverty disease', 'desc': 'What is the relationship between poverty and disease?'}
INFO:tensorflow:Example Q: {'title': 'unmarried-partner households', 'qid'
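The padding loop appends `PaddingInputExample`s until the example count is a multiple of the prediction batch size. The number of padding examples needed is `(-n) % batch_size`; a minimal sketch, with a hypothetical `pad_factory` standing in for `PaddingInputExample`:

```python
def pad_to_multiple(examples, batch_size, pad_factory):
    """Return a copy of examples padded with pad_factory() to a multiple of batch_size."""
    padded = list(examples)
    num_padding = -len(padded) % batch_size
    padded.extend(pad_factory() for _ in range(num_padding))
    return padded


padded = pad_to_multiple(range(70), 32, lambda: None)
print(len(padded), padded.count(None))  # 96 26
```

The padded rows still produce predictions, which is why the writing loop stops at `num_actual_predict_examples`.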

In [None]:

tf.logging.set_verbosity(tf.logging.ERROR)

#predict_file="gs://bertir/bert/models/marco/predict.tf_record"
#num_actual_predict_examples = 999240
tf.logging.info("***** Running prediction *****")
#tf.logging.info("  Num examples = %d (%d actual, %d padding)",
#                    len(predict_examples), num_actual_predict_examples,
#                    len(predict_examples) - num_actual_predict_examples)
tf.logging.info("  Batch size = %d", PREDICT_BATCH_SIZE)

predict_drop_remainder = True 
predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=MAX_SEQCONCAT_LENGTH,
        is_training=False,
        drop_remainder=predict_drop_remainder)

result = estimator.predict(input_fn=predict_input_fn)

output_predict_file = os.path.join(OUTPUT_DIR, "test_results.tsv")
with tf.gfile.GFile(output_predict_file, "w") as writer:
  num_written_lines = 0
  tf.logging.info("***** Predict results *****")
  for (i, prediction) in enumerate(result):
    probabilities = prediction["probabilities"]
    if i >= num_actual_predict_examples:
      break
    output_line = "\t".join(
            str(class_probability)
            for class_probability in probabilities) + "\n"
    writer.write(output_line)
    num_written_lines += 1
    if num_written_lines % 100000 == 0:
      print(num_written_lines)
assert num_written_lines == num_actual_predict_examples
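`test_results.tsv` stores one row of class probabilities per real test example, in the same order as `predict_examples`, and carries no IDs. Since `get_labels()` returns `["0", "1"]`, the second column is the model's probability that the passage is relevant. To rank passages, the rows can be paired back with the example guids (`test-<qid>-<docid>`). The sketch below is a hypothetical helper, assuming the guids are taken from `predict_examples[:num_actual_predict_examples]` (each example exposing `guid` as in BERT's `InputExample`) and that query IDs contain no `-`.

```python
def score_rows(tsv_text, guids):
    """Hypothetical helper: pair TSV rows with guids; column 1 is P(label == "1")."""
    rows = [line.split("\t") for line in tsv_text.splitlines() if line]
    if len(rows) != len(guids):
        raise ValueError("row count does not match number of examples")
    scored = []
    for guid, row in zip(guids, rows):
        # guid is "test-<qid>-<docid>"; docid itself may contain '-'.
        _, qid, docid = guid.split("-", 2)
        scored.append((qid, docid, float(row[1])))
    return scored


print(score_rows("0.9\t0.1\n0.2\t0.8\n", ["test-301-DOC-1_0", "test-301-DOC-2_0"]))
```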


In [None]:
# Eval the model.
eval_examples = processor.get_dev_examples(TASK_DATA_DIR)
eval_file = os.path.join(OUTPUT_DIR, "eval.tf_record")
file_based_convert_examples_to_features(eval_examples, label_list,
                                        MAX_SEQCONCAT_LENGTH, tokenizer,
                                        eval_file)
print('***** Started evaluation at {} *****'.format(datetime.datetime.now()))
print('  Num examples = {}'.format(len(eval_examples)))
print('  Batch size = {}'.format(EVAL_BATCH_SIZE))
# drop_remainder=True discards the final partial batch on the TPU, so up to
# EVAL_BATCH_SIZE - 1 examples are left out of the eval metrics.
eval_steps = int(len(eval_examples) / EVAL_BATCH_SIZE)
eval_input_fn = file_based_input_fn_builder(
    input_file=eval_file,
    seq_length=MAX_SEQCONCAT_LENGTH,
    is_training=False,
    drop_remainder=True)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
print('***** Finished evaluation at {} *****'.format(datetime.datetime.now()))
output_eval_file = os.path.join(OUTPUT_DIR, "eval_results.txt")
with tf.gfile.GFile(output_eval_file, "w") as writer:
  print("***** Eval results *****")
  for key in sorted(result.keys()):
    print('  {} = {}'.format(key, str(result[key])))
    writer.write("%s = %s\n" % (key, str(result[key])))

AttributeError: ignored
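With `drop_remainder=True`, evaluation runs `eval_steps = int(len(eval_examples) / EVAL_BATCH_SIZE)` full batches, so the last `len(eval_examples) % EVAL_BATCH_SIZE` examples are never scored. A quick check with a hypothetical dev-set size of 1,003 examples:

```python
def eval_coverage(num_examples, eval_batch_size):
    """Steps run, examples evaluated, and examples dropped with drop_remainder=True."""
    eval_steps = int(num_examples / eval_batch_size)
    evaluated = eval_steps * eval_batch_size
    return eval_steps, evaluated, num_examples - evaluated


print(eval_coverage(1003, 8))  # (125, 1000, 3)
```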