<a href="https://colab.research.google.com/github/silverstar0727/Today_I_Learned/blob/master/Pre_training_BERT_from_scratch_with_cloud_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pre-training BERT from scratch with cloud TPU

In this experiment, we will pre-train a state-of-the-art Natural Language Understanding model, [BERT](https://arxiv.org/abs/1810.04805), on arbitrary text data using Google Cloud infrastructure.

This guide covers all stages of the procedure, including:

1. Setting up the training environment
2. Downloading raw text data
3. Preprocessing text data
4. Learning a new vocabulary
5. Creating sharded pre-training data
6. Setting up GCS storage for data and model
7. Training the model on a cloud TPU

For persistent storage of training data and model, you will require a Google Cloud Storage bucket. 
Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) to create a GCP account and GCS bucket. New Google Cloud users have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. 

Steps 1-5 of this tutorial can be run without a GCS bucket for demonstration purposes. In that case, however, you will not be able to train the model.

**Note** 
The only parameter you *really have to set* is BUCKET_NAME in steps 6 and 7. Everything else has default values which should work for most use cases.

**Note** 
Pre-training a BERT-Base model on a TPUv2 will take about 54 hours. Google Colab is not designed for executing such long-running jobs and will interrupt the training process every 8 hours or so. For uninterrupted training, consider using a preemptible TPUv2 instance. 

That said, at the time of writing (09.05.2019), a Colab TPU lets you pre-train a BERT model from scratch for the negligible cost of storing the model and data in GCS (~1 USD).

Now, let's get to business.

MIT License

Copyright (c) [2019] [Antyukhov Denis Olegovich]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

## Step 1: setting up training environment
First and foremost, we get the packages required to train the model. 
The Jupyter environment allows executing bash commands directly from the notebook by using an exclamation mark ‘!’. I will be exploiting this approach to make use of several other bash commands throughout the experiment.

Now, let’s import the packages and authorize ourselves in Google Cloud.

In [None]:
!pip install sentencepiece
!git clone https://github.com/google-research/bert

import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder

auth.authenticate_user()
  
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/14/3d/efb655a670b98f62ec32d66954e1109f403db4d937c50d779a75b9763a29/sentencepiece-0.1.83-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
     |████████████████████████████████| 1.0MB 3.5MB/s
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.83
Cloning into 'bert'...
remote: Enumerating objects: 333, done.
remote: Total 333 (delta 0), reused 0 (delta 0), pack-reused 333
Receiving objects: 100% (333/333), 282.45 KiB | 3.82 MiB/s, done.
Resolving deltas: 100% (183/183), done.

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



2019-09-18 12:09:36,959 :  Using TPU runtime
2019-09-18 12:09:36,961 :  TPU address is grpc://10.91.210.154:8470


## Step 2: getting the data

We begin with obtaining a corpus of raw text data. For this experiment, we will be using the [OpenSubtitles](http://www.opensubtitles.org/) dataset, which is available for 65 languages [here](http://opus.nlpl.eu/OpenSubtitles-v2016.php). 

Unlike more common text datasets (like Wikipedia) it does not require any complex pre-processing. It also comes pre-formatted with one sentence per line.

Feel free to use the dataset for your language instead by changing the language code (en) below.

In [None]:
AVAILABLE =  {'af','ar','bg','bn','br','bs','ca','cs',
              'da','de','el','en','eo','es','et','eu',
              'fa','fi','fr','gl','he','hi','hr','hu',
              'hy','id','is','it','ja','ka','kk','ko',
              'lt','lv','mk','ml','ms','nl','no','pl',
              'pt','pt_br','ro','ru','si','sk','sl','sq',
              'sr','sv','ta','te','th','tl','tr','uk',
              'ur','vi','ze_en','ze_zh','zh','zh_cn',
              'zh_en','zh_tw','zh_zh'}

LANG_CODE = "en" #@param {type:"string"}

assert LANG_CODE in AVAILABLE, "Invalid language code selected"

!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt

--2019-09-18 12:10:08--  http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.en.gz
Resolving opus.nlpl.eu (opus.nlpl.eu)... 193.166.25.9
Connecting to opus.nlpl.eu (opus.nlpl.eu)|193.166.25.9|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2016/mono/en.txt.gz [following]
--2019-09-18 12:10:09--  https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2016/mono/en.txt.gz
Resolving object.pouta.csc.fi (object.pouta.csc.fi)... 86.50.254.1, 86.50.254.0
Connecting to object.pouta.csc.fi (object.pouta.csc.fi)|86.50.254.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2906709304 (2.7G) [application/gzip]
Saving to: ‘dataset.txt.gz’


2019-09-18 12:12:28 (20.1 MB/s) - ‘dataset.txt.gz’ saved [2906709304/2906709304]

The astronomers run at full speed turning around each time they are pressed too closely and reducing the fragile beings to dust.
At last, the astronomers have f

For demonstration purposes, we will only use a small fraction of the whole corpus for this experiment. 

When training the real model, make sure to uncheck the DEMO_MODE checkbox to use a 100x larger dataset.

Rest assured, 100M lines are perfectly sufficient to train a reasonably good BERT-base model.

In [None]:
DEMO_MODE = True #@param {type:"boolean"}

if DEMO_MODE:
  CORPUS_SIZE = 1000000
else:
  CORPUS_SIZE = 100000000 #@param {type: "integer"}
  
!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt
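
The `head -n` truncation above can also be done in pure Python, which may be handy outside a notebook shell. This is only a minimal sketch: `take_first_lines` and the toy file names are hypothetical and not part of the original notebook.

```python
import itertools
import os
import tempfile

def take_first_lines(src_path, dst_path, max_lines):
    """Copy at most max_lines lines from src_path to dst_path; return lines written."""
    written = 0
    with open(src_path, encoding="utf-8") as fi, \
         open(dst_path, "w", encoding="utf-8") as fo:
        # islice stops after max_lines lines without reading the rest of the file
        for line in itertools.islice(fi, max_lines):
            fo.write(line)
            written += 1
    return written

# Tiny self-contained demo on a temporary file
tmp_dir = tempfile.mkdtemp()
src = os.path.join(tmp_dir, "toy_dataset.txt")
dst = os.path.join(tmp_dir, "toy_subset.txt")
with open(src, "w", encoding="utf-8") as f:
    f.write("".join("line {}\n".format(i) for i in range(10)))

n_written = take_first_lines(src, dst, 3)
print(n_written)
```

Because the file is streamed line by line, memory use should stay flat even for the full corpus.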

## Step 3: preprocessing text

The raw text data we have downloaded contains punctuation, uppercase letters and non-UTF symbols, which we will remove before proceeding. During inference, we will apply the same normalization procedure to new data.

If your use case requires different preprocessing (e.g. if uppercase letters or punctuation are expected during inference), feel free to modify the function below to accommodate your needs.

In [None]:
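# raw string avoids Python treating "\w" as an invalid escape sequence
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # drop characters that cannot be encoded as UTF-8
  text = text.encode("utf-8", "ignore").decode("utf-8")
  # remove punctuation symbols by keeping only runs of word characters
  text = " ".join(regex_tokenizer.tokenize(text))
  return text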

def count_lines(filename):
  # count lines without loading the whole file into memory
  count = 0
  with open(filename, encoding="utf-8") as fi:
    for line in fi:
      count += 1
  return count

Check how that works.

In [None]:
normalize_text('Thanks to the advance, they have succeeded in getting over their adversaries.')

'thanks to the advance they have succeeded in getting over their adversaries'
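
If case and punctuation matter at inference time, a gentler normalizer might only collapse whitespace. The sketch below uses the standard `re` module; `normalize_text_cased` is a hypothetical name, and whether this is the right policy depends on your data.

```python
import re

_WHITESPACE = re.compile(r"\s+")

def normalize_text_cased(text):
    """Keep case and punctuation; drop unencodable characters and collapse whitespace."""
    # errors="ignore" drops characters that cannot be encoded as UTF-8 (e.g. lone surrogates)
    text = str(text).encode("utf-8", "ignore").decode("utf-8")
    return _WHITESPACE.sub(" ", text).strip()

print(normalize_text_cased("  Thanks to the advance,\tthey   have succeeded. \n"))
```

With a normalizer like this, the `DO_LOWER_CASE` setting in Step 5 would presumably also need to be switched off for the case information to survive tokenization.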

Apply normalization to the whole dataset.

In [None]:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l)+"\n")
      bar.add(1)



## Step 4: building the vocabulary

For the next step, we will learn a new vocabulary that we will use to represent our dataset. 

The BERT paper uses a WordPiece tokenizer, but the code for learning a WordPiece vocabulary has not been open-sourced. Instead, we will use the SentencePiece tokenizer in unigram mode. While it is not directly compatible with BERT, a small hack makes it work.

SentencePiece requires quite a lot of RAM, so running it on the full dataset in Colab will crash the kernel. To avoid this, we will randomly subsample a fraction of the dataset for building the vocabulary. Another option would be to use a machine with more RAM for this step - that decision is up to you.

Also, SentencePiece adds BOS and EOS control symbols to the vocabulary by default. We disable them explicitly by setting their indices to -1.

The typical values for VOC_SIZE are somewhere in between 32000 and 128000. We reserve NUM_PLACEHOLDERS tokens in case one wants to update the vocabulary and fine-tune the model after the pre-training phase is finished.

In [None]:
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

True
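
The reason for passing `VOC_SIZE - NUM_PLACEHOLDERS` to SentencePiece can be sketched as simple arithmetic. The count of five control symbols is taken from the list added later in this step; everything else uses the defaults above.

```python
VOC_SIZE = 32000
NUM_PLACEHOLDERS = 256
NUM_CTRL_SYMBOLS = 5  # [PAD], [UNK], [CLS], [SEP], [MASK], added later in this step

# Size requested from SentencePiece; this count includes its own <unk> entry
spm_vocab_size = VOC_SIZE - NUM_PLACEHOLDERS

# The <unk> entry is dropped when the .vocab file is read back
learnt_tokens = spm_vocab_size - 1

# Slots left for placeholder tokens once the control symbols are prepended
unused_slots = VOC_SIZE - NUM_CTRL_SYMBOLS - learnt_tokens

print(spm_vocab_size, learnt_tokens, unused_slots)
```

Under these defaults, the final vocabulary ends up with slightly fewer free slots than `NUM_PLACEHOLDERS`, because the control symbols take up part of the reserve.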

Now let's see how we can make the SentencePiece tokenizer work for the BERT model.

Below is a sentence tokenized using the WordPiece vocabulary from a pretrained English [BERT-base](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip) model from the official [repo](https://github.com/google-research/bert). 

In [None]:
testcase = "Colorless geothermal substations are generating furiously"



```
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")

['color',
 '##less',
 'geo',
 '##thermal',
 'sub',
 '##station',
 '##s',
 'are',
 'generating',
 'furiously']
```



As we can see, the WordPiece tokenizer prepends the subwords which occur in the middle of words with '##'. The subwords occurring at the beginning of words are unchanged. If the subword occurs both in the beginning and in the middle of words, both versions (with and without '##') are added to the vocabulary.
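
To make the '##' convention concrete, here is a minimal sketch of greedy longest-match-first splitting, which is the strategy the BERT repo describes for its WordPiece tokenizer. The function name and the toy vocabulary are made up for illustration; this is not the repo's implementation.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of a single lowercase word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # word-internal pieces carry the prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no valid segmentation exists
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {"sub", "##station", "##s", "color", "##less"}
print(wordpiece_tokenize("substations", toy_vocab))
print(wordpiece_tokenize("colorless", toy_vocab))
```

Note that a word with no valid segmentation maps to a single `[UNK]` token rather than a partial split.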

Now let's have a look at the vocabulary that the SentencePiece tokenizer has learned.

In [None]:
!ls

adc.json  dataset.txt	    sample_data      tokenizer.vocab
bert	  proc_dataset.txt  tokenizer.model


SentencePiece has created two files: tokenizer.model and tokenizer.vocab. Let's have a look at the learned vocabulary:

In [None]:
!head -n 100 tokenizer.vocab

<unk>	0
▁you	-3.2342
▁i	-3.2821
▁the	-3.56375
▁s	-3.84955
▁to	-3.87601
▁a	-3.9102
▁it	-3.97593
▁t	-4.25729
▁and	-4.32686
▁that	-4.33991
▁of	-4.57503
▁what	-4.59504
▁is	-4.67144
▁me	-4.7115
▁in	-4.71265
▁we	-4.7231
▁he	-4.88446
▁this	-4.89466
▁no	-4.98519
▁on	-4.99226
▁my	-4.99616
▁m	-5.02169
▁your	-5.02715
▁for	-5.05648
▁have	-5.06695
▁don	-5.09057
▁do	-5.12828
▁	-5.12998
▁re	-5.1371
▁can	-5.20534
▁was	-5.20891
▁know	-5.22494
▁be	-5.23557
▁are	-5.25247
▁not	-5.26685
▁all	-5.29439
▁with	-5.31925
▁but	-5.37188
▁so	-5.42124
▁get	-5.43031
▁here	-5.43446
▁just	-5.43719
▁ll	-5.49322
▁like	-5.5182
▁they	-5.5196
▁there	-5.52012
▁up	-5.55201
▁go	-5.55369
▁she	-5.5887
▁right	-5.61613
▁out	-5.64626
▁oh	-5.68517
s	-5.71583
▁come	-5.73109
▁if	-5.73944
▁him	-5.74635
▁one	-5.75141
▁about	-5.75416
▁got	-5.77074
▁at	-5.78535
▁now	-5.8265
▁yeah	-5.84144
▁how	-5.84948
▁her	-5.87603
▁well	-5.94933
▁let	-5.95131
▁good	-5.98986
▁want	-6.01295
▁ve	-6.02348
▁think	-6.04507
▁who	-6.06409
▁did	-6.09598
▁see	-6.

In [None]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 31743
Sample tokens: ['▁beth', '▁voice', '▁hologram', '▁condom', '▁hail', '▁regret', '▁satir', '▁entering', '▁admit', '▁marion']


As we may observe, SentencePiece does quite the opposite to WordPiece. From the [documentation](https://github.com/google/sentencepiece/blob/master/README.md):


SentencePiece first escapes the whitespace with a meta-symbol "▁" (U+2581) as follows:

`Hello▁World`.

Then, this text is segmented into small pieces, for example:

`[Hello] [▁Wor] [ld] [.]`

Subwords which occur after whitespace (which are also those that most words begin with) are prepended with '▁', while others are unchanged. This excludes subwords which only occur at the beginning of sentences and nowhere else. These cases should be quite rare, however. 
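
The whitespace escaping quoted above can be mimicked in a few lines. This sketch only reproduces the README example with plain string operations; the helper names are hypothetical and it does not call the SentencePiece library.

```python
WS = "\u2581"  # the "▁" meta-symbol

def escape_whitespace(text):
    """Replace spaces with the meta-symbol, as in the README example."""
    return text.replace(" ", WS)

def detokenize(pieces):
    """Concatenate pieces and turn the meta-symbol back into spaces."""
    return "".join(pieces).replace(WS, " ")

pieces = ["Hello", WS + "Wor", "ld", "."]
print(escape_whitespace("Hello World."))
print(detokenize(pieces))
```

Because the whitespace is kept inside the pieces, the original text can be recovered by simple concatenation.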

So, in order to obtain a vocabulary analogous to WordPiece, we need to perform a simple conversion: remove "▁" from the tokens that start with it and add "##" to the ones that don't.

In [None]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

In [None]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

We also add some special control symbols which are required by the BERT architecture. By convention, we put those at the beginning of the vocabulary.

In [None]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

We also append some placeholder tokens to the vocabulary. Those are useful if one wishes to update the pre-trained model with new, task-specific tokens. 

In that case, the placeholder tokens are replaced with new real ones, the pre-training data is re-generated, and the model is fine-tuned on new data.

In [None]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

32000
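
Swapping placeholders for real tokens later on could look roughly like the sketch below. `replace_placeholders` and the toy vocabulary are hypothetical; the point is only that each new token takes over the index of a placeholder, so the vocabulary size stays unchanged.

```python
def replace_placeholders(vocab, new_tokens, prefix="[UNUSED_"):
    """Return a copy of vocab with the first free placeholders replaced by new_tokens."""
    slots = [i for i, tok in enumerate(vocab) if tok.startswith(prefix)]
    if len(new_tokens) > len(slots):
        raise ValueError("not enough placeholder slots")
    updated = list(vocab)
    for slot, token in zip(slots, new_tokens):
        updated[slot] = token  # the index is preserved, only the surface form changes
    return updated

toy_vocab = ["[PAD]", "[UNK]", "hello", "##s", "[UNUSED_0]", "[UNUSED_1]"]
new_vocab = replace_placeholders(toy_vocab, ["tpu"])
print(new_vocab)
```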


Finally, we write the obtained vocabulary to file.

In [None]:
VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w", encoding="utf-8") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

Now let's see how the new vocabulary works in practice:

In [None]:
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(testcase)

['color',
 '##less',
 'geo',
 '##ther',
 '##mal',
 'subs',
 '##tation',
 '##s',
 'are',
 'generat',
 '##ing',
 'furious',
 '##ly']

Looking good!

## Step 5: generating pre-training data

With the vocabulary at hand, we are ready to generate pre-training data for the BERT model. Since our dataset might be quite large, we will split it into shards:

In [None]:
!mkdir ./shards
!split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

shard_0000  shard_0001	shard_0002  shard_0003
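
For environments without the `split` utility, the sharding step might be approximated in Python as follows. This is a sketch under the assumption that shards are consecutive chunks of lines with a four-digit numeric suffix; `split_into_shards` and the toy paths are hypothetical.

```python
import os
import tempfile

def split_into_shards(src_path, out_dir, lines_per_shard, prefix="shard_"):
    """Write consecutive chunks of lines_per_shard lines to numbered shard files."""
    os.makedirs(out_dir, exist_ok=True)
    shard_paths = []
    fo = None
    with open(src_path, encoding="utf-8") as fi:
        for i, line in enumerate(fi):
            if i % lines_per_shard == 0:
                if fo is not None:
                    fo.close()
                # Four-digit numeric suffix, similar to `split -a 4 -d`
                path = os.path.join(out_dir, "{}{:04d}".format(prefix, len(shard_paths)))
                shard_paths.append(path)
                fo = open(path, "w", encoding="utf-8")
            fo.write(line)
    if fo is not None:
        fo.close()
    return shard_paths

tmp_dir = tempfile.mkdtemp()
src = os.path.join(tmp_dir, "toy.txt")
with open(src, "w", encoding="utf-8") as f:
    f.write("".join("sentence {}\n".format(i) for i in range(10)))

shards = split_into_shards(src, os.path.join(tmp_dir, "shards"), lines_per_shard=4)
print([os.path.basename(p) for p in shards])
```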


Before we start generating, we need to set some model-specific parameters.  

In [None]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
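
To see how `MASKED_LM_PROB` and `MAX_PREDICTIONS` interact, the sketch below reproduces, to the best of my reading, the rule that `create_pretraining_data.py` uses to decide how many positions to mask in one sequence. The function name is hypothetical.

```python
def num_masked_tokens(num_tokens, masked_lm_prob=0.15, max_predictions_per_seq=20):
    """Number of positions selected for masked-LM prediction in one sequence."""
    return min(max_predictions_per_seq,
               max(1, int(round(num_tokens * masked_lm_prob))))

print(num_masked_tokens(128))                              # full-length sequence
print(num_masked_tokens(128, max_predictions_per_seq=10))  # cap becomes the limit
print(num_masked_tokens(3))                                # at least one position is masked
```

This is why `MAX_PREDICTIONS` is usually kept close to `MAX_SEQ_LENGTH * MASKED_LM_PROB`: a much smaller cap would truncate the masking, while a much larger one mostly adds padding.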

Now, for each shard we need to call the *create_pretraining_data.py* script. To that end, we will employ the *xargs* command.

Running this might take quite some time depending on the size of your dataset.

In [None]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
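
The `'{}'` arguments in the `format` call above exist only to re-emit the literal braces that xargs needs as its replacement string. A shortened sketch of the trick, with a made-up `echo` command standing in for the real script, together with the brace-doubling alternative:

```python
template = "xargs -P {} -I{} echo --input_file=./shards/{}"

# Passing the literal string '{}' for a slot re-emits the braces that xargs needs
cmd_a = template.format(2, "{}", "{}")

# Equivalent alternative: escape literal braces by doubling them
cmd_b = "xargs -P {} -I{{}} echo --input_file=./shards/{{}}".format(2)

print(cmd_a)
```

Both forms produce the same command string; which one reads better is largely a matter of taste.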

In [None]:
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD

## Step 6: setting up persistent storage

To preserve our hard-earned assets, we will persist them to Google Cloud Storage. Provided that you have created the GCS bucket, this should be simple.

We will create two directories in GCS, one for the data and one for the model.
In the model directory, we will put the model vocabulary and configuration file.

**Configure your BUCKET_NAME variable here before proceeding, otherwise the model and data will not be saved.**

In [None]:
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)

if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")

Below is the sample hyperparameter configuration for BERT-base. Change at your own risk.

In [None]:
# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": VOC_SIZE
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)
  
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w", encoding="utf-8") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")
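
As a rough sanity check on the configuration, the number of trainable weights implied by these hyperparameters can be estimated by hand. The sketch below counts only the embeddings, the transformer layers and the pooler; it ignores the pre-training output heads, and the function name is hypothetical.

```python
def bert_encoder_params(vocab_size, hidden=768, layers=12, intermediate=3072,
                        max_positions=512, type_vocab=2):
    """Rough count of trainable weights in the BERT encoder plus pooler."""
    layer_norm = 2 * hidden  # gain and bias
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    attention = 4 * (hidden * hidden + hidden)  # query, key, value, output projections
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
    per_layer = attention + feed_forward + 2 * layer_norm
    pooler = hidden * hidden + hidden
    return embeddings + layers * per_layer + pooler

print(bert_encoder_params(32000))  # vocabulary size used in this guide
print(bert_encoder_params(30522))  # vocabulary size of the released English uncased model
```

With `VOC_SIZE = 32000` this comes to roughly 110.6M weights, of which the word embeddings alone account for about 24.6M, so the vocabulary size has a noticeable effect on model size.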

In [None]:
if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

Copying file://pretraining_data/shard_0003.tfrecord [Content-Type=application/octet-stream]...
Copying file://bert_model/bert_config.json [Content-Type=application/json]...
Copying file://pretraining_data/shard_0002.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data/shard_0001.tfrecord [Content-Type=application/octet-stream]...
Copying file://pretraining_data/shard_0000.tfrecord [Content-Type=application/octet-stream]...
Copying file://bert_model/vocab.txt [Content-Type=text/plain]...
-
Operation completed over 6 objects/269.4 MiB.                                    


## Step 7: training the model

We are almost ready to begin training our model. If you wish to continue an interrupted training run, you may skip steps 2-6 and proceed from here.

**Make sure that you have set the BUCKET_NAME here as well.**

In [None]:
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

2019-09-18 12:40:05,216 :  Couldn't match files for checkpoint gs://bert_resourses/bert_model/model.ckpt-102500
2019-09-18 12:40:05,651 :  Using checkpoint: None
2019-09-18 12:40:05,656 :  Using 4 data shards
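
When resuming, it can be useful to know how far a previous run got. The estimator restores its state from the model directory on its own; the hypothetical helper below merely parses the step number out of a checkpoint path, assuming the `model.ckpt-<step>` naming seen in the log above.

```python
import re

def checkpoint_step(checkpoint_path):
    """Extract the global step from a path like '.../model.ckpt-102500'; None if absent."""
    if not checkpoint_path:
        return None
    match = re.search(r"model\.ckpt-(\d+)$", checkpoint_path)
    return int(match.group(1)) if match else None

TRAIN_STEPS = 1000000
step = checkpoint_step("gs://bert_resourses/bert_model/model.ckpt-102500")
remaining = TRAIN_STEPS - step
print(step, remaining)
```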


Prepare the training run configuration, build the estimator and input function, power up the bass cannon.

In [None]:
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

2019-09-18 12:42:14,316 :  Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7f242273e7b8>) includes params argument, but params are not passed to Estimator.
2019-09-18 12:42:14,325 :  Using config: {'_model_dir': 'gs://bert_resourses/bert_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2500, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.91.210.154:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f24285ab208>, '_task_type': 'worker', '_task_id': 0, '_
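
The `learning_rate`, `num_train_steps` and `num_warmup_steps` arguments passed to `model_fn_builder` define a schedule of linear warmup followed by linear decay to zero, as far as I understand the repo's `optimization.create_optimizer`. A plain-Python sketch of that schedule, with a hypothetical function name:

```python
def bert_learning_rate(step, init_lr=2e-5, num_train_steps=1000000, num_warmup_steps=10):
    """Linear warmup followed by linear decay to zero."""
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    # Polynomial decay with power 1.0 (i.e. linear) and an end learning rate of 0
    clipped = min(step, num_train_steps)
    return init_lr * (1.0 - clipped / num_train_steps)

for s in (0, 5, 10, 500000, 1000000):
    print(s, bert_learning_rate(s))
```

With only 10 warmup steps out of a million, the warmup phase is practically negligible under the defaults used here.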

Fire!

In [None]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

2019-09-18 12:42:19,833 :  Couldn't match files for checkpoint gs://bert_resourses/bert_model/model.ckpt-102500
2019-09-18 12:42:19,835 :  Querying Tensorflow master (grpc://10.91.210.154:8470) for TPU system metadata.
2019-09-18 12:42:19,857 :  Found TPU system:
2019-09-18 12:42:19,859 :  *** Num TPU Cores: 8
2019-09-18 12:42:19,860 :  *** Num TPU Workers: 1
2019-09-18 12:42:19,861 :  *** Num TPU Cores Per Worker: 8
2019-09-18 12:42:19,863 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 10112529597588573712)
2019-09-18 12:42:19,866 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 7136981085282186828)
2019-09-18 12:42:19,867 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 4902132816901732828)
2019-09-18 12:42:19,868 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 208929406772

ValueError: ignored

Training the model with the default parameters for 1 million steps will take ~53 hours. 

In case the kernel is restarted, you may always continue training from the latest checkpoint. 
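
The estimate above translates into some back-of-the-envelope numbers. The sketch assumes the quoted 53 hours and the roughly 8-hour Colab session length mentioned in the introduction; both are approximations.

```python
import math

TRAIN_STEPS = 1000000
TRAIN_BATCH_SIZE = 128
TOTAL_HOURS = 53    # estimate quoted above
SESSION_HOURS = 8   # approximate Colab session length mentioned in the introduction

steps_per_second = TRAIN_STEPS / (TOTAL_HOURS * 3600)
sequences_seen = TRAIN_STEPS * TRAIN_BATCH_SIZE
sessions_needed = math.ceil(TOTAL_HOURS / SESSION_HOURS)

print(round(steps_per_second, 2), sequences_seen, sessions_needed)
```

In other words, a full run would need to be resumed from a checkpoint about half a dozen times when using Colab alone.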

This concludes the guide to pre-training BERT from scratch on a cloud TPU. However, the really fun stuff is still to come, so stay tuned.

Keep learning!

In [None]:
!gsutil cp gs://hermes_assets/russian_uncased_L-12_H-768_A-12.zip gs://bert_resourses/