In [None]:
# Install dependencies
!pip install datasets
!pip install spacy
!python -m spacy download pt_core_news_sm
!pip install -q evaluate seqeval

### 1) Convert the dataset to NER format

In [None]:
import json
import re

import evaluate
import numpy as np
import pandas as pd
import spacy
from datasets import ClassLabel, Dataset, DatasetDict, Features, Sequence, Value, load_dataset
from tqdm.notebook import tqdm
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    pipeline,
)

def sliding_window(text, window_size=200, overlap=50):
    """
    Splits a text into sliding windows of size `window_size` with an `overlap`.

    Args:
        text (str): The text to be processed.
        window_size (int): The maximum number of tokens per window.
        overlap (int): The number of overlapping tokens between windows.

    Returns:
        List of str: List of texts divided into sliding windows.
    """
    # Load the SpaCy model for Portuguese (or another language, if needed)
    nlp = spacy.load('pt_core_news_sm')

    # Process the entire text with SpaCy
    doc = nlp(text)

    # Extract tokens
    tokens = [token.text for token in doc]

    # List to store the text windows
    windows = []

    # Sliding window implementation
    for i in range(0, len(tokens), window_size - overlap):
        # Capture a window of tokens
        window = tokens[i:i + window_size]
        windows.append(fix_tags_with_replace(" ".join(window)))

        # Stop if we are at the end of the text
        if i + window_size >= len(tokens):
            break

    return windows


def fix_tags_with_replace(text):
    """
    Normalize malformed entity tags in `text` by plain string replacement.

    For each known tag name, stray spaces inside opening tags are collapsed
    (`< TAG >` -> `<TAG>`), closing tags are rewritten to the `</TAG/>` form used
    throughout this notebook, and the space right after an opening tag is removed.
    The replacements run in a fixed order, tag by tag, each one seeing the
    output of the previous one.

    Args:
        text (str): Text whose tags may have been broken apart (e.g. by tokenization).

    Returns:
        str: Text with corrected tags.
    """
    tag_names = (
        "AGE", "PHONE", "FAX", "EMAIL", "URL", "IP_ADDRESS", "DATE", "IDNUM",
        "MEDICAL_RECORD", "DEVICE", "HEALTH_PLAN", "BIOID", "STREET", "CITY",
        "ZIP", "STATE", "COUNTRY", "LOCATION_OTHER", "ORGANIZATION", "HOSPITAL",
        "PATIENT", "DOCTOR", "USERNAME", "PROFESSION", "OTHER", "LOCATION",
    )

    for name in tag_names:
        opening = f"<{name}>"
        closing = f"</{name}/>"
        # Ordered (old, new) pairs -- the order is significant
        rewrites = (
            # spaces inside the opening tag
            (f"< {name} >", opening),
            (f"< {name}>", opening),
            (f"<{name} >", opening),
            # spaces inside the closing tag
            (f"</ {name} >", closing),
            (f"</ {name}>", closing),
            (f"</{name} >", closing),
            # closing tags with misplaced slashes
            # NOTE(review): this pair also drops the space that followed the tag,
            # gluing it to the next word; kept to preserve existing behaviour.
            (f"<{name}/> ", closing),
            (f"<{name}/ >", closing),
            (f"</{name}/ >", closing),
            # spaces between the tags and the inner content
            (f"<{name}> ", opening),
            (f" </{name}>", closing),
            (f"</{name}>", closing),
        )
        for old, new in rewrites:
            text = text.replace(old, new)

    return text


def label_ner_format(strings, tag_values):
    """
    Convert whitespace-split pieces of tagged text into numeric NER labels.

    For each piece, the first key of `tag_values` (in dict order) that occurs
    anywhere in the piece decides its label:
      - no key found                      -> 0
      - same key as the previous piece    -> tag_values[key] + 1
      - otherwise                         -> tag_values[key]
    A piece without a tag resets the "previous" tag.

    Args:
        strings (list of str): Pieces of text that may contain tags such as "<AGE>".
        tag_values (dict): Maps a tag string to its numeric value.

    Returns:
        list of int: One label per input piece.
    """
    labels_out = []
    last_tag = None

    for piece in strings:
        found = None
        for candidate in tag_values:
            if candidate in piece:
                found = candidate
                break

        if found is None:
            labels_out.append(0)
        elif found == last_tag:
            # Continuation of the previous entity
            labels_out.append(tag_values[found] + 1)
        else:
            labels_out.append(tag_values[found])

        last_tag = found

    return labels_out


# Entity labels in BIO format: 'O' for tokens outside any entity, then a B-/I- pair
# per entity type. The position of a label in `labels` is its class id.
_entity_types = (
    'AGE', 'PHONE', 'EMAIL', 'DATE', 'IDNUM', 'MEDICAL_RECORD', 'HEALTH_PLAN',
    'STREET', 'CITY', 'ZIP', 'STATE', 'COUNTRY', 'LOCATION_OTHER',
    'ORGANIZATION', 'HOSPITAL', 'PATIENT', 'DOCTOR', 'PROFESSION', 'OTHER',
)
labels = ['O'] + [f'{prefix}-{etype}' for etype in _entity_types for prefix in ('B', 'I')]

# label name -> class id
entity2id = {name: class_id for class_id, name in enumerate(labels)}

# "<TYPE>" tag -> id of its B- label. label_ner_format adds 1 for continuation
# tokens, which relies on each I- label directly following its B- label above.
tag2id = {f"<{name[2:]}>": class_id for name, class_id in entity2id.items() if name.startswith("B-")}

tag2id, entity2id

In [None]:
def split_tags(text):
    """
    Split multi-part DATE, PHONE and EMAIL annotations into separately tagged pieces.

    Examples of the rewrites (applied in this order):
      <DATE>12/05/2020</DATE/>      -> <DATE>12</DATE/> <DATE>/</DATE/> <DATE>05</DATE/> <DATE>/</DATE/> <DATE>2020</DATE/>
      <DATE>05/2020</DATE/>         -> <DATE>05</DATE/> <DATE>/</DATE/> <DATE>2020</DATE/>
      <PHONE>(21)</PHONE/>          -> ( <PHONE>21</PHONE/> )
      <PHONE>99856-7421</PHONE/>    -> <PHONE>99856</PHONE/> <PHONE>-</PHONE/> <PHONE>7421</PHONE/>
      <EMAIL>a.b@x.com</EMAIL/>     -> <EMAIL>a.b</EMAIL/> <EMAIL>@</EMAIL/> <EMAIL>x.com</EMAIL/>

    Args:
        text (str): Text annotated with <TAG>...</TAG/> markers.

    Returns:
        str: Text with the matching annotations split.
    """
    rewrite_rules = (
        # dd/mm/yyyy and dd/mm/yy
        (r'<DATE>(\d{2})/(\d{2})/(\d{2,4})</DATE/>',
         r'<DATE>\1</DATE/> <DATE>/</DATE/> <DATE>\2</DATE/> <DATE>/</DATE/> <DATE>\3</DATE/>'),
        # dd/mm (also mm/yyyy)
        (r'<DATE>(\d{2})/(\d{2,4})</DATE/>',
         r'<DATE>\1</DATE/> <DATE>/</DATE/> <DATE>\2</DATE/>'),
        # area code in parentheses: the parentheses move outside the tag
        (r'<PHONE>\((\d{2})\)</PHONE/>',
         r'( <PHONE>\1</PHONE/> )'),
        # number with hyphen
        (r'<PHONE>(\d{4,5})-(\d{4})</PHONE/>',
         r'<PHONE>\1</PHONE/> <PHONE>-</PHONE/> <PHONE>\2</PHONE/>'),
        # local part, "@", domain
        (r'<EMAIL>([\w\.-]+)@([\w\.-]+)</EMAIL/>',
         r'<EMAIL>\1</EMAIL/> <EMAIL>@</EMAIL/> <EMAIL>\2</EMAIL/>'),
    )
    for pattern, replacement in rewrite_rules:
        text = re.sub(pattern, replacement, text)
    return text

In [None]:
# Set to False to reuse the JSON files written by a previous run instead of re-chunking the corpus
preprocess_data = True
# Number of validation chunks kept for evaluation during fine-tuning
MAX_EVAL_CHUNKS = 500

### READ DATA
# Loaded unconditionally: the inference section uses dataset_['test'] even when the
# train/eval chunks are read back from disk.
dataset_ = load_dataset("Venturus/AnonyMED-BR")


def _split_into_chunks(split):
    """Split every sample of a dataset split into tagged windows with normalised whitespace."""
    return [" ".join(chunk.replace('\n', '').split()).strip()
            for sample in tqdm(split)
            for chunk in sliding_window(split_tags(sample["text"]))]


def _chunks_to_ner_records(chunks):
    """Turn tagged text windows into {'id', 'tokens', 'ner_tags'} records."""
    records = []
    for idx, chunk in enumerate(chunks):
        # Put spaces around a hyphen sitting between two tags so it becomes its own token
        spaced = chunk.replace("/>-<", "/> - <")
        records.append({
            'id': idx,
            # Surface tokens: the same space-split with every <...> tag removed
            'tokens': re.sub(r"<[^>]+>", "", spaced).split(' '),
            # Labels: read from the still-tagged pieces
            'ner_tags': label_ner_format(spaced.split(' '), tag2id),
        })
    return records


if preprocess_data:

    ### PRE-PROCESS DATA
    # Train
    list_chunks = _split_into_chunks(dataset_['train'])
    train_chunks = _chunks_to_ner_records(list_chunks)

    # Save intermediary step
    with open('bert_train.json', 'w', encoding='utf-8') as f:
        json.dump(train_chunks, f, ensure_ascii=False, indent=4)

    # Eval
    list_chunks_eval = _split_into_chunks(dataset_['validation'])
    eval_chunks = _chunks_to_ner_records(list_chunks_eval)

    # Save intermediary step
    with open('bert_eval.json', 'w', encoding='utf-8') as f:
        json.dump(eval_chunks, f, ensure_ascii=False, indent=4)

else:

    # The files were written as UTF-8 with ensure_ascii=False; read them back as UTF-8
    # so accented Portuguese text survives on platforms whose default encoding differs.
    with open('bert_train.json', 'r', encoding='utf-8') as file:
        train_chunks = json.load(file)

    with open('bert_eval.json', 'r', encoding='utf-8') as file:
        eval_chunks = json.load(file)


# Create ClassLabel structure
ner_tag_feature = ClassLabel(names=labels)

# Create the dataset schema
features = Features({
    'id': Value(dtype='string'),  # unique identifier of each sample
    'tokens': Sequence(Value(dtype='string')),  # list of tokens
    'ner_tags': Sequence(ner_tag_feature)  # list of labels aligned with the tokens
})

# Convert the data into Hugging Face Dataset objects
train_dataset = Dataset.from_list(train_chunks, features=features)
eval_dataset = Dataset.from_list(eval_chunks[:MAX_EVAL_CHUNKS], features=features)

# Combine into a DatasetDict
dataset = DatasetDict({
    "train": train_dataset,
    "validation": eval_dataset
})

# Map ids to labels to load a BERT model with correct output head
tag_names = dataset["train"].features["ner_tags"].feature.names
id2label = dict(enumerate(tag_names))
label2id = {label: idx for idx, label in id2label.items()}

dataset

### 2) Tokenize data

In [None]:
### Load model
model_name = "google-bert/bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_and_align_tags(records):
    """
    Tokenizes input text and aligns Named Entity Recognition (NER) tags with the resulting tokens.

    This function handles cases where words are split into multiple subtokens by the tokenizer
    (e.g., "ChatGPT" → ["Chat", "##G", "##PT"]). For such cases, the first subtoken inherits the
    original tag, while subsequent subtokens are ignored by assigning them the label `-100`.
    Special tokens (e.g., [CLS], [SEP]) are also assigned `-100` to exclude them from loss
    calculation during training.

    Sequences are truncated to 512 tokens but not padded here; padding is left to the
    data collator used by the Trainer, which pads each batch only to its longest sequence.

    Args:
        records (dict): A dictionary containing:
            - "tokens" (list of list of str): The tokenized input sentences (split by words).
            - "ner_tags" (list of list of int): The corresponding NER tags for each word.

    Returns:
        dict: A dictionary containing:
            - Tokenized inputs (as produced by the Hugging Face tokenizer).
            - "labels" (list of list of int): The aligned NER tags, where `-100` is used for
              subtokens and special tokens.
    """
    # No padding="max_length": padding every window to 512 tokens wastes compute, and
    # DataCollatorForTokenClassification already pads inputs and labels per batch.
    tokenized_results = tokenizer(records["tokens"], truncation=True, is_split_into_words=True,
                                  max_length=512)

    input_tags_list = []

    for i, given_tags in enumerate(records["ner_tags"]):
        # Index of the original word for each token (None for special tokens)
        word_ids = tokenized_results.word_ids(batch_index=i)

        previous_word_id = None
        input_tags = []

        for wid in word_ids:
            if wid is None:
                # Special token: excluded from the loss
                input_tags.append(-100)
            elif wid != previous_word_id:
                # First subtoken of a word carries the word's tag
                input_tags.append(given_tags[wid])
            else:
                # Remaining subtokens of an already-tagged word are ignored
                input_tags.append(-100)
            previous_word_id = wid

        input_tags_list.append(input_tags)

    # The model computes the loss from the `labels` field
    tokenized_results["labels"] = input_tags_list

    return tokenized_results

tokenized_datasets = dataset.map(tokenize_and_align_tags, batched=True)

tokenized_datasets

### 3) Fine-Tuning

In [None]:
seqeval = evaluate.load("seqeval")


def compute_metrics(p):
    """
    Compute overall entity-level metrics with seqeval for the Trainer.

    Args:
        p: A (logits, label_ids) pair. Predicted ids are the argmax over axis 2 of
           the logits; positions whose label is -100 (sub-tokens, special tokens)
           are ignored.

    Returns:
        dict: Overall "precision", "recall", "f1" and "accuracy" reported by seqeval.
    """
    logits, label_ids = p
    predicted_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_tags = []
    for pred_row, label_row in zip(predicted_ids, label_ids):
        # Keep only positions that carry a real label, then map ids back to tag names
        kept = [(pred_id, label_id) for pred_id, label_id in zip(pred_row, label_row) if label_id != -100]
        true_predictions.append([tag_names[pred_id] for pred_id, _ in kept])
        true_tags.append([tag_names[label_id] for _, label_id in kept])

    scores = seqeval.compute(predictions=true_predictions, references=true_tags)

    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }


# Token-classification model whose output head is sized to the BIO label set
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# NOTE(review): no evaluation strategy is passed, so with the library default
# `eval_steps` presumably has no effect and the validation split / compute_metrics
# are not used during training -- confirm against the installed transformers version.
training_args = TrainingArguments(
    output_dir="<path_to_output>",
    report_to="none",
    learning_rate=5e-5,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    eval_steps=500,
    logging_dir='logs',
    logging_steps=500,
    save_strategy="epoch",
    fp16=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the fine-tuned weights and the tokenizer together for the inference section
saved_model_dir = "BERT_model"
model.save_pretrained(saved_model_dir)
tokenizer.save_pretrained(saved_model_dir)
print("Model saved!")

### 4) Inference

In [None]:
def check_case_majority(s):
    """
    Convert a string entirely to upper or lower case, following its majority case.

    The string is upper-cased only when it has strictly more uppercase than
    lowercase letters; in every other case (including ties and strings without
    cased letters) it is lower-cased.

    Args:
        s (str): Input string.

    Returns:
        str: `s.upper()` or `s.lower()`.
    """
    uppers = lowers = 0
    for ch in s:
        if ch.isupper():
            uppers += 1
        elif ch.islower():
            lowers += 1
    return s.upper() if uppers > lowers else s.lower()


def insert_email_com(entities):
    """
    Append a ".com"-style suffix to predicted email pieces, then merge email runs.

    Step 1: for every I-EMAIL entity that is preceded by another I-EMAIL and is not
    followed by an I-EMAIL (i.e. the last piece of a run of at least two I-EMAILs),
    a synthetic I-EMAIL piece is inserted right after it:
      - word '.'             -> adds 'com'
      - word 'com' (any case) -> adds '' (an empty piece)
      - anything else        -> adds '.com'
    The synthetic piece starts at the previous piece's `end` and has score 1.0.

    Step 2: every run of consecutive B-EMAIL/I-EMAIL entities in the list is merged
    into one 'B-EMAIL' entity whose word is the concatenation of the pieces
    (normalised by `check_case_majority`), whose score is the mean of the piece
    scores, and whose span goes from the first piece's start to the last piece's end.

    NOTE(review): runs are merged by list adjacency only, not by text positions. If
    the entity list contains only entity tokens, two distinct emails with nothing
    predicted between them would be merged into one -- confirm this is acceptable.

    Args:
        entities (list): A list of entity dictionaries containing 'entity',
                         'word', 'start', 'end', 'score' and 'index'.

    Returns:
        list: New list of entities with email pieces merged; non-email entities
              are passed through unchanged.
    """
    # Step 1: insert the synthetic suffix piece after the last I-EMAIL of a run
    new_entities = []
    for i, ent in enumerate(entities):
        new_entities.append(ent)

        # Check if the current element is I-EMAIL
        if ent['entity'] == 'I-EMAIL':
            # Check if there is a previous and next element
            prev_ent = entities[i - 1] if i > 0 else None
            next_ent = entities[i + 1] if i + 1 < len(entities) else None

            # Boundary condition: previous piece is I-EMAIL and this is the last I-EMAIL of the run
            if (
                prev_ent is not None and
                prev_ent['entity'] == 'I-EMAIL' and
                (next_ent is None or next_ent['entity'] != 'I-EMAIL')
            ):
                # Decide what to add
                if ent['word'] == '.':
                    word_to_add = 'com'
                elif ent['word'].lower() == 'com':
                    word_to_add = ''
                else:
                    word_to_add = '.com'

                # The synthetic piece is assumed to follow the current one directly in the text
                start = ent['end']
                end = start + len(word_to_add)

                # Create new element
                new_element = {
                    'entity': 'I-EMAIL',
                    'score': 1.0,
                    'index': ent['index'] + 0.1,  # temporary for ordering
                    'word': word_to_add,
                    'start': start,
                    'end': end
                }

                new_entities.append(new_element)

    # Step 2: Merge sequences of B-EMAIL and I-EMAIL
    merged_entities = []
    i = 0
    while i < len(new_entities):
        ent = new_entities[i]

        if ent['entity'] in ['B-EMAIL', 'I-EMAIL']:
            # Start group
            start_idx = i
            words = [ent['word']]
            start_pos = ent['start']
            end_pos = ent['end']
            scores = [ent['score']]

            # Continue while sequence of EMAIL
            i += 1
            while i < len(new_entities) and new_entities[i]['entity'] in ['B-EMAIL', 'I-EMAIL']:
                words.append(new_entities[i]['word'])
                end_pos = new_entities[i]['end']
                scores.append(new_entities[i]['score'])
                i += 1

            # Pieces are joined without separators; the case is normalised to the majority case
            merged_entity = {
                'entity': 'B-EMAIL',
                'score': sum(scores) / len(scores),
                'index': new_entities[start_idx]['index'],
                'word': check_case_majority(''.join(words)),
                'start': start_pos,
                'end': end_pos
            }
            merged_entities.append(merged_entity)

        else:
            # Keep other entities unchanged
            merged_entities.append(ent)
            i += 1

    return merged_entities


def merge_time_entities(entities, tag):
    """
    Merge runs of text-contiguous entities of type `tag` into single entities.

    An entity belongs to the type when `tag` is a substring of its 'entity' label.
    Consecutive such entities are merged while each one starts exactly where the
    previous one ends. Every emitted group (even a single entity) is relabelled
    'B-<tag>' and keeps the first entity's score and index.

    Args:
        entities (list): A list of entity dictionaries.
        tag (str): The entity tag to merge (e.g., "DATE" or "PHONE").

    Returns:
        list: New list in which other entities are passed through unchanged (same
              objects) and `tag` runs are replaced by merged dictionaries.
    """
    merged = []
    pos = 0
    total = len(entities)

    while pos < total:
        head = entities[pos]
        if tag not in head['entity']:
            merged.append(head)
            pos += 1
            continue

        word = head['word']
        last = pos
        # Absorb following `tag` entities that start where the previous one ended
        while (last + 1 < total
               and tag in entities[last + 1]['entity']
               and entities[last]['end'] == entities[last + 1]['start']):
            last += 1
            word += entities[last]['word']

        merged.append({
            'entity': f'B-{tag}',
            'score': head['score'],
            'index': head['index'],
            'word': word,
            'start': head['start'],
            'end': entities[last]['end'],
        })
        pos = last + 1

    return merged


def merge_subwords(entities):
    """
    Merge WordPiece continuation tokens ("##...") into the entity they extend.

    A continuation token is appended to the previously kept entity only when it
    starts exactly where that entity ends, i.e. when both pieces belong to the same
    word in the text. The merged entity keeps the first piece's label and index,
    takes the last piece's `end`, and gets the lowest score of its pieces.

    A continuation token that cannot be attached is dropped. This applies to a
    leading "##" token, and to one whose word head was not returned as an entity.
    Before, such a token was glued onto whatever entity preceded it in the list,
    which could belong to a different word and corrupted that entity's text and span.

    Args:
        entities (list): List of entity dictionaries containing 'word', 'start', 'end', and 'score'.

    Returns:
        list: New list of entities (copies), with subwords merged into full words.
              The input dictionaries are not modified.
    """
    merged_entities = []
    current = None

    for entity in entities:
        if entity["word"].startswith("##"):
            if current is not None and current["end"] == entity["start"]:
                current["word"] += entity["word"][2:]  # strip "##" before concatenation
                current["end"] = entity["end"]
                current["score"] = min(current["score"], entity["score"])
            # else: orphan subword, dropped
            continue

        if current is not None:
            merged_entities.append(current)
        current = entity.copy()

    if current is not None:
        merged_entities.append(current)

    return merged_entities


def replace_with_entities(text, entities):
    """
    Replace each entity's span in `text` with its `<TYPE>` marker.

    Entities are processed from the last to the first, so splicing a marker in does
    not shift the offsets of entities still to be processed (this assumes the list
    is ordered by position). An entity whose `word` does not match `text[start:end]`
    is skipped with a printed warning.

    Args:
        text (str): Original text.
        entities (list): Entities with 'start', 'end', 'word' and 'entity' fields;
            the type is the part of 'entity' after the first '-'.

    Returns:
        str: Text with the matched entity spans replaced by their tag markers.
    """
    for idx in range(len(entities) - 1, -1, -1):
        ent = entities[idx]
        begin = ent["start"]
        finish = ent["end"]
        label = ent["entity"].split('-')[1]
        found = text[begin:finish]

        if found == ent['word']:
            text = f"{text[:begin]}<{label}>{text[finish:]}"
        else:
            print(f"Warning: Word mismatch. Expected '{ent['word']}', found '{found}' at position {begin}-{finish}. Skipping.")

    return text


def create_generative_format(text):
    """
    Collapse every `<TAG>content</TAG/>` annotation into a bare `<TAG>` marker.

    Args:
        text (str): Text containing tagged entities.

    Returns:
        str: Text in which each annotated span is replaced by its opening tag only.
    """
    # Group 1 is the tag name; the backreference requires the matching "</TAG/>" closer.
    # The replacement template re-emits only the opening tag.
    return re.sub(r"<(.*?)>(.*?)</\1/>", r"<\1>", text)


def find_missing_words(predicted_words, labels):
    """
    List the ground-truth words that never appear among the predicted words.

    Args:
        predicted_words (list): Predicted words.
        labels (list): Ground-truth label words (order and duplicates are kept).

    Returns:
        tuple: (number of missing words, list of missing words).
    """
    seen = set(predicted_words)
    missing_words = [word for word in labels if word not in seen]
    return len(missing_words), missing_words


def calculate_f1_score(tp, fp, fn, verbose=False):
    """
    Compute F1, recall and precision from raw counts.

    Any ratio whose denominator is zero is reported as 0.

    Args:
        tp (int): True positives.
        fp (int): False positives.
        fn (int): False negatives.
        verbose (bool, optional): If True, prints recall, precision and F1.

    Returns:
        tuple: (f1_score, recall, precision)
    """
    predicted_positives = tp + fp
    actual_positives = tp + fn

    precision = tp / predicted_positives if predicted_positives > 0 else 0
    recall = tp / actual_positives if actual_positives > 0 else 0

    denominator = precision + recall
    f1_score = 2 * (precision * recall) / denominator if denominator > 0 else 0

    if verbose:
        print('Recall:', recall)
        print('Precision:', precision)
        print('F1 Score:', f1_score, '\n')

    return f1_score, recall, precision


def eval(extractive_pred, dict_labels, verbose=False):
    """
    Score extracted entities against ground-truth labels (word-level match).

    A prediction is a true positive when its word is a label key and its type
    (the part of 'entity' after '-') equals that label's category. It is a false
    positive otherwise. False negatives are the label words never matched by a
    true positive.

    NOTE(review): this name shadows Python's built-in `eval`; it is kept because
    callers in this notebook use it.

    Args:
        extractive_pred (dict): Predictions under the key 'preds' (dicts with 'word' and 'entity').
        dict_labels (dict): Maps ground-truth words to entity categories.
        verbose (bool, optional): If True, prints detailed evaluation results.

    Returns:
        tuple: (f1, recall, precision)
    """
    TP = 0
    FP = 0
    correct_predicted_words = []
    wrong_predicted_words = []
    predicted_words = []
    wrong_predicted_category = []

    for pred in extractive_pred['preds']:
        word = pred['word']

        if word not in dict_labels:
            FP += 1
            wrong_predicted_words.append(word)
            continue

        category = pred['entity'].split('-')[1]
        if category == dict_labels[word]:
            TP += 1
            correct_predicted_words.append((word, category))
            predicted_words.append(word)
        else:
            FP += 1
            wrong_predicted_category.append((word, category))

    # False negatives: label words without a correct prediction
    FN, missing_words = find_missing_words(predicted_words, list(dict_labels))

    if verbose:
        print('Missing words:', missing_words)
        print('Correct Predicted words:', correct_predicted_words)
        print('Correct word but wrong category:', wrong_predicted_category)
        print('Wrong Predicted words:', wrong_predicted_words)
        print('Labels:', dict_labels)

    return calculate_f1_score(TP, FP, FN, verbose=verbose)


def insert_intermediate_element(data, tag, sep):
    """
    Insert a separator entity between two digit-only `tag` entities one character apart.

    When an entity whose label ends with `tag` and whose word is all digits is
    followed by another such entity that starts exactly one character after it
    ends, a new 'I-<tag>' entity with word `sep` is inserted to cover the gap.
    Its score is the mean of the two neighbours' scores and its index is the first
    neighbour's index + 1.

    Args:
        data (list): List of entity dictionaries.
        tag (str): Tag suffix to match (e.g., "PHONE" or "DATE").
        sep (str): Separator word to insert (e.g., "-" or "/").

    Returns:
        list: New list containing the original entities plus inserted separators.
    """
    def _is_digit_entity(ent):
        return ent["entity"].endswith(tag) and ent["word"].isdigit()

    result = []
    last_pos = len(data) - 1
    for pos, current in enumerate(data):
        result.append(current)

        if (_is_digit_entity(current)
                and pos < last_pos
                and _is_digit_entity(data[pos + 1])
                and current["end"] + 1 == data[pos + 1]["start"]):
            following = data[pos + 1]
            result.append({
                "entity": f"I-{tag}",
                "score": (current["score"] + following["score"]) / 2,
                "index": current["index"] + 1,
                "word": sep,
                "start": current["end"],
                "end": following["start"],
            })

    return result


def insert_phone_closing_parenthesis(entities):
    """
    Insert a missing closing parenthesis after a phone area code.

    Looks for a B-PHONE '(' immediately followed (in the list) by a B-PHONE
    two-character word that is not itself followed by ')'. In that case a B-PHONE ')'
    entity is inserted right after the code, spanning one character from the
    code's end and reusing the score of '('.

    The loop now also visits the last pair of entities. The old bound
    (`i < len - 2`) made the "no element after the code" case unreachable, so a
    "(21" at the very end of the list never received its ')'.

    Args:
        entities (list): List of entity dictionaries. It is modified in place.

    Returns:
        list: The same list, with ')' entities inserted where needed.
    """
    i = 0
    # `i + 1` must be a valid index; `i + 2` may be past the end
    while i < len(entities) - 1:
        opening = entities[i]
        code = entities[i + 1]
        if (opening['entity'] == 'B-PHONE' and opening['word'] == '(' and
                code['entity'] == 'B-PHONE' and len(code['word']) == 2 and
                (i + 2 >= len(entities) or entities[i + 2]['word'] != ')')):

            new_entity = {
                'entity': 'B-PHONE',
                'score': opening['score'],  # Using score of '(' for consistency
                'index': code['index'] + 1,
                'word': ')',
                'start': code['end'],
                'end': code['end'] + 1
            }

            # Insert right after the area code
            entities.insert(i + 2, new_entity)

        i += 1

    return entities

# Reload the fine-tuned checkpoint written by the training section and wrap it in a
# token-classification ("ner") pipeline for inference.
model_name = "BERT_model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
classifier = pipeline(task="ner", model=model, tokenizer=tokenizer)

Device set to use cpu


In [None]:
# Get test data
test_set = dataset_['test']

list_f1 = []
list_recall = []
list_precision = []
list_pred_gen = []
list_syn = []
list_preds = []
list_id = []
for test_sample in tqdm(test_set):

    # Create the sliding window to input on the NLP model
    clean_text = re.sub(r"<[^>]+>", "", test_sample['text'])
    windows = sliding_window((clean_text), window_size=150, overlap=0)

    full_masked_text = ''
    tags = []
    for window in windows:

        # Fix "(" and ")" in sliding window function
        window = window.replace("( ","(").replace(" )",")")

        # Run NLP model and merge tokens
        preds = merge_subwords(classifier(window))

        ## Fix missing "/" on DATE predictions
        preds = insert_intermediate_element(preds,"DATE", "/")
        ## FIX missing "-" on PHONE predictions
        preds = insert_intermediate_element(preds, "PHONE", "-")
        ## FIX missing ")" on PHONE predictions
        preds = insert_phone_closing_parenthesis(preds)
        ## FIX missing "com" on EMAIL predictions
        preds = insert_email_com(preds)

        # Merge DATE and PHONE tags
        preds = merge_time_entities(preds, 'DATE')
        preds = merge_time_entities(preds, 'PHONE')

        # Mask the entities on original text
        masked_text = replace_with_entities(window, preds)

        # Save a list of dicts containing predictions on extractive format
        tags += preds

        # Save text to be used in generative format
        full_masked_text += masked_text

    ### Save prediction in a list
    list_preds.append(tags)

    ### Extractive Format Evaluation ###
    extractive_pred = {'id': test_sample['id'], 'preds': tags}

    # Get labels in the evaluation format

    dict_labels = {(re.sub(r'[()]', '', item['word'])
                  if item['subcategory'] == 'PHONE' and re.fullmatch(r'\(\d{2}\)', item['word'])
                  else item['word']
                  ): item['subcategory']
                  for item in test_sample['labels']}

    # Run evaluation
    f1, recall, precision = eval(extractive_pred, dict_labels, verbose=True)

    # Save results for correct words and classes
    list_f1.append(f1)
    list_recall.append(recall)
    list_precision.append(precision)

    # Save if it is synthetic or not
    list_syn.append(test_sample['synthetic'])

    # Save example id
    list_id.append(test_sample['id'])

    ### Generative Format Evaluation ###
    list_pred_gen.append({'text': re.sub(r"<[^>]+>", "", test_sample['text']), 'masked_text':create_generative_format(test_sample['text']), 'prediction':full_masked_text})

### Create a dataframe with the evaluations
save_df = pd.DataFrame()
save_df['id'] = list_id
save_df['Recall'] = list_recall
save_df['Precision'] = list_precision
save_df['F1'] = list_f1
save_df['synthetic'] = list_syn
save_df['Prediction'] = list_preds
save_df.to_csv('BERT_results.csv', index=False)

### Save on the JSON format for evaluate by entity
save_df.to_json('BERT_results.json', orient="records", lines=True)

avg_f1 = sum(list_f1) / len(list_f1) if list_f1 else 0
avg_recall = sum(list_recall) / len(list_recall) if list_recall else 0
avg_precision = sum(list_precision) / len(list_precision) if list_precision else 0

print('Recall:', avg_recall)
print('Precision:', avg_precision)
print('F1:', avg_f1)

## Save predictions on the generative format
with open('/bert_generative_predictions.json', 'w', encoding='utf-8') as f:
      json.dump(list_pred_gen, f, ensure_ascii=False, indent=4)

### Evaluation per entity

In [None]:
# Entity category names considered in the per-entity evaluation
list_entities = (
    "PHONE AGE FAX EMAIL URL IP_ADDRESS DATE IDNUM "
    "MEDICAL_RECORD DEVICE HEALTH_PLAN BIOID STREET CITY "
    "ZIP STATE COUNTRY LOCATION_OTHER ORGANIZATION HOSPITAL "
    "PATIENT DOCTOR USERNAME PROFESSION OTHER LOCATION"
).split()

def eval_entity(extractive_pred, dict_labels, verbose=False):
    """
    Evaluate one sample's entity predictions against its reference labels.

    Predictions are matched to the reference by their exact surface string
    (``pred['word']``):

    - word in the reference and predicted class equals the reference class -> True Positive;
    - word in the reference but with a different class -> False Positive;
    - word not in the reference at all -> False Positive.

    False Negatives are computed by ``find_missing_words`` from the words that
    were predicted with the correct class and the full list of reference words.

    Args:
        extractive_pred (dict): Dictionary with a ``'preds'`` list. Each prediction is a
            dict with a ``'word'`` (predicted text span) and an ``'entity'`` tag of the form
            ``'<prefix>-<CLASS>'`` (presumably a BIO prefix, e.g. ``'B-PHONE'``). The class is
            the part after the first ``'-'``; a tag without ``'-'`` is used as the class as-is.
        dict_labels (dict): Reference labels mapping each word to its class name.
        verbose (bool, optional): If True, prints missing words, correct predictions,
            correct words with the wrong class, spurious words and the labels.
            Defaults to False.

    Returns:
        tuple: ``(f1, recall, precision)`` as returned by ``calculate_f1_score``.
    """
    TP, FP = 0, 0
    correct_predicted_words = []
    wrong_predicted_words = []
    predicted_words = []
    wrong_predicted_category = []
    for pred in extractive_pred['preds']:
        word = pred['word']
        # Strip the tag prefix once ("B-PHONE" -> "PHONE"); a bare class name is kept as-is.
        category = pred['entity'].split('-', 1)[-1]

        if word not in dict_labels:
            # The span is not a reference entity at all.
            FP += 1
            wrong_predicted_words.append(word)
        elif category == dict_labels[word]:
            TP += 1
            correct_predicted_words.append((word, category))
            predicted_words.append(word)
        else:
            # Right span, wrong class.
            FP += 1
            wrong_predicted_category.append((word, category))

    # Calculate False Negatives
    FN, missing_words = find_missing_words(predicted_words, list(dict_labels.keys()))
    if verbose:
        print('Missing words:', missing_words)
        print('Correct Predicted words:', correct_predicted_words)
        print('Correct word but wrong category:', wrong_predicted_category)
        print('Wrong Predicted words:', wrong_predicted_words)
        print('Labels:', dict_labels)

    # Calculate F1 Score
    f1, recall, precision = calculate_f1_score(TP, FP, FN, verbose=verbose)

    return f1, recall, precision

In [None]:
test_set = dataset_['test']

## Read the per-sample predictions saved by the overall evaluation
dict_preds = pd.read_json('BERT_results.json', dtype={"id": str}, orient="records", lines=True).to_dict('records')

# zip() below stops at the shorter input and would silently drop samples from
# the per-entity results; fail loudly instead.
assert len(test_set) == len(dict_preds), (
    f"{len(test_set)} test samples but {len(dict_preds)} saved predictions"
)

list_f1_entity = []
list_recall_entity = []
list_precision_entity = []
list_entity = []
list_id = []
list_syn_entity = []
for test_sample, dict_pred in zip(test_set, dict_preds):

    # Saved predictions and test samples must stay aligned row by row
    assert str(test_sample['id']) == dict_pred['id']

    ### Extractive Format Evaluation ###
    extractive_pred = {'id': test_sample['id'], 'preds': dict_pred['Prediction']}

    # Get labels in the evaluation format: a PHONE label that is exactly a
    # two-digit number in parentheses, e.g. "(11)", is stored without the parentheses
    dict_labels = {(re.sub(r'[()]', '', item['word'])
                  if item['subcategory'] == 'PHONE' and re.fullmatch(r'\(\d{2}\)', item['word'])
                  else item['word']
                  ): item['subcategory']
                  for item in test_sample['labels']}

    # Class of every prediction with its tag prefix removed ("B-PHONE" -> "PHONE"),
    # computed once per sample instead of once per entity in the loop below
    pred_categories = [pred['entity'].split('-', 1)[-1] for pred in extractive_pred['preds']]

    ## Evaluate performance per entity
    for entity in list_entities:
        ## Keep only the reference labels of this entity
        filtered_dict_labels = {key: value for key, value in dict_labels.items() if value == entity}

        ## Entities absent from this sample's labels are not scored (no row is recorded)
        if not filtered_dict_labels:
            continue

        ## Keep only the predictions the model assigned to this entity
        filtered_tags = [pred for pred, category in zip(extractive_pred['preds'], pred_categories)
                         if category == entity]
        filtered_extractive_pred = {'id': test_sample['id'], 'preds': filtered_tags}

        f1, recall, precision = eval_entity(filtered_extractive_pred, filtered_dict_labels, verbose=False)

        list_f1_entity.append(f1)
        list_recall_entity.append(recall)
        list_precision_entity.append(precision)
        list_entity.append(entity)
        list_id.append(test_sample['id'])

        # Save if it is synthetic or not
        list_syn_entity.append(test_sample['synthetic'])

### Create a dataframe with the evaluations (one row per sample and entity present in it)
entity_save_df = pd.DataFrame({
    'id': list_id,
    'Entity': list_entity,
    'Recall': list_recall_entity,
    'Precision': list_precision_entity,
    'F1': list_f1_entity,
    'synthetic': list_syn_entity,
})

# index=False, like the notebook's other CSV outputs, so no unnamed index column is written
entity_save_df.to_csv('BERT_entity_results.csv', index=False)

In [None]:
# Mean Precision / Recall / F1 per entity type, over every (sample, entity) row
# of the per-entity evaluation; saved to CSV and printed.
df_grouped_entity = (
    entity_save_df
    .groupby('Entity', as_index=False)[['Precision', 'Recall', 'F1']]
    .mean()
)
df_grouped_entity.to_csv('BERT_entity_final_results.csv', index=False)
print(df_grouped_entity)

In [None]:
# Same per-entity averages, but split by the `synthetic` flag so real and
# synthetic samples can be compared entity by entity; saved to CSV and printed.
df_grouped_syn_entity = (
    entity_save_df
    .groupby(['synthetic', 'Entity'], as_index=False)[['Precision', 'Recall', 'F1']]
    .mean()
)
df_grouped_syn_entity.to_csv('BERT_entity_final_grouped_results.csv', index=False)
print(df_grouped_syn_entity)

In [None]:
# Mean Precision / Recall / F1 of the overall (all-entity) evaluation, split by
# real vs. synthetic samples, computed from the per-sample results CSV.
df_results_ = pd.read_csv('BERT_results.csv')
df_results_grouped = (
    df_results_
    .groupby(['synthetic'], as_index=False)[['Precision', 'Recall', 'F1']]
    .mean()
)
df_results_grouped.to_csv('BERT_grouped_results.csv', index=False)