# Fine-tuning BERT in Flax on GLUE

This notebook fine-tunes a BERT model on one of the [GLUE tasks](https://gluebenchmark.com/). It has the following features:

*   Uses the [HuggingFace](https://github.com/huggingface/) datasets and tokenizers libraries.
*   Loads the pre-trained BERT weights from HuggingFace.
*   The model and training code are written in [Flax](http://www.github.com/google/flax).
*   Can be configured to fine-tune on CoLA, MRPC, SST-2, STS-B, QNLI, and RTE.

Run-times:

*   Single GPU: 8 min
*   Cloud TPU v3-8: 2 min

In [3]:
# General imports.
import os
import jax
import jax.numpy as jnp
import flax

# Huggingface datasets and transformers libraries.
import datasets
from transformers import BertTokenizerFast

# flax_bert-specific imports.
from flax import optim  # Legacy API: needs a Flax release that still ships flax.optim.
import data
import modeling as flax_models
from demo_lib import get_config, import_pretrained_params, create_model, create_optimizer, run_train, run_eval

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

## Set your Training Settings

In [4]:
train_settings = {
    'train_batch_size': 32,
    'eval_batch_size': 8,
    'learning_rate': 3e-5,
    'num_train_epochs': 3,
    'dataset_path': 'glue',
    'dataset_name': 'mrpc'  # ['cola', 'mrpc', 'sst2', 'stsb', 'qnli', 'rte']
}
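The task names listed in the comment above are configurations of the HuggingFace `glue` dataset, and each one exposes different text columns. A minimal sketch of a sanity check for the settings follows. The `TASK_TEXT_COLUMNS` table and the `validate_settings` helper are hypothetical names introduced here for illustration; they are not part of `demo_lib`.

```python
# Hypothetical helper, not part of demo_lib: validates the settings before
# any data is downloaded. Column names follow the HuggingFace `glue` dataset.
TASK_TEXT_COLUMNS = {
    'cola': ('sentence',),
    'sst2': ('sentence',),
    'mrpc': ('sentence1', 'sentence2'),
    'stsb': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'rte': ('sentence1', 'sentence2'),
}


def validate_settings(settings):
    """Returns the text columns for the chosen task, or raises ValueError."""
    name = settings['dataset_name']
    if name not in TASK_TEXT_COLUMNS:
        raise ValueError(
            f'Unsupported task {name!r}; choose one of {sorted(TASK_TEXT_COLUMNS)}')
    # All numeric settings must be strictly positive.
    for key in ('train_batch_size', 'eval_batch_size',
                'num_train_epochs', 'learning_rate'):
        if settings[key] <= 0:
            raise ValueError(f'{key} must be positive, got {settings[key]}')
    return TASK_TEXT_COLUMNS[name]


example_settings = {
    'train_batch_size': 32,
    'eval_batch_size': 8,
    'learning_rate': 3e-5,
    'num_train_epochs': 3,
    'dataset_path': 'glue',
    'dataset_name': 'mrpc',
}
text_columns = validate_settings(example_settings)
print(text_columns)
```

Failing early on a mistyped task name avoids waiting for a download that would only fail later.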

## Load dataset, tokenizers, and model.

In [5]:
# Load the GLUE task selected in the training settings.
dataset = datasets.load_dataset(train_settings['dataset_path'], train_settings['dataset_name'])

# Get pre-trained config and update it with the train configuration.
config = get_config('bert-base-uncased', dataset)
config.update(train_settings)

# Load HuggingFace tokenizer and data pipeline.
tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
data_pipeline = data.ClassificationDataPipeline(dataset, tokenizer)

# Create Flax model and optimizer.
pretrained_params = import_pretrained_params(config)
model = create_model(config, pretrained_params)
optimizer = create_optimizer(config, model, pretrained_params)

Reusing dataset glue (/home/marcvanzee/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached processed dataset at /home/marcvanzee/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-613c8ebe967bb60f.arrow
Loading cached processed dataset at /home/marcvanzee/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-9c525b7013511374.arrow
Loading cached processed dataset at /home/marcvanzee/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-9c0c46070208810a.arrow
loading weights file https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin from cache at /home/marcvanzee/.cache/torch/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f


## Run Training

In [6]:
optimizer = run_train(optimizer, data_pipeline, tokenizer, config)


Starting training on mrpc for 343 steps (3 epochs)...

Step 0       grad_norm = 108.0107421875
             loss = 0.8161654472351074
Step 200     grad_norm = 49.66865539550781
             loss = 0.27637040615081787
             seconds_per_step = 0.2847726345062256

Finished training.
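The step count in the log can be sanity-checked with a little arithmetic. The sketch below assumes that the MRPC training split has 3,668 examples (the size of the HuggingFace `glue`/`mrpc` train split) and that the total number of examples over all epochs is divided by the batch size. How `run_train` actually derives the number is not shown in this notebook, so treat the formula as an assumption that happens to reproduce the logged value. The function name `num_train_steps` is likewise hypothetical.

```python
def num_train_steps(num_examples, batch_size, num_epochs):
    """Total optimizer steps, dropping the final partial batch (assumed)."""
    return (num_examples * num_epochs) // batch_size


# Assumption: 3,668 training examples in MRPC; batch size and epoch count
# are taken from the training settings.
steps = num_train_steps(num_examples=3668, batch_size=32, num_epochs=3)

# Rough wall-clock estimate using the seconds_per_step value from the log.
seconds_per_step = 0.2847726345062256
estimated_seconds = steps * seconds_per_step

print(steps, round(estimated_seconds, 1))
```

The estimate covers only the training steps themselves; compilation, data loading, and evaluation add to the total run-time.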


## Run Evaluation

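The target eval_f1 for MRPC is 88.9 on a 0-100 scale. The output below reports metrics on a 0-1 scale, so the target corresponds to 0.889.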

In [8]:
run_eval(optimizer, data_pipeline, config)


Running eval...

eval_accuracy = 0.852941
eval_f1 = 0.895833
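For reference, accuracy and F1 for a binary task such as MRPC can be computed from predictions as sketched below. This is a generic illustration with made-up labels, not the code inside `run_eval`, and the function name `binary_metrics` is hypothetical. The last lines also compare the reported `eval_f1` with the target, which is quoted on a percentage scale.

```python
def binary_metrics(labels, predictions):
    """Accuracy and F1 for binary labels, treating 1 as the positive class."""
    pairs = list(zip(labels, predictions))
    true_pos = sum(1 for y, p in pairs if y == 1 and p == 1)
    false_pos = sum(1 for y, p in pairs if y == 0 and p == 1)
    false_neg = sum(1 for y, p in pairs if y == 1 and p == 0)
    accuracy = sum(1 for y, p in pairs if y == p) / len(pairs)
    # Guard against division by zero when there are no positive predictions
    # or no positive labels.
    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {'accuracy': accuracy, 'f1': f1}


# Made-up labels and predictions, for illustration only.
labels = [1, 1, 1, 0, 0, 1, 0, 1]
predictions = [1, 1, 0, 0, 1, 1, 0, 1]
metrics = binary_metrics(labels, predictions)
print(metrics)

# Compare the eval_f1 reported above (0-1 scale) with the 88.9 target
# (percentage scale).
reported_f1 = 0.895833
meets_target = reported_f1 * 100 >= 88.9
print(meets_target)
```

F1 is the harmonic mean of precision and recall, which is why it can differ noticeably from accuracy when the classes are imbalanced.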
