In [None]:
# Reload modules edited on disk (e.g. src.*) before each cell executes, and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from collections import namedtuple

import jellyfish
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from mpire import WorkerPool
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.utils.extmath import safe_sparse_dot
import torch
from tqdm import tqdm
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_dataset, select_frequent_k, frequent_k_names
from src.eval import metrics
from src.eval.utils import similars_to_ndarray
from src.models.swivel import SwivelModel, get_best_swivel_matches
from src.models.swivel_encoder import SwivelEncoderModel
from src.models.utils import remove_padding

In [None]:
# config

# name type to evaluate: "given" or "surname"
given_surname = "given"
# fail fast on a typo: any other value would silently pick the surname vocab size below
assert given_surname in ("given", "surname"), f"given_surname must be 'given' or 'surname', not {given_surname!r}"
vocab_size = 610000 if given_surname == "given" else 2100000  # swivel vocab size; part of the model paths
sample_size = 50000  # sample size for each evaluation split
embed_dim = 100  # swivel / encoder embedding dimension
encoder_layers = 2  # n_layers of the swivel encoder
num_matches = 500  # candidates retrieved per input name
batch_size = 256  # input names per batch
Config = namedtuple("Config", [
    "train_path",
    "eval_path",
    "test_path",
    "embed_dim",
    "swivel_vocab_path",
    "swivel_model_path",
    "encoder_model_path",
])
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train-augmented.csv.gz",
    # eval uses the original (non-augmented) training names
    eval_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-test.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
    # alternative encoder checkpoint (the "-augmented" variant):
    # encoder_model_path=f"s3://nama-data/data/models/fs-{given_surname}-encoder-model-{vocab_size}-{embed_dim}-{encoder_layers}-augmented.pth",
    encoder_model_path=f"s3://nama-data/data/models/fs-{given_surname}-encoder-model-{vocab_size}-{embed_dim}-{encoder_layers}-0.0.pth",
)

In [None]:
# Start a Weights & Biases run for this comparison, recording the experiment config.
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="70_compare_similarity",
    notes="",
    config=config._asdict(),
)

### Load data

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# The CUDA property / memory queries raise on CPU-only machines, so only report them on a GPU.
if device.type == "cuda":
    print("cuda total", torch.cuda.get_device_properties(device).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(device))
    print("cuda allocated", torch.cuda.memory_allocated(device))

In [None]:
# The train split is the augmented data; the eval and test splits are loaded with is_eval=True.
input_names_train, weighted_actual_names_train, candidate_names_train = load_dataset(
    config.train_path
)
input_names_eval, weighted_actual_names_eval, candidate_names_eval = load_dataset(
    config.eval_path, is_eval=True
)
input_names_test, weighted_actual_names_test, candidate_names_test = load_dataset(
    config.test_path, is_eval=True
)

In [None]:
vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"))
# map each vocabulary name to the value of its "index" column
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Build the swivel model sized from the vocabulary and configured embedding dim,
# load its saved weights onto `device`, and switch it to evaluation mode.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
swivel_model.load_state_dict(
    torch.load(fopen(config.swivel_model_path, "rb"), map_location=device)
)
swivel_model.to(device)
swivel_model.eval()

In [None]:
# Load the swivel encoder checkpoint at config.encoder_model_path onto `device` in evaluation mode.
# NOTE(review): an earlier variant set `encoder_model = None` here instead; confirm
# get_best_swivel_matches supports that before relying on it.
encoder_model = SwivelEncoderModel(
    n_layers=encoder_layers,
    output_dim=config.embed_dim,
    device=device,
)
encoder_model.load_state_dict(
    torch.load(fopen(config.encoder_model_path, "rb"), map_location=device)
)
encoder_model.to(device)
encoder_model.eval()

In [None]:
# Sample the evaluation sets. Every downstream cell reads all four of them:
#   train     - augmented in-vocabulary names
#   eval      - original in-vocabulary names
#   freq_eval - frequent in-vocabulary names chosen by select_frequent_k
#   test      - out-of-vocabulary names
# The random splits use a fixed seed so that reruns evaluate the same names.
sample_seed = 42

_, input_names_train_sample, _, weighted_actual_names_train_sample = train_test_split(
    input_names_train, weighted_actual_names_train, test_size=sample_size, random_state=sample_seed
)
candidate_names_train_sample = candidate_names_train

_, input_names_eval_sample, _, weighted_actual_names_eval_sample = train_test_split(
    input_names_eval, weighted_actual_names_eval, test_size=sample_size, random_state=sample_seed
)
candidate_names_eval_sample = candidate_names_eval

input_names_freq_eval_sample, weighted_actual_names_freq_eval_sample, candidate_names_freq_eval_sample = \
    select_frequent_k(input_names_eval, weighted_actual_names_eval, candidate_names_eval,
                      k=sample_size)

_, input_names_test_sample, _, weighted_actual_names_test_sample = train_test_split(
    input_names_test, weighted_actual_names_test, test_size=sample_size, random_state=sample_seed
)
candidate_names_test_sample = candidate_names_test

In [None]:
# Size of every sampled split: inputs, weighted actual names and candidates.
for label, values in (
    ("input_names_train_sample", input_names_train_sample),
    ("weighted_actual_names_train_sample", weighted_actual_names_train_sample),
    ("candidate_names_train_sample", candidate_names_train_sample),
    ("input_names_eval_sample", input_names_eval_sample),
    ("weighted_actual_names_eval_sample", weighted_actual_names_eval_sample),
    ("candidate_names_eval_sample", candidate_names_eval_sample),
    ("input_names_freq_eval_sample", input_names_freq_eval_sample),
    ("weighted_actual_names_freq_eval_sample", weighted_actual_names_freq_eval_sample),
    ("candidate_names_freq_eval_sample", candidate_names_freq_eval_sample),
    ("input_names_test_sample", input_names_test_sample),
    ("weighted_actual_names_test_sample", weighted_actual_names_test_sample),
    ("candidate_names_test_sample", candidate_names_test_sample),
):
    print(label, len(values))

In [None]:
# free memory: drop the full train/eval/test input and actual-name lists;
# later cells only read the *_sample variables (candidate_names_* are kept)
input_names_train = None
weighted_actual_names_train = None
input_names_eval = None
weighted_actual_names_eval = None
input_names_test = None
weighted_actual_names_test = None

### Set up other algorithms

In [None]:
# tfidf: character 1-3-gram ("char_wb") vectorizer fit on the train candidates only,
# then reused to project the other candidate sets into the same feature space.
tfidf_vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(1, 3), min_df=10, max_df=0.5)
tfidf_X_train_sample = tfidf_vectorizer.fit_transform(candidate_names_train_sample)
tfidf_X_eval_sample, tfidf_X_freq_eval_sample, tfidf_X_test_sample = (
    tfidf_vectorizer.transform(cand_names)
    for cand_names in (
        candidate_names_eval_sample,
        candidate_names_freq_eval_sample,
        candidate_names_test_sample,
    )
)

In [None]:
def calc_similarity_to(name, algo="levenshtein"):
    """Return a function that scores how similar a candidate row is to ``name``.

    The returned function takes a row (as passed by ``np.apply_along_axis``) whose first
    item is a candidate name. Padding is stripped from both names with ``remove_padding``
    before they are compared.

    Args:
        name: the (padded) name that candidates are compared against.
        algo: "levenshtein" or "damerau_levenshtein" (score = 1 - edit distance / longer
            length, in [0, 1]), or "jaro_winkler" (jellyfish's Jaro-Winkler similarity).

    Returns:
        A function ``row -> float`` where higher means more similar.

    Raises:
        ValueError: if ``algo`` is not one of the supported names (previously every
            candidate silently scored 0.0).
    """
    name = remove_padding(name)

    edit_distances = {
        "levenshtein": jellyfish.levenshtein_distance,
        "damerau_levenshtein": jellyfish.damerau_levenshtein_distance,
    }

    if algo in edit_distances:
        edit_distance = edit_distances[algo]

        def calc_similarity(row):
            cand_name = remove_padding(row[0])
            longest = max(len(name), len(cand_name))
            if longest == 0:
                # two empty names are identical; avoid dividing by zero
                return 1.0
            return 1 - (edit_distance(name, cand_name) / longest)

    elif algo == "jaro_winkler":

        def calc_similarity(row):
            return jellyfish.jaro_winkler_similarity(name, remove_padding(row[0]))

    else:
        raise ValueError(f"unknown similarity algorithm: {algo!r}")

    return calc_similarity

#### Similarity Function

In [None]:
def get_similars(shared, names, _=None):
    """Return the top-k most similar candidate names for each input name.

    Used as an mpire ``WorkerPool.map`` target: ``shared`` holds the pool's shared
    objects, and each task is a ``(names, batch_start)`` tuple from ``create_batches``
    (the batch offset arrives as ``_`` and is ignored).

    Args:
        shared: tuple ``(candidate_names, k, algo, tfidf_vectorizer, tfidf_X)``.
            ``candidate_names`` must be a numpy array (it is fancy-indexed below); the
            tfidf entries are only used when ``algo == "tfidf"``, otherwise ``algo`` is
            passed to ``calc_similarity_to``.
        names: iterable of input names.
        _: unused batch offset.

    Returns:
        One list per input name of ``(candidate_name, score)`` pairs sorted by descending
        score, of length ``min(k, len(candidate_names))``.
    """
    candidate_names, k, algo, tfidf_vectorizer, tfidf_X = shared
    # np.argpartition raises when k exceeds the number of candidates, so clamp it
    k = min(k, len(candidate_names))

    def get_similars_for_name(name):
        if k <= 0:
            return []
        if algo == "tfidf":
            x = tfidf_vectorizer.transform([name]).toarray()
            scores = safe_sparse_dot(tfidf_X, x.T).flatten()
        else:
            scores = np.apply_along_axis(calc_similarity_to(name, algo), 1, candidate_names[:, None])

        # select the top k without a full sort, then order only those k by descending score
        top_k_idx = np.argpartition(scores, -k)[-k:]
        sorted_idx = top_k_idx[np.argsort(scores[top_k_idx])[::-1]]

        return list(zip(candidate_names[sorted_idx], scores[sorted_idx]))

    return [get_similars_for_name(name) for name in names]

#### Create batches

In [None]:
def create_batches(names, batch_size):
    """Split ``names`` into consecutive chunks of at most ``batch_size`` items.

    Each chunk is returned as a ``(chunk, start_index)`` tuple; wrapping it in a tuple
    keeps mpire from expanding the chunk itself into separate task arguments.
    """
    return [(names[start:start + batch_size], start) for start in range(0, len(names), batch_size)]

#### Demo

In [None]:
# Top-10 test candidates for a familiar name, ranked by normalized Levenshtein similarity.
probe_name = "<richard>" if given_surname != "surname" else "<bostelman>"
get_similars(
    (candidate_names_test_sample, 10, "levenshtein", None, None),
    [probe_name],
)

### Test tfidf

In [None]:
# Top-10 test candidates for a familiar name, ranked by dot product of TF-IDF character n-gram vectors.
probe_name = "<richard>" if given_surname != "surname" else "<schumacher>"
get_similars(
    (candidate_names_test_sample, 10, "tfidf", tfidf_vectorizer, tfidf_X_test_sample),
    [probe_name],
)

### Test levenshtein

In [None]:
# Inspect a single test example (index 251 of the test sample): its input name...
ix = 251
input_names_test_sample[ix]

In [None]:
# ...its weighted actual names...
weighted_actual_names_test_sample[ix]

In [None]:
# ...and its num_matches best candidates by Levenshtein similarity (first 5 shown).
similar_names_scores = get_similars((candidate_names_test_sample, num_matches, "levenshtein", None, None), [input_names_test_sample[ix]])
similar_names_scores[0][:5]

In [None]:
# Convert the (name, score) lists to ndarray form, as evaluate_algos does before computing metrics.
# NOTE(review): this cell rebinds its own input, so running it twice re-converts an
# already-converted value; rerun the previous cell first.
similar_names_scores = similars_to_ndarray(similar_names_scores)

In [None]:
# Weighted recall for this example at threshold 0.95 (presumably counting actual names among
# candidates whose score passes the threshold; see src.eval.metrics).
metrics.weighted_recall_at_threshold(weighted_actual_names_test_sample[ix], similar_names_scores[0], 0.95)

In [None]:
# Same, at the looser 0.75 threshold.
metrics.weighted_recall_at_threshold(weighted_actual_names_test_sample[ix], similar_names_scores[0], 0.75)

# Evaluate each algorithm

In [None]:
# Algorithms to compare. Each record holds the algorithm name, the bounds of the threshold
# sweep, and the `distances` flag forwarded to the metric helpers.
SimilarityAlgo = namedtuple("SimilarityAlgo", ["name", "min_threshold", "max_threshold", "distances"])
similarity_algos = [
    SimilarityAlgo(name=algo_name, min_threshold=0.0, max_threshold=1.01, distances=False)
    for algo_name in (
        "swivel",
        "swivel_encoder",
        # "tfidf",
        "levenshtein",
        # "damerau_levenshtein",
        # "jaro_winkler",
    )
]

In [None]:
def evaluate_algos(similarity_algos,
                   swivel_vocab,
                   swivel_model,
                   encoder_model,
                   input_names,
                   weighted_actual_names,
                   candidate_names,
                   tfidf_X):
    """Compare similarity algorithms on one dataset using precision / weighted-recall curves.

    For every algorithm, this retrieves the best ``num_matches`` candidates for each input
    name. It then computes precision and weighted recall over a threshold sweep
    (``algo.min_threshold`` to ``algo.max_threshold`` in steps of 0.05) and prints the area
    under that curve. All curves are drawn on a single plot.

    Reads the notebook-level globals ``num_matches``, ``batch_size`` and ``tfidf_vectorizer``.

    Args:
        similarity_algos: sequence of SimilarityAlgo records
            (name, min_threshold, max_threshold, distances).
        swivel_vocab: name -> index mapping for the swivel model.
        swivel_model: swivel model, used by the "swivel" algorithm.
        encoder_model: swivel encoder model, passed for "swivel" and "swivel_encoder".
        input_names: names to look up.
        weighted_actual_names: weighted actual matches for each input name.
        candidate_names: numpy array of names to search among.
        tfidf_X: TF-IDF matrix of ``candidate_names`` (only used by "tfidf").

    Returns:
        dict mapping each algorithm name to the AUC printed for it.
    """
    n_jobs = 1

    _, ax = plt.subplots(1, 1, figsize=(20, 15))
    colors = cm.rainbow(np.linspace(0, 1, len(similarity_algos)))

    aucs = {}
    for algo, color in zip(similarity_algos, colors):
        print(algo.name)
        if algo.name in ("swivel", "swivel_encoder"):
            # "swivel_encoder" passes no swivel model/vocab, so matching relies on encoder_model
            use_swivel = algo.name == "swivel"
            similar_names_scores = get_best_swivel_matches(model=swivel_model if use_swivel else None,
                                                           vocab=swivel_vocab if use_swivel else None,
                                                           input_names=input_names,
                                                           candidate_names=candidate_names,
                                                           encoder_model=encoder_model,
                                                           k=num_matches,
                                                           batch_size=batch_size,
                                                           add_context=True,
                                                           n_jobs=n_jobs)
        else:
            # string-similarity / tfidf algorithms: score batches of names in an mpire WorkerPool
            input_names_batches = create_batches(input_names, batch_size=batch_size)
            with WorkerPool(
                shared_objects=(candidate_names, num_matches, algo.name, tfidf_vectorizer, tfidf_X),
            ) as pool:
                batched_scores = pool.map(get_similars, input_names_batches, progress_bar=True)
            # flatten the per-batch results, then convert to ndarray form for the metrics
            similar_names_scores = similars_to_ndarray(
                [name_scores for batch in batched_scores for name_scores in batch]
            )
        precisions, recalls = metrics.precision_weighted_recall_at_threshold(
            weighted_actual_names,
            similar_names_scores,
            min_threshold=algo.min_threshold,
            max_threshold=algo.max_threshold,
            step=0.05,
            distances=algo.distances,
            n_jobs=1,
            progress_bar=True,
        )
        auc = metrics.get_auc_from_precisions_recalls(
            precisions,
            recalls,
            distances=algo.distances
        )
        print("auc", auc)
        aucs[algo.name] = auc
        ax.plot(recalls, precisions, "o--", color=color, label=algo.name)

    ax.set(title="PR at threshold", xlabel="weighted recall", ylabel="precision",
           xlim=(0, 1.0), ylim=(0, 1.0))
    ax.legend()
    plt.show()
    return aucs

## on augmented in-vocabulary names (training data)

In [None]:
# augmented in-vocabulary names (train split)
evaluate_algos(
    similarity_algos=similarity_algos,
    swivel_vocab=swivel_vocab,
    swivel_model=swivel_model,
    encoder_model=encoder_model,
    input_names=input_names_train_sample,
    weighted_actual_names=weighted_actual_names_train_sample,
    candidate_names=candidate_names_train_sample,
    tfidf_X=tfidf_X_train_sample,
)

## on original in-vocabulary names (eval data)

In [None]:
# original in-vocabulary names (eval split)
evaluate_algos(
    similarity_algos=similarity_algos,
    swivel_vocab=swivel_vocab,
    swivel_model=swivel_model,
    encoder_model=encoder_model,
    input_names=input_names_eval_sample,
    weighted_actual_names=weighted_actual_names_eval_sample,
    candidate_names=candidate_names_eval_sample,
    tfidf_X=tfidf_X_eval_sample,
)

## on frequent in-vocabulary names (frequent eval data)

In [None]:
# frequent in-vocabulary names (freq_eval split)
evaluate_algos(
    similarity_algos=similarity_algos,
    swivel_vocab=swivel_vocab,
    swivel_model=swivel_model,
    encoder_model=encoder_model,
    input_names=input_names_freq_eval_sample,
    weighted_actual_names=weighted_actual_names_freq_eval_sample,
    candidate_names=candidate_names_freq_eval_sample,
    tfidf_X=tfidf_X_freq_eval_sample,
)

## on out-of-vocabulary names (test data)

In [None]:
# Sanity check on the test split: bucket every (input, actual) pair by how many of its
# names are in the swivel vocab. Two *distinct* in-vocab names should never occur here;
# identical names count as one.
n_zero = n_one = n_two = 0
for src_name, weighted_actuals in zip(input_names_test_sample, weighted_actual_names_test_sample):
    src_known = src_name in swivel_vocab
    for target_name, _, _ in weighted_actuals:
        target_known = target_name in swivel_vocab
        if src_known and target_known and src_name != target_name:
            n_two += 1
        elif src_known or target_known:
            n_one += 1
        else:
            n_zero += 1
print("two names in vocab (should not be possible)", n_two)
print("one name in vocab", n_one)
print("zero names in vocab", n_zero)

In [None]:
# out-of-vocabulary names (test split)
evaluate_algos(
    similarity_algos=similarity_algos,
    swivel_vocab=swivel_vocab,
    swivel_model=swivel_model,
    encoder_model=encoder_model,
    input_names=input_names_test_sample,
    weighted_actual_names=weighted_actual_names_test_sample,
    candidate_names=candidate_names_test_sample,
    tfidf_X=tfidf_X_test_sample,
)

In [None]:
# Close the W&B run opened by wandb.init at the top of the notebook.
wandb.finish()