In [1]:
# Enable IPython's autoreload extension in mode 2 so imported modules are re-imported
# before each cell runs (picks up edits to project code without restarting the kernel).
%load_ext autoreload
%autoreload 2

# Augment triplets

Augment the training triplets with easy negatives and with extra triplets for names that contain under-represented tokens.

In [4]:
from collections import Counter
import random
import re

import pandas as pd
from tqdm.auto import tqdm

from nama.data.filesystem import download_file_from_s3, save_file
from nama.data.utils import read_csv
from nama.models.tokenizer import get_tokenize_function_and_vocab

In [32]:
# config
# TODO do for given and surname
given_surname = "given"
# given_surname = "surname"

# Seed the `random` module: easy negatives are drawn with random.choice and the
# final rows are shuffled with random.shuffle, so without a seed every run
# produces a different augmented file.
random_seed = 42
random.seed(random_seed)

# how many leading rows of the preferred-names file are considered "common" names
num_common_names = 10000 if given_surname == "given" else 25000

# easy negatives sampled per unique (anchor, positive) pair
num_easy_pos_negs = 100
# easy negatives sampled per unique (anchor, hard-negative) pair
num_easy_neg_negs = 5
# the first num_common_negs common names are crossed with each other as (name, name, other-name)
num_common_negs = 1000
# copies of each common-name triplet
num_common_neg_copies = 2
# copies of each original (anchor, positive, hard-negative) triplet
num_anchor_pos_neg_copies = 1
# tokens seen fewer than this many times in the augmented data are "under-represented";
# also used as the number of common names paired with each under-represented name
under_represented_threshold = 500

# tokenizer settings (max_tokens is passed to the tokenizer; vocab_size selects the tokenizer file)
max_tokens = 10
vocab_size = 2048  # 256, 512, 1024, 1536, 2048
# appears in the triplets file name
tree_name_min_freq = 1000

# inputs
triplets_path = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-triplets-{tree_name_min_freq}.csv.gz"
pref_path = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
non_negatives_path = f"s3://fs-nama-data/2024/familysearch-names/processed/common_{given_surname}_non_negatives.csv"
tokenizer_path = f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-subword-tokenizer-{vocab_size}.json"

# output
augmented_path = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-triplets-{tree_name_min_freq}-augmented.csv.gz"

## Load data

In [5]:
# Load the triplets file, fetching it from S3 first when the configured path is an S3 URL.
if triplets_path.startswith("s3://"):
    path = download_file_from_s3(triplets_path)
else:
    path = triplets_path
triplets_df = read_csv(path)
# row count, then a peek at the first rows
print(len(triplets_df))
triplets_df.head(3)

3289279


Unnamed: 0,anchor,positive,positive_score,negative,negative_score
0,aloysius,alaysius,0.961516,aloisia,0.403759
1,rosanna,roseanna,0.792378,susanne,0.400109
2,stephen,stehen,0.968966,stearn,0.501695


In [6]:
# count the number of unique anchor-pos pairs and anchor-neg pairs
anchor_pos = set()
anchor_neg = set()
# itertuples() is a generator with no length, so tqdm needs an explicit total
# to render a real progress bar instead of a bare iteration counter
for tup in tqdm(triplets_df.itertuples(), total=len(triplets_df)):
    anchor = tup.anchor
    pos = tup.positive
    neg = tup.negative
    # pairs are stored as "anchor:other" strings
    anchor_pos.add(f"{anchor}:{pos}")
    anchor_neg.add(f"{anchor}:{neg}")
print(len(anchor_pos))
print(len(anchor_neg))

0it [00:00, ?it/s]

72350
594206


### read common names

In [8]:
def _is_plain_name(name):
    """True-ish when `name` is longer than one character and made only of a-z."""
    return len(name) > 1 and re.fullmatch(r'[a-z]+', name)


# Load the preferred-names file (from S3 when the path is an S3 URL).
if pref_path.startswith("s3://"):
    path = download_file_from_s3(pref_path)
else:
    path = pref_path
pref_df = read_csv(path)

# Keep the first num_common_names names, dropping single characters and anything not purely a-z.
leading_names = pref_df['name'][:num_common_names].tolist()
common_names = [candidate for candidate in leading_names if _is_plain_name(candidate)]

# only the name list is needed from here on; drop the reference to the frame
pref_df = None
len(common_names)

9977

### read common non-negatives

In [9]:
# Load the non-negatives file (from S3 when the path is an S3 URL).
if non_negatives_path.startswith("s3://"):
    path = download_file_from_s3(non_negatives_path)
else:
    path = non_negatives_path
common_non_negatives_df = read_csv(path)

# Each row holds two names; store every row as a (name1, name2) tuple for fast membership tests.
common_non_negatives = {
    (name1, name2)
    for name1, name2 in common_non_negatives_df.values.tolist()
}
len(common_non_negatives)

1122584

## Create training data

In [10]:
# Anchor-positive pairs to leave out of the training data because they teach the model bad habits.
# Only given names have such pairs; for surnames the list is empty.
bad_anchor_pos_pairs = (
    [
        ('maria', 'annamaria'),
        ('marie', 'annamarie'),
    ]
    if given_surname == "given"
    else []
)

In [11]:
def is_bad_anchor_pos_pair(name1, name2, bad_pairs=None):
    """Return True when (name1, name2) is a known-bad anchor-positive pair.

    The check is order-insensitive: (a, b) matches a listed (a, b) or (b, a).

    Args:
        name1: first name of the pair.
        name2: second name of the pair.
        bad_pairs: optional iterable of (name, name) tuples to check against.
            Defaults to the notebook-level `bad_anchor_pos_pairs`, so existing
            two-argument calls behave exactly as before.

    Returns:
        bool: True if the pair (in either order) is listed, else False.
    """
    if bad_pairs is None:
        # fall back to the notebook-level list, looked up at call time
        bad_pairs = bad_anchor_pos_pairs
    return any(
        (name1 == bad_name1 and name2 == bad_name2)
        or (name2 == bad_name1 and name1 == bad_name2)
        for bad_name1, bad_name2 in bad_pairs
    )

### Get tokenizer

In [12]:
# Load the subword tokenizer (from S3 when the path is an S3 URL) and show the vocabulary size.
if tokenizer_path.startswith("s3://"):
    path = download_file_from_s3(tokenizer_path)
else:
    path = tokenizer_path
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    tokenizer_path=path,
    max_tokens=max_tokens,
)
len(tokenizer_vocab)

2048

In [13]:
# Spot-check the tokenizer on one sample name (displays the token ids it produces).
tokenize('dallan')

[1584, 489, 1, 1, 1, 1, 1, 1, 1, 1]

### Add anchor-pos-neg triplets

In [14]:
# augment data with easy negatives
def _sample_easy_negatives(anchor, other, count):
    """Rejection-sample `count` easy negatives for `anchor` from `common_names`.

    A candidate is rejected when it equals `anchor` or `other`, or when
    (anchor, candidate) is in `common_non_negatives` (i.e. it isn't really a
    negative). Sampling is with replacement, so a name may appear more than once.

    NOTE: this loops until `count` candidates are accepted, so it never
    terminates if no common name passes the filter.
    """
    easy_negs = []
    while len(easy_negs) < count:
        candidate = random.choice(common_names)
        if anchor == candidate or other == candidate or (anchor, candidate) in common_non_negatives:
            continue
        easy_negs.append(candidate)
    return easy_negs


aug_data = []
seen_anchor_pos = set()  # "anchor,pos" keys already given easy negatives
seen_anchor_neg = set()  # "neg,anchor" keys already given easy negatives
# itertuples() has no length, so give tqdm the total to get a real progress bar
for tup in tqdm(triplets_df.itertuples(), total=len(triplets_df)):
    anchor = tup.anchor
    pos = tup.positive
    neg = tup.negative
    if is_bad_anchor_pos_pair(anchor, pos):
        continue
    # anchor, positive, hard-negative
    for _ in range(num_anchor_pos_neg_copies):
        aug_data.append({
            'anchor': anchor,
            'positive': pos,
            'negative': neg,
            'positive_score': tup.positive_score,
            'negative_score': tup.negative_score,
        })

    # add anchor-pos-easy-negatives
    # only add easy negatives the first time we see this anchor,pos pair
    anchor_pos_key = f"{anchor},{pos}"
    if anchor_pos_key not in seen_anchor_pos:
        seen_anchor_pos.add(anchor_pos_key)
        for easy_neg in _sample_easy_negatives(anchor, pos, num_easy_pos_negs):
            # anchor, positive, easy-negative
            aug_data.append({
                'anchor': anchor,
                'positive': pos,
                'negative': easy_neg,
                'positive_score': tup.positive_score,
                'negative_score': 0.0,
            })

    # add anchor-neg-easy-negatives
    # only add easy negatives the first time we see this anchor,neg pair
    neg_anchor_key = f"{neg},{anchor}"
    if neg_anchor_key not in seen_anchor_neg:
        seen_anchor_neg.add(neg_anchor_key)
        for easy_neg in _sample_easy_negatives(anchor, neg, num_easy_neg_negs):
            # the hard negative goes in the positive slot (with its own score)
            # and the easy negative is the negative
            aug_data.append({
                'anchor': anchor,
                'positive': neg,
                'negative': easy_neg,
                'positive_score': tup.negative_score,
                'negative_score': 0.0,
            })

len(aug_data)

0it [00:00, ?it/s]

13494823

### Add pos-pos-easyneg triplets

In [15]:
# Cross the first num_common_negs common names with each other: every ordered pair of
# distinct names that is not a known non-negative becomes (name, name, other) triplets.
top_common_names = common_names[:num_common_negs]
rows_before = len(aug_data)
for pos in tqdm(top_common_names):
    for neg in top_common_names:
        is_usable_negative = pos != neg and (pos, neg) not in common_non_negatives
        if is_usable_negative:
            # one fresh dict per copy
            aug_data.extend(
                {
                    'anchor': pos,
                    'positive': pos,
                    'negative': neg,
                    'positive_score': 1.0,
                    'negative_score': 0.0,
                }
                for _ in range(num_common_neg_copies)
            )
# number of rows this cell appended
cnt = len(aug_data) - rows_before
print(cnt)

  0%|          | 0/1000 [00:00<?, ?it/s]

1983724


### Add triplets for names we want to push apart

In [16]:
# Hand-curated (anchor, negative) name pairs to push apart; only defined for given names.
# NOTE(review): five pairs appear twice in this list -- ('anne','anders'), ('dejesus','de'),
# ('pearl','per'), ('emilio','mario') and ('paulino','antonino') -- once in the original
# batch and again in the "13 Oct 2023" batch. Each occurrence is expanded push_apart_copies
# times by the loop that consumes this list, so those pairs get double weight.
# Confirm whether that is intended.
if given_surname == "given":
    push_apart_pairs = [('charles', 'frances'),
                        ('marie', 'annie'),
                        ('james', 'jane'),
                        ('jane', 'janos'),
                        ('hannah', 'hans'),
                        ('frank', 'frederick'),
                        ('anne', 'anders'),
                        ('maria', 'manuel'),
                        ('maria', 'manuela'),
                        ('juan','julia'),
                        ('margaret','violet'),
                        ('antonio','emilio'),
                        ('edward','edwin'),
                        ('samuel','smith'),
                        ('martin','augustin'),
                        ('eva','evan'),
                        ('dejesus','de'),
                        ('dejesus','dean'),
                        ('eliza','luiza'),
                        ('frank','mark'),
                        ('benjamin','benita'),
                        ('andrew','matthew'),
                        ('andrew','mathew'),
                        ('guadalupe','guy'),
                        ('jeanne','susanne'),
                        ('delacruz','delaconcepcion'),
                        ('rebecca','veronica'),
                        ('rebecca','francesca'),
                        ('karl','karen'),
                        ('karl','karin'),
                        ('adam','ada'),
                        ('adam','addie'),
                        ('bertha','bruce'),
                        ('edith','edmond'),
                        ('mathias','elias'),
                        ('anton','anta'),
                        ('ethel','effie'),
                        ('delcarmen','oscar'),
                        ('santiago','santos'),
                        ('vicente','clemente'),
                        ('ysabel','ysidro'),
                        ('karen','karolina'),
                        ('ralph','christoph'),
                        ('raymond','reyes'),
                        ('maren','christen'),
                        ('christoph','jph'),
                        ('erzsebet','jozsef'),
                        ('carlos','marcos'),
                        ('ada','adamus'),
                        ('delaluz','dela'),
                        ('jennie','jemima'),
                        ('lorenzo','vincenzo'),
                        ('stina','stella'),
                        ('pearl','per'),
                        ('pearl','pehr'),
                        ('oscar','encarnacion'),
                        ('veronica','francesca'),
                        ('sebastiana','victoriana'),
                        ('elias','matias'),
                        ('myrtle','estelle'),
                        ('bernardo','leonardo'),
                        ('amy','amos'),
                        ('leslie','lester'),
                        ('rosario','hilario'),
                        ('karin','karolina'),
                        ('nora','norma'),
                        ('michaela','mc'),
                        ('christiana','luciana'),
                        ('chen','chester'),
                        ('angelina','augustina'),
                        ('sam','smith'),
                        ('soledad','solomon'),
                        ('mari','jacobi'),
                        ('mari','eli'),
                        ('mari','josephi'),
                        ('mari','li'),
                        ('delacrus','delaconcepcion'),
                        ('etta','etienne'),
                        ('imre','ines'),
                        ('florentina','valentina'),
                        ('jacobi','josephi'),
                        ('joanne','susanne'),
                        ('bernardino','florentino'),
                        ('josefina','rufina'),
                        ('eli','josephi'),
                        ('dean','delia'),
                        ('emilio','mario'),
                        ('jenny','jemima'),
                        ('paulino','antonino'),
                        # 13 Oct 2023
                        ('anne','anders'),
                        ('ann','amy'),
                        ('dejesus','de'),
                        ('dejesus','dedios'),
                        ('anders','an'),
                        ('ana','anastacia'),
                        ('de','dedios'),
                        ('mae','mette'),
                        ('betty','bell'),
                        ('jesus','julius'),
                        ('joao','joel'),
                        ('clarence','claire'),
                        ('martina','marta'),
                        ('roy','ray'),
                        ('pearl','per'),
                        ('veronica','domenica'),
                        ('elias','elisa'),
                        ('lidia','li'),
                        ('sven','sue'),
                        ('bernardino','paulino'),
                        ('bernardino','antonino'),
                        ('eli','elin'),
                        ('emilio','mario'),
                        ('antal','an'),
                        ('anta','an'),
                        ('paulino','antonino'),
                       ]
else:
    # no push-apart pairs for surnames
    push_apart_pairs = []
# number of identical triplets generated for each listed pair occurrence
push_apart_copies = 10

In [17]:
# For every push-apart pair, append push_apart_copies triplets that use the anchor as its
# own positive (score 1.0) and the paired name as the negative (score 0.0).
for anchor, neg in push_apart_pairs:
    aug_data.extend(
        {
            'anchor': anchor,
            'positive': anchor,
            'negative': neg,
            'positive_score': 1.0,
            'negative_score': 0.0,
        }
        for _ in range(push_apart_copies)
    )

## Analyze training data

In [18]:
# Invert the tokenizer vocabulary (token text -> id) into an id -> token text lookup.
token_id2text = {id_: text for text, id_ in tokenizer_vocab.items()}

In [19]:
def _tokens_before_padding(name):
    """Yield the token ids of `name`, stopping at the first id 1 (padding)."""
    for token_id in tokenize(name):
        if token_id == 1:
            return
        yield token_id


# Count how often each token id occurs across the anchor, positive and negative names
# of every augmented row.
counter = Counter()
for row in tqdm(aug_data, mininterval=1.0):
    for key in ('anchor', 'positive', 'negative'):
        for token_id in _tokens_before_padding(row[key]):
            counter[token_id] += 1

# Print rank, token id, token text and count, most frequent first.
for ix, (token, cnt) in enumerate(counter.most_common()):
    print(ix, token, token_id2text[token], cnt)

  0%|          | 0/15479697 [00:00<?, ?it/s]

0 34 ##a 2400867
1 37 ##e 2189892
2 35 ##n 1777948
3 31 ##o 1551878
4 45 ##s 1457386
5 75 ##ia 1188281
6 41 ##l 1183635
7 42 ##t 1106371
8 49 ##y 904235
9 123 ##ina 851578
10 38 ##i 771990
11 39 ##r 768875
12 71 ##ie 762546
13 40 ##d 761814
14 89 ##us 747787
15 64 ##in 733467
16 85 ##na 664897
17 36 ##h 641336
18 104 ##ine 633647
19 94 ##la 594915
20 171 ##io 562168
21 59 ##er 561259
22 72 ##on 545622
23 135 ##ta 543891
24 52 ##b 527673
25 73 ##is 524536
26 44 ##g 510915
27 32 ##m 493290
28 103 ##ne 478601
29 62 ##en 466277
30 56 ##z 415135
31 119 ##da 409005
32 65 ##el 375927
33 51 ##f 371385
34 705 ##ri 355007
35 153 ##ra 339920
36 61 mar 339460
37 30 z 333899
38 47 ##c 329101
39 97 ##le 326423
40 163 ar 326230
41 46 ##k 323331
42 58 ##an 319215
43 108 al 306982
44 33 ##p 300462
45 1368 ##no 289938
46 68 ##es 288016
47 66 ##il 285533
48 712 ##re 281678
49 19 o 276387
50 48 ##u 275225
51 187 la 274943
52 202 ##sa 273492
53 251 re 271362
54 134 el 265379
55 159 ##de 253631
56 142 ##te 

In [20]:
# Spot-check the tokenizer on one sample name (displays the token ids it produces).
tokenize('jewel')

[1935, 1, 1, 1, 1, 1, 1, 1, 1, 1]

### Find names that contain under-represented tokens

In [21]:
# Token ids whose count in the augmented data is below the threshold
# (ids never seen count as 0, so they are included).
under_represented_token_ids = {
    id_
    for id_ in tokenizer_vocab.values()
    if counter[id_] < under_represented_threshold
}
print(len(under_represented_token_ids))

14


In [22]:
# Collect every distinct name that appears as an anchor, positive or negative in the
# triplets, plus all common names. Feeding whole columns to set.update avoids a
# Python-level loop over every row of the (multi-million-row) triplets frame.
all_names = set()
for column in ('anchor', 'positive', 'negative'):
    all_names.update(triplets_df[column])
all_names.update(common_names)
len(all_names)

0it [00:00, ?it/s]

66978

In [23]:
def _token_ids_before_padding(name):
    """Return the token ids of `name` up to, but not including, the first id 1 (padding)."""
    ids = []
    for token_id in tokenize(name):
        if token_id == 1:
            break
        ids.append(token_id)
    return ids


# A name is under-represented when at least one of its (non-padding) tokens is.
under_represented_names = set()
for name in all_names:
    if any(token_id in under_represented_token_ids for token_id in _token_ids_before_padding(name)):
        under_represented_names.add(name)
print(len(under_represented_names))

# For each such name, show its tokens as (id, text, count) so the rare token is visible.
for name in under_represented_names:
    token_counts = [
        (token_id, token_id2text[token_id], counter[token_id])
        for token_id in _token_ids_before_padding(name)
    ]
    print(name, token_counts)

12
loveland [(133, 'lo', 175865), (1653, '##veland', 315)]
doleres [(717, 'dol', 42451), (410, '##eres', 284)]
goodridge [(11, 'g', 175164), (1666, '##oodr', 225), (1414, '##idg', 9693), (37, '##e', 2189892)]
goodrich [(11, 'g', 175164), (1666, '##oodr', 225), (138, '##ich', 50656)]
frinidad [(77, 'fr', 97189), (1415, '##inidad', 319)]
giuse [(1570, 'giuse', 149)]
zuzanny [(30, 'z', 333899), (1354, '##uz', 49129), (925, '##anny', 430)]
cleaveland [(555, 'cle', 53882), (34, '##a', 2400867), (1653, '##veland', 315)]
myfanny [(1678, 'my', 12357), (51, '##f', 371385), (925, '##anny', 430)]
tinidad [(24, 't', 188991), (1415, '##inidad', 319)]
traylor [(877, 'tr', 110279), (1661, '##aylor', 126)]
dwigh [(1863, 'dwig', 111), (36, '##h', 641336)]


In [24]:
# Rare tokens that are neither special tokens (contain '[') nor continuation pieces
# (contain '#') are themselves added to the under-represented names.
for token, id_ in tokenizer_vocab.items():
    is_rare = counter[id_] < under_represented_threshold
    is_plain_token = '[' not in token and '#' not in token
    if is_rare and is_plain_token:
        print(token)
        under_represented_names.add(token)

giuse
dwig


In [25]:
# Number of under-represented names after adding the rare whole-word tokens.
len(under_represented_names)

13

### Add names that contain under-represented tokens to aug_data

In [26]:
# Pair every under-represented name with the first under_represented_threshold common names
# as (name, name, common-name) triplets, skipping pairs that are identical, known
# non-negatives, or already seen (in either order) as an anchor-positive pair.
negative_candidates = common_names[:under_represented_threshold]
cnt = 0
for pos in tqdm(under_represented_names):
    for neg in negative_candidates:
        skip = (
            pos == neg
            or (pos, neg) in common_non_negatives
            or f"{pos},{neg}" in seen_anchor_pos
            or f"{neg},{pos}" in seen_anchor_pos
        )
        if skip:
            continue
        new_row = {
            'anchor': pos,
            'positive': pos,
            'negative': neg,
            'positive_score': 1.0,
            'negative_score': 0.0,
        }
        aug_data.append(new_row)
        cnt += 1
print(cnt)

  0%|          | 0/13 [00:00<?, ?it/s]

6493


## Re-Analyze training data

In [27]:
# Recount token occurrences now that the under-represented names have been added.
name_fields = ('anchor', 'positive', 'negative')
counter = Counter()
for row in tqdm(aug_data):
    for field in name_fields:
        token_ids = tokenize(row[field])
        for token in token_ids:
            # id 1 is padding: nothing after it is a real token
            if token == 1:
                break
            counter[token] += 1

# Print rank, token id, token text and count, most frequent first.
ranked_counts = counter.most_common()
for ix, (token, cnt) in enumerate(ranked_counts):
    print(ix, token, token_id2text[token], cnt)

  0%|          | 0/15486190 [00:00<?, ?it/s]

0 34 ##a 2402036
1 37 ##e 2190918
2 35 ##n 1777974
3 31 ##o 1551943
4 45 ##s 1457412
5 75 ##ia 1188307
6 41 ##l 1183635
7 42 ##t 1106397
8 49 ##y 904261
9 123 ##ina 851591
10 38 ##i 772003
11 39 ##r 768888
12 71 ##ie 762546
13 40 ##d 761814
14 89 ##us 747839
15 64 ##in 733493
16 85 ##na 664910
17 36 ##h 642349
18 104 ##ine 633647
19 94 ##la 594928
20 171 ##io 562168
21 59 ##er 561259
22 72 ##on 545622
23 135 ##ta 543904
24 52 ##b 527673
25 73 ##is 524536
26 44 ##g 510928
27 32 ##m 493290
28 103 ##ne 478601
29 62 ##en 466303
30 56 ##z 415161
31 119 ##da 409005
32 65 ##el 375927
33 51 ##f 372396
34 705 ##ri 355020
35 153 ##ra 339920
36 61 mar 339460
37 30 z 334897
38 47 ##c 329114
39 97 ##le 326423
40 163 ar 326230
41 46 ##k 323331
42 58 ##an 319215
43 108 al 306982
44 33 ##p 300462
45 1368 ##no 289938
46 68 ##es 288029
47 66 ##il 285533
48 712 ##re 281678
49 19 o 276387
50 48 ##u 275225
51 187 la 274943
52 202 ##sa 273505
53 251 re 271362
54 134 el 265379
55 159 ##de 253631
56 142 ##te 

In [28]:
# List the vocabulary entries that still never occur in the augmented data.
unused_tokens = [
    (id_, token)
    for token, id_ in tokenizer_vocab.items()
    if counter[id_] == 0
]
for id_, token in unused_tokens:
    print(id_, token, counter[id_])

2 [CLS] 0
0 [UNK] 0
3 [SEP] 0
1 [PAD] 0
1794 ##frie 0
4 [MASK] 0


In [29]:
# Spot-check the tokenizer on one sample name (displays the token ids it produces).
tokenize('zetty')

[30, 232, 49, 1, 1, 1, 1, 1, 1, 1]

## Save augmented triplets

In [35]:
# Sanity check before saving: show the container type and the number of augmented rows.
print(type(aug_data))
print(len(aug_data))

<class 'list'>
15486190


In [36]:
# Shuffle the rows in place, convert them to a DataFrame, and hand save_file a writer
# that dumps the frame as CSV (without the index) to the local path it is given.
random.shuffle(aug_data)
df = pd.DataFrame(aug_data)


def _write_augmented_csv(local_out_path):
    """Write `df` as CSV to `local_out_path`, omitting the index column."""
    return df.to_csv(local_out_path, index=False)


save_file(augmented_path, _write_augmented_csv)

In [34]:
# Display the destination path configured for the augmented triplets.
augmented_path

's3://fs-nama-data/2024/familysearch-names/processed/tree-hr-given-triplets-1000-augmented.csv.gz'