In [1]:
%load_ext autoreload
%autoreload 2

# Generate triplets from cross encoder

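Generate training triplets (anchor, positive, negative) scored by the cross encoder from 377:
- get all name pairs from tree-record attachments and the existing standard buckets
- add easy negatives drawn from the most frequent tree preferred names (10,000 for given names, 25,000 for surnames)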

In [17]:
import csv
from datetime import datetime
import os
import random
import re
import tempfile

from sentence_transformers.cross_encoder import CrossEncoder
import torch

from nama.data.filesystem import download_file_from_s3, upload_file_to_s3
from nama.data.utils import read_csv, load_dataset

In [3]:
# config
# TODO do for given and surname
given_surname = "given"
# given_surname = "surname"

# number of triplets written to each output file
num_training_examples = 10_000_000
# how many leading rows of the preferred-name file are considered common names
if given_surname == "given":
    num_common_names = 10_000
else:
    num_common_names = 25_000

# cross-encoder settings (the vocab size is part of the model directory name)
cross_encoder_vocab_size = 265
tokenizer_max_length = 32

# local reference file with the standard buckets
std_path = f"../references/std_{given_surname}.txt"

# inputs
pref_path = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
train_path = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz"
non_negatives_path = f"s3://fs-nama-data/2024/familysearch-names/processed/common_{given_surname}_non_negatives.csv"

# cross-encoder model: local directory and the S3 prefix it is downloaded from
cross_encoder_dir = f"../data/models/cross-encoder-{given_surname}-{cross_encoder_vocab_size}"
cross_encoder_dir_s3 = f"s3://fs-nama-data/2024/nama-data/data/models/cross-encoder-{given_surname}-{cross_encoder_vocab_size}/"

# outputs
triplets_path = f"s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-{given_surname}-train.csv"
triplets_path_common = f"s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-{given_surname}-common.csv"

In [4]:
# Report GPU availability and memory usage.
# The per-device queries are only made when CUDA is available, because
# torch.cuda.get_device_properties(0) raises on a machine without a GPU.
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

True
cuda total 8141471744
cuda reserved 0
cuda allocated 0


## Load data

In [19]:
# load cross encoder
# Fetch each model/tokenizer file from S3 unless it is already present locally.
for filename in (
    'CECorrelationEvaluator_dev_results.csv',
    'config.json',
    'merges.txt',
    'special_tokens_map.json',
    'tokenizer_config.json',
    'tokenizer.json',
    'vocab.json',
    'model.safetensors',
):
    print(filename)
    cross_encoder_dir_filename = os.path.join(cross_encoder_dir, filename)
    if os.path.exists(cross_encoder_dir_filename):
        # a local copy exists; nothing to download
        continue
    print('downloading', filename)
    download_file_from_s3(f"{cross_encoder_dir_s3}{filename}", cross_encoder_dir_filename)

CECorrelationEvaluator_dev_results.csv
config.json
merges.txt
special_tokens_map.json
tokenizer_config.json
tokenizer.json
vocab.json
model.safetensors


### load common names

In [6]:
# For an S3 location use whatever download_file_from_s3 returns; otherwise read the path as given.
if pref_path.startswith("s3://"):
    path = download_file_from_s3(pref_path)
else:
    path = pref_path
pref_df = read_csv(path)
# From the first num_common_names rows keep names that are longer than one
# character and consist only of the letters a-z.
common_names = [
    name
    for name in pref_df['name'][:num_common_names].tolist()
    if len(name) > 1 and re.fullmatch(r'[a-z]+', name)
]
# drop the reference to the frame now that the names have been extracted
pref_df = None
len(common_names)

9977

### load common non-negatives

In [7]:
# Seed the set of name pairs with the common non-negatives, skipping rows
# whose two names are identical.
name_pairs = set()
if non_negatives_path.startswith("s3://"):
    path = download_file_from_s3(non_negatives_path)
else:
    path = non_negatives_path
common_non_negatives_df = read_csv(path)
name_pairs.update(
    (name1, name2)
    for name1, name2 in common_non_negatives_df.values.tolist()
    if name1 != name2
)
len(name_pairs)

1122584

### load training dataset

In [9]:
# Resolve the training file (via download_file_from_s3 for S3 locations) and load it.
if train_path.startswith("s3://"):
    path = download_file_from_s3(train_path)
else:
    path = train_path
tree_names_train, attached_names_train, record_names_train = load_dataset(path)

In [10]:
# Summary counts for the loaded training data.
print("tree_names_train", len(tree_names_train))
print("attached_names_train", sum(map(len, attached_names_train)))
print(
    "total pairs",
    sum(freq for attachments in attached_names_train for _, freq in attachments),
)
print("record_names_train", len(record_names_train))
# distinct names appearing on either the tree side or the record side
print("total names", len(set(tree_names_train) | set(record_names_train)))

tree_names_train 581934
attached_names_train 3790287
total pairs 1072113643
record_names_train 791151
total names 836285


In [11]:
# Lists of names that belong together; each entry is one list of two or more
# distinct names. Filled from the training attachments and from the std file.
buckets = []

In [12]:
%%time
# One bucket per tree name: the tree name followed by its attached names,
# without repeats and in first-seen order. Buckets with a single name are skipped.
for tree_name, attachments in zip(tree_names_train, attached_names_train):
    names = list(dict.fromkeys([tree_name, *(name for name, _ in attachments)]))
    if len(names) < 2:
        continue
    buckets.append(names)
    # record every ordered pair of distinct names from the bucket
    name_pairs.update(
        (name1, name2) for name1 in names for name2 in names if name1 != name2
    )
print(len(buckets), sum(len(bucket) for bucket in buckets))

489770 3698123


### load std

In [13]:
# Add buckets from the std file. Each line has the form
# "<head names>:<tail names>", with names separated by single spaces.
with open(std_path) as f:
    # iterate the file lazily instead of materialising every line first
    for line in f:
        line = line.strip()
        if not line:
            # a blank line holds no bucket, and splitting it on ':' would
            # fail the two-way unpacking below
            continue
        head_names, tail_names = line.split(':')
        head_names = head_names.strip()
        tail_names = tail_names.strip()
        # head names first, then tail names; empty tokens (from repeated
        # spaces) and repeated names are dropped, keeping first-seen order
        names = list(dict.fromkeys(
            name
            for name in head_names.split(' ') + tail_names.split(' ')
            if len(name) > 0
        ))
        if len(names) < 2:
            continue
        buckets.append(names)
        # record every ordered pair of distinct names from the bucket
        name_pairs.update(
            (name1, name2) for name1 in names for name2 in names if name1 != name2
        )
print(len(buckets), sum(len(bucket) for bucket in buckets))

497178 3792663


In [14]:
# number of buckets that contain exactly two names
sum(1 for bucket in buckets if len(bucket) == 2)

210075

### name pairs

In [15]:
# Turn the set into a list so that write_triplets can pick a pair by random index.
name_pairs = list(name_pairs)
len(name_pairs)

314048990

## Generate and write triplets

In [20]:
# Load the cross encoder from the local model directory, passing
# tokenizer_max_length as its max_length.
model = CrossEncoder(cross_encoder_dir, max_length=tokenizer_max_length)

In [21]:
def harmonic_mean(x,y):
    """Return the harmonic mean of two scores.

    The harmonic mean is only meaningful for positive values. When either
    score is zero or negative, 0.0 is returned instead of dividing by zero
    (or producing a sign-dependent value from mixed-sign inputs).

    Args:
        x: first score.
        y: second score.

    Returns:
        2 / (1/x + 1/y) when both scores are positive, otherwise 0.0.
    """
    if x <= 0 or y <= 0:
        return 0.0
    return 2 / (1/x+1/y)

def choose3(names):
    """Pick a random (anchor, positive, negative) triple of names from a list.

    Indices are drawn until the negative differs from both the anchor and the
    positive. The anchor and positive must differ too, except when the list
    has exactly two names, where the anchor may double as the positive.
    The positive and negative are then ordered so that positive >= negative
    by string comparison; with two names this can make the returned negative
    equal to the anchor.

    Args:
        names: list of at least two names.

    Returns:
        Tuple (anchor, positive, negative) of names.

    Raises:
        ValueError: if fewer than two names are given; no valid triple
            exists, so the sampling loop would never finish.
    """
    if len(names) < 2:
        raise ValueError("choose3 requires at least two names")
    while True:
        anchor = random.randrange(len(names))
        pos = random.randrange(len(names))
        neg = random.randrange(len(names))
        if anchor != neg and pos != neg and (anchor != pos or len(names) == 2):
            if names[pos] < names[neg]:  # always return pos >= neg
                neg, pos = pos, neg
            return names[anchor], names[pos], names[neg]
        
def choose_pair(pair):
    """Return the two elements of `pair` as a tuple, in original or reversed order.

    A single random draw decides the order: below 0.5 keeps it, otherwise
    the elements are swapped.
    """
    keep_order = random.random() < 0.5
    if not keep_order:
        return pair[1], pair[0]
    return pair[0], pair[1]

def clamp(score):
    """Limit `score` to the closed interval [0.0, 1.0]."""
    # cap at 1.0 first, then floor at 0.0
    capped = score if score < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0

In [22]:
def write_triple(writer, triple):
    """Score one (anchor, positive, negative) triple and write it as a CSV row.

    Both candidates are scored against the anchor with the cross encoder in
    both argument orders, and the two directions are combined with the
    harmonic mean. A candidate identical to the anchor gets score 1.0.
    Scores are clamped to [0, 1], and the candidate with the higher score is
    written as the positive.

    Args:
        writer: object with a `writerow(dict)` method, e.g. csv.DictWriter.
        triple: (anchor, positive, negative) names.
    """
    anchor, pos, neg = triple
    # one predict call covering both candidates in both directions
    query_pairs = [[anchor, pos], [pos, anchor], [anchor, neg], [neg, anchor]]
    pos_fwd, pos_rev, neg_fwd, neg_rev = model.predict(query_pairs)
    pos_score = harmonic_mean(pos_fwd, pos_rev)
    neg_score = harmonic_mean(neg_fwd, neg_rev)
    # a name compared with itself is given the maximum score
    if anchor == pos:
        pos_score = 1.0
    if anchor == neg:
        neg_score = 1.0
    pos_score, neg_score = clamp(pos_score), clamp(neg_score)
    # make sure the higher-scoring candidate ends up in the positive slot
    if pos_score < neg_score:
        pos, neg = neg, pos
        pos_score, neg_score = neg_score, pos_score
    row = {
        'anchor': anchor,
        'positive': pos,
        'positive_score': pos_score,
        'negative': neg,
        'negative_score': neg_score,
    }
    writer.writerow(row)

In [23]:
# open each path, get names, write triplets
def write_triplets(num_easy_negs, path):
    """Generate scored triplets and write them to `path` as CSV.

    Triples are written until `num_training_examples` have been produced; a
    triple that was already written is skipped and counted as a repeat.
    Progress (written count, repeat count, timestamp) is printed every
    10,000 triples.

    Args:
        num_easy_negs: either the string 'common', in which case every triple
            consists of three distinct random names from `common_names`; or
            an integer, in which case each triple drawn from a random bucket
            of three or more names is followed by that many
            (anchor, positive, easy negative) triples whose pair comes from
            `name_pairs` and whose easy negative comes from `common_names`.
        path: destination file. For an "s3://" path the rows are written to
            a local temporary file that is uploaded afterwards.

    Returns:
        The number of triples written.
    """
    cnt = 0
    seen_triples = set()
    seen_cnt = 0
    is_s3 = path.startswith("s3://")
    if is_s3:
        # Only the name is needed; close the handle right away rather than
        # leaving it open for the rest of the run.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            local_path = tmp.name
    else:
        local_path = path
    with open(local_path, 'w', newline='') as f:
        # Create a CSV writer object
        writer = csv.DictWriter(f, fieldnames=['anchor','positive','positive_score','negative','negative_score'])

        # Write the column headers
        writer.writeheader()

        def emit(triple):
            """Score and write `triple` unless already written; return True if it was written."""
            nonlocal cnt, seen_cnt
            if triple in seen_triples:
                seen_cnt += 1
                return False
            seen_triples.add(triple)
            write_triple(writer, triple)
            cnt += 1
            if cnt % 10_000 == 0:
                print(cnt, seen_cnt, datetime.now())
            return True

        while cnt < num_training_examples:
            # write triples from common names
            if num_easy_negs == 'common':
                anchor = random.choice(common_names)
                pos = random.choice(common_names)
                neg = random.choice(common_names)
                if anchor == pos or anchor == neg or pos == neg:
                    continue
                # canonical order so the same three names are recognised as a repeat
                if pos < neg:
                    pos, neg = neg, pos
                emit((anchor, pos, neg))
                continue

            # write a triple from a random bucket with at least three names
            names = buckets[random.randrange(len(buckets))]
            if len(names) < 3:
                continue
            if not emit(choose3(names)):
                continue

            # follow it with easy negatives, without exceeding the requested total
            easy_neg_cnt = 0
            while easy_neg_cnt < num_easy_negs and cnt < num_training_examples:
                ix = random.randrange(len(name_pairs))
                anchor, pos = choose_pair(name_pairs[ix])
                easy_neg = random.choice(common_names)
                # a negative identical to the anchor or the positive is a
                # degenerate triple; draw again
                if easy_neg == anchor or easy_neg == pos:
                    continue
                if emit((anchor, pos, easy_neg)):
                    easy_neg_cnt += 1
    if is_s3:
        try:
            upload_file_to_s3(local_path, path)
        except Exception:
            # keep the generated file so the work is not lost
            print("upload failed; triplets kept at", local_path)
            raise
        # the temporary copy is no longer needed once it has been uploaded
        os.remove(local_path)
    return cnt

In [24]:
%%time
# Write triplets drawn from the buckets; num_easy_negs=0 adds no easy negatives.
cnt = write_triplets(0, triplets_path)
print(triplets_path, cnt)

10000 22 2024-12-09 14:52:36.601552
20000 98 2024-12-09 14:54:07.205455
30000 215 2024-12-09 14:55:35.700288
40000 365 2024-12-09 14:57:01.936077
50000 577 2024-12-09 14:58:24.775799
60000 835 2024-12-09 14:59:43.471984
70000 1130 2024-12-09 15:00:59.686857
80000 1505 2024-12-09 15:02:15.788286
90000 1878 2024-12-09 15:03:31.758060
100000 2319 2024-12-09 15:04:47.994638
110000 2806 2024-12-09 15:06:03.964517
120000 3347 2024-12-09 15:07:20.568084
130000 3891 2024-12-09 15:08:38.121188
140000 4530 2024-12-09 15:09:53.343506
150000 5211 2024-12-09 15:11:09.027005
160000 5949 2024-12-09 15:12:24.495304
170000 6714 2024-12-09 15:13:41.806128
180000 7469 2024-12-09 15:15:01.223954
190000 8307 2024-12-09 15:16:21.439283
200000 9207 2024-12-09 15:17:41.671322
210000 10190 2024-12-09 15:19:00.756251
220000 11133 2024-12-09 15:20:20.420970
230000 12190 2024-12-09 15:21:39.295882
240000 13314 2024-12-09 15:22:56.348757
250000 14470 2024-12-09 15:24:12.741543
260000 15688 2024-12-09 15:25:29.0023

In [25]:
%%time
# Write triplets whose three names are all drawn at random from common_names.
cnt = write_triplets('common', triplets_path_common)
print(triplets_path_common, cnt)

10000 0 2024-12-10 12:01:19.123701
20000 0 2024-12-10 12:02:54.716233
30000 0 2024-12-10 12:04:30.360344
40000 0 2024-12-10 12:06:06.527031
50000 0 2024-12-10 12:07:41.105663
60000 0 2024-12-10 12:09:07.918048
70000 0 2024-12-10 12:10:25.221582
80000 0 2024-12-10 12:11:42.829540
90000 0 2024-12-10 12:12:59.361385
100000 0 2024-12-10 12:14:15.632357
110000 0 2024-12-10 12:15:32.044534
120000 0 2024-12-10 12:16:48.569799
130000 0 2024-12-10 12:18:05.421929
140000 0 2024-12-10 12:19:21.724641
150000 0 2024-12-10 12:20:38.151789
160000 0 2024-12-10 12:21:54.649234
170000 0 2024-12-10 12:23:10.728672
180000 0 2024-12-10 12:24:27.094211
190000 0 2024-12-10 12:25:43.294990
200000 0 2024-12-10 12:26:59.254071
210000 0 2024-12-10 12:28:14.934871
220000 0 2024-12-10 12:29:30.440007
230000 0 2024-12-10 12:30:46.191775
240000 0 2024-12-10 12:32:01.951546
250000 0 2024-12-10 12:33:17.491244
260000 0 2024-12-10 12:34:33.241680
270000 0 2024-12-10 12:35:49.052408
280000 0 2024-12-10 12:37:05.110579
2