# Fine Tune Modern BERT for Named Entity Recognition Task
### This notebook fine tunes modern BERT using the "bigcode/pii-dataset"  dataset for detecting Named entities and then recognize PII sensitive elements in the identified NER elements.

In [25]:
!pip install --quiet transformers datasets seqeval torch
!pip install --quiet tf_keras
!pip install transformers[torch]



In [49]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoConfig
from transformers import Trainer, TrainingArguments
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score

# ****************************************
# Configure the root of the Project Repo
# ****************************************

In [50]:
debug=False
ROOTDIR ="/Users/pals/MICS/MIDS_266/project/privacy-ner-att"

In [51]:
MAX_SEQ_LENGTH = 128

In [52]:
conll_ds = load_dataset("conll2003", trust_remote_code=True)

In [53]:
conll_ds

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [54]:
ds_ele = conll_ds['train'][4]

In [55]:
if debug:
    print(ds_ele['tokens'], len(ds_ele['tokens']))
    print(ds_ele['pos_tags'], len(ds_ele['pos_tags']))
    print(ds_ele['chunk_tags'], len(ds_ele['chunk_tags']))
    print(ds_ele['ner_tags'], len(ds_ele['ner_tags']))

In [56]:
pos_tags   = conll_ds["train"].features["pos_tags"].feature.names
chunk_tags = conll_ds["train"].features["chunk_tags"].feature.names
ner_tags   = conll_ds["train"].features["ner_tags"].feature.names
if debug:
    print(ner_tags)


In [57]:
pos_id2tag   = { k: v for k,v in enumerate(pos_tags)}
chunk_id2tag = { k: v for k,v in enumerate(chunk_tags)}
ner_id2tag   = { k: v for k,v in enumerate(ner_tags)}
if debug:
    print(ner_id2tag)

In [58]:
pos_tag2id   = { v: k for k,v in enumerate(pos_tags)}
chunk_tag2id = { v: k for k,v in enumerate(chunk_tags)}
ner_tag2id   = { v: k for k,v in enumerate(ner_tags)}
if debug:
    print(ner_tag2id)

In [59]:
def convert_ner_ids_to_label(ner_ids: list) -> list[str]:
    return [ner_id2tag[i] for i in ner_ids]

In [60]:
model_name = "answerdotai/ModernBERT-base"

In [61]:
mbert_tokenizer = AutoTokenizer.from_pretrained(model_name)

In [62]:
def tokenize_and_align_ner_tags(conll_ds: dict, tokenizer: object) -> None:
    input_tokens = tokenizer(conll_ds['tokens'], 
                             truncation=True, 
                             padding="max_length",  # Ensure fixed input size
                             max_length=MAX_SEQ_LENGTH,        # Set max token limit
                             is_split_into_words=True)
    input_labels = []
    for i, ner_tag_i in enumerate(conll_ds['ner_tags']):
        # For each ner_tag line associated with input sentence
        # get the tokenized word ids and then associated the 
        # ner_label. Note: word_ids is not the actual tokens, rather it is 
        # just word index corresponding to original input
        word_ids = input_tokens.word_ids(batch_index=i) 
        aligned_ner_tag = []
        for word_id in word_ids:
            if word_id is None:
                aligned_ner_tag.append(-100)
            else:
                aligned_ner_tag.append(ner_tag_i[word_id]) 
        while len(aligned_ner_tag) < MAX_SEQ_LENGTH:
            aligned_ner_tag.append(-100)
        input_labels.append(aligned_ner_tag)
        if debug:
            if i == 0:
                print("input tokens ", i, conll_ds['tokens'][i])
                print("NER IDS      ", i, ner_tag_i)
                print("NER Labels   ", i, convert_ner_ids_to_label(ner_tag_i))
                print("Word Ids     ", i, input_tokens.word_ids(batch_index=i))
                print("Aligned NER  ", i, aligned_ner_tag)
    input_tokens['labels'] = input_labels
    return input_tokens

In [63]:
conll_ds = conll_ds.map(tokenize_and_align_ner_tags, batched=True, fn_kwargs={'tokenizer': mbert_tokenizer})
conll_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Map: 100%|████████████████████████| 3250/3250 [00:00<00:00, 22521.11 examples/s]


# Prepare the MBert configuration for creating the Modern Bert Classifier task

In [64]:
mbert_config = AutoConfig.from_pretrained(
    model_name,
    num_labels = len(ner_tags),
    id2label = ner_id2tag,
    label2ids = ner_tag2id
)

In [65]:
mbert_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    config=mbert_config
)

Some weights of ModernBertForTokenClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [66]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
print(f"Loading model to device: {device}")
mbert_model.to(device)

Loading model to device: mps


ModernBertForTokenClassification(
  (model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
          (rotary_emb): ModernBertRotaryEmbedding()
          (Wo): Linear(in_features=768, out_features=768, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=768, out_features=2304, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=1152, out_features=768, bias=False)
        )
      )
 

In [67]:


def compute_metrics(pred):
    logits, labels = pred

    # Convert logits to predicted class indices
    predictions = np.argmax(logits, axis=2)

    true_labels = []
    pred_labels = []

    for label_row, pred_row in zip(labels, predictions):  # Process each sentence
        temp_true = []
        temp_pred = []
        
        for label_id, pred_id in zip(label_row, pred_row):  # Process each token
            if label_id != -100:  # Ignore padding tokens
                temp_true.append(ner_id2tag[label_id])  # Convert true label to string
                temp_pred.append(ner_id2tag[pred_id])  # Convert predicted label to string

        if temp_true:  # Only add non-empty sequences
            true_labels.append(temp_true)
            pred_labels.append(temp_pred)

    return {
        "accuracy": accuracy_score(true_labels, pred_labels),
        "f1": f1_score(true_labels, pred_labels),
        "precision": precision_score(true_labels, pred_labels),
        "recall": recall_score(true_labels, pred_labels),
    }


In [68]:
MODEL   = "mbert"
OUTDIR  = f"{ROOTDIR}/build/{MODEL}/results"
LOGDIR  = f"{ROOTDIR}/build/{MODEL}/logs"
BATCH_SIZE = 8
NUM_EPOCHS = 10
WT_DECAY = 0.01

In [69]:
# Training arguments
training_args = TrainingArguments(
    output_dir=OUTDIR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=WT_DECAY,
    logging_dir=LOGDIR
)

In [70]:
# Initialize Trainer
trainer = Trainer(
    model=mbert_model,
    args=training_args,
    train_dataset=conll_ds["train"],
    eval_dataset=conll_ds["validation"],
    compute_metrics=compute_metrics
)


In [71]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.0861,0.076213,0.981009,0.92784,0.928048,0.927633
2,0.0386,0.080684,0.982083,0.93221,0.930373,0.934055
3,0.0231,0.079557,0.981858,0.929151,0.921608,0.936819
4,0.012,0.093292,0.982548,0.936286,0.93316,0.939432
5,0.0061,0.116714,0.983357,0.937702,0.93179,0.943689
6,0.0051,0.112998,0.9815,0.929968,0.92577,0.934205
7,0.0027,0.113267,0.984762,0.942995,0.942748,0.943241
8,0.0011,0.121052,0.984762,0.945334,0.946608,0.944063
9,0.0002,0.12575,0.984444,0.9415,0.937699,0.945332
10,0.0001,0.128712,0.984921,0.942687,0.93954,0.945855


TrainOutput(global_step=17560, training_loss=0.02277638827147801, metrics={'train_runtime': 7459.6955, 'train_samples_per_second': 18.822, 'train_steps_per_second': 2.354, 'total_flos': 1.196203276493568e+16, 'train_loss': 0.02277638827147801, 'epoch': 10.0})

In [77]:
def predict_ner(text, tokenizer, model):
    tokens = tokenizer(text, return_tensors="pt", truncation=True)
    device = torch.device("mps")
    print(f"Loading model to device: {device}")
    model.to(device)
    model.eval()
    tokens = {key: value.to(device) for key, value in tokens.items()}
    with torch.no_grad():
        outputs = model(**tokens)

    predictions = torch.argmax(outputs.logits, dim=2)
    predicted_labels = [ner_id2tag[id.item()] for id in predictions[0]]

    return list(zip(tokenizer.tokenize(text), predicted_labels))

In [78]:
# Test with a news sentence
text = "Kennedy delivered his famous speech at Rice University."
print(predict_ner(text, mbert_tokenizer, mbert_model))

Loading model to device: mps
[('Kenn', 'O'), ('edy', 'B-PER'), ('Ġdelivered', 'B-PER'), ('Ġhis', 'O'), ('Ġfamous', 'O'), ('Ġspeech', 'O'), ('Ġat', 'O'), ('ĠRice', 'O'), ('ĠUniversity', 'B-ORG'), ('.', 'I-ORG')]
