In [None]:
%load_ext autoreload
%autoreload 2

# Augment triplets

Augment the training triplets with easy negatives and with extra triplets for names that contain under-represented tokens.

In [None]:
from collections import Counter
import random
import re

import pandas as pd
from tqdm.auto import tqdm

from src.data.utils import read_csv
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# --- which names to process ---
given_surname = "given"  # interpolated into every data/model path below
num_common_names = 10000  # leading rows of the preferred-name file considered as common names

# --- reproducibility ---
# Easy negatives are drawn with random.choice further down; seeding the global
# RNG here makes a fresh Restart & Run All draw the same samples every time.
random_seed = 42
random.seed(random_seed)

# --- augmentation sizes ---
num_easy_pos_negs = 100  # easy negatives added the first time an anchor,positive pair is seen
num_easy_neg_negs = 5  # easy negatives added the first time a negative,anchor pair is seen
num_common_negs = 1000  # leading common names that are crossed with each other
num_common_neg_copies = 2  # copies of each common-name triplet
num_anchor_pos_neg_copies = 1  # copies of each original triplet
# tokens counted fewer times than this are treated as under-represented; the
# same value also caps how many common names each under-represented name is
# paired with
under_represented_threshold = 500

# --- tokenizer settings ---
max_tokens = 10
# NOTE(review): the four use_* flags below are not referenced by any cell
# visible in this notebook; kept as-is.
use_phonemes = False
use_edit_subwords = False
vocab_type = 'f'  # tokenizer based upon training name frequency
use_bigrams = False
use_pretrained_embeddings = False
subword_vocab_size = 2000  # 500, 1000, 1500, 2000

# --- inputs ---
triplets_path = f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000.csv.gz"

pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
common_non_negatives_path = f"../data/processed/common_{given_surname}_non_negatives-augmented.csv"
nama_bucket = 'nama-data'
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-{subword_vocab_size}{vocab_type}.json"

# --- output ---
augmented_path = f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000-augmented.csv.gz"

## Load data

In [None]:
# Read the input triplets. Later cells read the anchor, positive, negative,
# positive_score and negative_score fields of every row.
triplets_df = read_csv(triplets_path)
# Show the row count, then preview the first three rows.
print(len(triplets_df))
triplets_df.head(3)

In [None]:
# Tally the distinct anchor/positive and anchor/negative combinations in the
# input triplets. Each combination is stored as the string "<anchor>:<other>".
anchor_pos = set()
anchor_neg = set()
for row in tqdm(triplets_df.itertuples()):
    anchor_pos.add(f"{row.anchor}:{row.positive}")
    anchor_neg.add(f"{row.anchor}:{row.negative}")
# One count per line: anchor-positive pairs first, then anchor-negative pairs.
print(len(anchor_pos), len(anchor_neg), sep="\n")

### Read common names

In [None]:
# Keep the leading `num_common_names` preferred names that are longer than one
# character and consist solely of the letters a-z.
pref_df = read_csv(pref_path)
lowercase_only = re.compile(r'[a-z]+')
leading_names = pref_df['name'][:num_common_names].tolist()
common_names = [
    candidate
    for candidate in leading_names
    if len(candidate) > 1 and lowercase_only.fullmatch(candidate)
]
# Drop the reference to the full frame; only the filtered list is used later.
pref_df = None
len(common_names)

### Read common non-negatives

In [None]:
# Load the two-column file of name pairs that must never be used as
# anchor/negative, and keep them as a set of (name1, name2) tuples for fast
# membership checks.
common_non_negatives_df = read_csv(common_non_negatives_path)
common_non_negatives = {
    (first, second)
    for first, second in common_non_negatives_df.values.tolist()
}
len(common_non_negatives)

## Create training data

In [None]:
# Anchor-positive pairs that would teach the model bad habits; triplets built
# on them are skipped during augmentation.
bad_anchor_pos_pairs = [
    ('maria', 'annamaria'),
    ('marie', 'annamarie'),
]


def is_bad_anchor_pos_pair(name1, name2):
    """Return True when the two names form a listed bad pair, in either order."""
    return any(
        (name1, name2) in ((bad_first, bad_second), (bad_second, bad_first))
        for bad_first, bad_second in bad_anchor_pos_pairs
    )

### Get tokenizer

In [None]:
# Obtain the tokenize callable and its vocabulary from the project helper.
# Later cells iterate tokenize(name) as a sequence of token ids and iterate
# tokenizer_vocab.items() as (token text, token id) pairs.
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    max_tokens=max_tokens,
    subwords_path=subwords_path,
    nama_bucket=nama_bucket,
)
# Display the vocabulary size.
len(tokenizer_vocab)

In [None]:
# Spot check: display the tokenizer output for one sample name.
tokenize('dallan')

### Add anchor-pos-neg triplets

In [None]:
def _draw_easy_negative(anchor, other):
    """Return a random common name usable as an easy negative for `anchor`.

    A candidate is rejected and redrawn when it equals `anchor` or `other`, or
    when (anchor, candidate) is listed in `common_non_negatives`.
    """
    while True:
        candidate = random.choice(common_names)
        if anchor == candidate or other == candidate:
            continue
        if (anchor, candidate) in common_non_negatives:
            continue
        return candidate


# Build the augmented triplet list: the original triplets plus randomly drawn
# easy negatives.
aug_data = []
seen_anchor_pos = set()
seen_anchor_neg = set()
for tup in tqdm(triplets_df.itertuples()):
    anchor, pos, neg = tup.anchor, tup.positive, tup.negative
    if is_bad_anchor_pos_pair(anchor, pos):
        continue

    # 1) the original (anchor, positive, hard-negative) triplet
    aug_data.extend(
        {
            'anchor': anchor,
            'positive': pos,
            'negative': neg,
            'positive_score': tup.positive_score,
            'negative_score': tup.negative_score,
        }
        for _ in range(num_anchor_pos_neg_copies)
    )

    # 2) (anchor, positive, easy-negative): added only the first time this
    #    anchor,positive combination shows up
    pos_key = f"{anchor},{pos}"
    if pos_key not in seen_anchor_pos:
        seen_anchor_pos.add(pos_key)
        for _ in range(num_easy_pos_negs):
            aug_data.append({
                'anchor': anchor,
                'positive': pos,
                'negative': _draw_easy_negative(anchor, pos),
                'positive_score': tup.positive_score,
                'negative_score': 0.0,
            })

    # 3) (anchor, hard-negative, easy-negative): the hard negative takes the
    #    positive slot with its own score; added only the first time this
    #    negative,anchor combination shows up
    neg_key = f"{neg},{anchor}"
    if neg_key not in seen_anchor_neg:
        seen_anchor_neg.add(neg_key)
        for _ in range(num_easy_neg_negs):
            aug_data.append({
                'anchor': anchor,
                'positive': neg,
                'negative': _draw_easy_negative(anchor, neg),
                'positive_score': tup.negative_score,
                'negative_score': 0.0,
            })

len(aug_data)

### Add pos-pos-easyneg triplets

In [None]:
# Cross the leading `num_common_negs` common names with each other: every name
# is paired with itself as anchor and positive, and with each other name that
# is not a listed non-negative as the negative.
top_common_names = common_names[:num_common_negs]
cnt = 0
for pos in tqdm(top_common_names):
    for neg in top_common_names:
        if pos != neg and (pos, neg) not in common_non_negatives:
            copies = [
                {
                    'anchor': pos,
                    'positive': pos,
                    'negative': neg,
                    'positive_score': 1.0,
                    'negative_score': 0.0,
                }
                for _ in range(num_common_neg_copies)
            ]
            aug_data.extend(copies)
            cnt += len(copies)
print(cnt)

### Add triplets for names we want to push apart

In [None]:
# Hand-picked name pairs to push apart. Each entry is (anchor, negative); the
# following cell emits `push_apart_copies` triplets per entry with the anchor
# also used as the positive. An entry that appears twice in this list is
# therefore emitted twice as often.
push_apart_pairs = [('charles', 'frances'),
                    ('marie', 'annie'),
                    ('james', 'jane'),
                    ('jane', 'janos'),
                    ('hannah', 'hans'),
                    ('frank', 'frederick'),
                    ('anne', 'anders'),
                    ('maria', 'manuel'),
                    ('maria', 'manuela'),
                    ('juan','julia'),
                    ('margaret','violet'),
                    ('antonio','emilio'),
                    ('edward','edwin'),
                    ('samuel','smith'),
                    ('martin','augustin'),
                    ('eva','evan'),
                    ('dejesus','de'),
                    ('dejesus','dean'),
                    ('eliza','luiza'),
                    ('frank','mark'),
                    ('benjamin','benita'),
                    ('andrew','matthew'),
                    ('andrew','mathew'),
                    ('guadalupe','guy'),
                    ('jeanne','susanne'),
                    ('delacruz','delaconcepcion'),
                    ('rebecca','veronica'),
                    ('rebecca','francesca'),
                    ('karl','karen'),
                    ('karl','karin'),
                    ('adam','ada'),
                    ('adam','addie'),
                    ('bertha','bruce'),
                    ('edith','edmond'),
                    ('mathias','elias'),
                    ('anton','anta'),
                    ('ethel','effie'),
                    ('delcarmen','oscar'),
                    ('santiago','santos'),
                    ('vicente','clemente'),
                    ('ysabel','ysidro'),
                    ('karen','karolina'),
                    ('ralph','christoph'),
                    ('raymond','reyes'),
                    ('maren','christen'),
                    ('christoph','jph'),
                    ('erzsebet','jozsef'),
                    ('carlos','marcos'),
                    ('ada','adamus'),
                    ('delaluz','dela'),
                    ('jennie','jemima'),
                    ('lorenzo','vincenzo'),
                    ('stina','stella'),
                    ('pearl','per'),
                    ('pearl','pehr'),
                    ('oscar','encarnacion'),
                    ('veronica','francesca'),
                    ('sebastiana','victoriana'),
                    ('elias','matias'),
                    ('myrtle','estelle'),
                    ('bernardo','leonardo'),
                    ('amy','amos'),
                    ('leslie','lester'),
                    ('rosario','hilario'),
                    ('karin','karolina'),
                    ('nora','norma'),
                    ('michaela','mc'),
                    ('christiana','luciana'),
                    ('chen','chester'),
                    ('angelina','augustina'),
                    ('sam','smith'),
                    ('soledad','solomon'),
                    ('mari','jacobi'),
                    ('mari','eli'),
                    ('mari','josephi'),
                    ('mari','li'),
                    ('delacrus','delaconcepcion'),
                    ('etta','etienne'),
                    ('imre','ines'),
                    ('florentina','valentina'),
                    ('jacobi','josephi'),
                    ('joanne','susanne'),
                    ('bernardino','florentino'),
                    ('josefina','rufina'),
                    ('eli','josephi'),
                    ('dean','delia'),
                    ('emilio','mario'),
                    ('jenny','jemima'),
                    ('paulino','antonino'),
                    # 13 Oct 2023
                    ('anne','anders'),  # NOTE(review): repeats an earlier entry - confirm the double weight is intended
                    ('ann','amy'),
                    ('dejesus','de'),  # NOTE(review): repeats an earlier entry - confirm the double weight is intended
                    ('dejesus','dedios'),
                    ('anders','an'),
                    ('ana','anastacia'),
                    ('de','dedios'),
                    ('mae','mette'),
                    ('betty','bell'),
                    ('jesus','julius'),
                    ('joao','joel'),
                    ('clarence','claire'),
                    ('martina','marta'),
                    ('roy','ray'),
                    ('pearl','per'),  # NOTE(review): repeats an earlier entry - confirm the double weight is intended
                    ('veronica','domenica'),
                    ('elias','elisa'),
                    ('lidia','li'),
                    ('sven','sue'),
                    ('bernardino','paulino'),
                    ('bernardino','antonino'),
                    ('eli','elin'),
                    ('emilio','mario'),  # NOTE(review): repeats an earlier entry - confirm the double weight is intended
                    ('antal','an'),
                    ('anta','an'),
                    ('paulino','antonino'),  # NOTE(review): repeats an earlier entry - confirm the double weight is intended
                   ]
# Number of identical triplets generated for every entry of push_apart_pairs.
push_apart_copies = 10

In [None]:
# Turn every push-apart pair into `push_apart_copies` identical triplets in
# which the anchor doubles as its own positive.
for anchor, neg in push_apart_pairs:
    aug_data.extend(
        {
            'anchor': anchor,
            'positive': anchor,
            'negative': neg,
            'positive_score': 1.0,
            'negative_score': 0.0,
        }
        for _ in range(push_apart_copies)
    )

## Analyze training data

In [None]:
# Reverse lookup: token id -> token text.
token_id2text = {id_: text for text, id_ in tokenizer_vocab.items()}


def _counted_tokens(name):
    """Yield the token ids of `name`, stopping before the first id equal to 1."""
    # NOTE(review): id 1 presumably marks padding / end of name - confirm
    # against the tokenizer.
    for token_id in tokenize(name):
        if token_id == 1:
            return
        yield token_id


# Count how often each token id occurs across all three names of every triplet.
counter = Counter()
for row in tqdm(aug_data):
    for field in ('anchor', 'positive', 'negative'):
        counter.update(_counted_tokens(row[field]))

# Print rank, token id, token text and count, most frequent first.
for rank, (token_id, occurrences) in enumerate(counter.most_common()):
    print(rank, token_id, token_id2text[token_id], occurrences)

In [None]:
# Spot check: display the tokenizer output for one sample name.
tokenize('jewel')

### Find names that contain under-represented tokens

In [None]:
# Token ids whose count in the current training data is below the threshold
# (ids that were never seen count as zero).
under_represented_token_ids = {
    token_id
    for token_id in tokenizer_vocab.values()
    if counter[token_id] < under_represented_threshold
}
print(len(under_represented_token_ids))

In [None]:
# Gather every name that occurs anywhere in the input triplets, then add the
# common names.
all_names = set()
for row in tqdm(triplets_df.itertuples()):
    all_names.update((row.anchor, row.positive, row.negative))
all_names.update(common_names)
len(all_names)

In [None]:
def _leading_tokens(name):
    """Yield the token ids of `name`, stopping before the first id equal to 1."""
    for token_id in tokenize(name):
        if token_id == 1:
            return
        yield token_id


# A name qualifies as soon as one of its leading tokens is under-represented.
under_represented_names = {
    name
    for name in all_names
    if any(
        token_id in under_represented_token_ids
        for token_id in _leading_tokens(name)
    )
}
print(len(under_represented_names))

# For each qualifying name, list (token id, token text, current count).
for name in under_represented_names:
    token_counts = [
        (token_id, token_id2text[token_id], counter[token_id])
        for token_id in _leading_tokens(name)
    ]
    print(name, token_counts)

In [None]:
# Treat under-represented vocabulary entries as names in their own right,
# except entries whose text contains '[' or '#'.
for token, id_ in tokenizer_vocab.items():
    is_rare = counter[id_] < under_represented_threshold
    has_marker = '[' in token or '#' in token
    if is_rare and not has_marker:
        print(token)
        under_represented_names.add(token)

In [None]:
# Display how many under-represented names there are after adding whole tokens.
len(under_represented_names)

### Add names that contain under-represented tokens to aug_data

In [None]:
# Pair every under-represented name (as both anchor and positive) with the
# leading `under_represented_threshold` common names as negatives.
negative_candidates = common_names[:under_represented_threshold]
cnt = 0
for pos in tqdm(under_represented_names):
    for neg in negative_candidates:
        # never use the name itself or a listed non-negative as the negative
        if pos == neg:
            continue
        if (pos, neg) in common_non_negatives:
            continue
        # skip combinations already recorded as anchor,positive in either order
        if f"{pos},{neg}" in seen_anchor_pos or f"{neg},{pos}" in seen_anchor_pos:
            continue
        aug_data.append(dict(
            anchor=pos,
            positive=pos,
            negative=neg,
            positive_score=1.0,
            negative_score=0.0,
        ))
        cnt += 1
print(cnt)

## Re-Analyze training data

In [None]:
# Recount token occurrences now that the extra triplets have been appended.
counter = Counter()
for row in tqdm(aug_data):
    for name in (row['anchor'], row['positive'], row['negative']):
        kept_ids = []
        for token_id in tokenize(name):
            if token_id == 1:
                # ids from the first 1 onwards are not counted
                break
            kept_ids.append(token_id)
        counter.update(kept_ids)

# Print rank, token id, token text and count, most frequent first.
for rank, entry in enumerate(counter.most_common()):
    token_id, occurrences = entry
    print(rank, token_id, token_id2text[token_id], occurrences)

In [None]:
# List vocabulary entries that still never occur in the augmented data.
for token, id_ in tokenizer_vocab.items():
    if counter[id_] != 0:
        continue
    print(id_, token, counter[id_])

In [None]:
# Spot check: display the tokenizer output for one sample name.
tokenize('zetty')

## Save augmented triplets

In [None]:
# Display the final number of triplets that will be written out.
len(aug_data)

In [None]:
# Materialize the list of triplet dicts as a frame (one column per dict key)
# and write it to `augmented_path` without the index column.
df = pd.DataFrame(aug_data)
df.to_csv(augmented_path, index=False)