In [None]:
%load_ext autoreload
%autoreload 2

# Compare swivel vector similarities to tree-record name co-occurrences

## Make sure we handle surname prefixes!
TODO: How many triplets include prefixed vs. unprefixed surnames?

In [None]:
from collections import Counter
import random

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import torch
from tqdm.auto import tqdm

from src.data.filesystem import fopen
from src.data.utils import load_dataset
from src.models.swivel import SwivelModel, get_swivel_embeddings

In [None]:
# Config

# Name type whose artifacts are loaded. "given" selects the smaller vocabulary;
# any other value (presumably "surname" -- TODO confirm) selects the larger one.
given_surname = "given"
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
embed_dim = 100

# S3 locations of the training dataset, the swivel vocabulary and the swivel model.
train_path = f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz"
vocab_path = f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv"
model_path = f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth"


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Report GPU memory usage for device 0.
# The CUDA queries are guarded so that this cell also runs on a CPU-only
# machine, where torch.cuda.get_device_properties(0) would raise.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # release cached blocks first so the reserved figure reflects live usage
    torch.cuda.empty_cache()
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# Load the training dataset; load_dataset returns three parallel pieces:
# the input names, their weighted actual names, and the candidate names.
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = load_dataset(train_path)

In [None]:
# Size summary of the loaded training data.
print("input_names_train", len(input_names_train))
# number of (input name, actual name) rows across all input names
print("weighted_actual_names_train", sum(map(len, weighted_actual_names_train)))
# sum of the third element (the frequency) of every row
print(
    "total pairs",
    sum(pair_freq for rows in weighted_actual_names_train for _, _, pair_freq in rows),
)
print("candidate_names_train", len(candidate_names_train))
# distinct names appearing as either an input name or a candidate name
print("total names", len(set(input_names_train) | set(candidate_names_train)))

In [None]:
# Load the swivel vocabulary (name -> index) and the trained model weights.
vocab_df = pd.read_csv(fopen(vocab_path, "rb"))
vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
model = SwivelModel(len(vocab), embed_dim)
# map_location="cpu": the model built above is never moved off the CPU, so
# deserialize the checkpoint tensors on the CPU as well. Without this, a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load(fopen(model_path, "rb"), map_location="cpu"))
# switch to inference mode; as the last expression this also displays the model
model.eval()

## Review data

In [None]:
input_names_train[468]

In [None]:
candidate_names_train[:10]

In [None]:
# Record names that tree-name aaron (index 468) co-occurs with,
# keeping only rows whose frequency (third element) is at least 100.
rows = list(filter(lambda wan: wan[2] >= 100, weighted_actual_names_train[468]))
print(len(rows))
rows

In [None]:
np.where(candidate_names_train == '<aaaron>')

In [None]:
candidate_names_train[1008]

In [None]:
[row for row in weighted_actual_names_train[468] if row[0] == '<aaaron>']

In [None]:
# Total frequency of every tree name (input name) and of every record name
# (first element of a row), summed over all co-occurrence rows.
tree_name_counts = Counter()
record_name_counts = Counter()
for tree_name, co_occurrences in zip(input_names_train, weighted_actual_names_train):
    for entry in co_occurrences:
        freq = entry[2]
        tree_name_counts[tree_name] += freq
        record_name_counts[entry[0]] += freq
print(len(tree_name_counts))
print(len(record_name_counts))

In [None]:
tree_name_counts.most_common()[9500:9540]

In [None]:
record_name_counts.most_common()[70000:70020]

In [None]:
record_name_counts.most_common()[45000:45020]

In [None]:
# Denominators used by the scoring functions: total occurrences per tree name
# and per record name.
# NOTE(review): this accumulates the same sums as tree_name_counts /
# record_name_counts computed in an earlier cell.
total_tree_occurs = Counter()
total_record_occurs = Counter()
for tree_name, co_occurrences in tqdm(zip(input_names_train, weighted_actual_names_train)):
    for entry in co_occurrences:
        # Self co-occurrences (a name going to itself) are counted on purpose:
        # if a name usually goes to itself, we want its vector
        # to not be that close to another vector.
        freq = entry[2]
        total_tree_occurs[tree_name] += freq
        total_record_occurs[entry[0]] += freq

In [None]:
# Map each tree name to its position in input_names_train
# (if a name were repeated, the last position wins).
input_names_train_ixs = {name: pos for pos, name in enumerate(input_names_train)}

In [None]:
def score(tree_name, record_name):
    """Score how strongly a tree name and a record name co-occur.

    The score is the larger of two ratios:
      * co-occurrences / total occurrences of the tree name
      * co-occurrences / total occurrences of the record name

    Returns 0.0 when the pair never co-occurs (this also avoids dividing by a
    zero total for a record name that was never seen).

    Raises KeyError if tree_name is not in input_names_train_ixs.
    """
    tree_ix = input_names_train_ixs[tree_name]
    # frequency (third element) of the first row matching record_name, else 0
    co_occur = next(
        (row[2] for row in weighted_actual_names_train[tree_ix] if row[0] == record_name),
        0,
    )
    if not co_occur:
        # no co-occurrence: both ratios are zero, and the totals may be zero too
        return 0.0
    tree_co_occur_ratio = co_occur / total_tree_occurs[tree_name]
    record_co_occur_ratio = co_occur / total_record_occurs[record_name]
    return max(tree_co_occur_ratio, record_co_occur_ratio)

In [None]:
# Score every record name that the tree name at index 468 co-occurs with,
# then display the (tree name, record name, score) tuples, highest score first.
scores = [
    (input_names_train[468], wan[0], score(input_names_train[468], wan[0]))
    for wan in weighted_actual_names_train[468]
]
sorted(scores, key=lambda scored: -scored[2])

In [None]:
# Build (tree name, positive record name, positive score, negative record name,
# negative score) triplets for evaluation.
tree_name_min_freq = 2000          # skip tree names whose total count is below this
record_name_min_freq = 200         # skip record names whose total count is below this
pos_threshold = 0.5                # at least one name of a pair must score this high
max_triplets_per_tree_name = 2000  # cap on triplets generated per tree name

total_record_candidates = 0
total_tree_names = 0
triplets = []
for input_name, wans in tqdm(zip(input_names_train, weighted_actual_names_train),
                             total=len(input_names_train)):
    if tree_name_counts[input_name] < tree_name_min_freq:
        continue
    record_candidates = [wan for wan in wans
                         if record_name_counts[wan[0]] >= record_name_min_freq]
    # Score each distinct candidate exactly once instead of re-scoring it for
    # every pair; the tree name itself is never used as a positive or negative.
    candidate_scores = {}
    for wan in record_candidates:
        candidate_name = wan[0]
        if candidate_name != input_name and candidate_name not in candidate_scores:
            candidate_scores[candidate_name] = score(input_name, candidate_name)
    scored_names = list(candidate_scores)
    n_triplets = 0
    # Visit every unordered pair of distinct candidates once, in input order.
    for first_ix, first_name in enumerate(scored_names):
        first_score = candidate_scores[first_name]
        for second_name in scored_names[first_ix + 1:]:
            second_score = candidate_scores[second_name]
            if max(first_score, second_score) < pos_threshold:
                continue
            # the higher-scoring name is the positive; ties keep the first name
            if first_score < second_score:
                pos_name, pos_score = second_name, second_score
                neg_name, neg_score = first_name, first_score
            else:
                pos_name, pos_score = first_name, first_score
                neg_name, neg_score = second_name, second_score
            triplets.append((input_name, pos_name, pos_score, neg_name, neg_score))
            n_triplets += 1
            if n_triplets == max_triplets_per_tree_name:
                break
        if n_triplets == max_triplets_per_tree_name:
            break
    total_record_candidates += len(record_candidates)
    total_tree_names += 1
print('tree names', total_tree_names)
print('total record candidates for all tree names', total_record_candidates)
# guard against no tree name passing the frequency filter
print('avg record candidates per tree name',
      total_record_candidates / total_tree_names if total_tree_names else 0.0)
print('total triplets', len(triplets))

In [None]:
[triplet for triplet in triplets if triplet[0] == '<richard>']

## Evaluate

In [None]:
def softmax(z):
    """Row-wise softmax of a 2-D array.

    Each row's maximum is subtracted before exponentiating, so every exponent
    is <= 0; the normalised result is the same as without the shift.
    """
    assert z.ndim == 2
    # keepdims keeps the reduced axis so the (rows, 1) result broadcasts against z
    row_max = np.max(z, axis=1, keepdims=True)
    exps = np.exp(z - row_max)
    return exps / np.sum(exps, axis=1, keepdims=True)

def harmonic_mean(x,y):
    """Return the harmonic mean of x and y.

    If either value is 0 the result is 0.0 (rather than raising
    ZeroDivisionError while computing the reciprocals).
    """
    if x == 0 or y == 0:
        return 0.0
    return 2 / (1/x + 1/y)

In [None]:
def eval(tree_name, record_name):
    """Print co-occurrence statistics and the swivel cosine similarity for a pair.

    Prints the raw co-occurrence count, the totals for each name, the
    co-occurrence ratios (per tree name, per record name, harmonic mean, max,
    and co-occurrence over the union of occurrences), every tree name that
    goes to the record name, and the cosine similarity of the two swivel
    embeddings.

    A pair that never co-occurs is reported with a count and ratios of 0
    instead of raising.

    Raises KeyError if tree_name is not in input_names_train_ixs.

    NOTE: this shadows the builtin eval(); the name is kept because other
    cells call it.
    """
    # same O(1) lookup that score() uses
    tree_ix = input_names_train_ixs[tree_name]
    # frequency (third element) of the first row matching record_name, else 0
    co_occur = next(
        (row[2] for row in weighted_actual_names_train[tree_ix] if row[0] == record_name),
        0,
    )
    print('co-occur', co_occur)
    total_tree_occur = total_tree_occurs[tree_name]
    print('total tree occur', total_tree_occur)
    tree_co_occur_ratio = co_occur / total_tree_occur if total_tree_occur else 0.0
    print('tree co-occur ratio', tree_co_occur_ratio)
    total_record_occur = total_record_occurs[record_name]
    print('total record occur', total_record_occur)
    record_co_occur_ratio = co_occur / total_record_occur if total_record_occur else 0.0
    print('record co-occur ratio', record_co_occur_ratio)
    # every tree name that has a row for this record name, with its frequency
    tree_record_occur = [(name, row[2])
                         for name, rows in zip(input_names_train, weighted_actual_names_train)
                         for row in rows
                         if row[0] == record_name]
    print('all names going to record', tree_record_occur)
    # the harmonic mean is reported as 0 when either ratio is 0
    if tree_co_occur_ratio and record_co_occur_ratio:
        ratio_harmonic_mean = harmonic_mean(tree_co_occur_ratio, record_co_occur_ratio)
    else:
        ratio_harmonic_mean = 0.0
    print('harmonic mean', ratio_harmonic_mean)
    print('max', max(tree_co_occur_ratio, record_co_occur_ratio))
    # co-occurrences over the union of both names' occurrences
    union_occur = total_tree_occur + total_record_occur - co_occur
    print('tree-record co-occur ratio', co_occur / union_occur if union_occur else 0.0)
    embs = get_swivel_embeddings(model, vocab, [tree_name, record_name], add_context=True)
    print('cosine similarity', cosine_similarity([embs[0]], [embs[1]])[0][0])

In [None]:
eval('<aaron>', '<aaaron>')

In [None]:
eval('<aaron>', '<aron>')

In [None]:
eval('<aaron>', '<aarone>')

In [None]:
eval('<sarah>', '<sara>')

In [None]:
eval('<rebecca>', '<rebekah>')

In [None]:
eval('<donald>', '<ronald>')

In [None]:
eval('<richard>', '<richerd>')

In [None]:
eval('<richard>', '<richand>')

In [None]:
eval('<dallas>', '<dallan>')

In [None]:
eval('<dallin>', '<dallan>')

In [None]:
eval('<dallin>', '<dallen>')

In [None]:
eval('<joan>', '<joanne>')

In [None]:
eval('<joan>', '<joane>')

In [None]:
eval('<joanne>', '<joane>')

In [None]:
eval('<wilbur>', '<wilber>')

In [None]:
eval('<wilhelmina>', '<wilhelm>')

In [None]:
eval('<richard>', '<card>')

In [None]:
eval('<richard>', '<orchard>')