In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project code imported below (src.data.utils) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Generate triplets from cross encoder

Generate triplets using the cross encoder from 222
- get all pairs from tree-record-attachments and existing standard buckets
- add easy negatives (easy-negs) drawn from the 10,000 most frequent tree preferred names

In [None]:
import csv
from datetime import datetime
import random
import re

import pandas as pd
from sentence_transformers.cross_encoder import CrossEncoder
import torch
from tqdm import tqdm

from src.data.utils import read_csv, load_dataset_v2

In [None]:
# Name type being processed; interpolated into every input/output path below.
given_surname = 'given'

# Total number of triplets to write (upper bound of the generation loop).
num_training_examples = 10_000_000
# Optional suffix appended to the output file name to tell runs apart.
run = ''
# if common, write triplets from common names
# else write triplet from train and include num_easy_negs easy negatives
# (i.e. either the string 'common' or an integer count of easy-negative
# triplets written after each bucket triplet)
num_easy_negs = 'common'

# When False, buckets with fewer than 3 names are skipped during generation;
# when True, 2-name buckets are used and the anchor doubles as the positive.
# Also adds a '-dups' suffix to the output file name.
allow_dups = False

# Number of leading rows of the preferred-name table considered as common names
# (presumably the table is sorted by descending frequency — confirm).
num_common_names = 10_000

# Passed as max_length when the cross encoder is loaded.
tokenizer_max_length = 32
# Inputs
pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
train_path = f"s3://familysearch-names/processed/tree-hr-{given_surname}-train-v2.csv.gz"
common_non_negatives_path = f"../data/processed/common_{given_surname}_non_negatives-augmented.csv"

std_path = f"../references/std_{given_surname}.txt"
cross_encoder_dir = f"../data/models/cross-encoder-{given_surname}-10m-265-same-all"

# Output: CSV of scored triplets; the name encodes the mode, the dups flag and the run suffix.
triplets_path=f"../data/processed/cross-encoder-triplets-{given_surname}-{num_easy_negs}{'-dups' if allow_dups else ''}{run}.csv"
triplets_path

In [None]:
# Report GPU availability and memory usage before the model is loaded.
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # Device 0 can only be queried when CUDA is present; on a CPU-only machine
    # get_device_properties(0) raises instead of returning.
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

### load common names

In [None]:
# Take the first num_common_names preferred names and keep those that are longer
# than one character and consist solely of the letters a-z.
pref_df = read_csv(pref_path)
common_names = []
for candidate in pref_df['name'][:num_common_names].tolist():
    if len(candidate) <= 1:
        continue
    if re.fullmatch(r'[a-z]+', candidate):
        common_names.append(candidate)
# Drop the reference to the table; only common_names is used from here on.
pref_df = None
len(common_names)

### load common non-negatives

In [None]:
# Seed the set of known non-negative (name, name) pairs from the curated file,
# ignoring rows that pair a name with itself.
common_non_negatives_df = read_csv(common_non_negatives_path)
name_pairs = {
    (first, second)
    for first, second in common_non_negatives_df.values.tolist()
    if first != second
}
len(name_pairs)

### load training dataset

In [None]:
# Load the training split. The cells below use the results as parallel structures:
# attached_names_train[i] is iterated as (name, frequency) pairs belonging to
# tree_names_train[i]; record_names_train is only used for counts here.
tree_names_train, attached_names_train, record_names_train = \
    load_dataset_v2(train_path)

In [None]:
# Size summary of the training split, printed one "label value" line at a time.
summary = (
    ("tree_names_train", len(tree_names_train)),
    ("attached_names_train", sum(map(len, attached_names_train))),
    ("total pairs", sum(count for attached in attached_names_train for _, count in attached)),
    ("record_names_train", len(record_names_train)),
    ("total names", len(set(tree_names_train) | set(record_names_train))),
)
for label, value in summary:
    print(label, value)

In [None]:
# Each bucket is a list of distinct names that belong together; it is filled first
# from the training attachments and then from the standard file in the cells below.
buckets = []

In [None]:
# One bucket per tree name: the tree name followed by its attached names, with
# repeats removed (first occurrence wins). Buckets with a single name are dropped;
# every ordered pair of distinct names in a kept bucket becomes a known pair.
for tree_name, attachments in zip(tree_names_train, attached_names_train):
    variants = list(dict.fromkeys([tree_name, *(attached for attached, _ in attachments)]))
    if len(variants) >= 2:
        buckets.append(variants)
        name_pairs.update(
            (left, right)
            for left in variants
            for right in variants
            if left != right
        )
print(len(buckets), sum(map(len, buckets)))

### load std

In [None]:
# Add one bucket per line of the standard file. Each non-blank line has the form
# "<head names> : <tail names>", with names separated by whitespace.
with open(std_path) as f:
    for line in f:  # iterate lazily instead of materialising every line first
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing empty line) has no ':' and would make
            # the unpacking below fail; skip it.
            continue
        # Still raises ValueError on a line with a missing or extra ':' so that
        # malformed input is not silently accepted.
        head_names, tail_names = line.split(':')
        # Head names first, then tail names; repeats dropped, first occurrence wins.
        names = list(dict.fromkeys(head_names.split() + tail_names.split()))
        if len(names) < 2:
            continue
        buckets.append(names)
        # Every ordered pair of distinct names in the bucket is a known pair.
        for name1 in names:
            for name2 in names:
                if name1 != name2:
                    name_pairs.add((name1, name2))
print(len(buckets), sum(len(bucket) for bucket in buckets))

In [None]:
# Number of buckets holding exactly two names.
sum(1 for bucket in buckets if len(bucket) == 2)

### name pairs

In [None]:
# Turn the pair set into a list so that pairs can be drawn by random index later.
# NOTE: from here on name_pairs is a list, so the earlier cells that call
# name_pairs.add(...) cannot be re-run without first re-creating the set.
name_pairs = list(name_pairs)
len(name_pairs)

## Generate and write triplets

In [None]:
# Load the cross encoder stored in cross_encoder_dir; max_length comes from
# tokenizer_max_length (presumably the token limit per name pair — confirm).
model = CrossEncoder(cross_encoder_dir, max_length=tokenizer_max_length)

In [None]:
def harmonic_mean(x,y):
    """Return the harmonic mean of two scalar scores.

    The harmonic mean is only defined for positive values. A score that is zero
    or negative is treated as "no similarity" and the result is 0.0, which is
    also the limit of the harmonic mean as either argument approaches zero.
    This avoids a division by zero (x == 0, y == 0 or x == -y) and the
    meaningless values the raw formula produces for negative inputs.

    :param x: first score (scalar)
    :param y: second score (scalar)
    :return: 2 / (1/x + 1/y) for positive x and y, otherwise 0.0
    """
    if x <= 0 or y <= 0:
        return 0.0
    return 2 / (1/x+1/y)

def choose3(names):
    """Pick a random (anchor, positive, negative) triple of names from one bucket.

    With three or more names the three picks are distinct and every ordered
    choice of distinct positions is equally likely. With exactly two names the
    anchor is reused as the second pick, because no third name exists. The two
    non-anchor picks are returned in canonical order (positive >= negative as
    strings) so that the same selection always yields the same tuple.

    :param names: sequence of at least two names
    :return: tuple (anchor, positive, negative)
    :raises ValueError: if names has fewer than two entries (no valid triple
        exists; without this check a one-name bucket would never terminate)
    """
    if len(names) < 2:
        raise ValueError("choose3 needs at least two names")
    if len(names) == 2:
        # Only one other name is available: the anchor also fills the second slot.
        anchor = random.randrange(2)
        pos, neg = anchor, 1 - anchor
    else:
        anchor, pos, neg = random.sample(range(len(names)), 3)
    if names[pos] < names[neg]:  # always return pos >= neg
        neg, pos = pos, neg
    return names[anchor], names[pos], names[neg]
        
def choose_pair(pair):
    """Return the two entries of `pair` as a tuple, in either order with equal probability."""
    first, second = pair[0], pair[1]
    if random.random() >= 0.5:
        first, second = second, first
    return first, second

def clamp(score):
    """Limit `score` to the closed interval [0.0, 1.0]."""
    if score < 1.0:
        capped = score
    else:
        capped = 1.0
    return capped if capped > 0.0 else 0.0

In [None]:
# Display the output path that the generation cell below writes to.
triplets_path

In [None]:
def write_triple(writer, triple):
    """Score one (anchor, positive, negative) triple and write it as a CSV row.

    Each pair is scored by the cross encoder in both argument orders and the two
    scores are combined with the harmonic mean. A name paired with itself gets a
    score of 1.0. Scores are clamped to [0, 1], and the higher-scoring name is
    always written as the positive.

    :param writer: csv.DictWriter-like object with the columns anchor, positive,
        positive_score, negative, negative_score
    :param triple: (anchor, positive candidate, negative candidate)
    """
    anchor, pos, neg = triple
    # A single predict call covers both orderings of both pairs.
    pair_scores = model.predict([[anchor, pos], [pos, anchor], [anchor, neg], [neg, anchor]])
    pos_fwd, pos_rev, neg_fwd, neg_rev = pair_scores
    pos_score = harmonic_mean(pos_fwd, pos_rev)
    neg_score = harmonic_mean(neg_fwd, neg_rev)
    # Identical names override whatever the model returned.
    if anchor == pos:
        pos_score = 1.0
    if anchor == neg:
        neg_score = 1.0
    pos_score, neg_score = clamp(pos_score), clamp(neg_score)
    # Swap the roles when the nominal negative scored higher than the positive.
    if neg_score > pos_score:
        pos, neg = neg, pos
        pos_score, neg_score = neg_score, pos_score
    row = {
        'anchor': anchor,
        'positive': pos,
        'positive_score': pos_score,
        'negative': neg,
        'negative_score': neg_score,
    }
    writer.writerow(row)

In [None]:
# Sample candidate triplets, score each new one with the cross encoder and stream
# the rows to triplets_path until exactly num_training_examples have been written.
cnt = 0               # triplets written so far
seen_triples = set()  # every candidate (anchor, positive, negative) already written
seen_cnt = 0          # candidates rejected because they were already written
# Easy-negative triplets still owed for the most recently written bucket triplet
# (only used when num_easy_negs is an integer).
easy_negs_pending = 0
with open(triplets_path, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['anchor','positive','positive_score','negative','negative_score'])
    writer.writeheader()

    while cnt < num_training_examples:
        is_easy_neg = False
        if num_easy_negs == 'common':
            # All three names are drawn from the common-name list.
            anchor = random.choice(common_names)
            pos = random.choice(common_names)
            neg = random.choice(common_names)
            if anchor == pos or anchor == neg or pos == neg:
                continue
            # Canonical order (pos >= neg) so a selection maps to one tuple for dedup.
            if pos < neg:
                pos, neg = neg, pos
            triple = (anchor, pos, neg)
        elif easy_negs_pending > 0:
            # Known pair plus a random common name as the easy negative.
            is_easy_neg = True
            anchor, pos = choose_pair(random.choice(name_pairs))
            easy_neg = random.choice(common_names)
            if easy_neg == anchor or easy_neg == pos:
                # A "negative" identical to the anchor or the positive would give
                # a degenerate triplet; draw again.
                continue
            triple = (anchor, pos, easy_neg)
        else:
            # Triplet drawn from inside one bucket.
            bucket = random.choice(buckets)
            if not allow_dups and len(bucket) < 3:
                continue
            triple = choose3(bucket)

        if triple in seen_triples:
            seen_cnt += 1
            continue
        seen_triples.add(triple)
        write_triple(writer, triple)
        cnt += 1
        if cnt % 10_000 == 0:
            print(cnt, seen_cnt, datetime.now())

        if num_easy_negs != 'common':
            # A written bucket triplet is followed by num_easy_negs easy-negative
            # triplets; each written easy negative pays one of them off. Because
            # the outer loop condition is re-checked every time, the total never
            # exceeds num_training_examples.
            easy_negs_pending = easy_negs_pending - 1 if is_easy_neg else num_easy_negs
cnt