In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project code under src/) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Evaluate a bi-encoder model

Load a bi-encoder (either one trained directly or one trained from the output of a cross-encoder)
and evaluate it.

Note that a bi-encoder trained directly has an advantage here because it's being evaluated on its training data.

| type          | high-random    | low-non-negative  | differences| pos-neg|test (high low)|rps-above nns-below|
| ------------- | -------------- | ----------------- | ---------- | --------- | -------- | -------------------- |
| orig-old      | 495902 2051 121|674868 313972 76494|200002 38914| 156783 13212 | 1215 16786 | 13088 150760 |
| orig          | 495902 3444 196|674868 100534 14828|200002 9700 | 156783 599|1938 4081 |0.008 0.47 7390 56782 |
| ce-3          |495902 14295 598|674868 46658 4128  |200002 26806| 156783 68 |1892 5294 |0.13 0.56 7610 67924  |   
| ce-5          |495902 13768 638|674868 48352 5040  |200002 25367| 156783 73 |1871 5661 | 0.13 0.55 7693 68824 |
|ce-1@1        |495902 218716 3081|674868 13274   80 |200002 48979| 200002  0 |1693 8387 | 0.29 0.63 6962 84512 |
|ce-2@1        |495902 226092 2622|674868 13178   28 |200002 48680| 200002  0 |1669 8572 | 0.30 0.63 5991 90902 |
|               |                |                   |            |           |          |                      |
| aug38         | 495902 2783 147|674868 116760 17228|200002 7717 | 156783 670|1937 3946 | 0.005 0.45 6447 60470|
| aug33         | 495902 2281 131|674868 174158 24706|200002 8601 |156783 1564|1742 5070 | 0.002 0.42 6902 76200|
| aug40         | 495902 3140 175|674868  98194 15300|200002 7561 |200002 1586|2018 3594 | 0.007 0.47 6406 55454|
|aug40a         | 495902 2785 151|674868 114850 17370|200002 7840 |200002 2066|1952 3868 | 0.005 0.45 6438 59338|
| unaug         |495902 13578 616|674868 157320 24086|200002 17360|200002 5961|1351 14113|0.042 0.42 35167 89842|
| ce100         |495902 1592 554|674868 473534 404346|200002 152081|200002 140401|471 113326|0.001 0.16 92970 381612|
|ce100@3-aug40@6| 495902 3145 153|674868 101122 15260|200002  7593|200002 1724|2011 3396 | 0.007 0.47 6489 56714|
|aug40@6-ce0@1 |495902 34316 1661| 674868 32488  8754|200002 47882|200002   18|1844 5690 |0.061 0.62 21415 41730|
|aug40@6-ce0@6 |495902 33970 1646| 674868 34092  9010|200002 47085|200002   25|1831 5869 |0.060 0.61 21731 42944|
|aug40@6        | 495902 3057 151| 674868 98958 15140|200002 7595 |200002 1612|2018 3493 |0.007 0.47  6449 55490|
|aug40+ce0@3    | 495902 5239 243| 674868 49410  8228|200002 14994|200002  278|2207 2817 |0.030 0.54  6265 44058|

* aug40a uses a smoothing value of 0 in notebook 200_generate_triplets; all the other models use a smoothing value of 20

Configurations to compare:

- CE?@1 + aug40@1
- CE?@1 + aug40@6
- CE100@6


In [None]:
import re

import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import ttest_ind, mannwhitneyu
import torch

from src.data.utils import read_csv
from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# --- what to evaluate -------------------------------------------------------
# Name type; it is interpolated into every path below.
given_surname = "given"
# Identifier of the trained bi-encoder; selects the checkpoint file.
model_type = 'aug40+ce0'
# How many leading rows of the preferred-names file are considered "common".
num_common_names = 10000
# Passed to the tokenizer factory as `max_tokens`.
max_tokens = 10

# --- model artifacts --------------------------------------------------------
model_path = f"../data/models/bi_encoder-{given_surname}-{model_type}.pth"
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-2000f.json"

# --- evaluation data --------------------------------------------------------
test_triplets_path = f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000.csv.gz"
pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
common_non_negatives_path = f"../data/processed/common_{given_surname}_non_negatives-augmented.csv"

In [None]:
# Report whether a GPU is usable and, if so, how much of its memory is in use.
# The device queries are guarded: torch.cuda.get_device_properties(0) raises
# when no CUDA device is available, which would abort a "Run All" on a CPU-only machine.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    # Release cached allocator blocks left over from earlier work in this kernel
    # so the numbers printed below reflect the current state.
    torch.cuda.empty_cache()
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# read triplets
# Load the evaluation triplets from `test_triplets_path` using the project's read_csv helper.
# Later cells read the columns anchor, positive, positive_score, negative and negative_score.
triplets_df = read_csv(test_triplets_path)
# Row count, as a quick sanity check that the file loaded.
print(len(triplets_df))
# Show the first 30 rows for visual inspection.
triplets_df.head(30)

### read common names

In [None]:
# Names must be at least two characters long and consist solely of a-z.
_plain_name_re = re.compile(r'[a-z]+')


def _is_plain_name(name):
    """Return a truthy value when `name` is longer than one character and purely a-z."""
    return len(name) > 1 and _plain_name_re.fullmatch(name)


pref_df = read_csv(pref_path)
# Only the first `num_common_names` rows are considered
# (presumably the file is ordered most-frequent first -- TODO confirm).
common_names = list(filter(_is_plain_name, pref_df['name'][:num_common_names].tolist()))
# Drop the reference to the full frame; only the filtered list is needed from here on.
pref_df = None
len(common_names)

### read common non-negatives

In [None]:
# Load the known non-negative name pairs as a set of (name1, name2) tuples.
# Each row of the file must hold exactly two values, otherwise unpacking fails.
common_non_negatives_df = read_csv(common_non_negatives_path)
common_non_negatives = {
    (name1, name2)
    for name1, name2 in common_non_negatives_df.values.tolist()
}
len(common_non_negatives)

### load tokenizer

In [None]:
# Build the tokenizer from the subword file at `subwords_path`.
# `tokenize` is the callable applied to a single name in the evaluation cells below;
# `tokenizer_vocab` is the accompanying vocabulary (only its size is shown here).
# NOTE(review): presumably `max_tokens` caps/pads the tokens produced per name -- confirm in src.models.tokenizer.
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    max_tokens=max_tokens,
    subwords_path=subwords_path,
)
len(tokenizer_vocab)

### load bi-encoder

In [None]:
# Load the trained bi-encoder. The checkpoint is used as a whole model object
# (`.eval()` and `.predict()` are called on it), not as a state dict.
#
# SECURITY: torch.load unpickles arbitrary Python objects -- only load checkpoints
# that come from a trusted source.
#
# map_location: without it, a checkpoint saved from a GPU cannot be loaded on a
# machine that has no CUDA device. Remap to CPU in that case; when CUDA is
# available the stored device placement is kept unchanged.
#
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects fully
# pickled model objects; if loading fails there, pass weights_only=False explicitly.
# NOTE(review): the BiEncoder import at the top is presumably needed so the pickled
# class can be resolved -- confirm before removing it as "unused".
model = torch.load(
    model_path,
    map_location=None if torch.cuda.is_available() else "cpu",
)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

## Evaluate bi-encoder

### how many random pairs score high?

In [None]:
%%time

# Score every unordered pair among the first `n_names` common names, skipping
# pairs listed as known non-negatives, and count how many "random" pairs the
# model nevertheless scores high.
n_names = 1000
pos_threshold = 0.3   # a random pair scoring above this is reported
bad_threshold = 0.5   # ... and above this it is flagged with '***'

sampled_names = common_names[:n_names]
random_pair_scores = []
total = cnt = bad_cnt = 0
for first_ix, first_name in enumerate(sampled_names):
    first_tokens = tokenize(first_name)
    for second_name in sampled_names[first_ix + 1:]:
        # NOTE(review): only the (first, second) ordering is looked up; confirm the
        # non-negatives file contains both orderings of every pair.
        if (first_name, second_name) in common_non_negatives:
            continue
        sim = model.predict(first_tokens, tokenize(second_name))
        random_pair_scores.append(sim)
        total += 1
        if not (sim > pos_threshold):
            continue
        is_bad = sim > bad_threshold
        print(first_name, second_name, sim, '***' if is_bad else '')
        cnt += 1
        if is_bad:
            bad_cnt += 1
# pairs scored, pairs above pos_threshold, pairs above bad_threshold
print(total, cnt, bad_cnt)

### how many common non-negatives score low?

In [None]:
%%time

# Score every known non-negative pair and count how many the model scores low.
neg_threshold = 0.3   # a non-negative pair scoring below this is counted
bad_threshold = 0.1   # ... and below this it is flagged with '***'

non_negative_scores = []
total = cnt = bad_cnt = 0
for left_name, right_name in common_non_negatives:
    sim = model.predict(tokenize(left_name), tokenize(right_name))
    non_negative_scores.append(sim)
    total += 1
    if not (sim < neg_threshold):
        continue
    is_bad = sim < bad_threshold
    # Only the first 50 low-scoring pairs are printed; all of them are counted.
    if cnt < 50:
        print(left_name, right_name, sim, '***' if is_bad else '')
    cnt += 1
    if is_bad:
        bad_cnt += 1
# pairs scored, pairs below neg_threshold, pairs below bad_threshold
print(total, cnt, bad_cnt)

### how many pairs score significantly differently than their label?

In [None]:
%%time

# Count (anchor, name) pairs whose predicted score differs from the labelled
# score by more than `threshold`.
threshold = 0.25
# Rows with index 0..last_ix are processed, i.e. 100,001 triplets, which is why
# the recorded totals in the results table read 200002.
last_ix = 100_000

total = 0
cnt = 0
triplet_rows = zip(
    triplets_df['anchor'],
    triplets_df['positive'],
    triplets_df['positive_score'],
    triplets_df['negative'],
    triplets_df['negative_score'],
)
for ix, (anchor, pos, pos_score, neg, neg_score) in enumerate(triplet_rows):
    if ix > last_ix:
        break
    anchor_toks = tokenize(anchor)
    pos_toks = tokenize(pos)
    neg_toks = tokenize(neg)
    pos_pred = model.predict(anchor_toks, pos_toks)
    neg_pred = model.predict(anchor_toks, neg_toks)
    # Same check for the positive and then the negative member of the triplet.
    for name, pred, label in ((pos, pos_pred, pos_score), (neg, neg_pred, neg_score)):
        if abs(label - pred) > threshold:
            # Only the first 50 mismatches are printed; all of them are counted.
            if cnt < 50:
                print(anchor, name, pred, label)
            cnt += 1
    total += 2
# pairs examined, pairs off by more than threshold
print(total, cnt)

### how many positive pairs score negatively, and how many negative pairs score positively?

In [None]:
%%time

# Count pairs labelled clearly high that the model scores clearly low
# (pos_neg_cnt, printed with '***') and pairs labelled clearly low that the
# model scores clearly high (neg_pos_cnt). "Clearly" means at least `threshold`
# away from `midpoint`.
threshold = 0.1
midpoint = 0.3
high_cutoff = midpoint + threshold
low_cutoff = midpoint - threshold
# Rows with index 0..last_ix are processed (100,001 triplets).
last_ix = 100_000

total_pos = total_neg = 0
pos_neg_cnt = neg_pos_cnt = 0
cnt = 0
triplet_rows = zip(
    triplets_df['anchor'],
    triplets_df['positive'],
    triplets_df['positive_score'],
    triplets_df['negative'],
    triplets_df['negative_score'],
)
for ix, (anchor, pos, pos_score, neg, neg_score) in enumerate(triplet_rows):
    if ix > last_ix:
        break
    anchor_toks = tokenize(anchor)
    pos_toks = tokenize(pos)
    neg_toks = tokenize(neg)
    pos_pred = model.predict(anchor_toks, pos_toks)
    neg_pred = model.predict(anchor_toks, neg_toks)
    candidates = ((pos, pos_pred, pos_score), (neg, neg_pred, neg_score))

    # Pass 1: pairs labelled high -- did the model score them low?
    for name, pred, label in candidates:
        if label >= high_cutoff:
            if pred < low_cutoff:
                pos_neg_cnt += 1
                # Only the first 50 confusions (of either kind) are printed.
                if cnt < 50:
                    print(anchor, name, pred, label, '***')
                cnt += 1
            total_pos += 1

    # Pass 2: pairs labelled low -- did the model score them high?
    for name, pred, label in candidates:
        if label < low_cutoff:
            if pred >= high_cutoff:
                neg_pos_cnt += 1
                if cnt < 50:
                    print(anchor, name, pred, label)
                cnt += 1
            total_neg += 1
# high-labelled pairs, of which scored low, low-labelled pairs, of which scored high
print(total_pos, pos_neg_cnt, total_neg, neg_pos_cnt)


### Welch t-test and Mann-Whitney U test

In [None]:
# How well separated are random-pair scores from non-negative scores?
# A larger |t| (Welch's t-test, i.e. unequal variances) is better;
# a smaller Mann-Whitney U is better.
# NOTE(review): which U statistic mannwhitneyu returns has changed between SciPy
# releases -- confirm "smaller is better" still holds for the installed version.
welch_result = ttest_ind(random_pair_scores, non_negative_scores, equal_var=False)
mann_result = mannwhitneyu(random_pair_scores, non_negative_scores, use_continuity=False)
t_ttest = welch_result.statistic
t_mann = mann_result.statistic
# |t| truncated to an integer, and U expressed in millions.
print(int(abs(t_ttest)), int(t_mann/1_000_000))

In [None]:
# Use the midpoint between the two group means as a cut-off, then count how
# many scores fall on the wrong side: random pairs above it, non-negatives below it.
rps_mean = sum(random_pair_scores) / len(random_pair_scores)
nps_mean = sum(non_negative_scores) / len(non_negative_scores)
avg_mean = (rps_mean + nps_mean) / 2

rps_above_mean = sum(1 for score in random_pair_scores if score > avg_mean)
nns_below_mean = sum(1 for score in non_negative_scores if score < avg_mean)

# Line 1: cut-off and the two sample sizes. Line 2: the two means and the two miss counts.
print(f"{avg_mean:0.2} {len(random_pair_scores)} {len(non_negative_scores)}")
print(f"{rps_mean:0.2} {nps_mean:0.2} {rps_above_mean} {nns_below_mean}")

### graph results

In [None]:
# Overlay the two score distributions; the less they overlap, the better the
# model separates random pairs from known non-negatives.
fig, ax = plt.subplots(figsize=(10, 6))
shared_style = dict(bins=30, alpha=0.5)
ax.hist(non_negative_scores, label="Non negatives", color='green', **shared_style)
ax.hist(random_pair_scores, label="Random pairs", color='red', **shared_style)
ax.set_title('Overlapping Histogram')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.legend(loc='upper right')

# Show the plot
fig.tight_layout()
plt.show()