In [None]:
# Load IPython's autoreload extension and use mode 2 so that imported modules
# (e.g. the project's `src` package) are reloaded before code is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Compare Swivel and Levenshtein scores

In [None]:
from collections import namedtuple

import jellyfish
import matplotlib.pyplot as plt
from mpire import WorkerPool
import numpy as np
import pandas as pd
import random
from sklearn.model_selection import train_test_split
import torch
from tqdm import tqdm
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_dataset
from src.eval.utils import similars_to_ndarray
from src.models.swivel import SwivelModel, get_best_swivel_matches
from src.models.utils import remove_padding

In [None]:
# config

# which name type to analyze; it selects the data/model files below
given_surname = "given"
# the vocabulary size is part of the vocab/model file names and differs per name type
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
sample_size = 5000      # number of input names sampled for the comparison
embed_dim = 100         # swivel embedding dimension (also part of the model file name)
encoder_layers = 2
num_matches = 5000      # number of best candidates kept per input name
batch_size = 256        # input names per batch

Config = namedtuple(
    "Config",
    ["eval_path", "embed_dim", "swivel_vocab_path", "swivel_model_path"],
)
config = Config(
    eval_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
)

In [None]:
# wandb.init(
#     project="nama",
#     entity="nama",
#     name="64_analyze_scores",
#     group=given_surname,
#     notes="",
#     config=config._asdict(),
# )

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# The CUDA memory statistics can only be queried when a CUDA device exists;
# asking for them on a CPU-only machine raises, so skip them in that case.
if device.type == "cuda":
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

In [None]:
# Load the evaluation split: input names, their weighted actual matches and
# the candidate names that matches are drawn from.
(
    input_names_eval,
    weighted_actual_names_eval,
    candidate_names_eval,
) = load_dataset(config.eval_path, is_eval=True)

In [None]:
vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"))
# name -> id lookup built from the `name` and `index` columns of the vocab file
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Build the swivel model with one embedding per vocabulary entry and restore
# the trained weights onto the selected device.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
swivel_state = torch.load(fopen(config.swivel_model_path, "rb"), map_location=torch.device(device))
swivel_model.load_state_dict(swivel_state)
# drop the loaded state dict once it has been loaded into the model
del swivel_state
swivel_model.to(device)
# inference mode; as the last expression this also displays the model summary
swivel_model.eval()

In [None]:
# Draw `sample_size` input names (with their weighted actual names) for the comparison.
# A fixed seed keeps the sample identical across kernel restarts, so the scores
# and plots computed from it are reproducible.
sample_seed = 42
_, input_names_eval_sample, _, weighted_actual_names_eval_sample = train_test_split(
    input_names_eval,
    weighted_actual_names_eval,
    test_size=sample_size,
    random_state=sample_seed,
)
# the candidates are not sampled: every input name is scored against all of them
candidate_names_eval_sample = candidate_names_eval

In [None]:
# report the size of each sampled collection
for label, sample in (
    ("input_names_eval_sample", input_names_eval_sample),
    ("weighted_actual_names_eval_sample", weighted_actual_names_eval_sample),
    ("candidate_names_eval_sample", candidate_names_eval_sample),
):
    print(label, len(sample))

In [None]:
def calc_similarity_to(name):
    """Build a scorer that rates candidate names against ``name``.

    The score is a normalized Levenshtein similarity:
    ``1 - distance / max(len(name), len(candidate))``, so identical names score
    1 and names with nothing in common score 0. Padding is removed from both
    names (via ``remove_padding``) before they are compared.

    :param name: the reference name
    :return: a function taking a ``row`` whose first element is the candidate
        name (the shape ``np.apply_along_axis`` passes for a single-column
        array) and returning the similarity
    """
    name = remove_padding(name)
    # the reference name is the same for every candidate, so measure it once
    name_len = len(name)

    def calc_similarity(row):
        cand_name = remove_padding(row[0])
        longest = max(name_len, len(cand_name))
        if longest == 0:
            # both names are empty once padding is removed: they are identical,
            # and dividing by the zero length would raise ZeroDivisionError
            return 1.0
        dist = jellyfish.levenshtein_distance(name, cand_name)
        return 1 - (dist / longest)

    return calc_similarity

In [None]:
def get_similars(shared, names, _=None):
    """Return the ``k`` most Levenshtein-similar candidates for each name.

    :param shared: tuple ``(candidate_names, k)``; ``candidate_names`` must be a
        1-D numpy array (it is indexed with ``[:, None]`` and with index arrays)
    :param names: iterable of input names to score against every candidate
    :param _: ignored; presumably the batch offset that ``create_batches``
        stores next to each batch
    :return: one list per input name containing ``(candidate_name, score)``
        pairs sorted by descending score; at most ``k`` pairs per name
    """
    candidate_names_test, k = shared
    # np.argpartition raises when k exceeds the number of candidates, and k == 0
    # would make the `[-k:]` slice select every candidate, so clamp k up front.
    k = min(k, len(candidate_names_test))
    if k <= 0:
        # nothing to return (no candidates, or no matches requested)
        return [[] for _name in names]

    def get_similars_for_name(name):
        # score every candidate against `name`; each row holds one candidate
        scores = np.apply_along_axis(calc_similarity_to(name), 1, candidate_names_test[:, None])

        # select the k best with a partial sort, then order only those k
        # (equivalent to np.argsort(scores)[::-1][:k] without a full sort)
        partitioned_idx = np.argpartition(scores, -k)[-k:]
        sorted_partitioned_idx = np.argsort(scores[partitioned_idx])[::-1]
        sorted_scores_idx = partitioned_idx[sorted_partitioned_idx]

        candidate_names = candidate_names_test[sorted_scores_idx]
        candidate_scores = scores[sorted_scores_idx]

        return list(zip(candidate_names, candidate_scores))

    return [get_similars_for_name(name) for name in names]

In [None]:
def create_batches(names, batch_size):
    """Split ``names`` into consecutive slices of at most ``batch_size`` items.

    :param names: sliceable sequence of names
    :param batch_size: maximum number of names per batch
    :return: list of ``(batch, start_index)`` tuples, where ``start_index`` is
        the offset of the batch's first name within ``names``
    """
    # each batch is wrapped in a tuple to keep mpire from expanding the batch
    return [
        (names[start:start + batch_size], start)
        for start in range(0, len(names), batch_size)
    ]

In [None]:
# Swivel matches for every sampled input name, requesting `num_matches`
# candidates per name (presumably returned as (name, score) pairs - see how
# the result is consumed further down).
swivel_names_scores = get_best_swivel_matches(
    model=swivel_model,
    vocab=swivel_vocab,
    input_names=input_names_eval_sample,
    candidate_names=candidate_names_eval_sample,
    encoder_model=None,
    k=num_matches,
    batch_size=batch_size,
    add_context=True,
    n_jobs=1,
)

In [None]:
# total number of swivel (name, score) entries over all input names
print(sum(map(len, swivel_names_scores)))

In [None]:
# Score the sampled input names against all candidates with Levenshtein
# similarity, one batch of names per worker task.
input_names_batches = create_batches(input_names_eval_sample, batch_size=batch_size)
with WorkerPool(shared_objects=(candidate_names_eval_sample, num_matches)) as pool:
    lev_names_scores = pool.map(get_similars, input_names_batches, progress_bar=True)
# flatten the per-batch results into one entry per input name, then convert to ndarray
lev_names_scores = similars_to_ndarray(
    [name_score for batch in lev_names_scores for name_score in batch]
)

In [None]:
# total number of levenshtein (name, score) entries over all input names
print(sum(map(len, lev_names_scores)))

In [None]:
# find pairs in both with score above a threshold
swivel_threshold = 0.45
lev_threshold = 0.55
sample_rate = 0.01
# Dedicated seeded generator: the sampled pairs are the same on every run and
# the global `random` state is left untouched.
sample_rng = random.Random(42)
# sampled (swivel score, levenshtein score) pairs and their plot colors
xs = []
ys = []
cs = []
# the same pairs split into actual matches (pos) and non-matches (neg)
xs_pos = []
ys_pos = []
xs_neg = []
ys_neg = []
# weights of the sampled actual matches
weights = []
# per input name: number of actual names / swivel matches / levenshtein matches kept
actual_score_counts = []
swivel_score_counts = []
lev_score_counts = []
all_candidate_names = set(candidate_names_eval_sample)
for input_name, wans, swivels, levs in \
    zip(input_names_eval_sample, weighted_actual_names_eval_sample, swivel_names_scores, lev_names_scores):
    # actuals - ensure names are in all_candidate_names
    actual_weights = {name: weight for name, weight, _ in wans if name in all_candidate_names}
    actual_score_counts.append(len(actual_weights))
    # swivel
    swivel_scores = {name: score for name, score in swivels if score >= swivel_threshold}
    swivel_names = set(swivel_scores.keys())
    swivel_score_counts.append(len(swivel_scores))
    # levenshtein
    lev_scores = {name: score for name, score in levs if score >= lev_threshold}
    lev_names = set(lev_scores.keys())
    lev_score_counts.append(len(lev_scores))

    # count various scores
    candidate_names = swivel_names.intersection(lev_names)
    # Iterate in sorted order: the iteration order of a set of strings can change
    # between interpreter runs (string hash randomization), which would make the
    # seeded sampling pick different pairs each time.
    for candidate_name in sorted(candidate_names):
        # keep roughly `sample_rate` of the pairs so the scatter plots stay readable
        if sample_rng.random() > sample_rate:
            continue
        swivel_score = swivel_scores[candidate_name]
        lev_score = lev_scores[candidate_name]
        xs.append(swivel_score)
        ys.append(lev_score)
        if candidate_name in actual_weights:
            cs.append('green')
            xs_pos.append(swivel_score)
            ys_pos.append(lev_score)
            weights.append(actual_weights[candidate_name])
            # remove the sampled match so only the remaining actual names are
            # left for the (disabled) debugging output below
            del actual_weights[candidate_name]
        else:
            cs.append('red')
            xs_neg.append(swivel_score)
            ys_neg.append(lev_score)
#     for name in actual_weights.keys():
#         if name not in swivel_names:
#             print("swivel", input_name, name)
#         if name not in lev_names:
#             print("lev", input_name, name)

In [None]:
# total matches above threshold for each method
print(sum(swivel_score_counts), sum(lev_score_counts))
# sampled pairs, sampled actual matches, and the number of actual names scaled by the sample rate
num_sampled = len(cs)
num_sampled_actual = cs.count('green')
scaled_actual_total = sum(actual_score_counts)*sample_rate
print(num_sampled, num_sampled_actual, scaled_actual_total)

In [None]:
# all sampled pairs: green = actual match, red = not an actual match
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.scatter(x=xs, y=ys, c=cs)
ax.set_title("Swivel vs Levenshtein score")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim(swivel_threshold, 1.0)
ax.set_ylim(lev_threshold, 1.0)
plt.show()

In [None]:
# sampled pairs that are actual matches
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.scatter(x=xs_pos, y=ys_pos)
ax.set_title("Swivel vs Levenshtein score - positive only")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim(swivel_threshold, 1.0)
ax.set_ylim(lev_threshold, 1.0)
plt.show()

In [None]:
# sampled pairs that are not actual matches
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.scatter(x=xs_neg, y=ys_neg)
ax.set_title("Swivel vs Levenshtein score - negative only")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim(swivel_threshold, 1.0)
ax.set_ylim(lev_threshold, 1.0)
plt.show()

In [None]:
# distribution of the weights of the sampled actual matches
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.hist(weights, bins=100)
ax.set_title("Weights")
plt.show()

In [None]:
# distribution of the number of swivel matches above threshold per input name
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.hist(swivel_score_counts, bins=100)
ax.set_title("Swivel score counts")
plt.show()

In [None]:
# distribution of the number of levenshtein matches above threshold per input name
figure, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
ax.hist(lev_score_counts, bins=100)
ax.set_title("Levenshtein score counts")
plt.show()