In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

<h1 style="background-color:#DC143C; font-family:'Brush Script MT',cursive;color:white;font-size:200%; text-align:center;border-radius: 50% 20% / 10% 40%">JAX is NumPy + autodiff + GPU/TPU</h1>

It allows for fast scientific computing and machine learning with the normal NumPy API (+ additional APIs for special accelerator ops when needed)

JAX comes with powerful primitives, which you can compose arbitrarily:

Autodiff (jax.grad): Efficient any-order gradients w.r.t any variables

JIT compilation (jax.jit): Trace any function ⟶ fused accelerator ops

Vectorization (jax.vmap): Automatically batch code written for individual samples

Parallelization (jax.pmap): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for TPU pods)

If you don’t know JAX but just want to learn what you need to use Flax, you can check our JAX for the impatient notebook.

https://flax.readthedocs.io/en/latest/overview.html

<h1 style="background-color:#DC143C; font-family:'Brush Script MT',cursive;color:white;font-size:200%; text-align:center;border-radius: 50% 20% / 10% 40%">JAX for the Impatient</h1>

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

That notebook covers the basics of JAX so that you can get started with Flax; however, we very much recommend that you go through JAX’s own documentation after working through those basics.

https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html

#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

#Started at 18:17 Finished 19:14

In [None]:
#Four minutes here

%%capture
!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers
!conda install -y importlib-metadata

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

import os
# Only configure the TPU backend when the runtime exposes a TPU via the TPU_NAME env var.
if 'TPU_NAME' in os.environ:
    import requests
    # TPU_DRIVER_MODE is a module-level flag, so the version request is sent only once per
    # kernel session even if this cell is re-run.
    if 'TPU_DRIVER_MODE' not in globals():
        # The URL is built from the second ':'-separated piece of TPU_NAME
        # (presumably the '//host' part of a 'grpc://host:port' address — TODO confirm).
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # The response is stored but never checked, so a failed request goes unnoticed.
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1


    # NOTE(review): `jax.config.config.FLAGS` looks like a legacy API; the install cell does
    # not pin jax, so confirm these attributes exist in the version that gets installed.
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

In [None]:
import jax
# Show the accelerator devices visible to this process (rendered as the cell output).
jax.local_devices()

In [None]:
# Hugging Face checkpoint to fine-tune.
model_checkpoint = "bert-base-uncased" # 'roberta-base' raised an error for the author; the other checkpoints tried work.
# Number of examples per accelerator device; the global batch size is derived from it later.
per_device_batch_size = 32

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

import numpy as np
import datasets

def simple_rmse(preds, labels):
    """Root mean squared error between two NumPy arrays.

    The summed squared error is divided by ``preds.shape[0]`` (the length of the first
    axis of ``preds``) before the square root is taken.

    Args:
        preds: array of predicted values.
        labels: array of reference values, broadcast-compatible with ``preds``.

    Returns:
        The RMSE as a NumPy scalar.
    """
    squared_errors = np.square(preds - labels)
    mean_squared_error = np.sum(squared_errors) / preds.shape[0]
    return np.sqrt(mean_squared_error)


class RMSE(datasets.Metric):
    """`datasets` metric that reports the RMSE computed by ``simple_rmse``."""

    def _info(self):
        """Describe the metric: float32 predictions/references delivered in numpy format."""
        feature_spec = datasets.Features(
            {
                'predictions': datasets.Value('float32'),
                'references': datasets.Value('float32'),
            }
        )
        return datasets.MetricInfo(
            description="Calculates Root Mean Squared Error (RMSE) metric.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=feature_spec,
            codebase_urls=[],
            reference_urls=[],
            format='numpy',
        )

    def _compute(self, predictions, references):
        """Return ``{"RMSE": value}`` for the given predictions and references."""
        rmse_value = simple_rmse(predictions, references)
        return {"RMSE": rmse_value}

#Loading dataset and metric

I personally prefer Hugging Face Datasets because the library is very well designed and makes it easy to pre-process all the samples. It also has several convenient features, such as loading directly from a CSV file without using a pandas DataFrame as an intermediate.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

from datasets import load_dataset, load_metric
# Load the cleaned disaster-tweets CSV files as Hugging Face datasets. Each result is keyed
# by the split name given in `data_files` ('train' / 'test'), which later cells index into.
raw_train = load_dataset("csv", data_files={'train': ['../input/nlp-with-disaster-tweets-cleaning-data/train_data_cleaning.csv']})
raw_test = load_dataset('csv', data_files={'test': ['../input/nlp-with-disaster-tweets-cleaning-data/test_data_cleaning.csv']})

In [None]:
# Split the train set into train (90%) and valid (10%) sets.
# A fixed seed keeps the split identical across "Restart & Run All" executions.
# NOTE: re-running this cell on its own re-splits the already reduced train portion.
raw_train = raw_train["train"].train_test_split(test_size=0.1, seed=0)

In [None]:
# Instantiate the custom RMSE metric; the evaluation loop accumulates batches into it.
metric = RMSE()

#Pre-process the dataset

This is very generic pre-processing, nothing special: we just tokenize each sentence and pad it to a fixed length.

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to `model_checkpoint`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def preprocess_function(examples):
    """Tokenize a batch of training examples and attach their targets as labels.

    Args:
        examples: batch mapping that provides a "text" column and a "target" column.

    Returns:
        The tokenizer output (padded/truncated to 128 tokens) with an extra "labels"
        entry holding the values of the "target" column.
    """
    encoded = tokenizer(
        examples["text"],
        padding="max_length",
        max_length=128,
        truncation=True,
    )
    encoded["labels"] = examples["target"]
    return encoded

In [None]:
# Tokenize both splits in batches; the original columns are removed so that only the
# tokenizer outputs plus "labels" remain.
tokenized_dataset = raw_train.map(preprocess_function, batched=True, remove_columns=raw_train["train"].column_names)

In [None]:
# Display a summary of the tokenized splits.
tokenized_dataset

In [None]:
# The "test" split here is the 10% hold-out created by train_test_split earlier, i.e. our
# validation/evaluation set — it is not the competition test set.
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["test"]

#Model

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

# A single output unit: the classification head is trained as a regressor on the target.
num_labels = 1
# Seed for model initialisation; the same value later seeds the JAX PRNG key.
seed = 0

# NOTE: `config` is the Hugging Face model config here; this rebinds the name that the TPU
# setup cell uses for `jax.config.config`.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)

#Training and evaluation loop

In [None]:
import flax
import jax
import optax

from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state
from flax import traverse_util

In [None]:
# Number of full passes over the training split.
num_train_epochs = 10
# Peak learning rate of the one-cycle schedule built further down.
learning_rate = 2e-5

There are 8 cores in TPUv3-8, so the effective batch_size = 8 * per_device_batch_size

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

# Global batch size: one `per_device_batch_size` chunk for every local accelerator device.
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

I used the One-Cycle LR scheduler with cosine annealing. It is super easy to create this LR schedule with the Optax library, which is the recommended optimizer library for JAX-based NN libraries. Optax is developed by DeepMind and has several amazing features, so definitely give it a try!

TODO: Add citations to the original One-Cycle and Cosine Annealing papers.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

# Full batches per epoch (the incomplete remainder is dropped) times the number of epochs.
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

# One-cycle cosine schedule over all training steps, peaking at `learning_rate`.
# pct_start=0.1: per the optax docs, the fraction of steps spent raising the learning rate.
learning_rate_function = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate, pct_start=0.1, )
print("The number of train steps (all the epochs) is", num_train_steps)

#Create a Train State

Next, we will create the training state that includes the optimizer, the loss function, and is responsible for updating the model's parameters during training.

Most JAX transformations (notably jax.jit) require functions that are transformed to have no side-effects as it follows a functional programming type paradigm at its core. This is because any such side-effects will only be executed once, when the Python version of the function is run during compilation (see Stateful Computations in JAX). As a consequence, Flax models (which can be transformed by JAX transformations) are immutable, and the state of the model (i.e., its weight parameters) are stored outside of the model instance.

Flax provides a convenience class flax.training.train_state.TrainState, which stores things such as the model parameters, the loss function, the optimizer, and exposes an apply_gradients function to update the model's weight parameters.

We create a derived TrainState class that additionally stores a function that post-processes the model's logits as logits_function, as well as a loss_function.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

class TrainState(train_state.TrainState):
    """Flax train state extended with two callables used by the train/eval steps.

    Both fields are declared with ``pytree_node=False``, i.e. they are not treated as
    pytree (array) nodes of the state.
    """
    # Turns raw model logits into predictions; called by the evaluation step.
    logits_function: Callable = flax.struct.field(pytree_node=False)
    # Computes the training loss from (logits, labels); called by the training step.
    loss_function: Callable = flax.struct.field(pytree_node=False)

#AdamW Optimizer

We will be using the standard Adam optimizer with weight decay. For more information on AdamW (Adam + weight decay), one can take a look at this blog post. weight_decay value of 0.01 is a good starting point, you can tweak this hyper-parameter and experiment with how it influences the final trained model.

Regularizing the bias and/or LayerNorm has not shown to improve performance and can even be disadvantageous, which is why we disable it here. For more information on this, please check out the following blog post or paper.

Hence we create a decay_mask_fn which makes sure that weight decay is not applied to any bias or LayerNorm weights. This can easily be done by passing it as the mask argument to optax.adamw.

NOTE: Beginners (myself) can ignore the decay_mask_fn, the changes are minimal if you leave out doing this step.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def decay_mask_fn(params):
    """Build the weight-decay mask for a nested parameter dict.

    Args:
        params: nested dict of model parameters.

    Returns:
        A nested dict with the same structure whose leaves are booleans: ``False`` for
        every parameter named "bias" and for every ("LayerNorm", "scale") parameter,
        ``True`` for everything else.
    """
    mask_by_path = {}
    for path in traverse_util.flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layernorm_scale = path[-2:] == ("LayerNorm", "scale")
        # Decay applies only to parameters that are neither biases nor LayerNorm scales.
        mask_by_path[path] = not (is_bias or is_layernorm_scale)
    return traverse_util.unflatten_dict(mask_by_path)

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def adamw(weight_decay):
    """Create the AdamW optimizer used for fine-tuning.

    Args:
        weight_decay: weight-decay strength, applied only where ``decay_mask_fn``
            marks a parameter as decayable.

    Returns:
        The optax AdamW gradient transformation driven by ``learning_rate_function``.
    """
    optimizer_kwargs = dict(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
    return optax.adamw(**optimizer_kwargs)

In [None]:
# Build the optimizer with a weight decay of 0.01.
# NOTE(review): this rebinds `adamw` from the factory function to the optimizer it returns,
# so this cell cannot be re-run without first re-running the cell that defines `adamw`.
adamw = adamw(1e-2)

#Loss and eval functions

The standard loss function for regression problems is the MSE loss. The book by Bishop has an additional 0.5 factor, but we skip it here without loss of generality: that factor just scales the loss by a constant and has no impact on the gradients other than scaling them.

In [None]:
@jax.jit
def loss_function(logits, labels):
    """Mean squared error between the first logit of every example and its label.

    Args:
        logits: model outputs; only index 0 of the last axis is used.
        labels: regression targets, broadcast-compatible with ``logits[..., 0]``.

    Returns:
        The scalar MSE loss.
    """
    residuals = logits[..., 0] - labels
    return jnp.mean(jnp.square(residuals))

@jax.jit
def eval_function(logits):
    """Turn raw model logits into predictions by keeping index 0 of the last axis.

    Args:
        logits: model outputs whose last axis has at least one entry.

    Returns:
        ``logits`` with the last axis removed (its first entry selected).
    """
    first_logit = logits[..., 0]
    return first_logit

#Create the initial train state

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

# Bundle the model's forward function, its loaded parameters and the optimizer into one
# train state, together with the prediction and loss helpers defined earlier.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    logits_function=eval_function,
    loss_function=loss_function,
)

#Defining the training and evaluation step

During fine-tuning, we want to update the model parameters and evaluate the performance after each epoch.

Let's write the functions train_step and eval_step accordingly. During training the weight parameters should be updated as follows:

Define a loss function loss_function that first runs a forward pass of the model given data input. Remember that Flax models are immutable, and we explicitly pass it the state (in this case the model parameters and the RNG). loss_function returns a scalar loss (using the previously defined state.loss_function) between the model output and input targets.
Differentiate this loss function using jax.value_and_grad. This is a JAX transformation called automatic differentiation, which computes the gradient of loss_function given the input to the function (i.e., the parameters of the model), and returns the value and the gradient in a pair (loss, gradients).
Compute the mean gradient over all devices using the collective operation lax.pmean. As we will see below, each device runs train_step on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.
Use state.apply_gradients, which applies the gradients to the weights.
Below, you can see how each of the described steps above is put into practice.

NOTE: Taken from HuggingFace examples

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def train_step(state, batch, dropout_rng):
    """Run one optimisation step; meant to be executed under ``jax.pmap`` with axis "batch".

    Args:
        state: current train state (parameters, optimizer state, helper callables).
        batch: dict of model inputs plus a "labels" entry, which is popped from the dict.
        dropout_rng: PRNG key; it is split into a key for this step and a key for the next.

    Returns:
        A tuple ``(new_state, metrics, new_dropout_rng)`` where ``metrics`` holds the
        device-averaged "loss" and "learning_rate".
    """
    labels = batch.pop("labels")
    step_rng, next_rng = jax.random.split(dropout_rng)

    def compute_loss(params):
        # Forward pass in training mode; element 0 of the model output is the logits.
        outputs = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)
        return state.loss_function(outputs[0], labels)

    loss_value, grads = jax.value_and_grad(compute_loss)(state.params)
    # Average the gradients over all devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, "batch")
    updated_state = state.apply_gradients(grads=grads)

    step_metrics = {"loss": loss_value, "learning_rate": learning_rate_function(state.step)}
    step_metrics = jax.lax.pmean(step_metrics, axis_name="batch")
    return updated_state, step_metrics, next_rng

Now, we want to do parallelized training over all TPU devices. To do so, we use jax.pmap. This will compile the function once and run the same program on each device (it is an SPMD program). When calling this pmapped function, all inputs ("state", "batch", "dropout_rng") should be replicated for all devices, which means that the first axis of each argument is used to map over all TPU devices.

The argument donate_argnums is used to tell JAX that the first argument "state" is "donated" to the computation, because it is not needed anymore afterwards. XLA can make use of donated buffers to reduce the memory needed.

In [None]:
# Parallelise train_step across devices; "batch" is the axis name its pmean calls refer to.
# Argument 0 (the state) is donated so its buffers may be reused for the outputs.
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def eval_step(state, batch):
    """Run the model in inference mode on one batch and return its predictions.

    Args:
        state: train state providing ``apply_fn``, ``params`` and ``logits_function``.
        batch: dict of model inputs (without labels).

    Returns:
        ``state.logits_function`` applied to element 0 of the model output.
    """
    model_outputs = state.apply_fn(**batch, params=state.params, train=False)
    raw_logits = model_outputs[0]
    return state.logits_function(raw_logits)

In [None]:
# Parallel version of eval_step, mapped over the leading (device) axis of its inputs.
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

#Define Data Loaders

In a final step before we can start training, we need to define the data collators. The data collator is important to shuffle the training data before each epoch and to prepare the batch for each training and evaluation step.

First, a random permutation of the whole dataset is defined. Then, every time the training data collator is called the next batch of the randomized dataset is extracted, converted to a JAX array and sharded over all local TPU devices.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def train_data_loader(rng, dataset, batch_size):
    """Yield shuffled, device-sharded training batches for one epoch.

    Args:
        rng: PRNG key that determines the shuffle order.
        dataset: indexable dataset; indexing with an index array returns a dict of columns.
        batch_size: global batch size (across all devices).

    Yields:
        Dicts of JAX arrays passed through ``shard``. Examples that do not fill a
        complete batch are skipped.
    """
    num_batches = len(dataset) // batch_size
    shuffled = jax.random.permutation(rng, len(dataset))
    # Keep only as many indices as fit into complete batches, one row per batch.
    index_rows = shuffled[: num_batches * batch_size].reshape((num_batches, batch_size))

    for row in index_rows:
        columns = dataset[row]
        yield shard({name: jnp.array(values) for name, values in columns.items()})

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def eval_data_loader(dataset, batch_size):
    """Yield device-sharded evaluation batches in dataset order.

    Args:
        dataset: sliceable dataset; slicing returns a dict of columns.
        batch_size: global batch size (across all devices).

    Yields:
        Dicts of JAX arrays passed through ``shard``. Only complete batches are
        produced; trailing examples that do not fill a batch are not evaluated.
    """
    num_batches = len(dataset) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        columns = dataset[start : start + batch_size]
        yield shard({name: jnp.array(values) for name, values in columns.items()})

Next, we replicate/copy the weight parameters on each device, so that we can pass them to our pmapped functions.

In [None]:
# Copy the train state onto every local device so it can be passed to the pmapped steps.
# NOTE(review): not idempotent — re-running this cell replicates an already replicated state.
state = flax.jax_utils.replicate(state)

#Training

Now we define the full training loop. For each batch in each epoch, we run a training step. Here, we also need to make sure that the PRNGKey is sharded/split over each device. Having completed an epoch, we report the training metrics and can run the evaluation.

The first batch takes a bit longer to process, but that is nothing to worry about: during the first batch the XLA compiler is working hard to make everything super fast. The first batch takes close to 5 minutes to process, and after that entire epochs take ~5 seconds. Aren't TPUs amazing!!

5 seconds for an entire EPOCH!!

Note: The times mentioned above are an average estimate over 8 different runs on several different TPU machines and several model architectures.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

# Root PRNG key derived from `seed`, then one dropout key per local device.
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook


# Full fine-tuning loop: one shuffled pass over the training split per epoch, followed by an
# evaluation pass over the hold-out split. `epoch` is 1-based and is printed directly.
for epoch in tqdm(range(1, num_train_epochs + 1), desc="Epoch ...", position=0, leave=True):
    # Fresh key for this epoch's shuffle; `rng` is advanced so every epoch shuffles differently.
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            # chain(*...) flattens the per-device leading axis into one stream of examples.
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    # `train_metrics` comes from the LAST training batch of the epoch, not an epoch average.
    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    # The metric dict has a single entry; take its name and value together.
    metric_name, metric_value = next(iter(eval_metric.items()))
    eval_score = round(metric_value, 3)

    print(f"{epoch}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

#Generating Results

Our test dataset has slightly different pre-processing step because we do not have a label in the dataset. So, we should handle accordingly.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def preprocess_test_set_function(examples):
    """Tokenize a batch of test examples (no labels are available for the test set).

    Args:
        examples: batch mapping that provides a "text" column.

    Returns:
        The tokenizer output, padded/truncated to 128 tokens.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        max_length=128,
        truncation=True,
    )

In [None]:
# Tokenize the test split; all original columns are removed, leaving only tokenizer outputs.
tokenized_test_dataset = raw_test.map(preprocess_test_set_function, batched=True, remove_columns=raw_test["test"].column_names)

In [None]:
# Pull the "test" split out of the tokenized result and display its summary.
test_dataset = tokenized_test_dataset["test"]
test_dataset

We won't shard our data anymore because usually the test sets are very small and can be done entirely on one-core without having the additional overheads. So, we also have to "un-shard" our model and run entirely on the single device of the device slice. So we use the unreplicate method in the flax library.

#Generation

Final step. We have successfully fine-tuned a BERT model on the disaster-tweets task. That's amazing! It took us less than 10 mins to reach a very good score! Now it is time to get our model predictions on our test set.

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook


def test_data_loader(dataset, batch_size):
    """Yield un-sharded batches that cover the whole dataset, in order.

    Every example is yielded exactly once; the last batch may be smaller than
    ``batch_size``. No empty batch is produced when ``len(dataset)`` is an exact
    multiple of ``batch_size`` (an empty batch would otherwise reach the model),
    and an empty dataset yields nothing.

    Args:
        dataset: sliceable dataset; slicing returns a dict of columns.
        batch_size: maximum number of examples per batch.

    Yields:
        Dicts mapping each column name to a JAX array holding one batch.
    """
    for start in range(0, len(dataset), batch_size):
        columns = dataset[start : start + batch_size]
        yield {name: jnp.array(values) for name, values in columns.items()}

In [None]:
from flax.jax_utils import unreplicate

# Take a single, un-replicated copy of the train state for un-sharded inference.
unrep_state = unreplicate(state)

In [None]:
#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook

def generate_results():
    """Predict on the whole test set with the fine-tuned parameters.

    Iterates over ``test_dataset`` in batches of ``total_batch_size`` and runs the model
    in inference mode. The trained weights live in ``unrep_state.params`` and must be
    passed explicitly: without ``params`` the model falls back to the parameters it was
    loaded with, so the fine-tuning would be ignored.

    Returns:
        A list with one array of model outputs (element 0 of the output tuple) per batch.
    """
    preds = []
    for batch in test_data_loader(test_dataset, total_batch_size):
        # Only the first host process collects predictions.
        if jax.process_index() == 0:
            predictions = unrep_state.apply_fn(
                **batch,
                params=unrep_state.params,
                train=False,
                return_dict=False,
            )
            preds.append(predictions[0])
    return preds

In [None]:
# Run inference over the test set; `preds` is a list with one output array per batch.
preds = generate_results()

Now we clean-up and make our results "Submission ready". First we convert all JAX DeviceArray objects to Numpy arrays, then we create a submission file.

In [None]:
import numpy as np
# Convert every per-batch output to NumPy, stack them row-wise into one array, and display it.
preds = np.vstack([np.asarray(x) for x in preds])
preds

In [None]:
#import pandas as pd
#sample = pd.read_csv('../input/commonlitreadabilityprize/sample_submission.csv')
#sample.target = preds
#sample

In [None]:
#sample.to_csv('submission.csv',index=False)

#Background: JAX  https://flax.readthedocs.io/en/latest/overview.html#background-jax

TODO: Add an example for running on Google Cloud.

#Code by Kartheek Akella  https://www.kaggle.com/asvskartheek/bert-tpus-jax-huggingface/notebook