In [None]:
# Automatically reload edited modules (e.g. the project's src.* helpers) before each cell executes.
%load_ext autoreload
%autoreload 2

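# Split input name pairs into train and test sets (by frequency), then check that the test set is out-of-vocabulary with respect to the training names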

In [None]:
from collections import namedtuple
import wandb

from src.data.familysearch import train_test_split_on_frequency
from src.data.utils import load_dataset
from src.models.utils import add_padding

In [None]:
# Which kind of name this run processes; it selects the S3 files below and the probe name
# later on (anything other than "surname" is probed as a given name).
given_surname = "given"

# Run configuration: source pairs, destination train/test files, and the split threshold
# passed to train_test_split_on_frequency.
Config = namedtuple("Config", ["in_path", "train_path", "test_path", "threshold"])
config = Config(
    in_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar.csv.gz",
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train-unfiltered.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-test.csv.gz",
    threshold=0.5,
)

In [None]:
# Start a Weights & Biases run for this notebook, grouped by name type and logging the run config.
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="45_train_test_split",
    config=config._asdict(),
    notes="",
)

In [None]:
# Split the pairs in config.in_path using config.threshold; the results presumably land in
# config.train_path and config.test_path, which the next cell loads back.
train_test_split_on_frequency(
    config.in_path,
    config.train_path,
    config.test_path,
    config.threshold,
)

In [None]:
# Reload each split as (input names, weighted actual names, candidate names).
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = load_dataset(config.train_path)
(
    input_names_test,
    weighted_actual_names_test,
    candidate_names_test,
) = load_dataset(config.test_path)

In [None]:
# Training vocabulary: every name that appears as an input or a candidate in the train split.
vocab = {*input_names_train, *candidate_names_train}
print(len(vocab))

In [None]:
# Sanity-check the split: classify every (input name, actual name) test pair by how many of
# its names are in the training vocabulary. A pair of two distinct in-vocab names is not
# expected; a pair whose input and actual name are identical and in-vocab is counted as "one".
n_zero = n_one = n_two = 0
for input_name, wans in zip(input_names_test, weighted_actual_names_test):
    input_known = input_name in vocab
    for actual_name, _, _ in wans:
        actual_known = actual_name in vocab
        if input_known and actual_known and input_name != actual_name:
            n_two += 1
        elif not (input_known or actual_known):
            n_zero += 1
        else:
            n_one += 1
print("two names in vocab (should not be possible)", n_two)
print("one name in vocab", n_one)
print("zero names in vocab", n_zero)

In [None]:
# Summary of the training split. The header describes the three sequences load_dataset returned:
# input names (name1), each input's weighted actual names as (name2, weighted_count, co_occurrence)
# entries, and candidate names (name2).
print("train input names (name1), weighted actual (name1 -> [name2, weighted_count, co_occurrence], candidate names (name2)")
print("name1", len(input_names_train))
print("weighted actual - should be same as name1", len(weighted_actual_names_train))
print("number of actuals", sum(len(wa) for wa in weighted_actual_names_train))
print("name2", len(candidate_names_train))
print("total unique names", len(set(input_names_train).union(set(candidate_names_train))))
# weighted_actual_names_train is indexed in parallel with input_names_train (see the probe cell),
# so enforce the "should be same" invariant instead of only printing it.
assert len(weighted_actual_names_train) == len(input_names_train), \
    "weighted actual names must align one-to-one with input names"

In [None]:
# Summary of the (out-of-vocab) test split; same layout as the train summary: input names (name1),
# weighted actual names as (name2, weighted_count, co_occurrence) entries, candidate names (name2).
print("test out-of-vocab: input names (name1), weighted actual (name1 -> [name2, weighted_count, co_occurrence], candidate names (name2)")
print("name1", len(input_names_test))
print("weighted actual - should be same as name1", len(weighted_actual_names_test))
print("number of actuals", sum(len(wa) for wa in weighted_actual_names_test))
print("name2", len(candidate_names_test))
print("total unique names", len(set(input_names_test).union(set(candidate_names_test))))
# The sanity-check cell zips input_names_test with weighted_actual_names_test; zip would silently
# truncate on a length mismatch, so enforce the "should be same" invariant here.
assert len(weighted_actual_names_test) == len(input_names_test), \
    "weighted actual names must align one-to-one with input names"

### Probe datasets

In [None]:
def print_weighted_actual_names(label, weighted_actual_names, max=0):
    """Print ``label``, the entry count, then each entry on its own indented line.

    Args:
        label: heading printed first.
        weighted_actual_names: sized, sliceable sequence of entries to list.
        max: when positive and smaller than the number of entries, only the first
            ``max`` entries are listed. The printed total always counts every entry.
            (Name shadows the builtin but is kept for caller compatibility.)
    """
    print(label)
    total = len(weighted_actual_names)
    print("total", total)
    shown = weighted_actual_names
    if max > 0 and total > max:
        shown = shown[:max]
    for entry in shown:
        print("  ", entry)

# Probe one name ("jones" for surnames, "richard" otherwise): for each split, print the total
# weighted count of its actual names and list the first 20 of them.
probe_name = add_padding("jones" if given_surname == "surname" else "richard")
for split_label, split_inputs, split_actuals in (
    ("train", input_names_train, weighted_actual_names_train),
    ("test", input_names_test, weighted_actual_names_test),
):
    # Look the probe up once per split. list.index raises ValueError when the name is absent,
    # which would abort a Restart & Run All, so report that case instead.
    try:
        probe_actuals = split_actuals[split_inputs.index(probe_name)]
    except ValueError:
        print(f"{probe_name} not found in {split_label} input names")
        continue
    print("total weight", sum(wc for _, wc, _ in probe_actuals))
    print_weighted_actual_names(split_label, probe_actuals, 20)

In [None]:
# Close the W&B run opened by wandb.init at the top of the notebook.
wandb.finish()