# Getting Started with ModernBERT & GLUE

Created by: [Wayde Gilliam](https://twitter.com/waydegilliam)

## Encoders Strike Back!

Like many, I have fond memories of finetuning DeBERTa, RoBERTa and BERT models for a number of Kaggle comps and real-world problems (e.g., NER, sentiment analysis, etc.).  Encoder models were "the thing" back in the day and continue to be the primary workhorse for many ML pipelines today, though they have been eclipsed by recent advancements in LLMs, which are typically based on decoder-only architectures. Long have we awaited a return to an encoder model for the modern world. With ModernBERT, that wait is over! ModernBERT is a new encoder-only model that incorporates recent advances in making neural networks faster and more efficient, and it is better at the tasks encoder models have long excelled at, such as text classification.  In addition, ModernBERT lets us break out of the 512-token limit: its long-context capabilities give us 8,192 tokens to play with.

In this tutorial, we'll go through the steps of fine-tuning ModernBERT for one of the GLUE tasks, MRPC.  We'll cover some key settings required to use it with the HuggingFace `Trainer` and include some recommended hyperparameters that have served us well in fine-tuning ModernBERT for GLUE.  We'll also see how to use the model for inference and how to clean up the model from the GPU to free up resources.

As an aside, I'm running all this code on a single 3090 with plenty of GPU memory to spare.

Though not strictly necessary, **ModernBERT trains better with FlashAttention!** Training and inference will be much faster with it installed. See below:

ModernBERT is built on top of FlashAttention which is a highly optimized implementation of the attention mechanism that is faster and more memory efficient than the standard implementation.  ***The beauty of this is all you need to do is install it for ModernBERT to work with it!***  Here's how ...

For NVIDIA GPUs with compute capability 8.0+ (Ampere/Ada/Hopper architecture - A100, A6000, RTX 3090, RTX 4090, H100 etc):
```bash
pip install flash-attn --no-build-isolation
```

For older NVIDIA GPUs (pre-Ampere):
```bash
pip install flash-attn --no-deps
```
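As a quick sanity check before training, here is a minimal sketch that picks an attention backend name depending on whether the `flash_attn` package can be found. The helper name `pick_attn_implementation` is hypothetical (it is not part of any library). The two returned strings are values Transformers accepts for the `attn_implementation` argument of `from_pretrained`, and passing it explicitly is optional since ModernBERT picks up FlashAttention on its own once it is installed.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Return "flash_attention_2" if flash-attn can be found, else "sdpa"."""
    # find_spec only looks the package up; it does not import it or touch the GPU.
    has_flash = importlib.util.find_spec("flash_attn") is not None
    return "flash_attention_2" if has_flash else "sdpa"


attn_impl = pick_attn_implementation()
print(f"Attention backend to request: {attn_impl}")

# Illustration only (assumes a CUDA device and the packages installed above):
# AutoModelForSequenceClassification.from_pretrained(checkpoint, attn_implementation=attn_impl)
```

If this prints `sdpa` after you thought you installed FlashAttention, the install most likely landed in a different environment than the one your notebook kernel uses.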


In [None]:
#! pip install setuptools transformers datasets accelerate scikit-learn -Uqq
# install setuptools and do this before installing flash-attn
# pip install flash-attn --no-build-isolation


In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


In [3]:
import numpy as np
import pandas as pd
import torch
from functools import partial
import gc

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    TrainerCallback,
)

from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from scipy.stats import pearsonr, spearmanr

os.environ["TOKENIZERS_PARALLELISM"] = "false"




## What is GLUE?

The [General Language Understanding Evaluation (GLUE) benchmark](https://gluebenchmark.com/) is a collection of nine diverse natural language understanding tasks designed to evaluate and compare the performance of NLP models across various language comprehension challenges. By providing a standardized framework, GLUE facilitates the development of models that generalize well across multiple tasks, promoting advancements in creating robust and versatile language understanding systems. 

Let's put all these tasks in a dictionary along with some other helpful metadata about each one that might prove useful when iterating over all of them.



In [4]:
glue_tasks = {
    "cola": {
        "abbr": "CoLA",
        "name": "Corpus of Linguistic Acceptability",
        "description": "Predict whether a sequence is a grammatical English sentence",
        "task_type": "Single-Sentence Task",
        "domain": "Misc.",
        "size": "8.5k",
        "metrics": "Matthews corr.",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence"],
        "target": "label",
        "metric_funcs": [matthews_corrcoef],
        "n_labels": 2,
    },
    "sst2": {
        "abbr": "SST-2",
        "name": "Stanford Sentiment Treebank",
        "description": "Predict the sentiment of a given sentence",
        "task_type": "Single-Sentence Task",
        "domain": "Movie reviews",
        "size": "67k",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 2,
    },
    "mrpc": {
        "abbr": "MRPC",
        "name": "Microsoft Research Paraphrase Corpus",
        "description": "Predict whether two sentences are semantically equivalent",
        "task_type": "Similarity and Paraphrase Tasks",
        "domain": "News",
        "size": "3.7k",
        "metrics": "F1/Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence1", "sentence2"],
        "target": "label",
        "metric_funcs": [accuracy_score, f1_score],
        "n_labels": 2,
    },
    "stsb": {
        "abbr": "STS-B",
        "name": "Semantic Textual Similarity Benchmark",
        "description": "Predict the similarity score for two sentences on a scale from 1 to 5",
        "task_type": "Similarity and Paraphrase Tasks",
        "domain": "Misc.",
        "size": "7k",
        "metrics": "Pearson/Spearman corr.",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence1", "sentence2"],
        "target": "label",
        "metric_funcs": [pearsonr, spearmanr],
        "n_labels": 1,
    },
    "qqp": {
        "abbr": "QQP",
        "name": "Quora Question Pairs",
        "description": "Predict if two questions are a paraphrase of one another",
        "task_type": "Similarity and Paraphrase Tasks",
        "domain": "Social QA questions",
        "size": "364k",
        "metrics": "F1/Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["question1", "question2"],
        "target": "label",
        "metric_funcs": [f1_score, accuracy_score],
        "n_labels": 2,
    },
    "mnli-matched": {
        "abbr": "MNLI",
        "name": "Multi-Genre Natural Language Inference",
        "description": "Predict whether the premise entails, contradicts or is neutral to the hypothesis",
        "task_type": "Inference Tasks",
        "domain": "Misc.",
        "size": "393k",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation_matched", "test": "test_matched"},
        "inputs": ["premise", "hypothesis"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 3,
    },
    "mnli-mismatched": {
        "abbr": "MNLI",
        "name": "Multi-Genre Natural Language Inference",
        "description": "Predict whether the premise entails, contradicts or is neutral to the hypothesis",
        "task_type": "Inference Tasks",
        "domain": "Misc.",
        "size": "393k",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation_mismatched", "test": "test_mismatched"},
        "inputs": ["premise", "hypothesis"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 3,
    },
    "qnli": {
        "abbr": "QNLI",
        "name": "Stanford Question Answering Dataset",
        "description": "Predict whether the context sentence contains the answer to the question",
        "task_type": "Inference Tasks",
        "domain": "Wikipedia",
        "size": "105k",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["question", "sentence"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 2,
    },
    "rte": {
        "abbr": "RTE",
        "name": "Recognizing Textual Entailment",
        "description": "Predict whether one sentence entails another",
        "task_type": "Inference Tasks",
        "domain": "News, Wikipedia",
        "size": "2.5k",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence1", "sentence2"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 2,
    },
    "wnli": {
        "abbr": "WNLI",
        "name": "Winograd Schema Challenge",
        "description": "Predict if the sentence with the pronoun substituted is entailed by the original sentence",
        "task_type": "Inference Tasks",
        "domain": "Fiction books",
        "size": "634",
        "metrics": "Accuracy",
        "dataset_names": {"train": "train", "valid": "validation", "test": "test"},
        "inputs": ["sentence1", "sentence2"],
        "target": "label",
        "metric_funcs": [accuracy_score],
        "n_labels": 2,
    },
}

# for v in glue_tasks.values(): print(v)
glue_tasks.values()

glue_df = pd.DataFrame(glue_tasks.values(), columns=["abbr", "name", "task_type", "description", "size", "metrics"])
glue_df.columns = glue_df.columns.str.replace("_", " ").str.capitalize()
display(glue_df.style.set_properties(**{"text-align": "left"}))


|   | Abbr | Name | Task type | Description | Size | Metrics |
|---|---|---|---|---|---|---|
| 0 | CoLA | Corpus of Linguistic Acceptability | Single-Sentence Task | Predict whether a sequence is a grammatical English sentence | 8.5k | Matthews corr. |
| 1 | SST-2 | Stanford Sentiment Treebank | Single-Sentence Task | Predict the sentiment of a given sentence | 67k | Accuracy |
| 2 | MRPC | Microsoft Research Paraphrase Corpus | Similarity and Paraphrase Tasks | Predict whether two sentences are semantically equivalent | 3.7k | F1/Accuracy |
| 3 | STS-B | Semantic Textual Similarity Benchmark | Similarity and Paraphrase Tasks | Predict the similarity score for two sentences on a scale from 1 to 5 | 7k | Pearson/Spearman corr. |
| 4 | QQP | Quora Question Pairs | Similarity and Paraphrase Tasks | Predict if two questions are a paraphrase of one another | 364k | F1/Accuracy |
| 5 | MNLI | Multi-Genre Natural Language Inference | Inference Tasks | Predict whether the premise entails, contradicts or is neutral to the hypothesis | 393k | Accuracy |
| 6 | MNLI | Multi-Genre Natural Language Inference | Inference Tasks | Predict whether the premise entails, contradicts or is neutral to the hypothesis | 393k | Accuracy |
| 7 | QNLI | Stanford Question Answering Dataset | Inference Tasks | Predict whether the context sentence contains the answer to the question | 105k | Accuracy |
| 8 | RTE | Recognizing Textual Entailment | Inference Tasks | Predict whether one sentence entails another | 2.5k | Accuracy |
| 9 | WNLI | Winograd Schema Challenge | Inference Tasks | Predict if the sentence with the pronoun substituted is entailed by the original sentence | 634 | Accuracy |
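To show how this metadata gets used when iterating over tasks, here is a small sketch that resolves the GLUE config name and split names from a task key. The helper `resolve_glue_splits` and the trimmed `sample_tasks` dictionary are hypothetical names introduced just for this illustration. The entries are copied from the dictionary above.

```python
def resolve_glue_splits(task_key: str, dataset_names: dict[str, str]) -> tuple[str, str, str]:
    """Return (config name, train split, validation split) for a task key."""
    # "mnli-matched" and "mnli-mismatched" both live in the "mnli" config,
    # so everything after the first "-" is dropped.
    config_name = task_key.split("-")[0]
    return config_name, dataset_names["train"], dataset_names["valid"]


# Two entries from the dictionary above, trimmed to the fields used here.
sample_tasks = {
    "mrpc": {"train": "train", "valid": "validation", "test": "test"},
    "mnli-mismatched": {
        "train": "train",
        "valid": "validation_mismatched",
        "test": "test_mismatched",
    },
}

for key, names in sample_tasks.items():
    print(key, "->", resolve_glue_splits(key, names))
```

Keeping the two MNLI variants as separate keys means each one can be evaluated on its own validation split while sharing the same training data.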


## Convert to HuggingFace Format

In [15]:
import torch
from transformers import AutoConfig, AutoModelForMaskedLM

# 1. Load your config (edit as needed)
config = AutoConfig.from_pretrained("hf-checkpoints/modernbert-q")

# 2. Instantiate the model
model = AutoModelForMaskedLM.from_config(config)  # Auto classes are built via from_config, not the constructor

# 3. Load the .pt checkpoint (update path as needed)
state_dict = torch.load("checkpoints/modernbert-base-pretrain/latest-rank0.pt", map_location="cpu")

# If the checkpoint is a Composer or Lightning checkpoint, you may need to extract the state_dict:
# state_dict = torch.load("...pt", map_location="cpu")["state"]["model"]

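# 4. Load weights into the model; strict=False skips mismatched keys, so report them
load_result = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")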

# 5. Save in HuggingFace format
model.save_pretrained("hf-checkpoints/modernbert")
config.save_pretrained("hf-checkpoints/modernbert")

OSError: hf-checkpoints/modernbert-q is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`
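Because `strict=False` hides key mismatches, it helps to normalize the checkpoint keys before loading. Below is a minimal sketch of that idea using plain dictionaries in place of tensors. The nested `state`/`model` layout follows the comment in the cell above, while the `model.` prefix, the helper names, and the toy key names are assumptions made for illustration, so check them against your own checkpoint.

```python
def extract_model_state(checkpoint: dict, prefix: str = "model.") -> dict:
    """Pull the model weights out of a nested checkpoint and strip a key prefix."""
    # Assumed nested layout: {"state": {"model": {...weights...}}}.
    # A flat checkpoint (no "state" entry) is used as-is.
    state = checkpoint.get("state", {}).get("model", checkpoint)
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state.items()}


def diff_keys(expected: set[str], loaded: set[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) key lists, i.e. what strict=False would skip."""
    return sorted(expected - loaded), sorted(loaded - expected)


# Toy stand-ins: the key names are hypothetical and strings replace tensors.
toy_ckpt = {"state": {"model": {"model.embeddings.weight": "w0", "model.head.bias": "b0"}}}
cleaned = extract_model_state(toy_ckpt)
missing, unexpected = diff_keys({"embeddings.weight", "decoder.bias"}, set(cleaned))

print("missing:", missing)
print("unexpected:", unexpected)
```

If either list is long after cleaning, the config and the checkpoint probably describe different architectures, and the saved model would contain randomly initialized weights.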

## Let's Fine-Tune ModernBERT for MRPC

### Configuration

ModernBERT currently comes in two flavors, base and large. To keep things lean and mean, we'll use the "answerdotai/ModernBERT-base" checkpoint for this example.

In [13]:
task = "mrpc"
task_meta = glue_tasks[task]
train_ds_name = task_meta["dataset_names"]["train"]
valid_ds_name = task_meta["dataset_names"]["valid"]
test_ds_name = task_meta["dataset_names"]["test"]

task_inputs = task_meta["inputs"]
task_target = task_meta["target"]
n_labels = task_meta["n_labels"]
task_metrics = task_meta["metric_funcs"]

checkpoint = "answerdotai/ModernBERT-base"  # or "answerdotai/ModernBERT-large"

### Data

We'll use the `Datasets` library to load the data.  As it's always recommended to "look at your data" before we get training, we'll also print out a single example to see what we're working with, as well as the features of the dataset.

In [6]:
raw_datasets = load_dataset("glue", task)

print(f"{raw_datasets}\n")
print(f"{raw_datasets[train_ds_name][0]}\n")
print(f"{raw_datasets[train_ds_name].features}\n")

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

{'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', 'label': 1, 'idx': 0}

{'sentence1': Value('string'), 'sentence2': Value('string'), 'label': ClassLabel(names=['not_equivalent', 'equivalent']), 'idx': Value('int32')}






We can use the following dictionaries when building our model with `AutoModelForSequenceClassification` to map between the label ids and names.

In [7]:
def get_label_maps(raw_datasets, train_ds_name):
    labels = raw_datasets[train_ds_name].features["label"]

    id2label = {idx: name.upper() for idx, name in enumerate(labels.names)} if hasattr(labels, "names") else None
    label2id = {name.upper(): idx for idx, name in enumerate(labels.names)} if hasattr(labels, "names") else None

    return id2label, label2id

In [8]:
id2label, label2id = get_label_maps(raw_datasets, train_ds_name)

print(f"{id2label}")
print(f"{label2id}")


{0: 'NOT_EQUIVALENT', 1: 'EQUIVALENT'}
{'NOT_EQUIVALENT': 0, 'EQUIVALENT': 1}


MRPC is a sentence-pair classification task where we're given two sentences and asked to predict whether they are paraphrases of one another.  The dataset is split into train, validation and test sets. We'll need to keep all this in mind when we set up tokenization next with `AutoTokenizer`.

### Tokenizer

Next we define our tokenizer and a preprocess function to create the `input_ids` and `attention_mask` the model needs to train (as the output below shows, the ModernBERT tokenizer does not produce `token_type_ids`).  For this example, including `truncation=True` is enough as we'll rely on our data collation function below to put our batches into the correct shape.

In [14]:
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)


In [10]:
task_inputs

['sentence1', 'sentence2']

In [11]:
def preprocess_function(examples, task_inputs):
    inps = [examples[inp] for inp in task_inputs]
    tokenized = hf_tokenizer(*inps, truncation=True)
    return tokenized

In [12]:
tokenized_datasets = raw_datasets.map(partial(preprocess_function, task_inputs=task_inputs), batched=True)

print(f"{tokenized_datasets}\n")
print(f"{tokenized_datasets[train_ds_name][0]}\n")
print(f"{tokenized_datasets[train_ds_name].features}\n")

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1725
    })
})

{'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', 'label': 1, 'idx': 0, 'input_ids': [50281, 8096, 287, 9877, 10145, 521, 4929, 1157, 5207, 344, 1925, 346, 253, 5517, 346, 1157, 273, 21547, 940, 12655, 521, 1941, 964, 50282, 7676, 24247, 281, 779, 347, 760, 346, 253, 5517, 346, 1157, 3052, 287, 9877, 10145, 521, 4929, 273, 21547, 940, 12655, 521,

It's always good to see what the tokenizer is doing to our data to ensure the special tokens are where we expect them to be!

In [13]:
hf_tokenizer.decode(tokenized_datasets[train_ds_name][0]["input_ids"])

'[CLS]Amrozi accused his brother, whom he called " the witness ", of deliberately distorting his evidence.[SEP]Referring to him as only " the witness ", Amrozi accused his brother of deliberately distorting his evidence.[SEP]'
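To make the layout above concrete, here is a toy sketch of how a sentence pair is packed into one sequence and truncated. It works on whitespace-split words instead of real token ids, and `pack_pair` is a hypothetical helper, not the tokenizer's implementation. The trimming loop imitates the "longest first" idea behind `truncation=True`, and trimming the second sentence on ties is an arbitrary choice made for this sketch.

```python
def pack_pair(tokens_a: list[str], tokens_b: list[str], max_len: int) -> list[str]:
    """Build [CLS] A [SEP] B [SEP], trimming the longer side until it fits."""
    if max_len < 3:
        raise ValueError("max_len must leave room for [CLS] and two [SEP] tokens")
    a, b = list(tokens_a), list(tokens_b)
    # Three slots are reserved for the special tokens.
    while len(a) + len(b) > max_len - 3:
        # Drop one token from the end of whichever sentence is currently longer.
        longer = a if len(a) > len(b) else b
        longer.pop()
    return ["[CLS]", *a, "[SEP]", *b, "[SEP]"]


packed = pack_pair(
    "Amrozi accused his brother".split(),
    "He accused his brother".split(),
    max_len=9,
)
print(packed)
```

The same structure is what we just decoded: one `[CLS]` at the start, then each sentence followed by its own `[SEP]`.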

### Metrics

We'll use our `task_metrics` to compute the metrics for our model.  We'll return a dictionary of the metric name and value for each metric we're interested in.

In [14]:
def compute_metrics(eval_pred, task_metrics):
    predictions, labels = eval_pred

    metrics_d = {}
    for metric_func in task_metrics:
        metric_name = metric_func.__name__
        if metric_name in ["pearsonr", "spearmanr"]:
            score = metric_func(labels, np.squeeze(predictions))
        else:
            # scikit-learn metrics expect (y_true, y_pred) in that order
            score = metric_func(labels, np.argmax(predictions, axis=-1))

        if isinstance(score, tuple):
            metrics_d[metric_func.__name__] = score[0]
        else:
            metrics_d[metric_func.__name__] = score

    return metrics_d
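For intuition about what the classification branch computes, here is a dependency-free sketch that turns logits into predictions with an argmax and then derives accuracy and binary F1 by hand. The toy logits and labels are made up for illustration, and the helper names are hypothetical. In the notebook itself, scikit-learn does this work.

```python
def argmax_rows(logits: list[list[float]]) -> list[int]:
    """Index of the largest logit in each row (first index wins ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy_and_f1(labels: list[int], preds: list[int], positive: int = 1) -> tuple[float, float]:
    """Accuracy and binary F1 = 2*TP / (2*TP + FP + FN)."""
    tp = sum(1 for y, p in zip(labels, preds) if y == positive and p == positive)
    fp = sum(1 for y, p in zip(labels, preds) if y != positive and p == positive)
    fn = sum(1 for y, p in zip(labels, preds) if y == positive and p != positive)
    accuracy = sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1


# Made-up logits for five sentence pairs (columns: NOT_EQUIVALENT, EQUIVALENT).
toy_logits = [[0.2, 1.1], [1.5, -0.3], [0.1, 0.4], [-0.2, 0.9], [0.8, 0.1]]
toy_labels = [1, 0, 0, 1, 1]

toy_preds = argmax_rows(toy_logits)
toy_accuracy, toy_f1 = accuracy_and_f1(toy_labels, toy_preds)
print(toy_preds, round(toy_accuracy, 3), round(toy_f1, 3))
```

Here two of the three positive pairs are found and one negative pair is wrongly flagged, so TP=2, FP=1, FN=1.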

### Train

This is where the fun begins! Here we set up a few hyperparameters that have proven to work well for us in fine-tuning ModernBERT-base on GLUE tasks.  We'll also set up our model, data collator, and training arguments.

In [15]:
train_bsz, val_bsz = 32, 32
lr = 8e-5
betas = (0.9, 0.98)
n_epochs = 2
eps = 1e-6
wd = 8e-6

When configuring `AutoModelForSequenceClassification`, two settings are critical to get things working with the HuggingFace `Trainer`. One is the `num_labels` we're expecting and the other is to set `compile=False` to avoid using the `torch.compile` function which is not supported in Transformers at the time of this writing.

In [16]:
hf_model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=n_labels, id2label=id2label, label2id=label2id, compile=False
)


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Collation is easy for GLUE tasks as we can use the `DataCollatorWithPadding` class to pad our input_ids and attention_mask to the max length in the batch.

**Note**: If you have installed Flash Attention, ModernBERT removes the padding internally, which makes it the fastest version. SDPA and eager mode will be slower.

In [17]:
hf_data_collator = DataCollatorWithPadding(tokenizer=hf_tokenizer)
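If you're curious what padding "to the max length in the batch" means in practice, here is a tiny sketch of the idea on plain Python lists. The function `pad_batch`, the toy token ids, and the `pad_id` value of 0 are placeholders chosen for illustration. The real collator takes the padding id from the tokenizer and returns tensors.

```python
def pad_batch(batch: list[list[int]], pad_id: int = 0) -> dict[str, list[list[int]]]:
    """Right-pad every sequence to the longest one in this batch."""
    max_len = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


toy_batch = [[11, 12, 13, 14], [21, 22], [31, 32, 33]]
padded = pad_batch(toy_batch)
print(padded["input_ids"])
print(padded["attention_mask"])
```

Padding per batch rather than to a global maximum keeps short batches short, which is why `truncation=True` alone was enough in the preprocessing step.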

With all the pieces in place, we can now set up our `TrainingArguments` and `Trainer` and get to training! Lots of customization is possible here, and it is recommended to play with different schedulers and the hyperparameters we've started y'all off with above to improve results.

In [18]:
training_args = TrainingArguments(
    output_dir=f"aai_ModernBERT_{task}_ft",
    learning_rate=lr,
    per_device_train_batch_size=train_bsz,
    per_device_eval_batch_size=val_bsz,
    num_train_epochs=n_epochs,
    lr_scheduler_type="linear",
    optim="adamw_torch",
    adam_beta1=betas[0],
    adam_beta2=betas[1],
    adam_epsilon=eps,
    weight_decay=wd,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    bf16=True,
    bf16_full_eval=True,
    push_to_hub=False,
)
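To see what `lr_scheduler_type="linear"` does with these settings, here is a small worked sketch of a linear decay schedule. It assumes a single GPU, no gradient accumulation, and no warmup steps, and `linear_lr` is a hypothetical helper written for this illustration rather than the scheduler class the `Trainer` builds.

```python
import math


def linear_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Linear warmup (if any) followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return base_lr * remaining


# MRPC has 3,668 training examples; with a batch size of 32 that is 115 steps per epoch.
steps_per_epoch = math.ceil(3668 / 32)
total_steps = steps_per_epoch * 2  # n_epochs = 2

lr_after_epoch_1 = linear_lr(steps_per_epoch, total_steps, base_lr=8e-5)
lr_after_epoch_2 = linear_lr(total_steps, total_steps, base_lr=8e-5)
print(steps_per_epoch, total_steps, lr_after_epoch_1, lr_after_epoch_2)
```

Under these assumptions the learning rate is halved by the end of the first epoch, which lines up with the `train_learning_rate` of 4e-05 that the training log reports after epoch 1.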

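We define a custom `TrainerCallback` so that we can capture the training and evaluation logs separately and store them for later analysis. The `Trainer` also keeps its log entries in `trainer.state.log_history`, but splitting them up here makes it easy to turn each set into its own DataFrame.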


In [19]:
class MetricsCallback(TrainerCallback):
    def __init__(self):
        self.training_history = {"train": [], "eval": []}

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            if "loss" in logs:  # Training logs
                self.training_history["train"].append(logs)
            elif "eval_loss" in logs:  # Evaluation logs
                self.training_history["eval"].append(logs)

In [20]:
trainer = Trainer(
    model=hf_model,
    args=training_args,
    train_dataset=tokenized_datasets[train_ds_name],
    eval_dataset=tokenized_datasets[valid_ds_name],
    processing_class=hf_tokenizer,
    data_collator=hf_data_collator,
    compute_metrics=partial(compute_metrics, task_metrics=task_metrics),
)

metrics_callback = MetricsCallback()
trainer.add_callback(metrics_callback)

trainer.train()

train_history_df = pd.DataFrame(metrics_callback.training_history["train"])
train_history_df = train_history_df.add_prefix("train_")
eval_history_df = pd.DataFrame(metrics_callback.training_history["eval"])
train_res_df = pd.concat([train_history_df, eval_history_df], axis=1)

args_df = pd.DataFrame([training_args.to_dict()])

display(train_res_df)
display(args_df)

| Epoch | Training Loss | Validation Loss | Accuracy Score | F1 Score |
|---|---|---|---|---|
| 1 | 0.606 | 0.550361 | 0.720588 | 0.817308 |
| 2 | 0.4148 | 0.499648 | 0.754902 | 0.822064 |


|   | train_loss | train_grad_norm | train_learning_rate | train_epoch | eval_loss | eval_accuracy_score | eval_f1_score |
|---|---|---|---|---|---|---|---|
| 0 | 0.606 | 3.989608 | 4e-05 | 1.0 | 0.550361 | 0.720588 | 0.817308 |
| 1 | 0.4148 | 5.57359 | 0.0 | 2.0 | 0.499648 | 0.754902 | 0.822064 |


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_mrpc_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


### Inference

There are a number of options for inference within the HuggingFace ecosystem.  We'll go a bit old school here and just use the `forward` method of the model. We're not uploading this model to the hub, but this is an easy enough task for you to try out on your own should you like to share your ModernBERT finetune :).

In [21]:
ex_1 = "The quick brown fox jumps over the lazy dog."
ex_2 = "I love lamp!"

inf_inputs = hf_tokenizer(ex_1, ex_2, return_tensors="pt")
inf_inputs = inf_inputs.to("cuda")

hf_model.eval()  # make sure dropout is disabled before predicting
with torch.no_grad():
    logits = hf_model(**inf_inputs).logits

print(logits)
print(f"Prediction: {hf_model.config.id2label[logits.argmax().item()]}")


tensor([[ 1.2422, -1.5234]], device='cuda:0')
Prediction: NOT_EQUIVALENT
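The raw logits are not probabilities, so if you want a confidence score you can pass them through a softmax. Here is a dependency-free sketch applied to the two logits printed above. The `softmax` helper is written out by hand for illustration; with tensors you would typically use `torch.softmax(logits, dim=-1)` instead.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Logits from the inference cell above (index 0 = NOT_EQUIVALENT, 1 = EQUIVALENT).
probs = softmax([1.2422, -1.5234])
predicted_idx = probs.index(max(probs))
print([round(p, 4) for p in probs], predicted_idx)
```

The argmax of the probabilities is the same as the argmax of the logits, so the predicted label does not change. The softmax only tells us how strongly the model prefers it.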


### Cleanup

In [22]:
def cleanup(things_to_delete: list | None = None):
    if things_to_delete is not None:
        for thing in things_to_delete:
            if thing is not None:
                del thing

    gc.collect()
    torch.cuda.empty_cache()


In [23]:
cleanup(things_to_delete=[hf_model, trainer])
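One caveat worth knowing: `del thing` inside a helper only removes the helper's local name, so the object stays alive for as long as the caller still holds a reference to it. The sketch below demonstrates this on CPython with a weak reference. `FakeModel` and `cleanup_local_only` are hypothetical stand-ins that mimic the pattern above without needing a GPU.

```python
import gc
import weakref


class FakeModel:
    """Stand-in for a large object such as a model (hypothetical)."""


def cleanup_local_only(things: list) -> None:
    """Same pattern as cleanup(): delete the loop variable, then collect."""
    for thing in things:
        del thing  # only unbinds the local name
    gc.collect()


model = FakeModel()
probe = weakref.ref(model)  # lets us check whether the object still exists

cleanup_local_only([model])
alive_after_helper = probe() is not None

del model  # drop the caller's own reference
gc.collect()
alive_after_del = probe() is not None

print(alive_after_helper, alive_after_del)
```

So to actually release GPU memory, it is safest to also run something like `del hf_model, trainer` in the notebook cell itself before calling `gc.collect()` and `torch.cuda.empty_cache()`.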

## Train all the GLUE!

If you got this far you're probably wondering why I put together that dictionary of GLUE tasks if all we're doing is finetuning a single model. The answer is basically that I'm a good and lazy programmer who would like to easily run hyperparameter sweeps and/or fine-tunes on all the GLUE tasks. So ... let's do that!

We'll run with the training hyperparameters specified above and I leave it to the reader to improve the method below to be able to override these values should folks be looking for something to do :)
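For anyone taking up that invitation, one possible starting point is to bundle the hyperparameters in a small dataclass and create variations from it. This is only a sketch: `FinetuneHParams` and `with_overrides` are hypothetical names, the default values are the ones set earlier in this notebook, and wiring the object into the fine-tuning function is left to you.

```python
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class FinetuneHParams:
    """Defaults copied from the hyperparameter cell earlier in this notebook."""

    train_bsz: int = 32
    val_bsz: int = 32
    lr: float = 8e-5
    betas: tuple[float, float] = (0.9, 0.98)
    n_epochs: int = 2
    eps: float = 1e-6
    wd: float = 8e-6


def with_overrides(base: FinetuneHParams, **overrides) -> FinetuneHParams:
    """Return a copy with some fields changed; unknown names raise TypeError."""
    return replace(base, **overrides)


defaults = FinetuneHParams()
# A tiny learning-rate sweep built from the defaults.
sweep = [with_overrides(defaults, lr=lr) for lr in (5e-5, 8e-5, 1e-4)]

for hp in sweep:
    print(asdict(hp))
```

Making the dataclass frozen means a sweep can never accidentally mutate the shared defaults, and `asdict` gives you a row you can log next to each run's results.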

In [24]:
def finetune_glue_task(
    task: str, checkpoint: str = "answerdotai/ModernBERT-base", train_subset: int | None = None, do_cleanup: bool = True
):
    # 1. Load the task metadata
    task_meta = glue_tasks[task]
    train_ds_name = task_meta["dataset_names"]["train"]
    valid_ds_name = task_meta["dataset_names"]["valid"]

    task_inputs = task_meta["inputs"]
    n_labels = task_meta["n_labels"]
    task_metrics = task_meta["metric_funcs"]

    # 2. Load the dataset
    raw_datasets = load_dataset("glue", task.split("-")[0] if "-" in task else task)
    if train_subset is not None and len(raw_datasets["train"]) > train_subset:
        raw_datasets["train"] = raw_datasets["train"].shuffle(seed=42).select(range(train_subset))

    id2label, label2id = get_label_maps(raw_datasets, train_ds_name)

    # 3. Load the tokenizer
    hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenized_datasets = raw_datasets.map(partial(preprocess_function, task_inputs=task_inputs), batched=True)

    # 4. Define the compute metrics function
    task_compute_metrics = partial(compute_metrics, task_metrics=task_metrics)

    # 5. Load the model and data collator
    model_additional_kwargs = {"id2label": id2label, "label2id": label2id} if id2label and label2id else {}
    hf_model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=n_labels, compile=False, **model_additional_kwargs
    )

    hf_data_collator = DataCollatorWithPadding(tokenizer=hf_tokenizer)

    # 6. Define the training arguments and trainer
    training_args = TrainingArguments(
        output_dir=f"aai_ModernBERT_{task}_ft",
        learning_rate=lr,
        per_device_train_batch_size=train_bsz,
        per_device_eval_batch_size=val_bsz,
        num_train_epochs=n_epochs,
        lr_scheduler_type="linear",
        optim="adamw_torch",
        adam_beta1=betas[0],
        adam_beta2=betas[1],
        adam_epsilon=eps,
        weight_decay=wd,
        logging_strategy="epoch",
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        bf16=True,
        bf16_full_eval=True,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=hf_model,
        args=training_args,
        train_dataset=tokenized_datasets[train_ds_name],
        eval_dataset=tokenized_datasets[valid_ds_name],
        processing_class=hf_tokenizer,
        data_collator=hf_data_collator,
        compute_metrics=task_compute_metrics,
    )

    # Add callback to trainer
    metrics_callback = MetricsCallback()
    trainer.add_callback(metrics_callback)

    trainer.train()

    # 7. Get the training results and hyperparameters
    train_history_df = pd.DataFrame(metrics_callback.training_history["train"])
    train_history_df = train_history_df.add_prefix("train_")
    eval_history_df = pd.DataFrame(metrics_callback.training_history["eval"])
    train_res_df = pd.concat([train_history_df, eval_history_df], axis=1)

    args_df = pd.DataFrame([training_args.to_dict()])

    # 8. Cleanup (optional)
    if do_cleanup:
        cleanup(things_to_delete=[trainer, hf_model, hf_tokenizer, tokenized_datasets, raw_datasets])

    return train_res_df, args_df, hf_model, hf_tokenizer

This helper function encapsulates all the steps we've been through above and lets us easily run a fine-tune on a single task. In addition to the HuggingFace objects, it returns the training results and the training hyperparameters, all potentially helpful for performing sweeps and/or documenting your results.

Let's give it a go on both MRPC and CoLA.


In [25]:
train_res_df, args_df, hf_model, hf_tokenizer = finetune_glue_task(
    "mrpc", checkpoint="answerdotai/ModernBERT-base", do_cleanup=True
)

display(train_res_df)
display(args_df)


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Accuracy Score | F1 Score |
|---|---|---|---|---|
| 1 | 0.5637 | 0.414041 | 0.811275 | 0.872305 |
| 2 | 0.3267 | 0.371047 | 0.838235 | 0.881295 |


|   | train_loss | train_grad_norm | train_learning_rate | train_epoch | eval_loss | eval_accuracy_score | eval_f1_score |
|---|---|---|---|---|---|---|---|
| 0 | 0.5637 | 3.246023 | 4e-05 | 1.0 | 0.414041 | 0.811275 | 0.872305 |
| 1 | 0.3267 | 5.002162 | 0.0 | 2.0 | 0.371047 | 0.838235 | 0.881295 |


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_mrpc_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


In [26]:
train_res_df, args_df, hf_model, hf_tokenizer = finetune_glue_task(
    "cola", checkpoint="answerdotai/ModernBERT-base", do_cleanup=True
)

display(train_res_df)
display(args_df)

Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Matthews Corrcoef |
|---|---|---|---|
| 1 | 0.6151 | 0.525902 | 0.323947 |
| 2 | 0.4066 | 0.44187 | 0.492141 |


|   | train_loss | train_grad_norm | train_learning_rate | train_epoch | eval_loss | eval_matthews_corrcoef |
|---|---|---|---|---|---|---|
| 0 | 0.6151 | 8.100638 | 4e-05 | 1.0 | 0.525902 | 0.323947 |
| 1 | 0.4066 | 11.885351 | 0.0 | 2.0 | 0.44187 | 0.492141 |


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_cola_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


**Send it!**

Grab yourself a good cup of coffee, take your pups out for a walk, or whatever as your GPU purrs along while finetuning all things GLUE!

Note the `train_subset` parameter which allows us to train on a subset of the dataset. This is helpful for quickly testing the model on a small dataset to make sure all the bits work as expected.  Feel free to set it to `None` for a full send!

In [27]:
for task in glue_tasks.keys():
    print(f"----- Finetuning {task} -----")
    train_res_df, args_df, hf_model, hf_tokenizer = finetune_glue_task(
        task, checkpoint="answerdotai/ModernBERT-base", train_subset=1_000, do_cleanup=True
    )

    print(":: Results ::")
    display(train_res_df)
    display(args_df)


----- Finetuning cola -----


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Matthews Corrcoef
1,0.6078,0.604273,0.04376
2,0.405,0.595953,0.22808


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_matthews_corrcoef
0,0.6078,40.798195,4e-05,1.0,0.604273,0.04376
1,0.405,12.89271,0.0,2.0,0.595953,0.22808


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_cola_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning sst2 -----


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,0.5692,0.393088,0.834862
2,0.2314,0.313808,0.870413


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,0.5692,13.669675,4e-05,1.0,0.393088,0.834862
1,0.2314,5.367749,0.0,2.0,0.313808,0.870413


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_sst2_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning mrpc -----


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score,F1 Score
1,0.6376,0.508369,0.752451,0.831386
2,0.4279,0.446494,0.786765,0.851789


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score,eval_f1_score
0,0.6376,13.762523,4e-05,1.0,0.508369,0.752451,0.831386
1,0.4279,10.919047,0.0,2.0,0.446494,0.786765,0.851789


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_mrpc_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning stsb -----


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Pearsonr,Spearmanr
1,2.4414,1.323127,0.735928,0.744267
2,0.7436,0.831915,0.802678,0.804716


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_pearsonr,eval_spearmanr
0,2.4414,43.095963,4e-05,1.0,1.323127,0.735928,0.744267
1,0.7436,22.308477,0.0,2.0,0.831915,0.802678,0.804716


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_stsb_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning qqp -----


Map:   0%|          | 0/390965 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 390965/390965 [00:33<00:00, 11761.24 examples/s]
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,F1 Score,Accuracy Score
1,0.6074,0.499629,0.621215,0.728321
2,0.485,0.524866,0.615188,0.732278


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_f1_score,eval_accuracy_score
0,0.6074,8.23535,4e-05,1.0,0.499629,0.621215,0.728321
1,0.485,36.08305,0.0,2.0,0.524866,0.615188,0.732278


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_qqp_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning mnli-matched -----


Map:   0%|          | 0/9847 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 9847/9847 [00:01<00:00, 9226.73 examples/s]
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,1.1471,1.068802,0.406215
2,0.9358,1.027365,0.471014


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,1.1471,7.386725,4e-05,1.0,1.068802,0.406215
1,0.9358,13.833081,0.0,2.0,1.027365,0.471014


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_mnli-matched_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning mnli-mismatched -----


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,1.1471,1.055565,0.423922
2,0.9355,1.011183,0.494304


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,1.1471,7.386725,4e-05,1.0,1.055565,0.423922
1,0.9355,12.523493,0.0,2.0,1.011183,0.494304


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_mnli-mismatched_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning qnli -----


Map:   0%|          | 0/5463 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 5463/5463 [00:00<00:00, 7358.18 examples/s]
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,0.6957,0.52629,0.742998
2,0.4177,0.603152,0.731832


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,0.6957,69.748146,4e-05,1.0,0.52629,0.742998
1,0.4177,16.136906,0.0,2.0,0.603152,0.731832


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_qnli_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning rte -----


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 3000/3000 [00:00<00:00, 5602.09 examples/s]
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,0.7551,0.707934,0.509025
2,0.6556,0.690976,0.498195


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,0.7551,5.044735,4e-05,1.0,0.707934,0.509025
1,0.6556,6.33869,0.0,2.0,0.690976,0.498195


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_rte_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False


----- Finetuning wnli -----


Map:   0%|          | 0/146 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 146/146 [00:00<00:00, 5646.34 examples/s]
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy Score
1,0.7898,0.7097,0.56338
2,0.7183,0.691901,0.56338


:: Results ::


Unnamed: 0,train_loss,train_grad_norm,train_learning_rate,train_epoch,eval_loss,eval_accuracy_score
0,0.7898,6.372931,4e-05,1.0,0.7097,0.56338
1,0.7183,3.087399,0.0,2.0,0.691901,0.56338


Unnamed: 0,output_dir,overwrite_output_dir,do_train,do_eval,do_predict,eval_strategy,prediction_loss_only,per_device_train_batch_size,per_device_eval_batch_size,per_gpu_train_batch_size,...,split_batches,include_tokens_per_second,include_num_input_tokens_seen,neftune_noise_alpha,optim_target_modules,batch_eval_metrics,eval_on_start,use_liger_kernel,eval_use_gather_object,average_tokens_across_devices
0,aai_modernbert_wnli_ft,False,False,True,False,epoch,False,32,32,,...,,False,False,,,False,False,False,False,False
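
With results scattered across ten tasks, it can help to pull the final-epoch metrics into one place. Below is a minimal sketch, assuming each task's results are available as a list of row dictionaries with the column names printed above; the helper name `final_epoch_metrics` is hypothetical and not part of the tutorial's code. The example rows are copied from the MRPC output above.

```python
def final_epoch_metrics(rows):
    """Return the eval_* metrics (minus eval_loss) from the last epoch's row."""
    last = max(rows, key=lambda r: r["train_epoch"])
    return {
        name: value
        for name, value in last.items()
        if name.startswith("eval_") and name != "eval_loss"
    }


# Rows copied from the MRPC results table above.
mrpc_rows = [
    {"train_loss": 0.6376, "train_epoch": 1.0, "eval_loss": 0.508369,
     "eval_accuracy_score": 0.752451, "eval_f1_score": 0.831386},
    {"train_loss": 0.4279, "train_epoch": 2.0, "eval_loss": 0.446494,
     "eval_accuracy_score": 0.786765, "eval_f1_score": 0.851789},
]

summary = {"mrpc": final_epoch_metrics(mrpc_rows)}
print(summary)
```

Inside the loop above, something like `final_epoch_metrics(train_res_df.to_dict("records"))` could fill one such entry per task, assuming `train_res_df` has the columns shown in the outputs.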


## Conclusion

With ModernBERT, encoders are back, baby! We've seen that ModernBERT-base can compete with the best of them on GLUE tasks, and with a little more tuning, ModernBERT-large can do even better. I'm excited to see what the community builds with this model. We'll be exploring more of ModernBERT's capabilities in future tutorials.

Until next time, happy coding!
