In [None]:
# reload edited modules (e.g. the src.* helpers imported below) automatically before each cell runs
%load_ext autoreload
%autoreload 2

# Ignore: this notebook is no longer used. It was used to generate training data with Snorkel (snorkel.org).

In [None]:
from collections import defaultdict, namedtuple, Counter

import cologne_phonetics
import jellyfish
import pandas as pd
import phonetics
from pyphonetics import RefinedSoundex
from snorkel.labeling.model import LabelModel
from spellwise import CaverphoneTwo
from tqdm import tqdm

from src.data.match import levenshtein_similarity
from src.data.utils import load_datasets, frequent_k_names
from src.models.utils import remove_padding

In [None]:
# Config: everything below follows from which name type (given vs. surname) is processed

given_surname = "given"
if given_surname == "given":
    vocab_size = 610000
else:
    # any value other than "given" selects the larger vocab size
    vocab_size = 2100000
# NOTE(review): vocab_size is not used elsewhere in this notebook — confirm it is still needed

Config = namedtuple("Config", ["train_path"])
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
)

In [None]:
# load data: the single training split, unpacked into input names, their weighted actual names,
# and candidate names
(train,) = load_datasets([config.train_path])
input_names_train, weighted_actual_names_train, candidate_names_train = train
# every name seen on either the input or candidate side
all_names = set(input_names_train) | set(candidate_names_train)
# presumably the 250 most frequent input names (see frequent_k_names) — the names variants are generated for
sample_names = set(frequent_k_names(input_names_train, weighted_actual_names_train, 250))

In [None]:
# construct name_counts, tree2records, and record2trees
# Each entry of weighted_actual_names_train is a sequence of (name, _, co_occurrence) triples paired with
# the input (tree) name at the same position; the middle field is not used here.
name_counts = Counter()
for tree_name, weighted_records in zip(input_names_train, weighted_actual_names_train):
    for record_name, _, co_occurrence in weighted_records:
        # a pair's co-occurrence counts towards the totals of both of its names
        name_counts[tree_name] += co_occurrence
        name_counts[record_name] += co_occurrence
freq_names = {name for name, _ in name_counts.most_common(5000)}

# The same co-occurrence counts, indexed both ways:
# tree2records[tree_name][record_name] and record2trees[record_name][tree_name].
tree2records = defaultdict(Counter)
record2trees = defaultdict(Counter)
for tree_name, weighted_records in zip(input_names_train, weighted_actual_names_train):
    tree_is_frequent = tree_name in freq_names
    for record_name, _, co_occurrence in weighted_records:
        # TODO DON'T COPY - exclude frequent dissimilar pairs: both names are in freq_names and their
        # levenshtein similarity (padding removed) is <= 0.65
        if (
            tree_is_frequent
            and record_name in freq_names
            and levenshtein_similarity(remove_padding(tree_name), remove_padding(record_name)) <= 0.65
        ):
            continue
        tree2records[tree_name][record_name] += co_occurrence
        record2trees[record_name][tree_name] += co_occurrence

In [None]:
# define strategies and labeling functions based upon co-occurrence (tree2records and record2trees)
def tree_record_strategy(name, tree2records, record2trees):
    """tree->name => tree<->name

    Collect every name that directly co-occurs with ``name``, whether ``name`` appears on the tree
    side (``tree2records[name]``) or on the record side (``record2trees[name]``).

    Args:
        name: the name to collect co-occurring names for.
        tree2records: mapping tree name -> Counter of record name -> co-occurrence count.
        record2trees: mapping record name -> Counter of tree name -> co-occurrence count.

    Returns:
        Counter of co-occurring name -> summed co-occurrence count. As with Counter addition,
        entries whose total is not positive are dropped.
    """
    c = Counter()
    # .get (instead of indexing) so that looking up an absent name does not insert an empty
    # Counter into the shared defaultdicts.
    c += tree2records.get(name, {})
    c += record2trees.get(name, {})
    return c


def get_weight(name, occurrence, name_counts):
    """Return the fraction of ``name``'s total co-occurrence count that ``occurrence`` represents.

    Args:
        name: the name whose total count is the denominator.
        occurrence: co-occurrence count of ``name`` with one other name.
        name_counts: mapping name -> total co-occurrence count.

    Returns:
        ``occurrence / name_counts[name]``. Returns 0.0 when ``name`` is missing from
        ``name_counts`` or its total is 0, instead of raising KeyError / ZeroDivisionError.
    """
    # .get avoids KeyError for plain dicts and avoids inserting into a defaultdict
    total = name_counts.get(name, 0)
    if not total:
        return 0.0
    return occurrence / total


def tree_record_pair_strategy(name, tree2records, record2trees, name_counts, threshold=0, score=1):
    """tree->name, tree->alt_name => name<->alt_name

    Treat ``name`` as a record name. For every tree name matched to ``name``, every other record
    name matched to that same tree is proposed as a variant of ``name``.

    Args:
        name: record-side name to find variants for.
        tree2records: mapping tree name -> Counter of record name -> co-occurrence count.
        record2trees: mapping record name -> Counter of tree name -> co-occurrence count.
        name_counts: total co-occurrence count per name; each hop is normalized by it via get_weight.
        threshold: a hop whose normalized weight is below this value is ignored.
        score: amount added to an alternate name for each qualifying path that reaches it.

    Returns:
        Counter of alternate name -> accumulated score (``name`` itself is excluded).
    """
    c = Counter()
    # .get (instead of indexing) so lookups never insert empty Counters into the shared defaultdicts
    for tree_name, occurrence in record2trees.get(name, {}).items():
        if get_weight(name, occurrence, name_counts) < threshold:
            continue
        for alt_name, alt_occurrence in tree2records.get(tree_name, {}).items():
            if alt_name == name or get_weight(tree_name, alt_occurrence, name_counts) < threshold:
                continue
            c[alt_name] += score
    return c


def record_tree_pair_strategy(name, tree2records, record2trees, name_counts, threshold=0, score=1):
    """name->record, alt_name->record => name<->alt_name

    Treat ``name`` as a tree name. For every record name ``name`` was matched to, every other tree
    name matched to that same record is proposed as a variant of ``name``.

    Args:
        name: tree-side name to find variants for.
        tree2records: mapping tree name -> Counter of record name -> co-occurrence count.
        record2trees: mapping record name -> Counter of tree name -> co-occurrence count.
        name_counts: total co-occurrence count per name; each hop is normalized by it via get_weight.
        threshold: a hop whose normalized weight is below this value is ignored.
        score: amount added to an alternate name for each qualifying path that reaches it.

    Returns:
        Counter of alternate name -> accumulated score (``name`` itself is excluded).
    """
    c = Counter()
    # .get (instead of indexing) so lookups never insert empty Counters into the shared defaultdicts
    for record_name, occurrence in tree2records.get(name, {}).items():
        if get_weight(name, occurrence, name_counts) < threshold:
            continue
        for alt_name, alt_occurrence in record2trees.get(record_name, {}).items():
            if alt_name == name or get_weight(record_name, alt_occurrence, name_counts) < threshold:
                continue
            c[alt_name] += score
    return c


def tree_record_record_strategy(name, tree2records, record2trees, name_counts, threshold=0, score=1):
    """name->record, record(as tree name)->alt_record => name<->alt_record

    Two-hop transitive variants that stay within one mapping. The first walk follows
    ``tree2records`` twice (name -> record -> that record's own records). The second follows
    ``record2trees`` twice (name -> tree -> that tree's own trees).

    Args:
        name: the name to find variants for.
        tree2records: mapping tree name -> Counter of record name -> co-occurrence count.
        record2trees: mapping record name -> Counter of tree name -> co-occurrence count.
        name_counts: total co-occurrence count per name; each hop is normalized by it via get_weight.
        threshold: a hop whose normalized weight is below this value is ignored.
        score: amount added to an alternate name for each qualifying path that reaches it.

    Returns:
        Counter of alternate name -> accumulated score (``name`` itself is excluded).
    """
    c = Counter()

    def _follow(mapping):
        # Walk name -> mid_name -> alt_name within one mapping, skipping hops below threshold.
        # .get (instead of indexing) keeps the shared defaultdicts from growing empty Counters
        # for intermediate names that never appear as keys of this mapping.
        for mid_name, occurrence in mapping.get(name, {}).items():
            if get_weight(name, occurrence, name_counts) < threshold:
                continue
            for alt_name, alt_occurrence in mapping.get(mid_name, {}).items():
                if get_weight(mid_name, alt_occurrence, name_counts) < threshold:
                    continue
                if alt_name != name:
                    c[alt_name] += score

    _follow(tree2records)
    _follow(record2trees)
    return c


In [None]:
# define strategies based upon coders
def get_codes(coder, names, multiple_codes=False):
    """Build an inverted index from code to the names that produce it.

    Args:
        coder: callable applied to each name after remove_padding. It returns one code, or an
            iterable of codes when ``multiple_codes`` is True.
        names: iterable of names to index.
        multiple_codes: whether ``coder`` returns several codes per name.

    Returns:
        defaultdict(list) mapping code -> names, in input order. Empty/falsy codes are skipped in
        both modes. Otherwise every name the coder cannot encode would share the empty code and
        be treated as a variant of every other such name.
    """
    codes = defaultdict(list)
    for name in names:
        result = coder(remove_padding(name))
        # normalize the single-code case to a one-element sequence so both modes share the filter
        for code in (result if multiple_codes else (result,)):
            if code:
                codes[code].append(name)
    return codes


def get_code_matches(name, coder, codes, multiple_codes=False):
    """Return the names that share at least one code with ``name``.

    Args:
        name: name to look up; remove_padding is applied before encoding.
        coder: the same coder that was used to build ``codes``.
        codes: code -> names index, as returned by get_codes.
        multiple_codes: whether ``coder`` returns several codes per name.

    Returns:
        A new list of matching names, de-duplicated in first-seen order. It may include ``name``
        itself. Empty/falsy codes match nothing, and ``codes`` is never modified (no keys
        inserted, no internal list handed out).
    """
    result = coder(remove_padding(name))
    candidate_codes = result if multiple_codes else (result,)
    # dict.fromkeys de-duplicates while keeping a deterministic order (a set's string order
    # changes between interpreter runs); .get avoids inserting missing codes into a defaultdict.
    matches = dict.fromkeys(
        match for code in candidate_codes if code for match in codes.get(code, ())
    )
    return list(matches)


def code_strategy(name, coder, codes, multiple_codes=False, score=1):
    """Give ``score`` to every name that shares a code with ``name`` (excluding ``name`` itself).

    Each matching alternate name is set to ``score`` rather than accumulated, so repeated matches
    do not raise it further.

    Returns:
        Counter of alternate name -> score, in the order get_code_matches yields them.
    """
    matches = get_code_matches(name, coder, codes, multiple_codes=multiple_codes)
    return Counter({alt_name: score for alt_name in matches if alt_name != name})

In [None]:
# generate codes
def print_coder_stats(coder_name, codes):
    """Print the coder name, its number of distinct codes, and its total number of indexed names."""
    distinct_codes = len(codes)
    indexed_names = sum(map(len, codes.values()))
    print(coder_name, distinct_codes, indexed_names)


caverphone = CaverphoneTwo()
refined_soundex = RefinedSoundex()


def cologne(n):
    """Return element [1] of every entry that cologne_phonetics.encode returns for ``n``."""
    return [entry[1] for entry in cologne_phonetics.encode(n)]


# One code -> names index per coder, over every known name. dmetaphone and cologne are indexed
# with multiple_codes=True because their coders return a sequence of codes.
# NOTE(review): caverphone uses spellwise's private _pre_process method as its coder — confirm it
# yields the intended code. Its index is currently unused because "code_caverphone" is disabled.
nysiis_codes = get_codes(jellyfish.nysiis, all_names)
caverphone_codes = get_codes(caverphone._pre_process, all_names)
refined_soundex_codes = get_codes(refined_soundex.phonetics, all_names)
dmetaphone_codes = get_codes(phonetics.dmetaphone, all_names, True)
cologne_codes = get_codes(cologne, all_names, True)
metaphone_codes = get_codes(jellyfish.metaphone, all_names)

# name, number of distinct codes, number of indexed names
print_coder_stats("nysiis", nysiis_codes)
print_coder_stats("caverphone", caverphone_codes)
print_coder_stats("refined_soundex", refined_soundex_codes)
print_coder_stats("cologne", cologne_codes)
print_coder_stats("metaphone", metaphone_codes)
print_coder_stats("double metaphone", dmetaphone_codes)

In [None]:
# define strategy based upon levenshtein

def get_levenshtein_matches(name, names, threshold=0.65):
    """Map each of ``names`` that is similar enough to ``name`` to its levenshtein similarity.

    Padding is removed (via remove_padding) from both sides before comparing. A name is kept when
    its similarity is >= ``threshold``, so ``name`` itself is included if it appears in ``names``.
    """
    target = remove_padding(name)
    scored = (
        (candidate, levenshtein_similarity(target, remove_padding(candidate)))
        for candidate in names
    )
    return {candidate: similarity for candidate, similarity in scored if similarity >= threshold}


def lev_strategy(name, names, threshold, score=1):
    """Give ``score`` to every name with levenshtein similarity >= ``threshold`` to ``name``.

    ``name`` itself is excluded, and the similarity values are discarded.

    Returns:
        Counter of alternate name -> score.
    """
    matches = get_levenshtein_matches(name, names, threshold)
    return Counter(dict.fromkeys((alt_name for alt_name in matches if alt_name != name), score))

In [None]:
# apply variant-name strategies to freq_names to get possible variants
def add_variants(variants_for_name, learner_name, variants):
    """Record a vote of 1 from ``learner_name`` for every key of ``variants``.

    ``variants_for_name`` maps variant -> {learner name: vote}. Only the keys of ``variants`` are
    used; its scores are ignored.
    """
    for variant in variants.keys():
        votes = variants_for_name[variant]
        votes[learner_name] = 1


# Learner (labeling-function) names. This order is the column order of the label matrix passed to
# Snorkel. Disabled learners stay commented out within their group.
all_learners = [
    # co-occurrence learners
    "tree_record_pair_1e-3",
    # "tree_record_pair_1e-4",
    # "tree_record_pair_1e-5",
    # "record_tree_pair_1e-3",
    # "record_tree_pair_1e-4",
    # "record_tree_pair_1e-5",
    # "tree_record_record_1e-3",
    # "tree_record_record_1e-4",
    # "tree_record_record_1e-5",
    # code learners
    "code_nysiis",
    # "code_caverphone",
    "code_refined_soundex",
    "code_dmetaphone",
    "code_cologne",
    "code_metaphone",
    # levenshtein-similarity learners (suffix = threshold in percent)
    "lev_65",
    "lev_70",
    "lev_75",
    "lev_80",
    "lev_85",
    "lev_90",
]


def zero_learners():
    """Return a fresh dict mapping every name in all_learners to a vote of 0 (a defaultdict factory)."""
    return dict.fromkeys(all_learners, 0)

In [None]:
# Apply each active learner to every sampled name, then collect one row per (name, variant) pair
# holding that pair's learner votes.
# Learners commented out in all_learners (tree_record_pair_1e-4/1e-5, record_tree_pair_*,
# tree_record_record_*, code_caverphone) have no add_variants call here; re-enable both together.
lev_learner_thresholds = [
    ("lev_65", 0.65),
    ("lev_70", 0.70),
    ("lev_75", 0.75),
    ("lev_80", 0.80),
    ("lev_85", 0.85),
    ("lev_90", 0.90),
]
min_lev_threshold = min(threshold for _, threshold in lev_learner_thresholds)

name_pairs = []
# sorted() makes the row order of name_pairs_df reproducible; set iteration order of strings is not
for name in tqdm(sorted(sample_names)):
    variants_for_name = defaultdict(zero_learners)

    # co-occurrence learner
    add_variants(variants_for_name,
                 "tree_record_pair_1e-3",
                 tree_record_pair_strategy(name, tree2records=tree2records, record2trees=record2trees,
                                           name_counts=name_counts, threshold=1e-3))

    # code learners
    add_variants(variants_for_name,
                 "code_nysiis",
                 code_strategy(name, coder=jellyfish.nysiis, codes=nysiis_codes))
    add_variants(variants_for_name,
                 "code_refined_soundex",
                 code_strategy(name, coder=refined_soundex.phonetics, codes=refined_soundex_codes))
    add_variants(variants_for_name,
                 "code_dmetaphone",
                 code_strategy(name, coder=phonetics.dmetaphone, codes=dmetaphone_codes, multiple_codes=True))
    add_variants(variants_for_name,
                 "code_cologne",
                 code_strategy(name, coder=cologne, codes=cologne_codes, multiple_codes=True))
    add_variants(variants_for_name,
                 "code_metaphone",
                 code_strategy(name, coder=jellyfish.metaphone, codes=metaphone_codes))

    # levenshtein learners: scan all_names once at the lowest threshold, then filter for each
    # higher threshold. This gives the same variants as one lev_strategy call per threshold with
    # one full scan instead of six.
    lev_scores = get_levenshtein_matches(name, all_names, threshold=min_lev_threshold)
    for lev_learner, lev_threshold in lev_learner_thresholds:
        add_variants(variants_for_name,
                     lev_learner,
                     {alt_name: 1 for alt_name, similarity in lev_scores.items()
                      if alt_name != name and similarity >= lev_threshold})

    # gather all variants into name_pairs
    for variant, learners in variants_for_name.items():
        learners["name"] = name
        learners["variant"] = variant
        name_pairs.append(learners)

# add possible variants to train dataframe
name_pairs_df = pd.DataFrame(name_pairs)
print(len(name_pairs_df))

In [None]:
# number of distinct names appearing either as the sampled name or as a proposed variant
len(set(name_pairs_df["name"]) | set(name_pairs_df["variant"]))

In [None]:
# add sum of scores: how many learners voted 1 for each (name, variant) pair
name_pairs_df["sum_learners"] = name_pairs_df.loc[:, all_learners].sum(axis="columns")

In [None]:
# review a random sample of candidate pairs. n is capped at the row count because
# DataFrame.sample raises ValueError when n exceeds it without replacement, and the fixed
# random_state shows the same sample on every run.
print(len(name_pairs_df))
name_pairs_df.sample(n=min(200, len(name_pairs_df)), random_state=42)

In [None]:
# range of learner-vote totals, plotted with one bin per integer total from low to high
low, high = min(name_pairs_df["sum_learners"]), max(name_pairs_df["sum_learners"])
print(low, high)
name_pairs_df.hist(column="sum_learners", bins=(high - low) + 1)

In [None]:
# number of candidate pairs that at least 3 learners voted for
len(name_pairs_df.query("sum_learners >= 3"))

In [None]:
# inspect a reproducible sample of pairs that at least 3 learners voted for
# (n is capped because DataFrame.sample raises ValueError when n exceeds the row count)
name_pairs_df[name_pairs_df["sum_learners"] >= 3].pipe(
    lambda pairs: pairs.sample(n=min(40, len(pairs)), random_state=42)
)

In [None]:
# choose a positive ratio: the prior fraction of candidate pairs in class 1 (a proposed variant).
# It is passed to LabelModel.fit below as class_balance=[1 - pos_ratio, pos_ratio].
# NOTE(review): how 0.382 was chosen is not recorded here — document its source.
pos_ratio = 0.382

In [None]:
# train snorkel on strategy results. The label matrix has one row per (name, variant) pair and one
# column per learner, in all_learners order. Each cell is 1 when the learner proposed the pair
# (add_variants) and 0 otherwise (zero_learners).
# class_balance[i] is the prior for class i. Class 1 is the "proposed variant" vote, matching the
# predict_proba(...)[:, 1] column used below, so [1 - pos_ratio, pos_ratio] is not reversed.
# NOTE(review): Snorkel treats -1 as ABSTAIN and 0 as a vote for class 0. A learner that merely did
# not propose a pair is therefore counted as voting "not a variant" — confirm this is intended.
snorkel_train = name_pairs_df.loc[:, all_learners].to_numpy()
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(snorkel_train, class_balance=[1 - pos_ratio, pos_ratio], n_epochs=1000, log_freq=50)

In [None]:
# add snorkel predictions to train dataframe
# predict_proba returns one probability column per class. Column 1 is class 1, the label learners
# emit for a proposed variant.
# NOTE(review): with Snorkel's default tie_break_policy, predict() may return -1 (abstain) for exact
# ties — confirm before treating "predict" as strictly 0/1.
name_pairs_df["predict_proba"] = label_model.predict_proba(L=snorkel_train)[:, 1]
name_pairs_df["predict"] = label_model.predict(L=snorkel_train)

In [None]:
# number of pairs predicted to be variants (class 1). Compare with == 1 rather than summing:
# LabelModel.predict's default tie-break policy emits -1 (abstain), which .sum() would subtract.
(name_pairs_df["predict"] == 1).sum()

In [None]:
# review positive and negative examples: a reproducible sample, capped at the available row count
# because DataFrame.sample raises ValueError when n exceeds it
name_pairs_df.sample(n=min(200, len(name_pairs_df)), random_state=42)