# BERT finetuning tasks in 5 minutes with Cloud TPU

<table class="tfo-notebook-buttons" align="left" >
 <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/bert_finetuning_with_cloud_tpus.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/tpu/blob/master/tools/colab/bert_finetuning_with_cloud_tpus.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>


**BERT**, or **B**idirectional **E**ncoder **R**epresentations from **T**ransformers, is a new method of pre-training language representations which obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. The academic paper can be found here: https://arxiv.org/abs/1810.04805.

This Colab demonstrates using a free Colab Cloud TPU to fine-tune sentence and sentence-pair classification tasks built on top of pretrained BERT models.

**Note:**  You will need a GCP (Google Cloud Platform) account and a GCS (Google Cloud 
Storage) bucket for this Colab to run.

Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to create a GCP account and a GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.

Once you finish the setup, let's start!

**Firstly**, we need to set up the Colab TPU runtime environment, verify that a TPU device is successfully connected, and upload credentials to the TPU so that it can access the GCS bucket.

In [None]:
import datetime
import json
import os
import pprint
import random
import string
import sys

# Colab magic: select TensorFlow 1.x before importing it. The rest of this
# notebook relies on tf.Session and tf.contrib.
%tensorflow_version 1.x
import tensorflow as tf

# Fail fast when COLAB_TPU_ADDR is missing, i.e. no TPU runtime is attached.
assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

# Interactive Google sign-in for this Colab session.
from google.colab import auth
auth.authenticate_user()
with tf.Session(TPU_ADDRESS) as session:
  # Listing the devices doubles as a connectivity check for the TPU worker.
  print('TPU devices:')
  pprint.pprint(session.list_devices())

  # Upload credentials to TPU.
  # NOTE(review): /content/adc.json is presumably written by
  # auth.authenticate_user() above -- confirm before relying on the path.
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.

TensorFlow 1.x selected.
TPU address is grpc://10.1.181.170:8470
TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 6664397363455324686),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 16788073122837000344),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1903484303307337696),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 15451089682091094059),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 17712571090995915926),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 3790129473032030317),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 18417493745187398433),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 11194780545641338139),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:6, TPU,

# Data Loader, Model Preparation

**Secondly**, prepare and import BERT modules.

In [None]:
import sys
!rm -r bert_repo6
!test -d bert_repo6 || git clone https://github.com/AdeDZY/SIGIR19-BERT-IR bert_repo6
if not 'bert_repo6' in sys.path:
  sys.path += ['bert_repo6']
if not '.' in sys.path:
  sys.path += ['.']

rm: cannot remove 'bert_repo6': No such file or directory
Cloning into 'bert_repo6'...
remote: Enumerating objects: 87, done.[K
remote: Total 87 (delta 0), reused 0 (delta 0), pack-reused 87[K
Unpacking objects: 100% (87/87), done.


In [None]:
# Tidy the module search path without clobbering it.
#
# The previous version of this cell replaced sys.path with a hard-coded list of
# Python 3.6 locations, which breaks every import as soon as the runtime ships
# a different interpreter. Instead, drop duplicate entries (keeping the first
# occurrence, so lookup order is unchanged) and make sure the working directory
# and the cloned repository are present.
_seen_paths = set()
_clean_path = []
for _entry in list(sys.path) + ['.', 'bert_repo6']:
  if _entry not in _seen_paths:
    _seen_paths.add(_entry)
    _clean_path.append(_entry)
# Assign in place so any existing reference to the sys.path list stays valid.
sys.path[:] = _clean_path

In [None]:
# Display the effective module search path; '.' and 'bert_repo6' must be listed
# for the BERT imports in the next cell to resolve.
sys.path

['',
 '/env/python',
 '/usr/lib/python36.zip',
 '/usr/lib/python3.6',
 '/usr/lib/python3.6/lib-dynload',
 '/usr/local/lib/python3.6/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.6/dist-packages/IPython/extensions',
 '/root/.ipython',
 '.',
 'bert_repo6']

In [None]:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import os
import modeling
import optimization
import tokenization
import tensorflow as tf
import random
import json
from run_qe_classifier import *
import random


AttributeError: ignored

In [None]:
# Cross-validation settings, read by MyClueWebProcessor and by the output path.
# FOLD: 1, 2, 3, 4, 5 -- the fold held out for testing; the remaining four
#   folds are used for training.
# QUERY_FIELD: desc, title -- key(s) of each query JSON record to use as the
#   query text. The value is split on ' ', so several fields can be given
#   separated by single spaces.
FOLD=5 #@param {type:"integer"}
QUERY_FIELD="desc" #@param {type:"string"}

In [None]:
class MyClueWebProcessor(DataProcessor):
    """Builds query/document relevance examples from cross-validation fold files.

    Expected layout of ``data_dir`` (as read by the code below):

    * ``<k>.trec.with_json`` -- one candidate per line: six space-separated
      fields ``qid _ docid rank _ _``, then ``#``, then a JSON object with a
      ``doc`` entry holding ``body`` and, optionally, ``title``.
    * ``qrels`` -- four space-separated fields ``qid _ docid rel`` per line;
      a pair counts as relevant when ``rel > 0``.
    * ``queries.json`` -- one JSON object per line with at least ``qid`` and
      the configured query field(s).

    Each example pairs the query text(s) with ``"<title>. <first
    MAX_BODY_WORDS body words>"`` and carries label ``"1"`` (relevant) or
    ``"0"``.
    """

    # Number of leading space-separated body words kept for each document.
    MAX_BODY_WORDS = 200

    def __init__(self, fold=None, query_field=None):
        """Configure the train/test split.

        Args:
            fold: Fold number held out for testing. Defaults to the notebook
                global ``FOLD``.
            query_field: Space-separated query JSON key(s) used as query text.
                Defaults to the notebook global ``QUERY_FIELD``.
        """
        self.max_test_depth = 100
        self.max_train_depth = 100
        self.n_folds = 5
        self.fold = FOLD if fold is None else fold
        fields = QUERY_FIELD if query_field is None else query_field
        self.q_fields = fields.split(' ')
        tf.logging.info("Using query fields {}".format(' '.join(self.q_fields)))

        # Folds are numbered 1..n_folds. Starting right after `fold`, the next
        # n_folds - 1 folds (wrapping around) are used for training ...
        self.train_folds = [(self.fold + i) % self.n_folds + 1 for i in range(self.n_folds - 1)]
        # ... and the remaining one is the test fold. Despite the plural name
        # this attribute holds a single fold number.
        self.test_folds = (self.fold + self.n_folds - 1) % self.n_folds + 1
        tf.logging.info("Train Folds: {}".format(str(self.train_folds)))
        tf.logging.info("Test Fold: {}".format(str(self.test_folds)))

    def get_train_examples(self, data_dir):
        """Return the shuffled examples of all training folds.

        Candidates ranked deeper than ``self.max_train_depth`` are skipped.
        """
        qrels, qid2queries = self._load_qrels_and_queries(data_dir)
        examples = []
        for fold_id in self.train_folds:
            examples.extend(
                self._read_fold_examples(
                    data_dir, fold_id, "train", self.max_train_depth, qrels, qid2queries))
        random.shuffle(examples)
        return examples

    def get_test_examples(self, data_dir):
        """Return the examples of the test fold, in file order.

        Candidates ranked deeper than ``self.max_test_depth`` are skipped.
        """
        qrels, qid2queries = self._load_qrels_and_queries(data_dir)
        return self._read_fold_examples(
            data_dir, self.test_folds, "test", self.max_test_depth, qrels, qid2queries)

    def _load_qrels_and_queries(self, data_dir):
        """Read ``qrels`` and ``queries.json`` from ``data_dir``, closing both files."""
        with tf.gfile.Open(os.path.join(data_dir, "qrels")) as qrel_file:
            qrels = self._read_qrel(qrel_file)
        tf.logging.info("Qrel size: {}".format(len(qrels)))

        with tf.gfile.Open(os.path.join(data_dir, "queries.json")) as query_file:
            qid2queries = self._read_queries(query_file)
        tf.logging.info("Loaded {} queries.".format(len(qid2queries)))
        return qrels, qid2queries

    def _read_fold_examples(self, data_dir, fold_id, set_type, max_depth, qrels, qid2queries):
        """Turn one ``<fold_id>.trec.with_json`` file into a list of InputExample.

        ``set_type`` ("train" or "test") only prefixes the example guid;
        ``max_depth`` is the deepest rank that is still kept.
        """
        path = os.path.join(data_dir, "{}.trec.with_json".format(fold_id))
        examples = []
        # The context manager closes the file even if a malformed line raises.
        with tf.gfile.Open(path) as fold_file:
            for line in fold_file:
                # Everything before the first '#' is the TREC run line, everything
                # after it is the document JSON (which may itself contain '#').
                trec_line, _, doc_json = line.strip().partition('#')

                qid, _, docid, rank, _, _ = trec_line.strip().split(' ')
                assert qid in qid2queries, "QID {} not found".format(qid)
                # Check the rank before the comparatively costly JSON parse.
                if int(rank) > max_depth:
                    continue

                query = qid2queries[qid]
                q_text_list = [tokenization.convert_to_unicode(query[field]) for field in self.q_fields]

                # Document text is the title followed by the first
                # MAX_BODY_WORDS words of the body.
                doc = json.loads(doc_json)["doc"]
                truncated_body = ' '.join(doc["body"].split(' ')[:self.MAX_BODY_WORDS])
                d = tokenization.convert_to_unicode(doc.get("title", "") + ". " + truncated_body)

                # Relevant if the full docid, or the part before its first '_',
                # is judged relevant for this query.
                # NOTE(review): the '_' suffix is presumably a passage index -- confirm.
                relevant = (qid, docid) in qrels or (qid, docid.split('_')[0]) in qrels
                label = tokenization.convert_to_unicode("1" if relevant else "0")

                guid = "%s-%s-%s" % (set_type, qid, docid)
                examples.append(
                    InputExample(guid=guid, text_a_list=q_text_list, text_b=d, label=label)
                )
        return examples

    def _read_qrel(self, qrel_file):
        """Return the set of ``(qid, docid)`` pairs whose relevance grade is > 0."""
        qrels = set()
        for line in qrel_file:
            qid, _, docid, rel = line.strip().split(' ')
            if int(rel) > 0:
                qrels.add((qid, docid))
        return qrels

    def _read_queries(self, query_file):
        """Return a dict mapping ``qid`` to its query JSON record (one record per line)."""
        qid2queries = {}
        for i, line in enumerate(query_file):
            json_dict = json.loads(line)
            qid2queries[json_dict['qid']] = json_dict
            # Log the first few records as a sanity check of the file format.
            if i < 3:
                tf.logging.info("Example Q: {}".format(json_dict))
        return qid2queries

    def get_labels(self):
        """Return the label vocabulary: "0" = not relevant, "1" = relevant."""
        return ["0", "1"]


NameError: ignored

**Thirdly**, prepare for training:

*  Specify the task name and the location of the training data in the bucket.
*  Specify the BERT pretrained model.
*  Specify the GCS bucket, and create the output directory for model checkpoints and eval results.



In [None]:
# Free-form run name; it becomes part of the model output directory below.
TASK = 'cw-descinit-firstp' #@param {type:"string"}

# Available pretrained model checkpoints:
#   uncased_L-12_H-768_A-12: uncased BERT base model
#   uncased_L-24_H-1024_A-16: uncased BERT large model
#   cased_L-12_H-768_A-12: cased BERT base model
BERT_MODEL = 'uncased_L-12_H-768_A-12' #@param {type:"string"}
BERT_PRETRAINED_DIR = 'gs://cloud-tpu-checkpoints/bert/' + BERT_MODEL
print('***** BERT pretrained directory: {} *****'.format(BERT_PRETRAINED_DIR))
# List the checkpoint files to confirm the directory exists and is readable.
!gsutil ls $BERT_PRETRAINED_DIR

# Name of a GCS bucket (without the gs:// prefix) holding the data and
# receiving checkpoints and results.
BUCKET = 'bertir' #@param {type:"string"}
assert BUCKET, 'Must specify an existing GCS bucket name'
# One output directory per (task, fold) so cross-validation runs stay separate.
OUTPUT_DIR = 'gs://{}/bert/models/{}-fold{}'.format(BUCKET, TASK, FOLD)
tf.gfile.MakeDirs(OUTPUT_DIR)
print('***** Model output directory: {} *****'.format(OUTPUT_DIR))

# Path inside BUCKET that holds the fold files, qrels and queries.json.
DATA_DIR = "clueweb09/cv_descinit/" #@param {type:"string"}
TASK_DATA_DIR = 'gs://{}/{}'.format(BUCKET, DATA_DIR) 

# List the data files to confirm the directory exists and is readable.
!gsutil ls $TASK_DATA_DIR


***** BERT pretrained directory: gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12 *****
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_config.json
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.index
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.meta
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/checkpoint
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/vocab.txt
***** Model output directory: gs://bertir/bert/models/cw-descinit-first200-fold5 *****
gs://bertir/clueweb09/cv_descinit/1.trec.with_json
gs://bertir/clueweb09/cv_descinit/2.trec.with_json
gs://bertir/clueweb09/cv_descinit/3.trec.with_json
gs://bertir/clueweb09/cv_descinit/4.trec.with_json
gs://bertir/clueweb09/cv_descinit/5.trec.with_json
gs://bertir/clueweb09/cv_descinit/qrels
gs://bertir/clueweb09/cv_descinit/queries.json


# Train

In [None]:
# Setup task specific model and TPU running config.

import modeling
import optimization
import tokenization

import os
# NOTE(review): '0' presumably enables all native TensorFlow log output -- confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' 

tf.logging.set_verbosity(tf.logging.INFO)


# Model Hyper Parameters
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32
PREDICT_BATCH_SIZE = 32

LEARNING_RATE = 1e-5
NUM_TRAIN_EPOCHS = 1.0
# Fraction of the training steps used for learning-rate warm-up.
WARMUP_PROPORTION = 0.1
# Passed as the sequence length to feature conversion and to the input functions.
MAX_SEQCONCAT_LENGTH = 200

# Model configs
SAVE_CHECKPOINTS_STEPS = 20000
ITERATIONS_PER_LOOP = 1000
NUM_TPU_CORES = 8
VOCAB_FILE = os.path.join(BERT_PRETRAINED_DIR, 'vocab.txt')
CONFIG_FILE = os.path.join(BERT_PRETRAINED_DIR, 'bert_config.json')
INIT_CHECKPOINT = os.path.join(BERT_PRETRAINED_DIR, 'bert_model.ckpt')
# Lower-casing is derived from the checkpoint name so it matches the chosen model.
DO_LOWER_CASE = BERT_MODEL.startswith('uncased')


# Data processor, label vocabulary and WordPiece tokenizer shared by the
# training and inference cells.
processor = MyClueWebProcessor()
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(vocab_file=VOCAB_FILE, do_lower_case=DO_LOWER_CASE)

# Estimator run configuration: checkpoints go to OUTPUT_DIR and the work is
# sharded over NUM_TPU_CORES shards of the TPU at TPU_ADDRESS.
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=OUTPUT_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=ITERATIONS_PER_LOOP,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

#


INFO:tensorflow:Using query fields desc
INFO:tensorflow:Train Folds: [1, 2, 3, 4]
INFO:tensorflow:Dev Fold: 4
INFO:tensorflow:Test Fold: 5


In [None]:
# Load the training examples first: their count determines the step budget.
train_examples = processor.get_train_examples(TASK_DATA_DIR)
# Steps = examples / batch size * epochs, truncated to an integer; the first
# WARMUP_PROPORTION of those steps are warm-up steps.
num_train_steps = int(len(train_examples) / TRAIN_BATCH_SIZE * NUM_TRAIN_EPOCHS)
num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)

# NOTE(review): model_fn_builder is not defined in this notebook; presumably it
# comes from the `run_qe_classifier` star import -- confirm.
model_fn = model_fn_builder(
    bert_config=modeling.BertConfig.from_json_file(CONFIG_FILE),
    num_labels=len(label_list),
    init_checkpoint=INIT_CHECKPOINT,
    learning_rate=LEARNING_RATE,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    use_tpu=True,
    use_one_hot_embeddings=True)

# A single estimator is reused for training and prediction, with the batch
# sizes fixed here.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=True,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE,
    predict_batch_size=PREDICT_BATCH_SIZE)

INFO:tensorflow:Qrel size: 9863
INFO:tensorflow:Example Q: {'qid': '1', 'desc': "Find information on President Barack Obama's family history, including genealogy, national origins, places and dates of birth, etc.", 'subtopics': ['Find the TIME magazine photo essay "Barack Obama\'s Family Tree".', "Where did Barack Obama's parents and grandparents come from?", "Find biographical information on Barack Obama's mother."], 'title': 'obama family tree'}
INFO:tensorflow:Example Q: {'qid': '2', 'desc': 'Find information on French Lick Resort and Casino in Indiana.', 'subtopics': ['Find the homepage for French Lick Resort and Casino.', "What casinos are located within a day's drive of French Lick Resort and Casino?", 'What jobs are available at French Lick Casino and Resort?', 'Are there discounted packages for staying at French Lick Resort and Casino?'], 'title': 'french lick resort and casino'}
INFO:tensorflow:Example Q: {'qid': '3', 'desc': 'Find tips, resources, supplies for getting organiz

In [None]:
# Train the model.
# Convert the training examples into features and write them to a record file
# in the output directory. No training happens in this cell; the next cell
# reads this file.
train_file = os.path.join(OUTPUT_DIR, "train.tf_record")
# NOTE(review): the helper's return value is not used anywhere visible in this
# notebook -- confirm whether it returns anything meaningful.
train_features = file_based_convert_examples_to_features(
    train_examples, label_list, MAX_SEQCONCAT_LENGTH, tokenizer, train_file)



INFO:tensorflow:Writing example 0 of 15769
INFO:tensorflow:*** Example ***
INFO:tensorflow:guid: train-154-clueweb09-en0006-52-16869
INFO:tensorflow:tokens: [CLS] find information on nutritional or health benefits of fig ##s . [SEP] nutritional supplements . web buzz ##le . com home world news latest articles escape hatch topics free ec ##ards endless buzz topics amino acid supplement anti ##ox ##ida ##nt supplement anti ##ox ##ida ##nts free radicals healthy eating healthy recipes minerals nutrition facts sports nutrition vitamin ##s nutritional supplements causes of magnesium deficiency and treatment magnesium is one of the essential minerals required by the body . magnesium deficiency can lead to various disorders . here is a discussion about the causes of magnesium deficiency and how it can be treated with dietary intake and magnesium supplements . side effects of fl ##ax ##see ##d oil fl ##ax ##see ##d oil has been h ##ype ##d as great nutritional supplement over the years without

In [None]:
# Re-derive the record file path so this cell can run without re-running the
# conversion cell, as long as the file already exists in the bucket.
train_file = os.path.join(OUTPUT_DIR, "train.tf_record")

print('***** Started training at {} *****'.format(datetime.datetime.now()))
print('  Batch size = {}'.format(TRAIN_BATCH_SIZE))
tf.logging.info("  Num steps = %d", num_train_steps)
# Training input pipeline over the record file written earlier.
train_input_fn = file_based_input_fn_builder(
    input_file=train_file,
    seq_length=MAX_SEQCONCAT_LENGTH,
    is_training=True,
    drop_remainder=True)
# max_steps is the total step budget computed from the number of examples.
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
print('***** Finished training at {} *****'.format(datetime.datetime.now()))

***** Started training at 2019-02-10 21:56:46.205307 *****
  Batch size = 16
INFO:tensorflow:  Num steps = 985
INFO:tensorflow:Querying Tensorflow master (grpc://10.110.95.26:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 17357500703184256678)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 18095123436672065651)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 4425371603390326467)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16432993633361050829)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task

# Inference

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
tf.logging.set_verbosity(tf.logging.INFO)

predict_examples = processor.get_test_examples(TASK_DATA_DIR)
# Remember the real example count: the padding added below must not end up in
# the results file.
num_actual_predict_examples = len(predict_examples)
assert num_actual_predict_examples > 0, 'No test examples found in {}'.format(TASK_DATA_DIR)

# Pad the example list to a whole number of batches. The batch size is taken
# from PREDICT_BATCH_SIZE -- the value the estimator was built with -- instead
# of a second hard-coded 32 that could silently drift out of sync with it.
predict_batch_size = PREDICT_BATCH_SIZE
num_padding_examples = (-num_actual_predict_examples) % predict_batch_size
predict_examples.extend(PaddingInputExample() for _ in range(num_padding_examples))

# Write the (padded) test examples to a record file in the output directory.
predict_file = os.path.join(OUTPUT_DIR, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
                                        MAX_SEQCONCAT_LENGTH, tokenizer,
                                        predict_file)

INFO:tensorflow:Qrel size: 9863
INFO:tensorflow:Example Q: {'qid': '1', 'desc': "Find information on President Barack Obama's family history, including genealogy, national origins, places and dates of birth, etc.", 'subtopics': ['Find the TIME magazine photo essay "Barack Obama\'s Family Tree".', "Where did Barack Obama's parents and grandparents come from?", "Find biographical information on Barack Obama's mother."], 'title': 'obama family tree'}
INFO:tensorflow:Example Q: {'qid': '2', 'desc': 'Find information on French Lick Resort and Casino in Indiana.', 'subtopics': ['Find the homepage for French Lick Resort and Casino.', "What casinos are located within a day's drive of French Lick Resort and Casino?", 'What jobs are available at French Lick Casino and Resort?', 'Are there discounted packages for staying at French Lick Resort and Casino?'], 'title': 'french lick resort and casino'}
INFO:tensorflow:Example Q: {'qid': '3', 'desc': 'Find tips, resources, supplies for getting organiz

In [None]:

# Silence TensorFlow's INFO/WARNING output during prediction. The banners
# below use print() because tf.logging.info() would be dropped at this
# verbosity.
tf.logging.set_verbosity(tf.logging.ERROR)
predict_file = os.path.join(OUTPUT_DIR, "predict.tf_record")

print("***** Running prediction*****")
print("  Batch size = %d" % PREDICT_BATCH_SIZE)

# The preceding cell pads the example list to a multiple of the batch size, so
# dropping the remainder discards no real example.
predict_drop_remainder = True
predict_input_fn = file_based_input_fn_builder(
    input_file=predict_file,
    seq_length=MAX_SEQCONCAT_LENGTH,
    is_training=False,
    drop_remainder=predict_drop_remainder)

result = estimator.predict(input_fn=predict_input_fn)

# One line per real test example, in input order: the class probabilities
# separated by tabs.
output_predict_file = os.path.join(OUTPUT_DIR, "test_results.tsv")
with tf.gfile.GFile(output_predict_file, "w") as writer:
  num_written_lines = 0
  print("***** Predict results *****")
  for (i, prediction) in enumerate(result):
    # Everything past the real examples is batch padding; stop there.
    if i >= num_actual_predict_examples:
      break
    probabilities = prediction["probabilities"]
    output_line = "\t".join(
        str(class_probability)
        for class_probability in probabilities) + "\n"
    writer.write(output_line)
    num_written_lines += 1
    # Progress indicator for long runs.
    if num_written_lines % 100000 == 0:
      print(num_written_lines)
assert num_written_lines == num_actual_predict_examples, (
    'Wrote {} prediction lines but expected {}'.format(
        num_written_lines, num_actual_predict_examples))