In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Split input pairs into train and test sets

In [None]:
from collections import namedtuple
import wandb

from src.data.familysearch import train_test_split
from src.data.utils import load_train_test
from src.models.utils import add_padding

In [None]:
# Name type to process; it is interpolated into every S3 path below and is
# also used as the W&B run group.
given_surname = "given"

# Immutable settings record; field order matches the positional argument
# order used when calling train_test_split.
Config = namedtuple(
    "Config",
    "pref_path in_path train_path test_path "
    "freq_train_path freq_test_path freq_cutoff train_cutoff",
)

config = Config(
    # S3 locations handed to train_test_split.
    pref_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    in_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar.csv.gz",
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test.csv.gz",
    freq_train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-freq.csv.gz",
    freq_test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-freq.csv.gz",
    # Numeric thresholds passed through to train_test_split; their exact
    # semantics are defined there (NOTE(review): confirm against src.data.familysearch).
    freq_cutoff=1000,
    train_cutoff=0,
)

In [None]:
# Open a Weights & Biases run for this notebook, grouped by name type, and
# attach the full configuration as a plain mapping.
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="45_train_test_split",
    notes="",
    config=config._asdict(),
)

In [None]:
# Run the split with the paths and cutoffs from the config, passed
# positionally in the order train_test_split is called with in this notebook.
train_test_split(
    config.pref_path,
    config.in_path,
    config.train_path,
    config.test_path,
    config.freq_train_path,
    config.freq_test_path,
    config.freq_cutoff,
    config.train_cutoff,
)

In [None]:
# Load three of the datasets named in the config; load_train_test returns one
# item per path, in the same order as the paths given.
train, freq_train, freq_test = load_train_test(
    [
        config.train_path,
        config.freq_train_path,
        config.freq_test_path,
    ]
)

# Each dataset unpacks into (input names, weighted actual names, candidate names).
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = train
(
    input_names_freq_train,
    weighted_actual_names_freq_train,
    candidate_names_freq_train,
) = freq_train
(
    input_names_freq_test,
    weighted_actual_names_freq_test,
    candidate_names_freq_test,
) = freq_test

In [None]:
# Size summary for the train dataset.
print("train input names (name1), weighted actual (name1 -> [name2, weighted_count, co_occurrence], candidate names (name2)")
print("name1", len(input_names_train))
print("weighted actual - should be same as name1", len(weighted_actual_names_train))
# Total number of entries across all per-name lists.
print("number of actuals", sum(map(len, weighted_actual_names_train)))
print("name2", len(candidate_names_train))
# Distinct names appearing as either an input or a candidate.
print("total unique names", len(set(input_names_train) | set(candidate_names_train)))

In [None]:
# Size summary for the freq_train dataset.
print("test in-vocab: input names (name1), weighted actual (name1 -> [name2, weighted_count, co_occurrence], candidate names (name2)")
print("name1", len(input_names_freq_train))
print("weighted actual - should be same as name1", len(weighted_actual_names_freq_train))
# Total number of entries across all per-name lists.
print("number of actuals", sum(map(len, weighted_actual_names_freq_train)))
print("name2", len(candidate_names_freq_train))
# Distinct names appearing as either an input or a candidate.
print("total unique names", len(set(input_names_freq_train) | set(candidate_names_freq_train)))

In [None]:
# Size summary for the freq_test dataset.
print("test out-of-vocab: input names (name1), weighted actual (name1 -> [name2, weighted_count, co_occurrence], candidate names (name2)")
print("name1", len(input_names_freq_test))
print("weighted actual - should be same as name1", len(weighted_actual_names_freq_test))
# Total number of entries across all per-name lists.
print("number of actuals", sum(map(len, weighted_actual_names_freq_test)))
print("name2", len(candidate_names_freq_test))
# Distinct names appearing as either an input or a candidate.
print("total unique names", len(set(input_names_freq_test) | set(candidate_names_freq_test)))

In [None]:
# probe datasets to validate
def print_weighted_actual_names(label, weighted_actual_names, max=0):
    """Print a label, the number of entries, and then the entries themselves.

    Args:
        label: Heading printed on the first line.
        weighted_actual_names: Sliceable sequence of entries to show.
        max: When strictly between 0 and the number of entries, only the
            first ``max`` entries are printed; otherwise all are printed.
            The reported total always reflects the full sequence.
    """
    total = len(weighted_actual_names)
    print(label)
    print("total", total)
    # Truncate only when a positive limit actually shortens the sequence.
    shown = weighted_actual_names[:max] if 0 < max < total else weighted_actual_names
    for entry in shown:
        print("  ", entry)

# Pick a probe name appropriate to the name type being processed and pad it
# with add_padding before looking it up among the input names.
if given_surname == "surname":
    probe_name = add_padding("chesworth")
else:
    probe_name = add_padding("richard")

# Show up to 20 weighted actual names recorded for the probe name in each
# frequency dataset; .index raises ValueError if the name is absent.
print_weighted_actual_names(
    "freq_train",
    weighted_actual_names_freq_train[input_names_freq_train.index(probe_name)],
    20,
)
print_weighted_actual_names(
    "freq_test",
    weighted_actual_names_freq_test[input_names_freq_test.index(probe_name)],
    20,
)

In [None]:
# End the Weights & Biases run that was opened earlier in this notebook.
wandb.finish()