# Fine-tuning BERT in Flax on GLUE

This notebook fine-tunes a BERT model on one of the [GLUE tasks](https://gluebenchmark.com/). It has the following features:

*   Uses the [HuggingFace](https://github.com/huggingface/) datasets and tokenizers libraries.
*   Loads the pre-trained BERT weights from HuggingFace.
*   The model and training code are written in [Flax](http://www.github.com/google/flax).
*   Can be configured to fine-tune on CoLA, MRPC, SST-2, STS-B, QNLI, or RTE.

Run-times:

*   Single GPU: 8 min
*   Cloud TPU v3-8: 2 min

In [1]:
# General imports.
import os
import jax
import jax.numpy as jnp
import flax

# HuggingFace datasets and transformers libraries.
import datasets
from transformers import BertTokenizerFast

# flax_bert-specific imports.
from flax import optim
import data
import modeling as flax_models
from demo_lib import (
    get_config, import_pretrained_params, create_model, create_optimizer,
    run_train, run_eval)

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

## Set your Training Settings

In [2]:
train_settings = {
    'train_batch_size': 32,
    'eval_batch_size': 8,
    'learning_rate': 3e-5,
    'num_train_epochs': 3,
    'dataset_path': 'glue',
    'dataset_name': 'mrpc',  # One of: 'cola', 'mrpc', 'sst2', 'stsb', 'qnli', 'rte'.
}

## Load dataset, tokenizer, and model

In [8]:
# Load the GLUE task selected in train_settings.
dataset = datasets.load_dataset(
    train_settings['dataset_path'], train_settings['dataset_name'])

# Get pre-trained config and update it with the train configuration.
config = get_config('bert-base-uncased', dataset)
config.update(train_settings)

# Load HuggingFace tokenizer and data pipeline.
tokenizer = BertTokenizerFast.from_pretrained(config.tokenizer)
data_pipeline = data.ClassificationDataPipeline(dataset, tokenizer)

# Create Flax model and optimizer.
pretrained_params = import_pretrained_params(config)
model = create_model(config, pretrained_params)
optimizer = create_optimizer(config, model, pretrained_params)

Reusing dataset glue (/home/marcvanzee/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/marcvanzee/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
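
MRPC is a sentence-pair task, so each example has to be packed into a single BERT input. The internals of `data.ClassificationDataPipeline` are not shown in this notebook, so the following is only a minimal sketch of the standard BERT pair layout (`[CLS] A [SEP] B [SEP]`, with segment ids and a padding mask). The function name `pack_sentence_pair` is hypothetical, and whitespace-split words stand in for the WordPiece tokens that `BertTokenizerFast` would produce.

```python
from typing import Dict, List


def pack_sentence_pair(tokens_a: List[str], tokens_b: List[str],
                       max_length: int) -> Dict[str, list]:
    """Packs two token lists into one padded BERT-style input (illustrative)."""
    tokens = ["[CLS]", *tokens_a, "[SEP]", *tokens_b, "[SEP]"]
    if len(tokens) > max_length:
        # Real tokenizers usually truncate; this sketch just refuses.
        raise ValueError("pair does not fit into max_length")

    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP].
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    attention_mask = [1] * len(tokens)

    # Pad to a fixed length so that examples can be batched together.
    num_pad = max_length - len(tokens)
    return {
        "tokens": tokens + ["[PAD]"] * num_pad,
        "token_type_ids": token_type_ids + [0] * num_pad,
        "attention_mask": attention_mask + [0] * num_pad,
    }


example = pack_sentence_pair(
    ["the", "cat", "sat"], ["a", "cat", "was", "sitting"], max_length=12)
for name, values in example.items():
    print(name, values)
```

Padding positions get an attention mask of 0 so the model can ignore them.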


## Run Training

In [9]:
optimizer = run_train(optimizer, data_pipeline, tokenizer, config)


Starting training on mrpc for 343 steps (3 epochs)...

Step 0       grad_norm = 2.1773173809051514
             loss = 0.655493438243866
Step 200     grad_norm = 84.15443420410156
             loss = 0.31002330780029297
             seconds_per_step = 0.31129512190818787

Finished training.
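
The 343 steps reported in the log follow from the settings chosen earlier. The sketch below shows one formula that reproduces that number, assuming the MRPC training split has 3,668 examples and that the trailing partial batch is dropped. How `demo_lib` actually computes the step count is not shown in this notebook, so `num_train_steps` is a hypothetical helper, not the library's code.

```python
def num_train_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Number of full batches seen over all epochs (illustrative formula)."""
    return (num_examples * num_epochs) // batch_size


# Assumption: 3,668 training pairs in GLUE MRPC.
steps = num_train_steps(num_examples=3668, batch_size=32, num_epochs=3)
print(steps)  # 343
```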


## Run Evaluation

The target eval_f1 for MRPC is 88.9, which corresponds to 0.889 on the 0 to 1 scale that `run_eval` reports.

In [10]:
run_eval(optimizer, data_pipeline, config)


Running eval...

eval_accuracy = 0.843137
eval_f1 = 0.890785
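
For reference, the two reported metrics can be computed from predicted and true labels as sketched below. This is a plain-Python illustration of the standard definitions of accuracy and binary F1, not the code used inside `run_eval`; the helper name `binary_accuracy_and_f1` and the toy labels are made up for the example.

```python
from typing import Sequence, Tuple


def binary_accuracy_and_f1(predictions: Sequence[int],
                           labels: Sequence[int]) -> Tuple[float, float]:
    """Returns (accuracy, F1) with label 1 treated as the positive class."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty and equal in length")

    pairs = list(zip(predictions, labels))
    true_pos = sum(1 for p, y in pairs if p == 1 and y == 1)
    false_pos = sum(1 for p, y in pairs if p == 1 and y == 0)
    false_neg = sum(1 for p, y in pairs if p == 0 and y == 1)

    accuracy = sum(1 for p, y in pairs if p == y) / len(pairs)
    # F1 is the harmonic mean of precision and recall: 2TP / (2TP + FP + FN).
    denominator = 2 * true_pos + false_pos + false_neg
    f1 = 2 * true_pos / denominator if denominator else 0.0
    return accuracy, f1


toy_predictions = [1, 1, 0, 1, 0, 1]
toy_labels = [1, 0, 0, 1, 1, 1]
accuracy, f1 = binary_accuracy_and_f1(toy_predictions, toy_labels)
print(f"accuracy = {accuracy:.4f}, f1 = {f1:.4f}")  # accuracy = 0.6667, f1 = 0.7500
```

Because F1 ignores true negatives, it can be higher than accuracy when most examples are positive, which is consistent with the MRPC output above.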
