In [None]:
# Automatically reload edited project modules (e.g. src.models.*) before each cell executes.
%load_ext autoreload
%autoreload 2

# Evaluate a bi-encoder model

Load a bi-encoder (either one trained directly or one trained from the output of a cross-encoder)
and evaluate it.

Note that a bi-encoder trained directly has a small advantage here because it's being evaluated on its training data.

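| type | notes  | high-random  | low-non-negative  | differences  | pos-neg               |
| ---- | ------ | ------------ | ----------------- | ------------ | --------------------- |
| orig |        | 496483 129 2 | 88259 65071 36343 | 200002 38914 | 115358 7485 5680 1805 |

Each entry lists the counters printed by the matching evaluation section below:
- **high-random**: `total cnt bad_cnt` — random common-name pairs scoring above 0.5 (bad: above 0.7).
- **low-non-negative**: `total cnt bad_cnt` — non-negative pairs scoring below 0.4 (bad: below 0.2).
- **differences**: `total cnt` — pairs whose prediction differs from the label by more than 0.25.
- **pos-neg**: `total cnt pos_neg_cnt neg_pos_cnt` — confidently-labeled pairs predicted confidently on the opposite side of 0.5 (margin 0.1).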


In [None]:
import re

import matplotlib.pyplot as plt
import pandas as pd
import torch
from tqdm.auto import tqdm

from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# --- experiment configuration ---
given_surname = "given"      # name type under evaluation; nickname pairs are only loaded for "given"
bi_encoder_type = 'orig'     # suffix selecting which trained bi-encoder file to load
num_common_names = 10000     # how many rows of the preferred-names file to consider
max_tokens = 10              # passed to the tokenizer factory

# --- inputs: data, tokenizer and model artifacts ---
pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
test_triplets_path = f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000.csv.gz"
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-2000f.json"
model_path = f"../data/models/bi_encoder-{given_surname}-{bi_encoder_type}.pth"

# --- reference lists of name pairs that must not be counted as negatives ---
common_non_negatives_path = f"../references/common_{given_surname}_non_negatives.csv"
name_variants_path = f"../references/{given_surname}_variants.csv"
given_nicknames_path = "../references/givenname_nicknames.csv"

In [None]:
# Report GPU availability and memory usage. The device-specific queries are guarded so the
# notebook still runs on a CPU-only machine, where get_device_properties(0) would raise.
torch.cuda.empty_cache()
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# read triplets
# Labeled test triplets. na_filter=False disables NA detection, so values such as "NA" stay as strings.
triplets_df = pd.read_csv(test_triplets_path, na_filter=False)
print(triplets_df.shape[0])
triplets_df.iloc[:3]

### read common names

In [None]:
pref_df = pd.read_csv(pref_path, na_filter=False)
# Take the first num_common_names rows (presumably ordered by frequency — TODO confirm) and keep
# only names that are at least two characters long and consist solely of lowercase a-z.
common_names = [
    candidate
    for candidate in pref_df.head(num_common_names)['name'].tolist()
    if len(candidate) > 1 and re.fullmatch(r'[a-z]+', candidate)
]
pref_df = None  # drop the reference so the frame can be garbage-collected
len(common_names)

### read common non-negatives

In [None]:
# Unordered name pairs that must not be treated as negatives, stored as "smaller:larger" strings.
# Keys are split back apart on ':' elsewhere, so names are assumed not to contain ':'.
common_non_negatives = set()


def _pair_key(name1, name2):
    """Return the canonical "smaller:larger" key for an unordered pair of names."""
    if name1 > name2:
        name1, name2 = name2, name1
    return f"{name1}:{name2}"


def add_common_non_negative(name1, name2):
    """Record (name1, name2) as a non-negative pair; argument order does not matter."""
    common_non_negatives.add(_pair_key(name1, name2))


def is_common_non_negative(name1, name2):
    """Return True if (name1, name2), in either order, has been recorded as a non-negative pair."""
    return _pair_key(name1, name2) in common_non_negatives

In [None]:
# Load the curated list of common name pairs that should never be scored as negatives.
common_non_negatives_df = pd.read_csv(common_non_negatives_path, na_filter=False)
for first_name, second_name in common_non_negatives_df.values.tolist():
    add_common_non_negative(first_name, second_name)
len(common_non_negatives)

In [None]:
# Known spelling variants are also non-negatives; merge them into the same pair set.
name_variants_df = pd.read_csv(name_variants_path, na_filter=False)
for variant_a, variant_b in name_variants_df.values.tolist():
    add_common_non_negative(variant_a, variant_b)

In [None]:
# Given names only: every pair of names appearing on the same line of the nicknames file
# (presumably a name together with its nicknames) is recorded as a non-negative pair.
if given_surname == "given":
    with open(given_nicknames_path, "rt") as f:
        for line in f:
            # Strip each field: otherwise the last name on every line keeps its trailing newline
            # and its pairs never match. Empty fields (e.g. trailing commas) are dropped.
            names = [name.strip() for name in line.split(',') if name.strip()]
            for name1 in names:
                for name2 in names:
                    # visits each unordered pair once and skips self-pairs
                    if name1 > name2:
                        add_common_non_negative(name1, name2)
len(common_non_negatives)

### load tokenizer

In [None]:
# Build the subword tokenizer function and its vocabulary from the saved subword file;
# max_tokens is forwarded as-is (presumably the padded/truncated sequence length — TODO confirm).
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    subwords_path=subwords_path,
    max_tokens=max_tokens,
)
len(tokenizer_vocab)

### load bi-encoder model

In [None]:
# The .pth file holds a whole pickled model object (the result is used via model.predict below),
# not just a state_dict. torch >= 2.6 defaults to weights_only=True and would refuse to load it,
# so request full unpickling explicitly. BiEncoder must be importable for unpickling to succeed.
# NOTE(review): unpickling can execute arbitrary code — only load model files you produced yourself.
model = torch.load(model_path, weights_only=False)

## Evaluate bi-encoder

### how many random pairs score highly?

In [None]:
%%time

# Score every unordered pair among the first n_names common names, skipping known non-negatives.
# These random pairs should mostly score low, so high scores point to likely false positives.
n_names = 1000
pos_threshold = 0.5   # pairs scoring above this are printed and counted
bad_threshold = 0.7   # ...and flagged with *** above this
total = cnt = bad_cnt = 0
common_negative_scores = []
for ix, pos in enumerate(common_names[:n_names]):
    pos_tokens = tokenize(pos)
    for neg in common_names[ix + 1:n_names]:
        if is_common_non_negative(pos, neg):
            continue
        sim = model.predict(pos_tokens, tokenize(neg))
        common_negative_scores.append(sim)
        total += 1
        if not sim > pos_threshold:
            continue
        is_bad = sim > bad_threshold
        print(pos, neg, sim, '***' if is_bad else '')
        cnt += 1
        if is_bad:
            bad_cnt += 1
print(total, cnt, bad_cnt)

### how many common non-negatives score low?

In [None]:
%%time

# Curated non-negative pairs should score high; count those that don't (printing at most 50).
neg_threshold = 0.4   # pairs scoring below this are counted
bad_threshold = 0.2   # ...and flagged with *** below this
total = cnt = bad_cnt = 0
for pair in tqdm(common_non_negatives):
    name1, name2 = pair.split(':')
    sim = model.predict(tokenize(name1), tokenize(name2))
    total += 1
    if not sim < neg_threshold:
        continue
    is_bad = sim < bad_threshold
    if cnt < 50:
        print(name1, name2, sim, '***' if is_bad else '')
    cnt += 1
    if is_bad:
        bad_cnt += 1
print(total, cnt, bad_cnt)

In [None]:
# Number of distinct non-negative pairs (canonical "a:b" keys) evaluated above.
len(common_non_negatives)

### how many pairs score significantly differently than their label?

In [None]:
%%time

# For both pairs in each triplet, compare the bi-encoder's predicted similarity with the labeled
# score and count predictions that differ from the label by more than `threshold`
# (printing at most 50 examples as: anchor, name, prediction, label).
threshold = 0.25
max_triplets = 100_000  # evaluate exactly the first max_triplets triplets
total = 0
cnt = 0
triplets_sample = triplets_df[
    ['anchor', 'positive', 'positive_score', 'negative', 'negative_score']
].head(max_triplets)
for anchor, pos, pos_score, neg, neg_score in tqdm(
    triplets_sample.itertuples(index=False, name=None), total=len(triplets_sample)
):
    anchor_toks = tokenize(anchor)
    for name, score in ((pos, pos_score), (neg, neg_score)):
        pred = model.predict(anchor_toks, tokenize(name))
        if abs(score - pred) > threshold:
            if cnt < 50:
                print(anchor, name, pred, score)
            cnt += 1
        total += 1
print(total, cnt)

### how many positive pairs score negatively, and how many negative pairs score positively?

In [None]:
%%time

# Count confidently-labeled pairs that the bi-encoder places confidently on the opposite side of 0.5.
# Labels within `threshold` of 0.5 are treated as ambiguous and excluded from `total`.
threshold = 0.1
max_triplets = 100_000  # evaluate exactly the first max_triplets triplets
total = 0
cnt = 0
pos_neg_cnt = 0
neg_pos_cnt = 0


def classify_against_label(score, pred, threshold):
    """Compare a predicted similarity with its labeled score around the 0.5 decision point.

    Returns:
        None       if the label is ambiguous (not >= 0.5 + threshold and not < 0.5 - threshold);
        'pos_neg'  if the label is positive (>= 0.5 + threshold) but pred < 0.5 - threshold;
        'neg_pos'  if the label is negative (< 0.5 - threshold) but pred >= 0.5 + threshold;
        'ok'       otherwise.
    """
    upper = 0.5 + threshold
    lower = 0.5 - threshold
    # written as "not (confident)" rather than a range check so NaN labels are also excluded
    if not (score >= upper or score < lower):
        return None
    if score >= upper and pred < lower:
        return 'pos_neg'
    if score < lower and pred >= upper:
        return 'neg_pos'
    return 'ok'


triplets_sample = triplets_df[
    ['anchor', 'positive', 'positive_score', 'negative', 'negative_score']
].head(max_triplets)
for anchor, pos, pos_score, neg, neg_score in tqdm(
    triplets_sample.itertuples(index=False, name=None), total=len(triplets_sample)
):
    anchor_toks = tokenize(anchor)
    for name, score in ((pos, pos_score), (neg, neg_score)):
        pred = model.predict(anchor_toks, tokenize(name))
        outcome = classify_against_label(score, pred, threshold)
        if outcome is None:
            continue
        total += 1
        if outcome == 'ok':
            continue
        if outcome == 'pos_neg':
            pos_neg_cnt += 1
        else:
            neg_pos_cnt += 1
        if cnt < 50:
            print(anchor, name, pred, score, '***' if outcome == 'pos_neg' else '')
        cnt += 1
print(total, cnt, pos_neg_cnt, neg_pos_cnt)


### Plot score distributions: non-negative pairs vs. random common-name pairs

In [None]:
# Score every non-negative pair for the histogram, listing each pair that scores below lower_threshold.
lower_threshold = 0.1
non_negative_scores = []
cnt = 0
for pair in tqdm(common_non_negatives):
    name1, name2 = pair.split(':')
    sim = model.predict(tokenize(name1), tokenize(name2))
    non_negative_scores.append(sim)
    if not sim < lower_threshold:
        continue
    print(name1, name2, sim)
    cnt += 1
print(cnt)

In [None]:
# sizes of the two score samples being compared in the histogram
print(*(len(scores) for scores in (non_negative_scores, common_negative_scores)))

In [None]:
# Compare the predicted-similarity distributions of non-negative pairs and random common-name pairs.
# Both histograms share one range (so bin edges line up) and are density-normalized, because the
# random-pair sample is much larger and would otherwise dwarf the non-negatives.
score_range = (
    float(min(min(non_negative_scores), min(common_negative_scores))),
    float(max(max(non_negative_scores), max(common_negative_scores))),
)
hist_kwargs = dict(bins=30, range=score_range, alpha=0.5, density=True)

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(non_negative_scores, label="Non negatives", color='green', **hist_kwargs)
ax.hist(common_negative_scores, label="Common negatives", color='red', **hist_kwargs)
ax.set_title("Bi-encoder similarity: non-negative pairs vs. random common-name pairs")
ax.set_xlabel("Predicted similarity")
ax.set_ylabel("Density")
ax.legend(loc='upper right')

fig.tight_layout()
plt.show()