In [None]:
%load_ext autoreload
%autoreload 2

# Train a bi-encoder from cross-encoder
Learn name-to-vec encodings by training a bi-encoder on triplets scored by the cross-encoder.

### Results

Hyperparameters
* epochs
* embedding_dim
* subword_vocab_size
* add 20% to score

All tests are score^20% trainall noback10mask 256dim 6epochs

| phon/sub | notes  | uni/bi | size | loss | errors | negs |
| -------- | ------ | ------ | ---- | ---- | ------ | ---- |
| Subwords | base | unigrams | 2000f |  |  |  |


new data @ 12 epochs:         ??? ??? ??, 189 7, 1011 193, 61728
old data @ 12 epochs:         132 124 13,   1 0,  330 197, 56461
old data w model @ 12 epochs: 145 110 76,   8 0,  482 246, 48875

In [None]:
from collections import Counter
import random
import re

import matplotlib.pyplot as plt
import pandas as pd


import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# name type to train on; interpolated into the file paths below and checked
# before the nickname file is read
given_surname = "given"
# number of leading rows of the preferred-name file considered for common_names
num_common_names = 10000
# number of training batches between running-loss reports in train()
report_size = 10000
# passed to both the tokenizer and the BiEncoder constructor
max_tokens = 10

vocab_type = 'f'  # tokenizer based upon training name frequency
# when True, the embedding layer is initialised from get_pretrained_embeddings()
use_pretrained_embeddings = False
subword_vocab_size = 2000  # 500, 1000, 1500, 2000
# tokens counted fewer than this many times are treated as under-represented;
# the same value also caps how many common names are used as negatives for them
under_represented_threshold = 200

# hyperparameters
embedding_dim = 256
learning_rate = 0.001
batch_size = 64
num_epochs = 12
use_amsgrad = False

# triplets
# only used to select the triplets file and to name the saved model
num_easy_negs = 5

train_triplets_path=f"../data/processed/cross-encoder-triplets-{given_surname}-{num_easy_negs}.csv"

pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
# NOTE(review): the next three paths are not referenced by any other cell visible
# in this notebook — confirm whether they are still needed
tfidf_path=f"s3://nama-data/data/models/fs-{given_surname}-tfidf-v2.joblib"
phoneme_vocab_path = f"s3://nama-data/data/models/fs-{given_surname}-espeak_phoneme_vocab.json"
phoneme_bigrams_vocab_path = f"s3://nama-data/data/models/fs-{given_surname}-espeak_phoneme_vocab_bigrams.json"

nama_bucket = 'nama-data'
subwords_path=f"../data/models/fs-{given_surname}-subword-tokenizer-{subword_vocab_size}{vocab_type}.json"

# name pairs that must not be used as negatives
common_non_negatives_path = f"../references/common_{given_surname}_non_negatives.csv"
name_variants_path = f"../references/{given_surname}_variants.csv"
given_nicknames_path = "../references/givenname_nicknames.csv"

# output location for the trained model
model_path = f"../data/models/bi_encoder-{given_surname}-ce-{num_easy_negs}.pth"

In [None]:
# Report GPU availability and memory usage. The per-device queries are only
# issued when CUDA is available: torch.cuda.get_device_properties(0) raises on
# a CPU-only machine, which would otherwise stop a Run All at this cell.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    torch.cuda.empty_cache()
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# read train triplets
# na_filter=False turns off pandas' missing-value detection, so names that look
# like NA markers (and empty strings) are kept as plain strings
train_triplets_df = pd.read_csv(train_triplets_path, na_filter=False)
print(len(train_triplets_df))
train_triplets_df.head(3)

In [None]:
# Tally the distinct anchor:positive and anchor:negative name pairs that occur
# in the training triplets, then print both totals.
anchor_pos = set()
anchor_neg = set()
for row in tqdm(train_triplets_df.itertuples()):
    anchor_pos.add(f"{row.anchor}:{row.positive}")
    anchor_neg.add(f"{row.anchor}:{row.negative}")
for pair_set in (anchor_pos, anchor_neg):
    print(len(pair_set))

### read common names

In [None]:
# From the first num_common_names rows of the preferred-name file, keep the
# names that are longer than one character and consist only of lowercase a-z.
pref_df = pd.read_csv(pref_path, na_filter=False)
lowercase_only = re.compile(r'[a-z]+')
candidate_names = pref_df['name'][:num_common_names].tolist()
common_names = [
    candidate for candidate in candidate_names
    if len(candidate) > 1 and lowercase_only.fullmatch(candidate)
]
pref_df = None  # drop the reference to the DataFrame; only the list is used afterwards
len(common_names)

### read common non-negatives

In [None]:
# Name pairs that must not be used as negatives, stored as "first:second" keys
# with the two names in sorted order so that lookups are order-independent.
common_non_negatives = set()


def _non_negative_key(name1, name2):
    """Return the order-independent set key for a pair of names."""
    first, second = (name2, name1) if name1 > name2 else (name1, name2)
    return f"{first}:{second}"


def add_common_non_negative(name1, name2):
    """Record the pair (name1, name2) as a non-negative, whichever order is given."""
    common_non_negatives.add(_non_negative_key(name1, name2))


def is_common_non_negative(name1, name2):
    """Return True if the pair has been recorded, regardless of argument order."""
    return _non_negative_key(name1, name2) in common_non_negatives

In [None]:
# Register every row of the common non-negatives file as a non-negative pair.
# The unpacking below requires each row to have exactly two columns.
common_non_negatives_df = pd.read_csv(common_non_negatives_path, na_filter=False)
for name1, name2 in common_non_negatives_df.values.tolist():
    add_common_non_negative(name1, name2)
len(common_non_negatives)

### add name variants to common non-negatives

In [None]:
# Load the name-variant pairs and preview the first rows
name_variants_df = pd.read_csv(name_variants_path, na_filter=False)
print(len(name_variants_df))
name_variants_df.head(3)

In [None]:
# Variants of one another must not be used as negatives; the unpacking requires
# each row of the variants file to have exactly two columns
for name1, name2 in name_variants_df.values.tolist():
    add_common_non_negative(name1, name2)

### add given nicknames to common non-negatives

In [None]:
# Each line of the nicknames file is a comma-separated group of names; every
# pair of distinct names within a group is registered as a non-negative.
if given_surname == "given":
    with open(given_nicknames_path, "rt") as f:
        for line in f:
            # Lines read from a file keep their trailing newline, which would
            # otherwise stay attached to the last name of every group and make
            # its pairs unmatchable. Strip each field and drop empty ones.
            names = [field.strip() for field in line.split(',')]
            names = [name for name in names if name]
            for name1 in names:
                for name2 in names:
                    # visit each unordered pair once and skip identical names
                    if name1 > name2:
                        add_common_non_negative(name1, name2)
len(common_non_negatives)

In [None]:
# use CE model to re-score
# from sentence_transformers.cross_encoder import CrossEncoder

# cross_encoder_dir = f"../data/models/cross-encoder-{given_surname}-10m-265-same-all"
# tokenizer_max_length = 32
# model = CrossEncoder(cross_encoder_dir, max_length=tokenizer_max_length)

# def harmonic_mean(x,y):
#     return 2 / (1/x+1/y)

## Create training data

### Get tokenizer

In [None]:
# tokenize is called on name strings throughout this notebook and its result is
# wrapped in torch.tensor; tokenizer_vocab maps token text to token id
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    max_tokens=max_tokens,
    subwords_path=subwords_path,
    nama_bucket=nama_bucket,
)
len(tokenizer_vocab)

In [None]:
# spot-check: display the tokenizer output for a sample name
tokenize('dallan')

### Add anchor-pos-neg triplets

In [None]:
# Build one training example per triplet row:
# (anchor_tokens, pos_tokens, neg_tokens, target_margin)
all_data = []
triplet_columns = ['anchor', 'positive', 'negative', 'positive_score', 'negative_score']
triplet_rows = zip(*(train_triplets_df[column] for column in triplet_columns))
for anchor, pos, neg, pos_score, neg_score in tqdm(triplet_rows, mininterval=2):
    # Optional re-scoring with the CE model (disabled):
    #     anchor_pos1, anchor_pos2, anchor_neg1, anchor_neg2 = \
    #         model.predict([[anchor, pos], [pos, anchor], [anchor, neg], [neg, anchor]])
    #     pos_score = harmonic_mean(anchor_pos1, anchor_pos2)
    #     neg_score = harmonic_mean(anchor_neg1, anchor_neg2)
    # anchor, positive, hard-negative
    example = {
        'anchor': torch.tensor(tokenize(anchor)),
        'pos': torch.tensor(tokenize(pos)),
        'neg': torch.tensor(tokenize(neg)),
        # regression target: positive score minus negative score
        'target': torch.tensor(pos_score - neg_score, dtype=torch.float),
    }
    all_data.append(example)
len(all_data)

### Add triplets for names we want to push apart

In [None]:
# (anchor, negative) name pairs to push apart; empty here, so no triplets are added for them
push_apart_pairs = []
# number of identical triplets appended to all_data for each pair
push_apart_copies = 10

In [None]:
# For every push-apart pair, append push_apart_copies triplets in which the
# anchor doubles as its own positive and the target margin is 1.0.
for anchor, neg in push_apart_pairs:
    anchor_tokens, neg_tokens = tokenize(anchor), tokenize(neg)
    all_data.extend(
        {
            'anchor': torch.tensor(anchor_tokens),
            'pos': torch.tensor(anchor_tokens),
            'neg': torch.tensor(neg_tokens),
            'target': torch.tensor(1.0, dtype=torch.float)
        }
        for _ in range(push_apart_copies)
    )

## Analyze training data

In [None]:
# display which triplets file the analysis below is based on
train_triplets_path

In [None]:
# Reverse lookup: token id -> token text
token_id2text = {id_: text for text, id_ in tokenizer_vocab.items()}

# Count token occurrences over every anchor/pos/neg sequence in all_data.
# Id 1 ends a sequence here — presumably the [PAD] id; TODO confirm that it
# equals tokenizer_vocab['[PAD]'].
counter = Counter()
for row in tqdm(all_data):
    for key in ('anchor', 'pos', 'neg'):
        token_ids = row[key].tolist()
        if 1 in token_ids:
            token_ids = token_ids[:token_ids.index(1)]
        counter.update(token_ids)
# Print tokens from most to least frequent: rank, id, text, count
for ix, (token, cnt) in enumerate(counter.most_common()):
    print(ix, token, token_id2text[token], cnt)

In [None]:
# spot-check: display the tokenizer output for a sample name
tokenize('jewel')

### Find names that contain under-represented tokens

In [None]:
# Token ids counted fewer than under_represented_threshold times
# (ids that never occurred count as 0)
under_represented_token_ids = {
    token_id
    for token_id in tokenizer_vocab.values()
    if counter[token_id] < under_represented_threshold
}
print(len(under_represented_token_ids))

In [None]:
# Every distinct name appearing anywhere in the triplets, plus the common names
all_names = set()
for row in tqdm(train_triplets_df.itertuples()):
    all_names.update((row.anchor, row.positive, row.negative))
all_names.update(common_names)
len(all_names)

In [None]:
def _tokens_before_pad(name):
    """Yield the token ids of `name` up to, but not including, the first id 1 (pad)."""
    for token_id in tokenize(name):
        if token_id == 1:  # pad
            return
        yield token_id


# A name is under-represented when any of its non-pad tokens is under-represented
under_represented_names = {
    name for name in all_names
    if any(token_id in under_represented_token_ids for token_id in _tokens_before_pad(name))
}
print(len(under_represented_names))
# Show each selected name with (id, text, count) for every one of its tokens
for name in under_represented_names:
    token_counts = [
        (token_id, token_id2text[token_id], counter[token_id])
        for token_id in _tokens_before_pad(name)
    ]
    print(name, token_counts)

In [None]:
# Treat rarely-counted vocabulary entries as names in their own right, skipping
# any token containing '[' or '#' (presumably special tokens and "##"
# continuation pieces — TODO confirm against the tokenizer).
for token, id_ in tokenizer_vocab.items():
    is_rare = counter[id_] < under_represented_threshold
    is_plain = '[' not in token and '#' not in token
    if is_rare and is_plain:
        print(token)
        under_represented_names.add(token)

In [None]:
# number of under-represented names after adding the rare whole tokens
len(under_represented_names)

### Add names that contain under-represented tokens to all_data

In [None]:
# Pair each under-represented name (used as both anchor and positive) with the
# first under_represented_threshold common names as negatives, target margin 1.0.
# Identical names and known non-negative pairs are skipped.
cnt = 0
negative_candidates = common_names[:under_represented_threshold]
for pos in tqdm(under_represented_names):
    pos_tokens = tokenize(pos)
    for neg in negative_candidates:
        usable = pos != neg and not is_common_non_negative(pos, neg)
        if usable:
            neg_tokens = tokenize(neg)
            all_data.append({
                'anchor': torch.tensor(pos_tokens),
                'pos': torch.tensor(pos_tokens),
                'neg': torch.tensor(neg_tokens),
                'target': torch.tensor(1.0, dtype=torch.float)
            })
            cnt += 1
print(cnt)

## Split training data

In [None]:
# Set aside 10% of all_data as val_data; random_state fixes the split across runs
train_data, val_data = train_test_split(all_data, test_size=0.10, random_state=42)
print(len(train_data), len(val_data))        

In [None]:
# preview the first few training examples (dicts of anchor/pos/neg/target tensors)
train_data[:5]

## Re-Analyze training data

In [None]:
# Recount token occurrences, this time over train_data only. Id 1 terminates a
# sequence (presumably the [PAD] id — TODO confirm).
counter = Counter()
for row in tqdm(train_data):
    for token_ids in (row['anchor'].tolist(), row['pos'].tolist(), row['neg'].tolist()):
        end = token_ids.index(1) if 1 in token_ids else len(token_ids)
        counter.update(token_ids[:end])
# Print tokens from most to least frequent: rank, id, text, count
for ix, (token, cnt) in enumerate(counter.most_common()):
    print(ix, token, token_id2text[token], cnt)

In [None]:
# List vocabulary entries that never occur in train_data
for token, id_ in tokenizer_vocab.items():
    occurrences = counter[id_]
    if occurrences == 0:
        print(id_, token, occurrences)

In [None]:
# spot-check: display the tokenizer output for a sample name
tokenize('zetty')

## Train bi-encoder

In [None]:
def loss_fn(anchors, positives, negatives, labels):
    """Mean squared error between predicted and target similarity margins.

    The predicted margin of each triplet is
    cos(anchor, positive) - cos(anchor, negative), with the cosine taken over
    the last dimension. `labels` holds the target margins.
    """
    # dot-product alternative, left disabled:
    # anchor_pos_sim = (anchors * positives).sum(dim=-1)
    # anchor_neg_sim = (anchors * negatives).sum(dim=-1)
    pos_similarity = F.cosine_similarity(anchors, positives, dim=-1)
    neg_similarity = F.cosine_similarity(anchors, negatives, dim=-1)
    return F.mse_loss(pos_similarity - neg_similarity, labels)

In [None]:
def _triplet_batch_loss(model, loss_fn, data):
    """Embed one batch of triplets with `model` and return `loss_fn`'s value for it.

    `data` is a batch dict with the keys 'anchor', 'pos', 'neg' and 'target'.
    """
    anchor_embeddings = model(data['anchor'])  # Shape: (batch_size, embedding_dim)
    pos_embeddings = model(data['pos'])  # Shape: (batch_size, embedding_dim)
    neg_embeddings = model(data['neg'])  # Shape: (batch_size, embedding_dim)
    return loss_fn(anchor_embeddings, pos_embeddings, neg_embeddings, data['target'])


# Training loop
def train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, verbose=True):
    """Train `model` for `num_epochs` epochs and validate after every epoch.

    Args:
        model: module applied separately to the anchor, positive and negative batches.
        train_loader: iterable of batch dicts ('anchor', 'pos', 'neg', 'target').
        val_loader: iterable of batch dicts with the same keys.
        loss_fn: callable(anchor_emb, pos_emb, neg_emb, target) returning a scalar loss.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over train_loader.
        verbose: when True, print the mean training loss every `report_size`
            batches (`report_size` is a notebook-level setting) and the
            validation loss after each epoch.

    Returns:
        The mean per-batch validation loss of the final epoch, or NaN when no
        validation loss could be computed (num_epochs == 0 or an empty val_loader).
    """
    # NaN until an epoch has produced a validation loss, so that num_epochs == 0
    # returns a value instead of raising UnboundLocalError
    val_loss = float("nan")

    for epoch in range(num_epochs):
        # make sure gradient tracking is on
        model.train()
        running_loss = 0

        for ix, data in enumerate(train_loader):
            # zero gradients
            optimizer.zero_grad()

            # Forward pass and loss
            loss = _triplet_batch_loss(model, loss_fn, data)

            # Backward pass and optimization step
            loss.backward()
            optimizer.step()

            # Calculate loss and report
            if verbose:
                running_loss += loss.item()
                if ix % report_size == report_size - 1:
                    avg_loss = running_loss / report_size  # loss per batch
                    print(f"Epoch {epoch} batch {ix} loss {avg_loss}")
                    running_loss = 0

        # set model to evaluation mode
        model.eval()

        # disable gradient computation
        running_loss = 0
        num_val_batches = 0
        with torch.no_grad():
            for data in val_loader:
                running_loss += _triplet_batch_loss(model, loss_fn, data).item()
                num_val_batches += 1

        # calculate average validation loss; an empty val_loader would otherwise
        # divide by zero
        val_loss = running_loss / num_val_batches if num_val_batches else float("nan")
        if verbose:
            print(f"VALIDATION: Epoch {epoch} loss {val_loss}")
        # epoch_model_path = f"{model_path}-{epoch}"
        # torch.save(model.state_dict(), epoch_model_path)

    # return final epoch validation loss
    return val_loss

## Hyperparameter search

## Train model

In [None]:
def get_pretrained_embeddings(tokenizer_vocab, embedding_dim):
    """Build bag-of-characters initial embeddings for every token in the vocab.

    The embedding is divided into 26 consecutive bands of
    `embedding_dim // 26` dimensions, one per letter a-z. For each token, the
    band of every letter it contains (after stripping '#' from both ends) is
    set to 1.0; all other dimensions stay 0.0.

    Args:
        tokenizer_vocab: mapping of token text to token id; ids index the rows
            of the result, so they must lie in range(len(tokenizer_vocab)).
        embedding_dim: width of each embedding row.

    Returns:
        A float tensor of shape (len(tokenizer_vocab), embedding_dim).
    """
    dims_per_char = embedding_dim // 26
    embeddings = torch.zeros(len(tokenizer_vocab), embedding_dim, dtype=torch.float)
    for token, id_ in tokenizer_vocab.items():
        # tokens starting with '[' (e.g. [PAD]) keep an all-zero row;
        # startswith also tolerates an empty token, unlike token[0]
        if token.startswith('['):
            continue
        for ch in token.strip('#'):
            letter_index = ord(ch) - ord('a')
            # Characters outside a-z have no band. Without this guard their
            # out-of-range offsets would write outside the 26 letter bands.
            if not 0 <= letter_index < 26:
                continue
            # set dimensions for ch in embedding to 1.0
            start = letter_index * dims_per_char
            embeddings[id_, start:start + dims_per_char] = 1.0
    return embeddings

In [None]:
%%time

# Construct pre-trained embeddings based upon bag-of-characters
# (only when use_pretrained_embeddings is set; otherwise None is passed to BiEncoder)
pretrained_embeddings = None
if use_pretrained_embeddings:
    pretrained_embeddings = get_pretrained_embeddings(tokenizer_vocab, embedding_dim)
    
# Create an instance of the bi-encoder model
model = BiEncoder(embedding_dim, len(tokenizer_vocab), max_tokens, tokenizer_vocab['[PAD]'], pretrained_embeddings)

# Define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, amsgrad=use_amsgrad)

# Create data loader
### NOTE: train on all_data in place of train_data
# val_data was split out of all_data, so every validation example is also in
# the training set: the validation loss printed during training is NOT a
# held-out estimate. Pass train_data to the first DataLoader to get one.
train_loader = DataLoader(all_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

# runs with verbose=True; the returned final-epoch validation loss becomes the cell output
train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs)

## Save Model

In [None]:
# Saves the whole model object (not just its state_dict) to model_path.
# NOTE(review): loading a fully pickled model requires the BiEncoder class to be
# importable at the same module path — confirm that downstream loaders rely on this.
torch.save(model, model_path)