# JAX & FLAX for Token Classification

> #### This is a basic transformer training notebook for token classification. The main advantage of JAX and Flax here is training speed: one epoch takes about 30 minutes. To improve the results, you can modify the model to accept additional inputs.

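##### Data processing is taken from: [https://www.kaggle.com/thedrcat/feedback-prize-huggingface-baseline-training](https://www.kaggle.com/thedrcat/feedback-prize-huggingface-baseline-training)
##### The code is based on the following works: [https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface](https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface)
##### and [https://github.com/huggingface/transformers/blob/master/examples/flax/token-classification/run_flax_ner.py](https://github.com/huggingface/transformers/blob/master/examples/flax/token-classification/run_flax_ner.py)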

# Install dependencies

In [None]:
!mkdir ./model
!mkdir ./token

In [None]:
# Upgrade pip itself before installing the pinned packages in the next cell.
!pip install --upgrade pip

In [None]:
%%capture
# %%capture hides the (long) install output of this cell.
# Pinned JAX + CUDA 11 jaxlib build; transformers, optax and flax from their GitHub sources.
! pip install jax==0.2.25
! pip install jaxlib==0.1.74+cuda11 -f https://storage.googleapis.com/jax-releases/jax_releases.html
! pip install git+https://github.com/huggingface/transformers.git
! pip install git+https://github.com/deepmind/optax.git
! pip install --upgrade -q git+https://github.com/google/flax.git #pip install flax
# seqeval backs the entity-level metric loaded later via load_metric("seqeval").
! pip install seqeval
! conda install -y -c conda-forge datasets
! conda install -y importlib-metadata

In [None]:
import os
# Turn off the HuggingFace tokenizers library's parallelism via its environment variable.
os.environ.update(TOKENIZERS_PARALLELISM="False")

In [None]:
import jax, flax, tensorflow
# Report the installed versions of the core libraries, one per line.
print(*(f'{lib.__name__}: {lib.__version__}' for lib in (jax, flax, tensorflow)), sep='\n')
# Private XLA API: shows whether the installed jaxlib is an optimized build
# (the notebook displays the returned value).
jax.lib.xla_client._xla.is_optimized_build()

# Dataset

In [None]:
import numpy as np
import datasets
from datasets import load_dataset, load_metric, Dataset
import pandas as pd
from jax import lax, random, numpy as jnp
# import random
from typing import Tuple
from pathlib import Path
from collections import defaultdict
# import flax
import optax
import json
import datetime
import pickle
from itertools import chain
from tqdm.auto import tqdm
from typing import Callable
# import jax
# import jax.numpy as np #jnp
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util
from jax.experimental.maps import xmap

In [None]:
# Tokenization settings. NOTE(review): not referenced in this notebook; presumably
# they were used when building the pickled tokenized dataset loaded below.
max_length: int = 1024
stride: int = 128
min_tokens: int = 6

# TRAINING HYPERPARAMS
# NOTE(review): the training code below uses total_batch_size, learning_rate and
# num_train_epochs instead; these constants are not referenced there.
BS: int = 1
GRAD_ACC: int = 8
LR: float = 5e-5
WD: float = 0.01
WARMUP: float = 0.1
N_EPOCHS: int = 10

In [None]:
#read train data
train = pd.read_csv('../input/feedback-prize-2021/train.csv')
train.head(1)

In [None]:
# check unique classes
# Unique discourse types; each becomes an entity class of the BIO tag set built next.
classes = train.discourse_type.unique().tolist()
classes

In [None]:
# BIO tag vocabulary built from `classes` (n = len(classes)):
#   'B-<class>' -> 0 .. n-1, 'I-<class>' -> n .. 2n-1, 'O' -> 2n,
#   'Special' -> -100 (positions that get_labels drops before scoring).
tags = defaultdict()
n_cls = len(classes)
for idx, name in enumerate(classes):
    tags[f'B-{name}'] = idx
    tags[f'I-{name}'] = n_cls + idx
tags['O'] = 2 * n_cls
tags['Special'] = -100

l2i = dict(tags)

# Reverse lookup id -> tag; -100 maps back to 'Special' because l2i already contains it.
i2l = {tag_id: tag for tag, tag_id in l2i.items()}

N_LABELS = len(i2l) - 1  # not accounting for -100

In [None]:
# Pretrained checkpoint shared by the tokenizer, the config and the model weights.
model_checkpoint: str = "google/bigbird-roberta-base"

In [None]:
from transformers import AutoTokenizer, BigBirdTokenizerFast, BigBirdTokenizer
# Fast BigBird tokenizer for the checkpoint, plus a quick sanity check on a sample sentence.
tokenizer = BigBirdTokenizerFast.from_pretrained(model_checkpoint)
tokenizer.tokenize("The weather is fine today. ")

In [None]:
# Not sure if this is needed, but in case we create a span with certain class without starting token of that class,
# let's convert the first token to be the starting token.

# Example id sequence (assuming 7 classes): 0..6 are B- tags, 7..13 I- tags, 14 is 'O'.
e = [0,7,7,7,1,1,8,8,8,9,9,9,14,4,4,4]

def fix_beginnings(labels, n_classes=7):
    """Make every run of I- tags start with the matching B- tag, in place.

    Label ids follow the layout of ``l2i``: ``B-<c>`` is ``c`` and ``I-<c>`` is
    ``c + n_classes``. An I- tag that follows neither the same I- tag nor its own
    B- tag (including an I- tag at position 0) is rewritten to that B- tag.

    Args:
        labels: mutable sequence of integer label ids; modified in place.
            Ids outside the I- range (B- tags, 'O', -100) are left untouched.
        n_classes: number of entity classes (default 7).

    Returns:
        The same ``labels`` object.
    """
    inside_ids = range(n_classes, 2 * n_classes)
    # None never equals a label id, so an I- tag at index 0 is treated as a span start.
    prev_lab = None
    for i, curr_lab in enumerate(labels):
        if curr_lab in inside_ids and prev_lab not in (curr_lab, curr_lab - n_classes):
            labels[i] = curr_lab - n_classes
        # Compare against the (possibly rewritten) previous label, as before.
        prev_lab = labels[i]
    return labels

fix_beginnings(e)

In [None]:
# Load the pre-tokenized dataset splits prepared offline.
# NOTE(security): pickle.load can execute arbitrary code; only load this file from a
# trusted source (here an attached Kaggle input dataset).
pickle_path = Path('/kaggle/input/feedback-pickles/bigbird_tokenizer.pickle')
with pickle_path.open('rb') as handle:
    tokenized_datasets = pickle.load(handle)

In [None]:
# Inspect the loaded object; the next cell reads its "train" and "test" splits.
tokenized_datasets

In [None]:
# The test was created by the 0.1 split of the data which is our validation/evaluation dataset.
# "test" is the held-out split; it is used as the evaluation set during training.
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["test"]

# Model
This is a token-classification problem: for every token the model outputs one logit per BIO label (15 labels: a B- and an I- tag for each of the 7 discourse types, plus `O`).

In [None]:
from transformers import FlaxAutoModelForTokenClassification, AutoConfig, FlaxBigBirdForTokenClassification #, is_tensorboard_available
# from transformers.file_utils import get_full_repo_name
# from transformers.utils import check_min_version
# from transformers.utils.versions import require_version

# One output logit per BIO tag. Derive the head size from the label vocabulary
# (N_LABELS excludes the -100 'Special' id) instead of hard-coding 15, so the
# classifier always matches l2i/i2l.
num_labels = N_LABELS
seed = 0

config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForTokenClassification.from_pretrained(model_checkpoint, config=config, seed=seed)

# Training and evaluation loop

In [None]:
# Number of epochs and the peak value of the one-cycle learning-rate schedule.
num_train_epochs, learning_rate = 15, 1e-5

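On a TPUv3-8 (8 cores) the effective batch size would be `8 * per_device_batch_size`. This notebook runs on GPU (`backend='gpu'` in the `pmap` calls below), and the overall batch size is simply hard-coded to 4; it must be divisible by the number of local devices.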

In [None]:
# Global batch size. The collators shard each batch across the local devices,
# so it has to be divisible by jax.local_device_count().
total_batch_size = 4  # per_device_batch_size * jax.local_device_count()
print(f"The overall batch size (both for training and eval) is {total_batch_size}")

In [None]:
# eval_dataset

In [None]:
# Full batches per epoch (the incomplete last batch is dropped by train_data_collator).
step_per_epoch = len(train_dataset) // total_batch_size
# Optimizer steps over the whole run; the LR schedule spans all of them.
num_train_steps = step_per_epoch * num_train_epochs
# num_train_steps already includes the epoch count. Multiplying by num_train_epochs
# again inflated the "Step ... / total" log and made the final-step check unreachable.
total_steps = num_train_steps
learning_rate_function = optax.cosine_onecycle_schedule(
    transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.01
)
print("The number of train steps (all the epochs) is", num_train_steps)

In [None]:
class TrainState(train_state.TrainState):
    """Flax TrainState that also carries the prediction and loss functions.

    Both extra fields are declared with ``pytree_node=False``, so they are kept as
    static attributes rather than pytree leaves (they are not traced or replicated
    as arrays by jit/pmap/replicate).
    """
    # logits -> predicted label ids (called in eval_step).
    logits_function: Callable = flax.struct.field(pytree_node=False)
    # (logits, labels) -> scalar loss (called in train_step).
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [None]:
def decay_mask_fn(params):
    """Build a boolean pytree that says which parameters receive weight decay.

    Leaves whose path ends in ``"bias"`` or in ``("LayerNorm", "scale")`` map to
    False; every other leaf maps to True. The result has the same nesting as
    ``params``.
    """
    def _should_decay(path):
        if path[-1] == "bias":
            return False
        return path[-2:] != ("LayerNorm", "scale")

    flat_mask = {}
    for path in traverse_util.flatten_dict(params):
        flat_mask[path] = _should_decay(path)
    return traverse_util.unflatten_dict(flat_mask)

In [None]:
def adamw(weight_decay, lr=0.000001):
    """Build the optimizer used for training.

    Despite its name this returns ``optax.lamb``. An AdamW alternative would be
    ``optax.adamw(learning_rate=lr, b1=0.9, b2=0.999, eps=1e-6,
    weight_decay=weight_decay, mask=decay_mask_fn)``. Weight decay is applied only
    to the leaves that ``decay_mask_fn`` marks True.

    Args:
        weight_decay: decoupled weight-decay coefficient.
        lr: learning rate, either a float or an optax schedule such as
            ``learning_rate_function``. It defaults to the previously hard-coded
            1e-6, so existing calls behave the same.

    Returns:
        An optax GradientTransformation.

    NOTE(review): train_step logs ``learning_rate_function(state.step)``. That value
    is not what the optimizer uses unless ``lr=learning_rate_function`` is passed.
    """
    return optax.lamb(learning_rate=lr, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=weight_decay, mask=decay_mask_fn)

In [None]:
# Instantiate the optimizer. This rebinds `adamw` from the factory function to the
# optax transformation that TrainState.create receives as `tx`.
adamw = adamw(weight_decay=1e-2)

## Loss and eval functions
Token classification is trained with the softmax cross-entropy between each token's logits and its one-hot gold label, averaged over tokens. Label `-100` marks special tokens; the evaluation code (`get_labels`) drops those positions before scoring. `eval_fn` turns logits into predicted label ids with `argmax`.

In [None]:
def eval_fn(logits):
    """Return predicted label ids: the argmax of ``logits`` over the last (class) axis."""
    return logits.argmax(axis=-1)

In [None]:
@jax.jit
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy over the tokens whose label is not -100.

    Args:
        logits: float array of shape ``(..., num_classes)``.
        labels: int array of shape ``logits.shape[:-1]``; ``-100`` marks
            ignored (special) positions.

    Returns:
        Scalar: summed per-token cross-entropy divided by the number of
        non-ignored tokens (0.0 when every token is ignored).
    """
    # Ignored positions are removed from both the sum and the count. With a plain
    # mean they contributed zero loss but still diluted the average. The class count
    # comes from the logits instead of a hard-coded 15.
    valid = labels != -100
    safe_labels = jnp.where(valid, labels, 0)  # any in-range index; masked below
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_nll = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    token_nll = jnp.where(valid, token_nll, 0.0)
    return token_nll.sum() / jnp.maximum(valid.sum(), 1)

In [None]:
# seqeval scorer: compute_metrics reads its overall precision/recall/f1/accuracy.
# Batches of tag-string sequences are accumulated with metric.add_batch.
metric = load_metric("seqeval")

In [None]:
# @jax.jit
def get_labels(y_pred, y_true):
    """Convert predicted and gold label-id sequences into tag-string sequences.

    Positions whose gold id is -100 are dropped from both outputs. The remaining
    ids are mapped to tag strings through the global ``i2l``.

    Returns:
        ``(true_predictions, true_labels)``: two lists of lists of tag strings.
    """
    pred_tags, gold_tags = [], []
    for pred_seq, gold_seq in zip(y_pred, y_true):
        kept = [(p, g) for p, g in zip(pred_seq, gold_seq) if g != -100]
        pred_tags.append([i2l[p] for p, _ in kept])
        gold_tags.append([i2l[g] for _, g in kept])
    return pred_tags, gold_tags

In [None]:
# #  @jax.jit
def compute_metrics():
    """Finalize the scores accumulated in the global seqeval ``metric``.

    If the global flag ``return_entity_level_metrics`` is falsy, return only the
    overall precision/recall/f1/accuracy. Otherwise return every result, with
    nested dict values flattened into ``"<key>_<subkey>"`` entries.
    """
    results = metric.compute()
    if not return_entity_level_metrics:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    flattened = {}
    for name, score in results.items():
        if not isinstance(score, dict):
            flattened[name] = score
            continue
        for sub_name, sub_score in score.items():
            flattened[f"{name}_{sub_name}"] = sub_score
    return flattened

In [None]:
# import jax
# import jax.numpy as np

# key = jax.random.PRNGKey(0)

# Show which XLA backend (cpu/gpu/tpu) JAX picked up.
backend_platform = jax.lib.xla_bridge.get_backend().platform
print('JAX is running on', backend_platform)

## Create the initial train state

In [None]:
# Root PRNG key. Split it once, so the dropout keys and the key that the training loop
# keeps splitting for shuffling come from distinct subkeys. Previously `rng` itself was
# split here and again in the loop, which reuses a key across two splits.
rng = jax.random.PRNGKey(seed)
rng, _dropout_root = jax.random.split(rng)
# One dropout key per local device, consumed and renewed by the pmapped train step.
dropout_rngs = jax.random.split(_dropout_root, jax.local_device_count())

In [None]:
# Bundle the model's apply function, its parameters, the optimizer and the static
# loss/prediction helpers into one training state (replicated across devices later).
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    loss_function=cross_entropy_loss,
    logits_function=eval_fn,
)

In [None]:
def train_data_collator(rng, dataset, batch_size):
    """Yield shuffled batches of ``batch_size`` examples, sharded over the local devices.

    The shuffle order comes from ``rng``. The trailing incomplete batch is dropped,
    so exactly ``len(dataset) // batch_size`` batches are produced.
    """
    n_batches = len(dataset) // batch_size
    usable = n_batches * batch_size
    order = jax.random.permutation(rng, len(dataset))[:usable]
    for idx in order.reshape((n_batches, batch_size)):
        examples = dataset[idx]
        yield shard({name: jnp.array(values) for name, values in examples.items()})

In [None]:
def eval_data_collator(dataset, batch_size):
    """Yield consecutive batches of ``batch_size`` examples, sharded over the local devices.

    Examples past the last full batch (``len(dataset) % batch_size`` of them) are
    not yielded.
    """
    n_full = len(dataset) // batch_size
    for start in range(0, n_full * batch_size, batch_size):
        chunk = dataset[start : start + batch_size]
        yield shard({key: jnp.array(val) for key, val in chunk.items()})

In [None]:
# token classification
@jax.jit
def train_step(state, batch, dropout_rng):
    """One optimizer step on a single device's shard; run via pmap(axis_name="batch").

    Args:
        state: the (replicated) TrainState.
        batch: dict of model inputs plus a "labels" entry, which is popped as targets.
        dropout_rng: this device's PRNG key for dropout.

    Returns:
        ``(new_state, metrics, new_dropout_rng)``. ``metrics`` holds the loss and
        ``learning_rate_function(state.step)``, both averaged across devices.

    NOTE(review): jax.jit is redundant under pmap (pmap compiles the function itself).
    The logged learning rate comes from learning_rate_function and may not be the rate
    the optimizer (see adamw) actually applies.
    """
    targets = batch.pop("labels")
    # `dropout_rng` is used for this step; `new_dropout_rng` is returned for the next one.
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    def loss_function(params):
        # [0] takes the logits from the model output.
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    # Average gradients over all devices so every replica applies the same update.
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng

In [None]:
# Data-parallel train step over local GPUs. donate_argnums=(0,) lets XLA reuse the
# input state's buffers, so the old `state` must not be used after each call.
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,), backend='gpu')

In [None]:
# token classification
@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

In [None]:
# Data-parallel evaluation step over local GPUs (no gradient or state update).
parallel_eval_step = jax.pmap(eval_step, axis_name="batch", backend='gpu')

In [None]:
# import jax
# import jax.numpy as np

# key = jax.random.PRNGKey(0)

# print('JAX is running on', jax.lib.xla_bridge.get_backend().platform)

In [None]:
# Copy the train state to every local device so it can be passed to the pmapped steps.
state = flax.jax_utils.replicate(state)

In [None]:
# Evaluation bookkeeping for the training loop.
return_entity_level_metrics = False  # compute_metrics returns only overall scores
eval_steps = 404  # running counter shown in the eval log line
total_eval_steps = 4040  # denominator shown in the eval log line
eval_stp = 404  # added to eval_steps after each evaluation
best = 0  # best validation F1 so far; a checkpoint is saved when it improves

In [None]:
# import jax
# import jax.numpy as np

# key = jax.random.PRNGKey(0)

# Re-check the active XLA backend right before training starts.
print(f'JAX is running on {jax.lib.xla_bridge.get_backend().platform}')

In [None]:
from flax.jax_utils import replicate, unreplicate
from itertools import chain
import time

def _evaluate(state):
    """Score the replicated model on `eval_dataset` and return `compute_metrics()`.

    Only full batches are evaluated (see eval_data_collator), so the last
    ``len(eval_dataset) % total_batch_size`` examples are skipped.
    """
    for batch in tqdm(
        eval_data_collator(eval_dataset, total_batch_size),
        total=len(eval_dataset) // total_batch_size,
        desc="Evaluating ...",
        position=2,
    ):
        labels = batch.pop("labels")
        predictions = parallel_eval_step(state, batch)
        # Merge the device axis into the batch axis: (devices, per_device, seq) -> (batch, seq).
        # Assumes every field is a padded (batch, seq) array before sharding.
        seq_len = labels.shape[-1]
        predictions = np.asarray(predictions).reshape(-1, seq_len)
        labels = np.array(labels).reshape(-1, seq_len)  # writable copy; masked below
        attention_mask = np.asarray(batch["attention_mask"]).reshape(labels.shape)
        # Exclude padding from scoring. The former np.array(chain(...)) built a 0-d
        # object array, so this mask never selected any position.
        labels[attention_mask == 0] = -100
        preds, refs = get_labels(predictions, labels)
        metric.add_batch(predictions=preds, references=refs)
    return compute_metrics()


train_time = 0
epochs = tqdm(range(num_train_epochs), desc=f"Epoch ... (1/{num_train_epochs})", position=0)
for epoch in epochs:

    train_start = time.time()
    train_metrics = []

    # Fresh shuffling key for this epoch.
    rng, input_rng = jax.random.split(rng)

    # train
    for step, batch in enumerate(tqdm(train_data_collator(input_rng, train_dataset, total_batch_size), total=step_per_epoch, desc="Training...", position=1)):
        state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
        train_metrics.append(train_metric)

        cur_step = (epoch * step_per_epoch) + (step + 1)
        # Logging, evaluation and checkpointing happen once per epoch. total_steps is a
        # multiple of step_per_epoch, so the last step is covered by this check as well.
        if cur_step % step_per_epoch != 0:
            continue

        # Training metrics of the epoch's last step.
        train_metric = unreplicate(train_metric)
        train_time += time.time() - train_start
        epochs.write(
            f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )
        train_metrics = []

        eval_metrics = _evaluate(state)
        epochs.write(
            f"Step... ({eval_steps}/{total_eval_steps} | Val Loss: prc: {eval_metrics['precision']}, rec: {eval_metrics['recall']}, f1: {eval_metrics['f1']}, acc: {eval_metrics['accuracy']})"
        )
        eval_steps = eval_steps + eval_stp

        # Keep the checkpoint with the best validation F1. Only the main process writes
        # files; before, `params` was computed only on process 0 but used on all of them.
        f1 = float(eval_metrics['f1'])
        if f1 > best:
            if jax.process_index() == 0:
                params = jax.device_get(unreplicate(state.params))
                model.save_pretrained('./model', params=params)
                tokenizer.save_pretrained('./token')
            best = f1

    # set_description_str refreshes the bar; assigning `.desc` directly does not.
    epochs.set_description_str(f"Epoch ... {epoch + 1}/{num_train_epochs}")

In [None]:
# from IPython.display import FileLink
# FileLink(r'./model/config.json')