# BERT pretraining on Japanese wiki

This notebook is intended to be run on Colaboratory with a TPU runtime.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14Ky8w5NodVyfk7tm13u6vdaGPl5qvPxL)


In [0]:
import tensorflow as tf

In [0]:
tf.__version__

'1.12.0'

In [0]:
!git clone --recursive https://github.com/yoheikikuta/bert-japanese.git

Cloning into 'bert-japanese'...
remote: Enumerating objects: 73, done.
remote: Counting objects: 100% (73/73), done.
remote: Compressing objects: 100% (47/47), done.
remote: Total 73 (delta 32), reused 58 (delta 20), pack-reused 0
Unpacking objects: 100% (73/73), done.
Submodule 'bert' (https://github.com/google-research/bert.git) registered for path 'bert'
Cloning into '/content/bert-japanese/bert'...
remote: Enumerating objects: 299, done.        
remote: Total 299 (delta 0), reused 0 (delta 0), pack-reused 299        
Receiving objects: 100% (299/299), 184.07 KiB | 3.61 MiB/s, done.
Resolving deltas: 100% (178/178), done.
Submodule path 'bert': checked out 'f39e881b169b9d53bea03d2d341b31707a6c052b'


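Authenticate with your Google account. The resulting credentials are uploaded to the TPU in the next cell so that it can access GCS.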

In [0]:
from google.colab import auth
auth.authenticate_user()

## Check TPU devices

In [0]:
import datetime
import json
import os
import pprint
import random
import string
import sys
import tensorflow as tf

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; select TPU under Runtime > Change runtime type.'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

with tf.Session(TPU_ADDRESS) as session:
  print('TPU devices:')
  pprint.pprint(session.list_devices())

  # Upload credentials to TPU.
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.

TPU address is grpc://10.8.143.218:8470
TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 3453988297340998056),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 269548181985733838),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 10368517432306680665),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 5577581473910429639),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 1252546431492591496),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 5574517467752852166),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 14832744018215649907),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 12561706912073700509),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 853713
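To confirm that the runtime exposes as many TPU cores as the `--num_tpu_cores=8` flag used below expects, you can filter the device list by type. A minimal sketch, assuming you pass in the device name strings (for example `[d.name for d in session.list_devices()]` inside the session above); the helper `count_tpu_cores` and the sample names are hypothetical:

```python
import re


def count_tpu_cores(device_names):
    """Count devices whose name ends in '/device:TPU:<index>'.

    CPU, XLA_CPU, XLA_GPU and TPU_SYSTEM devices do not match the pattern,
    so only the TPU cores themselves are counted.
    """
    pattern = re.compile(r"/device:TPU:\d+$")
    return sum(1 for name in device_names if pattern.search(name))


# Hypothetical device names in the same format as the output above.
names = ["/job:tpu_worker/replica:0/task:0/device:CPU:0",
         "/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0"]
names += ["/job:tpu_worker/replica:0/task:0/device:TPU:%d" % i for i in range(8)]
print(count_tpu_cores(names))
```

If the count differs from 8, adjust `--num_tpu_cores` accordingly.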

## Set input and output

You need to upload the `all-maxseq(128|512).tfrecord` pretraining data to your GCS bucket.  
Trained model checkpoints will be saved to the GCS path given in `OUTPUT_GCS`.

In [0]:
INPUT_DATA_GCS = 'gs://bert-wiki-ja/data'

In [0]:
TARGET_DIRS = [
  'AA',
  'AB',
  'AC',
  'AD',
  'AE',
  'AF',
  'AG',
  'AH',
  'AI',
  'AJ',
  'AK',
  'AL',
  'AM',
  'AN',
  'AO',
  'AP',
  'AQ',
  'AR',
  'AS',
  'AT',
  'AU',
  'AV',
  'AW',
  'AX',
  'AY',
  'AZ',
  'BA',
  'BB'
]
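The 28 directory names above follow a contiguous two-letter sequence, so the list could also be generated rather than typed out. A minimal sketch, assuming the data directories run without gaps from `AA` to `BB`; `two_letter_dirs` is a hypothetical helper:

```python
import itertools
import string


def two_letter_dirs(first, last):
    """Return two-letter names from `first` to `last` inclusive,
    in the order AA, AB, ..., AZ, BA, BB, ..."""
    names = [a + b for a, b in itertools.product(string.ascii_uppercase, repeat=2)]
    return names[names.index(first):names.index(last) + 1]


target_dirs = two_letter_dirs('AA', 'BB')
print(len(target_dirs), target_dirs[:3], target_dirs[-3:])
```

Writing the list out explicitly, as the notebook does, has the advantage that individual directories can be dropped without changing the helper.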

In [0]:
# MAX_SEQ_LEN = 128
MAX_SEQ_LEN = 512
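`MAX_SEQ_LEN` interacts with the `--max_predictions_per_seq` flag used below. The upstream BERT README suggests setting `max_predictions_per_seq` to around `max_seq_length * masked_lm_prob`, and the value passed to the training script must match the one used when the tfrecords were created. A quick check of that rule of thumb, assuming `masked_lm_prob` is the upstream `create_pretraining_data.py` default of 0.15 (this notebook does not state the value used):

```python
import math

MASKED_LM_PROB = 0.15  # assumed: upstream create_pretraining_data.py default


def suggested_max_predictions(max_seq_len, masked_lm_prob=MASKED_LM_PROB):
    """Rule of thumb: about max_seq_len * masked_lm_prob masked tokens per sequence."""
    return math.ceil(max_seq_len * masked_lm_prob)


for seq_len in (128, 512):
    print(seq_len, suggested_max_predictions(seq_len))
```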

In [0]:
INPUT_FILE = ','.join( [ '{}/{}/all-maxseq{}.tfrecord'.format(INPUT_DATA_GCS, elem, MAX_SEQ_LEN) for elem in TARGET_DIRS] )
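`INPUT_FILE` is a single comma-separated string. In the upstream BERT `run_pretraining.py`, this flag is split on commas and each piece is expanded with `tf.gfile.Glob`. Assuming the `bert-japanese` script handles it the same way, the sketch below rebuilds the string offline for a few directories and splits it back, without touching GCS (`build_input_file` is a hypothetical helper):

```python
def build_input_file(bucket, dirs, max_seq_len):
    """Join one tfrecord path per directory into a comma-separated string."""
    return ','.join('{}/{}/all-maxseq{}.tfrecord'.format(bucket, d, max_seq_len)
                    for d in dirs)


input_file = build_input_file('gs://bert-wiki-ja/data', ['AA', 'AB', 'BB'], 512)
patterns = input_file.split(',')  # one path (or glob pattern) per entry
for p in patterns:
    print(p)
```

Inside the notebook, `tf.gfile.Exists(path)` can be used on each entry to check that the files are actually present before launching training.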

In [0]:
OUTPUT_GCS = 'gs://bert-wiki-ja/model'

## Execute pre-training

Note that you must grant the Cloud TPU service account `service-xxx@cloud-tpu.iam.gserviceaccount.com` the following roles on the specified GCS bucket:
- Storage Legacy Bucket Reader
- Storage Legacy Bucket Writer
- Storage Legacy Object Reader
- Storage Object Viewer
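One way to grant these roles is from the command line with `gsutil iam ch`. The sketch below only builds and prints the commands, one per role, using the IAM role IDs that correspond to the role names above. It assumes `gsutil iam ch` accepts the full `roles/storage.*` IDs (check `gsutil help iam` for your version); `SERVICE_ACCOUNT` keeps the elided `service-xxx` placeholder and `BUCKET` is a hypothetical value to replace with your own:

```python
# Hypothetical placeholders: substitute your TPU service account and bucket.
SERVICE_ACCOUNT = 'service-xxx@cloud-tpu.iam.gserviceaccount.com'
BUCKET = 'gs://bert-wiki-ja'

# IAM role IDs for the four role names listed above.
ROLES = [
    'roles/storage.legacyBucketReader',
    'roles/storage.legacyBucketWriter',
    'roles/storage.legacyObjectReader',
    'roles/storage.objectViewer',
]


def grant_commands(account, bucket, roles):
    """Build one `gsutil iam ch` command per role (printed, not executed)."""
    return ['gsutil iam ch serviceAccount:{}:{} {}'.format(account, role, bucket)
            for role in roles]


for cmd in grant_commands(SERVICE_ACCOUNT, BUCKET, ROLES):
    print(cmd)
```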


In [0]:
# Run used with MAX_SEQ_LEN = 128; kept commented out for reference (its log is shown below).
# !python bert-japanese/src/run_pretraining.py \
#   --input_file={INPUT_FILE} \
#   --output_dir={OUTPUT_GCS} \
#   --use_tpu=True \
#   --tpu_name={TPU_ADDRESS} \
#   --num_tpu_cores=8 \
#   --do_train=True \
#   --do_eval=True \
#   --train_batch_size=256 \
#   --max_seq_length={MAX_SEQ_LEN} \
#   --max_predictions_per_seq=20 \
#   --num_train_steps=1000000 \
#   --num_warmup_steps=10000 \
#   --save_checkpoints_steps=10000 \
#   --learning_rate=1e-4

INFO:tensorflow:*** Input Files ***
INFO:tensorflow:  gs://bert-wiki-ja/data/AA/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AB/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AC/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AD/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AE/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AF/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AG/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AH/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AI/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AJ/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AK/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AL/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AM/all-maxseq128.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AN/all-maxseq128.tfrecord
INFO:tensorflow:  gs://ber
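With TPUEstimator, `--train_batch_size` is the global batch size and is split across the `--num_tpu_cores` cores, so it should be divisible by the core count. A small sketch of the resulting per-core batch and the total number of training sequences for the 128-token configuration above, assuming one global batch per training step (`per_core_batch` and `sequences_seen` are hypothetical helpers):

```python
def per_core_batch(global_batch, num_cores):
    """Per-core batch when a global batch is split evenly across TPU cores."""
    if global_batch % num_cores != 0:
        raise ValueError('batch size %d is not divisible by %d cores'
                         % (global_batch, num_cores))
    return global_batch // num_cores


def sequences_seen(global_batch, num_steps):
    """Total sequences processed over training, one global batch per step."""
    return global_batch * num_steps


print(per_core_batch(256, 8))
print(sequences_seen(256, 1000000))
```

For the 512-token run below, the same arithmetic gives 64 / 8 = 8 sequences per core.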

In [0]:
!python bert-japanese/src/run_pretraining.py \
  --input_file={INPUT_FILE} \
  --output_dir={OUTPUT_GCS} \
  --use_tpu=True \
  --tpu_name={TPU_ADDRESS} \
  --num_tpu_cores=8 \
  --do_train=True \
  --do_eval=True \
  --train_batch_size=64 \
  --max_seq_length={MAX_SEQ_LEN} \
  --max_predictions_per_seq=20 \
  --num_train_steps=1400000 \
  --num_warmup_steps=10000 \
  --save_checkpoints_steps=10000 \
  --learning_rate=1e-4

INFO:tensorflow:*** Input Files ***
INFO:tensorflow:  gs://bert-wiki-ja/data/AA/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AB/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AC/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AD/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AE/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AF/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AG/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AH/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AI/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AJ/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AK/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AL/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AM/all-maxseq512.tfrecord
INFO:tensorflow:  gs://bert-wiki-ja/data/AN/all-maxseq512.tfrecord
INFO:tensorflow:  gs://ber
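The `--learning_rate`, `--num_warmup_steps` and `--num_train_steps` flags together define the learning-rate schedule. Assuming the script uses the optimizer from the upstream BERT `optimization.py` (linear warmup to the peak rate, then linear decay to zero at `num_train_steps`), the rate at a given global step can be sketched as below. Both runs write to the same `OUTPUT_GCS`, and an Estimator resumes from the latest checkpoint in its model directory, so the 512-token run presumably continues from step 1,000,000 of the 128-token run; that would be consistent with `--num_train_steps=1400000`, although the notebook does not say so explicitly.

```python
def bert_learning_rate(step, init_lr=1e-4, num_train_steps=1400000,
                       num_warmup_steps=10000):
    """Linear warmup, then linear (power 1.0) decay to 0.

    Mirrors the schedule in upstream BERT's optimization.py; assumed,
    not verified against bert-japanese.
    """
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 to init_lr.
        return init_lr * float(step) / num_warmup_steps
    # Decay: steps past num_train_steps are clamped, so the rate ends at 0.
    step = min(step, num_train_steps)
    return init_lr * (1.0 - float(step) / num_train_steps)


for step in (0, 5000, 10000, 1000000, 1400000):
    print(step, bert_learning_rate(step))
```

Under these assumptions, the rate at step 1,000,000 is about 2.86e-5, so the 512-token phase would start well below the 1e-4 peak.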
