## Setup

```
# -U upgrades packages to the latest allowed versions, -q keeps pip quiet
!pip install tensorflow-gpu==2.0 tensorflow_datasets gpustat transformers -Uq
```

**About**

<img src="https://upload.wikimedia.org/wikipedia/en/thumb/6/6d/Nvidia_image_logo.svg/200px-Nvidia_image_logo.svg.png" width="90px" align="right" style="margin-right: 0px;">

This notebook was put together by Timothy Liu (`timothyl@nvidia.com`) for the [**PyCon SG**](https://pycon.sg/) 2019 tutorial on [**Improving Deep Learning Performance in TensorFlow**](https://github.com/NVAITC/pycon-sg19-tensorflow-tutorial).

**Acknowledgements**

* This notebook uses some materials adapted from TensorFlow documentation.
* This notebook uses the [HuggingFace Transformers library](https://github.com/huggingface/transformers).
* This notebook uses the [GLUE (MRPC) Dataset](https://gluebenchmark.com/) ([TensorFlow Datasets page](https://www.tensorflow.org/datasets/catalog/glue)).

**Dataset Citation**

```
@inproceedings{wang2019glue,
  title={ {GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
  author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
  note={In the Proceedings of ICLR.},
  year={2019}
}
```

```python
import tensorflow.compat.v2 as tf
import tensorflow_datasets
```

```python
import time


class TimeHistory(tf.keras.callbacks.Callback):
    """Record the wall-clock duration of every training epoch."""

    def on_train_begin(self, logs=None):
        self.times = []  # one duration (in seconds) per epoch

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)
```
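
Keras calls `on_train_begin` once, then `on_epoch_begin` and `on_epoch_end` around every epoch. To make that call order concrete without a GPU or TensorFlow, the sketch below replays it with a hypothetical `fake_fit` loop, using `time.sleep` as a stand-in for training work. It illustrates the hooks only and is not how Keras implements `fit`.

```python
import time


class TimeHistorySketch:
    """Plain-Python stand-in for the TimeHistory callback (no Keras needed)."""

    def on_train_begin(self, logs=None):
        self.times = []  # one wall-clock duration per epoch

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_time_start)


def fake_fit(callback, epochs, work_seconds):
    """Hypothetical training loop that only reproduces the hook call order."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        time.sleep(work_seconds[epoch])  # stand-in for one epoch of training
        callback.on_epoch_end(epoch)


cb = TimeHistorySketch()
# the first "epoch" is slower, mimicking one-off graph construction/compilation cost
fake_fit(cb, epochs=3, work_seconds=[0.05, 0.01, 0.01])
print(len(cb.times), min(cb.times))
```

Taking `min(cb.times)` skips the deliberately slow first epoch, which is the reasoning behind measuring peak throughput from the fastest epoch.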

# Sequence Classification with BERT in TF 2.0

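```python
# enable XLA JIT compilation, which fuses TensorFlow ops into larger compiled kernels
tf.config.optimizer.set_jit(True)

# enable automatic mixed precision (AMP): a graph rewrite that runs eligible ops in float16
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
```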
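
With AMP, eligible ops run in float16, whose range is narrow: the smallest positive (subnormal) value is about 6e-8 and the largest finite value is 65504. The stdlib-only sketch below uses Python's `struct` half-precision format (`'e'`) to show a small gradient value flushing to zero in float16, and how multiplying by a loss scale (2**15 here, an illustrative choice) before the cast and dividing afterwards preserves it. It is a numeric illustration of why AMP is paired with loss scaling, not a reproduction of TensorFlow's graph rewrite.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_fp16(x: float) -> bool:
    """False if x is too large in magnitude to pack as float16."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


grad = 1e-8            # hypothetical small gradient value
loss_scale = 2.0 ** 15

unscaled_fp16 = to_fp16(grad)                        # underflows to 0.0
recovered = to_fp16(grad * loss_scale) / loss_scale  # scale, cast, unscale

print(unscaled_fp16, recovered)
print(fits_in_fp16(65504.0), fits_in_fp16(70000.0))  # True False
```

The second line is the other side of the trade-off: too large a scale pushes values past the float16 maximum, which is why the scale has to be chosen (or adapted) carefully.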

## Load BERT Tokenizer

```python
from transformers import BertTokenizer, TFBertForSequenceClassification, glue_convert_examples_to_features

# load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
```
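
`BertTokenizer` splits text into WordPiece sub-word units. The sketch below shows the greedy longest-match-first lookup at the core of WordPiece, run against a tiny hypothetical vocabulary (`toy_vocab`). The real tokenizer also performs basic whitespace and punctuation splitting first and uses the pretrained model's full vocabulary, so treat this as an illustration of the matching rule only.

```python
def wordpiece(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # word-internal pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece matched: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able", "play", "##ing"}  # hypothetical vocabulary
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece("playing", toy_vocab))    # ['play', '##ing']
print(wordpiece("xyz", toy_vocab))        # ['[UNK]']
```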

## Input Pipeline

### Load Dataset

```python
data, info = tensorflow_datasets.load("glue/mrpc", with_info=True)

# number of examples in each split, used to size the training and validation steps
train_examples = info.splits["train"].num_examples
valid_examples = info.splits["validation"].num_examples
```

```
INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (/home/jovyan/tensorflow_datasets/glue/mrpc/0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from /home/jovyan/tensorflow_datasets/glue/mrpc/0.0.2
```

## Build Input Pipeline

```python
BATCH_SIZE = 40

# convert GLUE examples into BERT input features, returned as a tf.data.Dataset
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, max_length=128, task="mrpc")
# shuffle with a 512-example buffer, batch, repeat indefinitely, prefetch 8 batches ahead
train_dataset = train_dataset.shuffle(512).batch(BATCH_SIZE).repeat(-1).prefetch(8)

valid_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, max_length=128, task="mrpc")
valid_dataset = valid_dataset.batch(BATCH_SIZE)
```
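
For a sentence-pair task such as MRPC, each converted example carries `input_ids`, `attention_mask` and `token_type_ids` padded to `max_length`, plus its label. The sketch below shows that layout on plain token strings rather than vocabulary ids: `[CLS] A [SEP] B [SEP]`, segment ids 0 then 1, an attention mask of 1 for real tokens, then right padding. The helper name `encode_pair` and its longest-first truncation rule are assumptions for illustration, not the library's code.

```python
def encode_pair(tokens_a, tokens_b, max_length=128, pad_token="[PAD]"):
    """Lay out a sentence pair the way BERT expects (token strings, not ids)."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)
    # longest-first truncation: leave room for [CLS] and two [SEP] tokens
    while len(tokens_a) + len(tokens_b) > max_length - 3:
        longer = tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b
        longer.pop()

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    attention_mask = [1] * len(tokens)

    # pad on the right up to max_length; padding positions are masked out with 0
    padding = max_length - len(tokens)
    tokens += [pad_token] * padding
    token_type_ids += [0] * padding
    attention_mask += [0] * padding
    return tokens, attention_mask, token_type_ids


tokens, mask, types = encode_pair(["he", "left"], ["he", "went", "away"], max_length=10)
print(tokens)  # ['[CLS]', 'he', 'left', '[SEP]', 'he', 'went', 'away', '[SEP]', '[PAD]', '[PAD]']
print(mask)    # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
print(types)   # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
```

Fixing every example to the same length is what lets the pipeline batch them directly with `.batch(BATCH_SIZE)`.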

## Build BERT Model

### Load Pre-trained BERT Model

```python
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased")
```

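```python
# Adam with a small fine-tuning learning rate, wrapped for dynamic loss scaling
# so that small float16 gradients do not underflow under AMP
opt = tf.keras.optimizers.Adam(learning_rate=3e-5)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

# the model outputs logits, so the loss applies the softmax itself
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
acc = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model.compile(optimizer=opt,
              loss=loss,
              metrics=[acc])
```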
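
Passing `"dynamic"` to `LossScaleOptimizer` means the loss is multiplied by a scale factor before gradients are computed, the gradients are divided by it again before the update, and the factor adapts during training. The plain-Python sketch below shows one common adaptation rule: skip the step and halve the scale when any gradient is non-finite, double it after a run of consecutive finite steps. The starting value 2**15, the 2000-step period and the factor of 2 are assumed to match TensorFlow's `DynamicLossScale` defaults of this era; treat them, and the floor of 1, as illustrative.

```python
import math


class DynamicLossScaleSketch:
    """Illustrative dynamic loss scale (not TensorFlow's implementation)."""

    def __init__(self, initial_scale=2.0 ** 15, increment_period=2000, multiplier=2.0):
        self.scale = initial_scale
        self.increment_period = increment_period
        self.multiplier = multiplier
        self.good_steps = 0  # consecutive steps with all-finite gradients

    def update(self, scaled_grads):
        """Return unscaled grads to apply, or None if the step must be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # overflow: skip this update and back off (floor of 1 chosen for this sketch)
            self.scale = max(self.scale / self.multiplier, 1.0)
            self.good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the current scale
        self.good_steps += 1
        if self.good_steps >= self.increment_period:
            # a long run without overflow: try a larger scale
            self.scale *= self.multiplier
            self.good_steps = 0
        return grads


ls = DynamicLossScaleSketch(increment_period=3)
print(ls.update([float("inf"), 1.0]), ls.scale)  # None 16384.0
for _ in range(3):
    ls.update([1.0, 2.0])
print(ls.scale)  # 32768.0
```

The rule keeps the scale as large as possible without overflowing, so small gradients stay representable in float16 while occasional overflows only cost a skipped step.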

## Train BERT Model

```python
time_callback = TimeHistory()
```

```python
history = model.fit(train_dataset, epochs=4, steps_per_epoch=train_examples//BATCH_SIZE,
                    validation_data=valid_dataset, validation_steps=valid_examples//BATCH_SIZE,
                    validation_freq=3, callbacks=[time_callback])
```

```
Train for 91 steps, validate for 10 steps
Epoch 1/4
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
Epoch 2/4
Epoch 3/4
Epoch 4/4
```
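
The logged step counts follow from integer division of the MRPC split sizes (3,668 training and 408 validation examples) by `BATCH_SIZE`. Because the training pipeline repeats indefinitely, a Keras epoch here does not line up with one pass over the data: the 28 leftover training examples simply roll into the next epoch, while the last 8 validation examples are never evaluated. With an integer `validation_freq=3`, validation runs after every third epoch, so with `epochs=4` it runs once, after epoch 3. The sketch below just redoes that arithmetic.

```python
BATCH_SIZE = 40
train_examples, valid_examples = 3668, 408  # MRPC train/validation split sizes

steps_per_epoch = train_examples // BATCH_SIZE          # 91, as logged
validation_steps = valid_examples // BATCH_SIZE         # 10, as logged
examples_seen_per_epoch = steps_per_epoch * BATCH_SIZE  # 3640 of 3668

# integer validation_freq: validate after each epoch whose 1-based index is a multiple of it
epochs, validation_freq = 4, 3
validation_epochs = [e for e in range(1, epochs + 1) if e % validation_freq == 0]

print(steps_per_epoch, validation_steps, examples_seen_per_epoch, validation_epochs)
# 91 10 3640 [3]
```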


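```python
# fastest epoch; the first epoch also pays one-off graph construction and XLA compilation costs
epoch_time = min(time_callback.times)
# approximate: floor-divides the full training split size, although each epoch
# processes steps_per_epoch * BATCH_SIZE = 91 * 40 = 3640 examples
egs_per_sec = train_examples//epoch_time

print("Peak Examples/s:", egs_per_sec)
```

```
Peak Examples/s: 128.0
```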
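
The printed figure divides the full training-split size by the fastest epoch's wall-clock time and floors the result with `//`. A slightly more precise measure divides the examples actually processed per epoch (91 × 40 = 3,640) using true division. Note also that the epoch that runs validation (epoch 3 here) includes the validation pass in its time, so it is rarely the minimum. The helper `peak_throughput` below is a hypothetical name, and the epoch times are made-up values for illustration, not measurements.

```python
def peak_throughput(epoch_times, examples_per_epoch):
    """Examples per second over the fastest recorded epoch."""
    return examples_per_epoch / min(epoch_times)


hypothetical_times = [90.0, 30.0, 32.0, 30.5]  # seconds; illustrative, not measured

as_in_notebook = 3668 // min(hypothetical_times)           # floored, full split size
processed = peak_throughput(hypothetical_times, 91 * 40)   # examples actually run

print(as_in_notebook, round(processed, 1))  # 122.0 121.3
```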
