In [1]:
%load_ext autoreload
%autoreload 2

# Generate tree anchor name, pos record name, neg record name triplets

## DEPRECATED

Use the swivel model over high-frequency names to generate (anchor, pos, pos_score, neg, neg_score) triplets.

For each high-frequency name and every other high-frequency name that is similar, generate 5 hard negatives and 5 easy negatives.

The hard negatives come from other high-frequency names that are similar. The easy negatives come from other very-high-frequency names, whether they are similar or not.

In [2]:
from bisect import bisect_right
import gzip
import math
import os
import shutil
import tempfile

import numpy as np
import pandas as pd
import random
import torch
from tqdm.auto import tqdm

from nama.data.filesystem import download_file_from_s3, upload_file_to_s3
from nama.data.utils import read_csv
from nama.models.swivel import SwivelModel, get_best_swivel_matches

In [3]:
# Config

# TODO run both given and surname
given_surname = "given"
# given_surname = "surname"

high_freq_threshold = 1000
very_high_freq_threshold = 10000
pos_threshold = 0.4
hard_neg_threshold = 0.3
hard_neg_count = 10
easy_neg_count = 30

vocab_size = 610000 if given_surname == "given" else 2100000
embed_dim = 100

frequencies_path = f"s3://fs-nama-data/2024/familysearch-names/interim/tree-hr-{given_surname}-aggr-v2.parquet"
swivel_vocab_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv"
swivel_model_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth"

triplets_path=f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-triplets-{hard_neg_count}-{easy_neg_count}.csv.gz"

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

cuda:0
True
cuda total 8141471744
cuda reserved 0
cuda allocated 0


## Load data

In [5]:
# load counts
frequencies_path = download_file_from_s3(frequencies_path) if frequencies_path.startswith("s3://") else frequencies_path
counts_df = pd.read_parquet(frequencies_path)
print(counts_df.shape)
counts_df.head(3)

(25541154, 10)


Unnamed: 0,name,alt_name,frequency,reverse_frequency,sum_name_frequency,total_name_frequency,total_alt_name_frequency,ordered_prob,unordered_prob,similarity
0,a,a,1622927,1622927,2578937,36295683,36295683,0.629301,0.04680698,1.0
1,a,aa,154,139,2578937,36295683,5067,6e-05,8.071524e-06,0.5
2,a,aaa,3,5,2578937,36295683,143,1e-06,2.204111e-07,0.333333


In [6]:
counts_df = counts_df[['alt_name', 'total_alt_name_frequency']].drop_duplicates()
print(counts_df.shape)
counts_df.head(3)

(6148634, 2)


Unnamed: 0,alt_name,total_alt_name_frequency
0,a,36295683
1,aa,5067
2,aaa,143


In [7]:
# load swivel vocab
swivel_vocab_path = download_file_from_s3(swivel_vocab_path) if swivel_vocab_path.startswith("s3://") else swivel_vocab_path
vocab_df = read_csv(swivel_vocab_path)
swivel_vocab = {name: _id for name, _id in zip(vocab_df["name"], vocab_df["index"])}
print(len(swivel_vocab))

610000


In [8]:
swivel_model_path = download_file_from_s3(swivel_model_path) if swivel_model_path.startswith("s3://") else swivel_model_path
swivel_model = SwivelModel(len(swivel_vocab), embed_dim)
swivel_model.load_state_dict(torch.load(swivel_model_path, map_location=torch.device(device)))
swivel_model.to(device)
swivel_model.eval()

  swivel_model.load_state_dict(torch.load(swivel_model_path, map_location=torch.device(device)))


SwivelModel(
  (wi): Embedding(610000, 100)
  (wj): Embedding(610000, 100)
  (bi): Embedding(610000, 1)
  (bj): Embedding(610000, 1)
)

In [9]:
# get high-frequency names, ignoring initials
high_freq_names = [name for name, freq in zip(counts_df['alt_name'], counts_df['total_alt_name_frequency']) 
                   if freq > high_freq_threshold and name in swivel_vocab and len(name) > 1]
len(high_freq_names)

43379

In [10]:
# get very-high-frequency names along with the log of their frequency, ignoring initials
very_high_freq_name_freqs = {name: math.log10(freq) for name, freq in zip(counts_df['alt_name'], counts_df['total_alt_name_frequency']) 
                   if freq > very_high_freq_threshold and name in swivel_vocab and len(name) > 1}
len(very_high_freq_name_freqs)

9743

In [11]:
very_high_freq_name_positions = []
very_high_freq_name_names = []
start_pos = 0.0
total_freq = sum(very_high_freq_name_freqs.values())
for name, freq in very_high_freq_name_freqs.items():
    very_high_freq_name_positions.append(start_pos)
    very_high_freq_name_names.append(name)
    start_pos += freq / total_freq
print(very_high_freq_name_positions[0:10])
print(very_high_freq_name_names[0:10])

[0.0, 0.00010380820314508952, 0.00019260745835782074, 0.00030672626576713275, 0.0004382758681311588, 0.0005344989823098632, 0.0006286051530217165, 0.0007200461856900705, 0.0008096374979576926, 0.0009153062339552252]
['aage', 'aagot', 'aaltje', 'aaron', 'aart', 'aase', 'ab', 'abad', 'abagail', 'abba']


In [12]:
def find_name_for_position(positions, names, input_position):
    """
    Finds the name associated with the highest position that is less than
    or equal to the input_position using binary search.

    :param positions: List of positions, sorted by position.
    :param names: List of names, sorted by position.
    :param input_position: The input position to search for.
    :return: The name associated with the highest position <= input_position, or None if no such position exists.
    """
    # Find the index where input_position would fit
    index = bisect_right(positions, input_position) - 1
    
    # Check if the index is valid
    if index >= 0:
        return names[index]
    return None

In [13]:
print(find_name_for_position(very_high_freq_name_positions, very_high_freq_name_names, 0.0001))
print(find_name_for_position(very_high_freq_name_positions, very_high_freq_name_names, 0.00025))

aage
aaltje


## Generate triplets

In [14]:
def save_to_csv(df, filepath):
    """
    Save a DataFrame to CSV, either creating a new file or appending to existing one.
    
    Parameters:
    df (pandas.DataFrame): DataFrame to save
    filepath (str): Path to the CSV file
    
    Returns:
    bool: True if successful, False if an error occurred
    """
    if not os.path.exists(filepath):
        # File doesn't exist - create new file with headers
        df.to_csv(filepath, index=False)
    else:
        # File exists - append without headers
        df.to_csv(filepath, mode='a', header=False, index=False)

In [15]:
%%time
temp_filepath = f"{tempfile.NamedTemporaryFile(delete=False).name}.csv"
print(temp_filepath)
for anchor_name in tqdm(high_freq_names):
    triplets = []
    # get positives and hard negatives
    swivel_scores = get_best_swivel_matches(model=swivel_model, 
                                            vocab=swivel_vocab, 
                                            input_names=np.array([anchor_name]),
                                            candidate_names=np.array(high_freq_names), 
                                            encoder_model=None,
                                            k=1000, 
                                            batch_size=1000,
                                            add_context=True,
                                            progress_bar=False,
                                            n_jobs=1)
    pos_names = [(name, score) for name, score in swivel_scores[0] if score > pos_threshold]
    hard_neg_names = [(name, score) for name, score in swivel_scores[0] if score > hard_neg_threshold]
    if len(pos_names) == 0 or len(hard_neg_names) == 0:
        continue
    # get easy negatives
    swivel_scores = get_best_swivel_matches(model=swivel_model, 
                                            vocab=swivel_vocab, 
                                            input_names=np.array([anchor_name]),
                                            candidate_names=np.array(very_high_freq_name_names), 
                                            encoder_model=None,
                                            k=len(very_high_freq_name_names), 
                                            batch_size=1000,
                                            add_context=True,
                                            progress_bar=False,
                                            n_jobs=1)
    easy_neg_name_scores = {name: score for name, score in swivel_scores[0]}
    # generate triplets
    for pos_name, pos_score in pos_names:
        # add hard negatives
        for _ in range(hard_neg_count):
            neg_name, neg_score = random.choice(hard_neg_names)
            temp_pos_name, temp_pos_score = (pos_name, pos_score) if pos_score > neg_score else (neg_name, neg_score)
            temp_neg_name, temp_neg_score = (neg_name, neg_score) if pos_score > neg_score else (pos_name, pos_score)
            triplets.append({
                'anchor': anchor_name, 
                'positive': pos_name, 
                'positive_score': pos_score, 
                'negative': neg_name, 
                'negative_score': neg_score
            })            
        # add easy negatives
        for _ in range(easy_neg_count):
            neg_name = find_name_for_position(very_high_freq_name_positions, very_high_freq_name_names, random.random())
            neg_score = easy_neg_name_scores[neg_name]
            temp_pos_name, temp_pos_score = (pos_name, pos_score) if pos_score > neg_score else (neg_name, neg_score)
            temp_neg_name, temp_neg_score = (neg_name, neg_score) if pos_score > neg_score else (pos_name, pos_score)
            triplets.append({
                'anchor': anchor_name, 
                'positive': pos_name, 
                'positive_score': pos_score, 
                'negative': neg_name, 
                'negative_score': neg_score
            })
    # save triplets
    df = pd.DataFrame(triplets)
    save_to_csv(df, temp_filepath)        

/tmp/tmpi6pxekvu.csv


  0%|          | 0/43379 [00:00<?, ?it/s]

CPU times: user 1d 51min 23s, sys: 5min 41s, total: 1d 57min 4s
Wall time: 3h 48min 4s


## Save triplets to S3

In [16]:
temp_gz_filepath = f"{tempfile.NamedTemporaryFile(delete=False).name}.csv.gz"
print(temp_gz_filepath)
with open(temp_filepath, 'rb') as f_in:
    with gzip.open(temp_gz_filepath, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
upload_file_to_s3(temp_gz_filepath, triplets_path)

/tmp/tmp1v977a17.csv.gz


True