# Fine-tuning BERT in Flax on GLUE

This notebook fine-tunes a BERT model on one of the [GLUE tasks](https://gluebenchmark.com/). It has the following features:

*   Uses the [HuggingFace](https://github.com/huggingface/) datasets and tokenizers libraries.
*   Loads the pre-trained BERT weights from HuggingFace.
*   Model and training code are written in [Flax](http://www.github.com/google/flax).
*   Can be configured to fine-tune on CoLA, MRPC, SST-2, STS-B, QNLI, or RTE.

Run-times on MRPC:

*   Cloud TPU v3-8: 40s

## Training Settings

In [1]:
# TODO: either use HF config or unify this in another way
train_settings = {
    'train_batch_size': 32,
    'eval_batch_size': 8,
    'learning_rate': 5e-5,
    'num_train_epochs': 3,
    'dataset_path': 'glue',
    'dataset_name': 'mrpc'  # ['cola', 'mrpc', 'sst2', 'stsb', 'qnli', 'rte']
}

## Create HuggingFace dataset, tokenizer, and pre-processing pipeline

In [2]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['XLA_FLAGS'] = '--xla_multiheap_size_constraint_per_heap=-1'

# HF imports
import datasets
from transformers import BertTokenizerFast
# Demo imports
import data
from demo_lib import get_config

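# Use the configured dataset path instead of repeating the 'glue' literal.
dataset = datasets.load_dataset(
    train_settings['dataset_path'], train_settings['dataset_name'])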

config = get_config('bert-base-uncased', dataset)
config.update(train_settings)
tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
tokenizer.model_max_length = config.max_seq_length

data_pipeline = data.ClassificationDataPipeline(dataset, tokenizer)
train_iter = data_pipeline.get_inputs(
  split='train', batch_size=config.train_batch_size, training=True)

Reusing dataset glue (/home/skyewm/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached processed dataset at /home/skyewm/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-319515621adfe7f0.arrow
Loading cached processed dataset at /home/skyewm/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-53b8739ff1af4dc2.arrow
Loading cached processed dataset at /home/skyewm/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-f72c7188db37a2b9.arrow


## Create flax model and optimizer

In [3]:
from demo_lib import import_pretrained_params, create_model, create_optimizer

# TODO: inline some or all of this; TODO: what is flax
pretrained_params = import_pretrained_params(config)
model = create_model(config, pretrained_params)
optimizer = create_optimizer(config, model, pretrained_params)

loading weights file https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin from cache at /home/skyewm/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f


## Set up training step function

In [4]:
# Demo imports
import training
from demo_lib import get_num_train_steps, get_learning_rate_fn

# TODO: remove train history and maybe train state
num_train_steps = get_num_train_steps(config, data_pipeline)
learning_rate_fn = get_learning_rate_fn(config, num_train_steps)
train_history = training.TrainStateHistory(learning_rate_fn)
train_state = train_history.initial_state()

# TODO: move pmap out of create_train_step
train_step_fn = training.create_train_step(clip_grad_norm=1.0)

Num train examples: 3668
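
The step count follows from the settings above: 3668 training examples, 3 epochs, and a batch size of 32. The sketch below shows one formula that reproduces the 343 steps reported in the training log. The helper name `estimate_train_steps` is hypothetical, and it is an assumption that `get_num_train_steps` rounds down over the whole run in this way.

```python
def estimate_train_steps(num_examples: int, num_epochs: int, batch_size: int) -> int:
    """Hypothetical helper: count full batches over the whole run.

    Assumes the trailing partial batch is dropped (floor division).
    """
    return (num_examples * num_epochs) // batch_size


# Values taken from train_settings and the "Num train examples" output.
steps = estimate_train_steps(num_examples=3668, num_epochs=3, batch_size=32)
print(steps)
```

Flooring per epoch instead (`3668 // 32 * 3`) would give 342, so the reported count is consistent with rounding once over all epochs.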


## Run Training

In [5]:
print(f'\nStarting training on {config.dataset_name} for {num_train_steps} '
      f'steps ({config.num_train_epochs:.0f} epochs)...\n')

for step, batch in zip(range(0, num_train_steps), train_iter):
  optimizer, train_state = train_step_fn(optimizer, batch, train_state)
  if step % 10 == 0:
    print(f'step {step}/{num_train_steps}')

print('\nTraining finished!')


Starting training on mrpc for 343 steps (3 epochs)...

Compiling train (takes about 20s)
compiling train_step
done 23.806512117385864
Step 0       grad_norm = 4.391973972320557
             loss = 0.42315247654914856
step 0/343
step 10/343
step 20/343
step 30/343
step 40/343
step 50/343
step 60/343
step 70/343
step 80/343
step 90/343
step 100/343
step 110/343
step 120/343
step 130/343
step 140/343
step 150/343
step 160/343
step 170/343
step 180/343
step 190/343
Step 200     grad_norm = 18.06053924560547
             loss = 0.9341562986373901
             seconds_per_step = 0.06082286685705185
step 200/343
step 210/343
step 220/343
step 230/343
step 240/343
step 250/343
step 260/343
step 270/343
step 280/343
step 290/343
step 300/343
step 310/343
step 320/343
step 330/343
step 340/343

Training finished!
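
The log reports `grad_norm = 18.06` at step 200, well above the `clip_grad_norm=1.0` passed to `create_train_step`. A minimal sketch of clipping by global norm follows, written in plain Python for readability. It is an assumption that `create_train_step` clips this way and that the logged value is the pre-clip norm; the function name `clip_by_global_norm` and the toy gradients are illustrative only.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale gradient vectors so their joint L2 norm is at most max_norm.

    Illustrative sketch: `grads` is a list of lists of floats standing in
    for the per-parameter gradient arrays of a real model.
    """
    global_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Only shrink gradients; leave them untouched when already small enough.
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, global_norm


toy_grads = [[3.0, 4.0], [12.0]]  # joint L2 norm is 13
clipped, norm_before = clip_by_global_norm(toy_grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, norm_after)
```

All gradients share one scale factor, so the update direction is preserved and only its length changes.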


## Run Evaluation

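The target eval_f1 for MRPC is 88.9, with run-to-run variation of about 1.0 point. The metrics printed by the cell are on a 0-1 scale, so the target corresponds to roughly 0.889.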

In [6]:
# demo imports
from demo_lib import get_validation_splits, get_prefix

# TODO: inline eval_fn
eval_fn = training.create_eval_fn()

for split in get_validation_splits(config.dataset_name):
  # Evaluate the split being iterated over rather than a hard-coded one.
  eval_iter = data_pipeline.get_inputs(
      split=split, batch_size=config.eval_batch_size, training=False)
  eval_stats = eval_fn(optimizer, eval_iter)
  eval_metric = datasets.load_metric(config.dataset_path, config.dataset_name)
  eval_metric.add_batch(
    predictions=eval_stats['prediction'],
    references=eval_stats['label'])
  eval_metrics = eval_metric.compute()
  for name, val in sorted(eval_metrics.items()):
    print(f'{get_prefix(split)}_{name} = {val:.06f}', flush=True)

compiling compute_classification_stats
done 2.0442707538604736
eval_accuracy = 0.850490
eval_f1 = 0.896785
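
For reference, the two reported numbers can be computed from predictions and labels as sketched below. This is a plain-Python illustration of accuracy and binary F1 with label 1 as the positive class, not the code of the GLUE metric itself; the function name `accuracy_and_f1` and the toy inputs are hypothetical.

```python
def accuracy_and_f1(predictions, references, positive_label=1):
    """Return (accuracy, F1) for binary classification labels."""
    pairs = list(zip(predictions, references))
    correct = sum(p == r for p, r in pairs)
    tp = sum(p == positive_label and r == positive_label for p, r in pairs)
    fp = sum(p == positive_label and r != positive_label for p, r in pairs)
    fn = sum(p != positive_label and r == positive_label for p, r in pairs)

    accuracy = correct / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return accuracy, f1


toy_predictions = [1, 1, 0, 1, 0, 1]
toy_references = [1, 0, 0, 1, 1, 1]
toy_accuracy, toy_f1 = accuracy_and_f1(toy_predictions, toy_references)
print(f'accuracy = {toy_accuracy:.6f}, f1 = {toy_f1:.6f}')
```

F1 ignores true negatives, so it can sit above or below accuracy depending on how the classes are balanced.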
