# Fine-Tuning a Sentence Transformer for People Name Matching

This notebook fine-tunes a `SentenceTransformer` model to match people's names efficiently - even across character sets.

In [38]:
import json
import logging
import os
import random
import re
import sys
import time
import warnings
from collections import defaultdict
from itertools import product
from numbers import Number
from typing import Callable, Dict, List, Literal, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pytest
import random
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as nF
import wandb
from datasets import Dataset
from pyspark.sql import SparkSession, DataFrame
from scipy.spatial import distance
from sklearn.metrics import (
    accuracy_score,
    precision_recall_curve,
    precision_recall_fscore_support,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    roc_curve,
    auc
)
from sklearn.model_selection import train_test_split
from sentence_transformers import InputExample, SentenceTransformer, SentencesDataset, SentenceTransformerTrainer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction, BinaryClassificationEvaluator
from sentence_transformers.model_card import SentenceTransformerModelCardData
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
from tenacity import retry
from torch.utils.data import DataLoader
from torch.optim import RAdam
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModel, EarlyStoppingCallback, TrainingArguments, Trainer
from transformers.integrations import WandbCallback

from utils import (
    augment_gold_labels,
    compute_sbert_metrics,
    compute_classifier_metrics,
    format_dataset,
    gold_label_report,
    preprocess_logits_for_metrics,
    structured_encode_address,
    tokenize_function,
    to_dict,
    save_transformer,
    load_transformer,
)

#### Pin Random Seeds for Reproducibility

In [2]:
RANDOM_SEED = 31337

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.mps.manual_seed(RANDOM_SEED)

#### Setup Basic Logging

In [3]:
logging.basicConfig(stream=sys.stderr, level=logging.ERROR)

logger = logging.getLogger(__name__)

#### Ignore Warnings

In [4]:
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

#### Configure Weights & Biases

`wandb` needs some environment variables to work.

In [5]:
os.environ["WANDB_LOG_MODEL"] = "end"
os.environ["WANDB_WATCH"] = "gradients"
os.environ["WANDB_PROJECT"] = "libpostal-reborn"
os.environ["WANDB_DISABLED"] = "false"
os.environ["WANDB_IGNORE_GLOBS"] = ".env"

#### Optionally Disable `wandb` Uploads

Weights and Biases can be slow...

In [6]:
os.environ["WANDB_MODE"] = "online"

#### Configure Huggingface APIs

In [7]:
os.environ["HF_ENDPOINT"] = "https://huggingface.co/"

#### Configure Huggingface APIs

Squash any warnings...

In [8]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

#### Configure Pandas to Show More Rows

In [9]:
pd.set_option("display.max_rows", 40)
pd.set_option("display.max_columns", None)

### Use CUDA or MPS if Avaialable

CPU training and even inference with sentence transformers and deep learning models is quite slow. Since all machine learning in this library is based on [PyTorch](https://pytorch.org/get-started/locally/), we can assign all ML operations to a GPU in this one block of code. Otherwise we default to CPU without acceleration. The notebook is still workable in this mode, you just may need to grab a cup of tea or coffee while you wait for it to train the Sentence-BERT model below.

In [10]:
# Check for CUDA or MPS availability and set the device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    logger.debug("Using Apple GPU acceleration")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    logger.debug("Using NVIDIA CUDA GPU acceleration")
else:
    device = "cpu"
    logger.debug("Using CPU for ML")

device

device(type='cuda')

### Use Weights & Biases for Logging Metrics

Weights & Biases has a free account for individuals with public projects. Using it will produce charts during our training runs that anyone can view. You can create your own project for this notebook and login with that key to log your own training runs.

You may need to run the following command from your shell before the next cell, otherwise you will have to paste your project key into the 

```bash
wandb login
```

In [11]:
# Login to wandb. Comment out if you already haven't via `wandb login` from a CLI
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrjurney[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Parse [Open Sanctions' Matcher Training Data](https://www.opensanctions.org/docs/pairs/) using PySpark

It is convenient to work with nested JSON in PySpark's dataflow oriented API rather than Pandas.

In [16]:
spark = SparkSession.builder.appName("Feature Calculation - Nature Paper").getOrCreate()

25/01/02 19:05:10 WARN Utils: Your hostname, heracles resolves to a loopback address: 127.0.0.1; using 10.1.10.3 instead (on interface eno1)
25/01/02 19:05:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/02 19:05:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [22]:
records_df = spark.read.json("data/pairs.json", )
records_df.show()

                                                                                

+---------+--------------------+--------------------+
|judgement|                left|               right|
+---------+--------------------+--------------------+
| positive|{Aminat Ramzanovn...|{АМИНАТ РАМЗАНОВН...|
| positive|{Приватне підприє...|{Private enterpri...|
| positive|{Private enterpri...|{ПП "МАГІСТАР-СГ"...|
| positive|{Приватне підприє...|{ПП "МАГІСТАР-СГ"...|
| positive|{Акціонерне товар...|{Акціонерне товар...|
| positive|{Акціонерне товар...|{Joint-stock comp...|
| positive|{Акціонерне товар...|{ОТКРЫТОЕ АКЦИОНЕ...|
| positive|{Акціонерне товар...|{Joint-stock comp...|
| positive|{Акціонерне товар...|{ОТКРЫТОЕ АКЦИОНЕ...|
| positive|{Joint-stock comp...|{ОТКРЫТОЕ АКЦИОНЕ...|
| positive|{GRUPO MECÁNICA D...|{Grupo Mecánica d...|
| positive|{Grupo Mecánica d...|{GRUPO MECANICA D...|
| positive|{Grupo Mecánica d...|{Grupo Mecánica d...|
| positive|{GRUPO MECÁNICA D...|{GRUPO MECANICA D...|
| positive|{GRUPO MECÁNICA D...|{Grupo Mecánica d...|
| positive|{Grupo Mecánica d

In [33]:
records_df.printSchema()

root
 |-- judgement: string (nullable = true)
 |-- left: struct (nullable = true)
 |    |-- caption: string (nullable = true)
 |    |-- datasets: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- first_seen: string (nullable = true)
 |    |-- id: string (nullable = true)
 |    |-- last_change: string (nullable = true)
 |    |-- last_seen: string (nullable = true)
 |    |-- properties: struct (nullable = true)
 |    |    |-- address: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- addressEntity: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- alias: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- asset: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- bikCode: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- bir

In [28]:
name_pairs_schema = T.ArrayType(
    T.StructType([
        T.StructField("name1", T.StringType(), True),
        T.StructField("name2", T.StringType(), True),
        T.StructField("label", T.IntegerType(), True),
    ])
)

In [32]:
people_df = records_df.filter(records_df.left.schema == "Person")
people_df.select("left.properties", "right.properties").show(5, truncate=False)

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------

In [50]:
@F.pandas_udf(name_pairs_schema, functionType=F.PandasUDFType.SCALAR)
def parse_name_pairs(
    left_names: pd.Series,
    left_aliases: pd.Series,
    right_names: pd.Series,
    right_aliases: pd.Series,
    judgement_series: pd.Series
) -> pd.Series:
    """
    For each row:
      1) Combine left_names and left_aliases into one list, same for right.
      2) Take the cartesian product of (combined_left, combined_right).
      3) label = 1 if judgement == 'positive', else 0.
      4) Skip pairs if label=0 and name1==name2.
      5) Return an array of { name1, name2, label } dicts per row.
    """

    results = []

    for ln, la, rn, ra, judgement in zip(
        left_names, left_aliases, right_names, right_aliases, judgement_series
    ):
        # Handle potential null arrays (use empty lists if None)
        ln = ln if ln is not None else []
        la = la if la is not None else []
        rn = rn if rn is not None else []
        ra = ra if ra is not None else []

        combined_left = list(ln) + list(la)
        combined_right = list(rn) + list(ra)

        # Convert judgement to binary
        label_val = 1 if (judgement or "").lower() == "positive" else 0

        pair_list = []
        for name1, name2 in product(combined_left, combined_right):
            # Skip if negative and identical
            if label_val == 0 and name1 == name2:
                continue
            pair_list.append({
                "name1": name1,
                "name2": name2,
                "label": label_val
            })

        results.append(pair_list)

    return pd.Series(results)

In [46]:
# Flatten the columns you need
flat_people_df = people_df.select(
    F.col("left.properties.name").alias("left_names"),
    F.col("left.properties.alias").alias("left_aliases"),
    F.col("right.properties.name").alias("right_names"),
    F.col("right.properties.alias").alias("right_aliases"),
    F.col("judgement")
)
flat_people_df.show(5, truncate=False)

+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------+----------------------------+-------------------------------------------------------------------+---------+
|left_names                                               |left_aliases                                                                                                     |right_names                 |right_aliases                                                      |judgement|
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------+----------------------------+-------------------------------------------------------------------+---------+
|[Аминат Рамзановна Ахмадова, Aminat Ramzanovna Akhmadova]|[Axmadova Aminat Ramzanovna, Ahmadova Aminat Ramzanovna, Ахмадова Аминат Рамзановна, Akhmadova Ami

In [52]:
parsed_df = flat_people_df.select(
    parse_name_pairs(
        F.col("left_names"),
        F.col("left_aliases"),
        F.col("right_names"),
        F.col("right_aliases"),
        F.col("judgement")
    ).alias("pairs")
)
parsed_df.show(10, truncate=False)

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|pairs                                                                                                                                                                                                                                                                                                                                                                                                                                

In [69]:
# Explode to one row per name pair
name_pairs_df = (
    parsed_df.select(F.explode("pairs").alias("pair"))
    .selectExpr(
        "pair.name1 AS name1",
        "pair.name2 AS name2",
        "pair.label",
    )
)
print(f"Total generated pairs: {name_pairs_df.count():,}")

name_pairs_df.sample(0.0001, seed=RANDOM_SEED).show(30, truncate=False)

                                                                                

Total generated pairs: 3,650,933


[Stage 40:>                                                         (0 + 4) / 4]

+-----------------------------+---------------------------------+-----+
|name1                        |name2                            |label|
+-----------------------------+---------------------------------+-----+
|Шалигін Андрій Андрійович    |ШАЛИГІН Андрій Андрійович        |1    |
|Зайко Татьяна Ивановна       |ZAIKO Tatiana Ivanovna           |1    |
|Emad Hmeisho                 |Al Sabuni                        |0    |
|Magdalina Marina VLADIMIROVNA|ВЛАДИМИРОВНА Магдалина Марина    |1    |
|Amis Ashi                    |Amis al Ashi                     |1    |
|Ротенберг Карина Юрїївна     |Karina Rotenberg                 |1    |
|Vavilov Oleksandr            |VAVILOV Aleksander Aleksandrovich|1    |
|SHUVALOVA Olga Viktorovna    |Olga Viktorovna SHUVALOVA        |1    |
|Manal AL-AKHRAZ              |Manal AL-ASSAD                   |1    |
|Andrej Jurjevitj PAVLJUTJENKO|ПАВЛЮЧЕНКО Андрей Юрьевич        |1    |
|Доля Ніколай Івнович         |ДОЛЯ Ниĸолай Иванович            

                                                                                

# Machine Learning Approaches to Name Matching

In this section we fine-tune a pre-trained embedding model to our task, try it on our data and search for a threshold similarity that results in good performance for our address matching problem.

## Text Embeddings, Sentence Encoding, `SentenceTransformers`, Vector Distance and Cosine Similarity

Text embeddings are trained on large volumes of text that include the names of people and companies. As a result they have some understanding of address strings and can do a form of semantic comparison that is less explicit than logical comparisons with address parsing. They're an important benchmark to explore. Huggingface has an excellent [introduction to sentence similarity](https://huggingface.co/tasks/sentence-similarity).

In our first machine learning approach, we are going to use transfer learning to load a pre-trained [sentence transformer](https://sbert.net) models from huggingface. We will use the training data we've prepared to fine-tune this model to our task, before rigorously evaluating it along with our other approaches.

Sentence transformers sentence encode strings of different distances into fixed-length vectors, a technique called sentence encoding. Once two address strings are embedded into a pair of equal length vectors, they can be compared with cosine similarity to get a distance, the inverse of which is a similarity score.

### Convert our `pyspark.sql.DataFrame` to a `List[sentence_transformers.InputExample]`

First we need to convert our PySpark `pyspark.sql.DataFrame` to a list of sentence transformer input examples. `InputExamples` require two fields `texts=List[str, str]` and `label`.

In [None]:
train_df, tmp_df = train_test_split(augment_results_df, test_size=0.2, shuffle=True)
eval_df, test_df = train_test_split(tmp_df, test_size=0.5, random_state=42, shuffle=True)

train_dataset = Dataset.from_dict({
    "sentence1": train_df["Address1"].tolist(),
    "sentence2": train_df["Address2"].tolist(),
    "label": train_df["Label"].tolist(),
})

eval_dataset = Dataset.from_dict({
    "sentence1": eval_df["Address1"].tolist(),
    "sentence2": eval_df["Address2"].tolist(),
    "label": eval_df["Label"].tolist(),
})

test_dataset = Dataset.from_dict({
    "sentence1": test_df["Address1"].tolist(),
    "sentence2": test_df["Address2"].tolist(),
    "label": test_df["Label"].tolist(),
})

print(f"Training data:   {len(train_df):,}")
print(f"Evaluation data: {len(eval_df):,}")
print(f"Test data:       {len(eval_df):,}")

### Configure Fine-Tuning, Initialize a `SentenceTransformer`

To use the training data we prepared to fine-tune a `SentenceTransformer`, we need to select and load a pre-trained model from Huggingface Hub. Here are some models you can try:

* [sentence-transformers/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) - multilingual paraphrase models are designed to compare sentences in terms of their semantics.
* [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) - a robust, multilingual model optimized for a variety of tasks
* [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) - MPNet is another paraphrase model architecture we can fine-tune for address comparison
* [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) - a top performing MPNet model

In [None]:
# SBERT_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# SBERT_MODEL = "joe32140/ModernBERT-base-msmarco"  # See https://huggingface.co/joe32140/ModernBERT-base-msmarco
# SBERT_MODEL = "rjurney/ModernBERT-address-matcher"
SBERT_MODEL = "answerdotai/ModernBERT-base"
VARIANT = "original"
MODEL_SAVE_NAME = (SBERT_MODEL + "-" + VARIANT).replace("/", "-")

# Make sure these match the values in the data augmentation notebook for accurate loggging and reporting
CLONES_PER_RUN = 100
RUNS_PER_EXAMPLE = 2

EPOCHS = 10
BATCH_SIZE = 32
PATIENCE = 2
LEARNING_RATE = 5e-5
DATASET_MULTIPLE = CLONES_PER_RUN * RUNS_PER_EXAMPLE
SBERT_OUTPUT_FOLDER = f"data/fine-tuned-sbert-{MODEL_SAVE_NAME}"
SAVE_EVAL_STEPS = 100

### Initialize Weights & Biases

Weights and biases `wandb` package makes it simple to monitor the performance of your training runs.

In [None]:
# Initialize Weights & Biases
wandb.init(
    entity="rjurney",
    # set the wandb project where this run will be logged
    project="libpostal-reborn",
    # track hyperparameters and run metadata
    config={
        "variant": VARIANT,
        "dataset_multiple": DATASET_MULTIPLE,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "patience": PATIENCE,
        "learning_rate": LEARNING_RATE,
        "sbert_model": SBERT_MODEL,
        "model_save_name": MODEL_SAVE_NAME,
        "sbert_output_folder": SBERT_OUTPUT_FOLDER,
        "save_eval_steps": SAVE_EVAL_STEPS,
    },
    save_code=True,
)

### Setup our `SentenceTransformer` Model

Choose the model to fine-tune above in `SBERT_MODEL` and instantiate it below.

In [None]:
# sbert_model = SentenceTransformer(
#     SBERT_MODEL,
#     device=device,
#     model_card_data=SentenceTransformerModelCardData(
#         language="en",
#         license="apache-2.0",
#         model_name=f"{SBERT_MODEL}-address-matcher-{VARIANT}",
#     ),
# )

### Pool the new `ModernBERT` Model

We will setup mean pooling to aggregate the tokens in a string into a single overall value (which pre-loaded `SentenceTransformers` already do).

In [None]:
# 1) Transformer module
word_embedding_model = models.Transformer(SBERT_MODEL, max_seq_length=8192)

# 2) Pooling module
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode_cls_token=False,
    pooling_mode_mean_tokens=False,
    pooling_mode_max_tokens=True
)

# Combine the modules to get a SentenceTransformer
sbert_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

### Use Multi-GPU if Available

In [None]:
import torch

device_count = torch.cuda.device_count()
print(f"GPU count: {device_count:,}")

pool = sbert_model.start_multi_process_pool()

In [None]:
# pool = sbert_model.start_multi_process_pool()

# sbert_model.encode_multi_process(sentences, pool)

### Evaluate our Model Before Fine-Tuning

Let's see what it can do without fine-tuning, then we'll compare our subjective results afterwards. This won't work very well, fine-tuning is required!

In [None]:
def sbert_compare(address1: str, address2: str) -> float:
    """sbert_compare - sentence encode each address into a fixed-length text embedding.
    Fixed-length means they can be compared with cosine similarity."""
    embedding1 = sbert_model.encode(address1)
    embedding2 = sbert_model.encode(address2)

    # Compute cosine similarity
    return 1 - distance.cosine(embedding1, embedding2)


def sbert_match(row: pd.Series) -> pd.Series:
    """sbert_match - SentenceTransformer address matching, float iytoyt"""
    return sbert_compare(row["Address1"], row["Address2"])


def sbert_compare_binary(address1: str, address2: str, threshold: float = 0.5) -> Literal[0, 1]:
    """sbert_match - compare and return a binary match"""
    similarity = sbert_compare(address1, address2)
    return 1 if similarity >= threshold else 0


def sbert_match_binary(row: pd.Series, threshold: float = 0.5) -> pd.Series:
    """sbert_match_binary - SentenceTransformer address matching, binary output"""
    return sbert_compare_binary(row["Address1"], row["Address2"], threshold=threshold)

In [None]:
# Still too similar - very hard to train them away from this behavior!
sbert_compare(
    "101 Oak Lane, Atlanta, GA 30308",
    "102 Oak Lane, Atlanta, GA 30308",
)

In [None]:
# A little bit further away ...
sbert_compare(
    "101 Oak Lane, Atlanta, GA 30308",
    "101 Oak Ln., Atlanta, GA 30308",
)

In [None]:
# Properly distant ...
sbert_compare(
    "3413 Sean Way, Lawrenceville, GA 30044",
    "1202 Oak Rd., Lawrenceville, GA 30304",
)

In [None]:
# Properly similar ...
sbert_compare(
    "3413 Sean Way, Lawrenceville, GA 30044",
    "3413 Sean Way, Lawrenceville, GA 30044, USA",
)

### Evaluate the Test Set with the Untrained Model

Let's see how well the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) model does on its own. This is our baseline score.

In [None]:
# Initialize the evaluator
binary_acc_evaluator = BinaryClassificationEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    labels=eval_dataset["label"],
    name=SBERT_MODEL,
)
pd.DataFrame([binary_acc_evaluator(sbert_model)])

### Computing Metrics with `sklearn.metrics`

We use [scikit-learn metrics](https://scikit-learn.org/stable/modules/model_evaluation.html) instead to compute our evaluation metrics.

In [None]:
# This will rapidly train the embedding model. MultipleNegativesRankingLoss did not work.
loss = losses.ContrastiveLoss(model=sbert_model)

sbert_args = SentenceTransformerTrainingArguments(
    output_dir=SBERT_OUTPUT_FOLDER,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    warmup_ratio=0.1,
    run_name=SBERT_MODEL,
    load_best_model_at_end=True,
    save_total_limit=5,
    save_steps=SAVE_EVAL_STEPS,
    eval_steps=SAVE_EVAL_STEPS,
    save_strategy="steps",
    eval_strategy="steps",
    greater_is_better=False,
    metric_for_best_model="eval_loss",
    learning_rate=LEARNING_RATE,
    logging_dir="./logs",
    weight_decay=0.02,
)

trainer = SentenceTransformerTrainer(
    model=sbert_model,
    args=sbert_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=binary_acc_evaluator,
    compute_metrics=compute_sbert_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)],
)

trainer.train()

In [None]:
print(f"Best model checkpoint path: {trainer.state.best_model_checkpoint}")

In [None]:
pd.DataFrame([trainer.evaluate()])

In [None]:
trainer.save_model(SBERT_OUTPUT_FOLDER)

In [None]:
print(SBERT_OUTPUT_FOLDER)

In [None]:
wandb.finish()

### Try the Model from Our Best Epoch

We fine-tuned the model for `EPOCHS` nubmer of epochs, but the last epoch isn't always best. The `TrainingArgument` `load_best_model_at_end=True` loads the model at the end.

Another way to load the best model is to load our output folder and evaluate that `SentenceTransformer` on some examples to get a gestalt sense for its performance.

```python
sbert_model = SentenceTransformer(OUTPUT_FOLDER, device=device)
```

In [None]:
sbert_compare(
    "101 Oak Lane, Atlanta, GA 30308",
    "102 Oak Lane, Atlanta, GA 30308",
)

In [None]:
sbert_compare(
    "101 Oak Lane, Macon, GA 30308",
    "101 Oak Lane, Atlanta, GA 30408",
)

In [None]:
sbert_compare(
    "101 Oak Lane, Atlanta, GA 30308",
    "101 Oak Ln., Atlanta, GA 30308",
)

In [None]:
sbert_compare(
    "3413 Sean Way, Lawrenceville, GA 30044",
    "1202 Oak Rd., Lawrenceville, GA 30304",
)

In [None]:
sbert_compare(
    "3413 Sean Way, Lawrenceville, GA 30044",
    "3413 Sean Way, Lawrenceville, GA 30044, USA",
)

### Evaluate ROC Curve to Determine Optimum Similarity Threshold

0.5 is an arbitrary line on which to divide positive (match, 1) and negative (mismatch, 0). Let's evaluate the ROC Curve of the F1 score to see what it should be set to. Recall that the `sbert_match` function has a `threshold: float = 0.5` argument.

#### Evaluate on our Augmented Test Dataset

First we'll evaluate the ROC curve on our augmented test dataset.

In [None]:
y_true = test_df["Label"]
y_scores = test_df.apply(sbert_match, axis=1)

In [None]:
# Compute precision-recall curve
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# Compute F1 score for each threshold
f1_scores = [f1_score(y_true, y_scores >= t) for t in thresholds]

# Find the threshold that maximizes the F1 score
best_threshold_index = np.argmax(f1_scores)
best_threshold = thresholds[best_threshold_index]
best_f1_score = f1_scores[best_threshold_index]

print(f'Best Threshold: {best_threshold}')
print(f'Best F1 Score: {best_f1_score}')

roc_auc = roc_auc_score(y_true, y_scores)
print(f'AUC-ROC: {roc_auc}')

In [None]:
# Create a DataFrame for Seaborn
pr_data = pd.DataFrame({
    'Precision': precision[:-1],
    'Recall': recall[:-1],
    'F1 Score': f1_scores
})

# Plot Precision-Recall curve using Seaborn
sns.lineplot(data=pr_data, x='Recall', y='Precision', marker='o')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Augmented Test Set Precision-Recall Curve')
plt.show()

### Plot a ROC Curve for our Gold Labeled Data

We need to see the ROC Curve for our gold labeled data as well. We care more about performance on this data.

In [None]:
y_true = gold_df["Label"]
y_scores = gold_df.apply(sbert_match, axis=1)

In [None]:
# Compute precision-recall curve
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# Compute F1 score for each threshold
f1_scores = [f1_score(y_true, y_scores >= t) for t in thresholds]

# Find the threshold that maximizes the F1 score
best_threshold_index = np.argmax(f1_scores)
best_threshold = thresholds[best_threshold_index]
best_f1_score = f1_scores[best_threshold_index]

print(f'Best Threshold: {best_threshold}')
print(f'Best F1 Score: {best_f1_score}')

roc_auc = roc_auc_score(y_true, y_scores)
print(f'AUC-ROC: {roc_auc}')

In [None]:
# Create a DataFrame for Seaborn
pr_data = pd.DataFrame({
    'Precision': precision[:-1],
    'Recall': recall[:-1],
    'F1 Score': f1_scores
})

# Plot Precision-Recall curve using Seaborn
sns.lineplot(data=pr_data, x='Recall', y='Precision', marker='o')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Gold Label Precision-Recall Curve')
plt.show()

### Debugging Errors on our Gold Labels

Let's evaluate the data using our `gold_label_report` function with the best F1 score. Then we can view the errors and figure out where our model is failing.

In [None]:
raw_df, grouped_df = gold_label_report(
    gold_df,
    [
        sbert_match_binary,
    ],
    threshold=best_threshold
)

#### Label Description Group Analysis

You can see the types of address pairs we are failing on. This can guide our data augmentation / programmatic labeling work at a high level.

In [None]:
grouped_df.head(40)

In [None]:
grouped_df["sbert_match_binary_acc"].sort_values().head(40)

#### What it Got Right ...

In [None]:
# Truthiness analysis
correct_df = raw_df[raw_df["sbert_match_binary_correct"]].reset_index()
print(f"Number correct: {len(correct_df):,}")

correct_df.head(20)

In [None]:
# Error analysis
wrong_df = raw_df[raw_df["sbert_match_binary_correct"] == False].reset_index()
print(f"Number wrong: {len(wrong_df):,}")

wrong_df.head(20)