In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

# Train a bi-encoder to learn name-to-vec encodings
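Names are tokenized into espeak phonemes (optionally phoneme bigrams), and the encoder is trained on anchor-name pairs with an MSE loss between the dot product of the two embeddings and the target score. Subwords and other n-grams are alternative tokenizations still to try.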

In [None]:
import json
import random

import joblib
import pandas as pd
from phonemizer.separator import Separator
from phonemizer.backend import EspeakBackend

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import ray
from ray import air, tune
from ray.air import session
from ray.tune.search.hyperopt import HyperOptSearch
from sklearn.model_selection import train_test_split

from src.data.filesystem import fopen

In [None]:
# --- Dataset -----------------------------------------------------------------
given_surname = "given"  # name type; interpolated into every path below
sample_frac = 1.0        # fraction of triplet rows to keep after loading

# --- Tokenization / model ----------------------------------------------------
use_bigrams = True       # selects the bigram vocab and bigram lookup in tokenize()
max_tokens = 15          # fixed token-sequence length per name
embedding_dim = 16

# --- Optimisation defaults (also the first point tried by the tuner) ---------
learning_rate = 1e-3
batch_size = 32
num_epochs = 10
report_size = 10_000     # number of batches between running-loss reports in train()

# --- Paths -------------------------------------------------------------------
triplets_path = f"s3://familysearch-names/processed/tree-hr-{given_surname}-triplets.csv.gz"
tfidf_path = f"s3://nama-data/data/models/fs-{given_surname}-tfidf-v2.joblib"
vocab_path = f"s3://nama-data/data/models/fs-{given_surname}-espeak_phoneme_vocab.json"
bigrams_vocab_path = f"s3://nama-data/data/models/fs-{given_surname}-espeak_phoneme_vocab_bigrams.json"
model_path = f"../data/models/bi_encoder-{given_surname}"

In [None]:
# Report GPU status. The device-property and memory queries require a CUDA
# device, so they are guarded to keep this cell from raising on CPU-only
# machines (which would break Restart & Run All).
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

In [None]:
# Shut down any Ray instance left over from an earlier run of this cell,
# then start a new one for the hyperparameter search.
ray.shutdown()
ray.init()

## Load data

In [None]:
# Read the triplets CSV and keep a random `sample_frac` fraction of its rows.
# Later cells read the columns anchor, positive, negative, positive_score
# and negative_score.
# NOTE(review): sample() is unseeded, so the selection/order differs per run.
triplets_df = pd.read_csv(triplets_path).sample(frac=sample_frac)
print(len(triplets_df))
triplets_df.head(3)

In [None]:
# Load the token -> id vocabulary from JSON: the bigram vocabulary when
# use_bigrams is set, otherwise the plain phoneme vocabulary.
with fopen(bigrams_vocab_path if use_bigrams else vocab_path, 'r') as f:
    phoneme_vocab = json.load(f)

In [None]:
# Append the special tokens after the loaded vocabulary.
# setdefault keeps this cell idempotent: with plain assignment, re-running the
# cell would set both [UNK] and [PAD] to len(phoneme_vocab), which equals
# vocab_size and is therefore out of range for the embedding table.
phoneme_vocab.setdefault('[UNK]', len(phoneme_vocab))
phoneme_vocab.setdefault('[PAD]', len(phoneme_vocab))
vocab_size = len(phoneme_vocab)
vocab_size

## Set up the phonemizer

In [None]:
# espeak backend (en-us) that tokenize() uses to turn a name into phonemes.
# The separator puts a space between phones and '|' between words; syllable
# boundaries are not marked.
espeak = EspeakBackend('en-us')
separator = Separator(phone=' ', syllable=None, word='|')

## Create training data

In [None]:
def tokenize(name, max_tokens):
    """Convert a name into a fixed-length list of token ids.

    The name is phonemized with the module-level espeak backend and split on
    spaces. Each phoneme maps to its id in ``phoneme_vocab``; when
    ``use_bigrams`` is set, an 'END' marker is appended and the
    "<previous>,<current>" bigram id is preferred (the first phoneme is paired
    with 'START'), falling back to the single-phoneme id. Unknown entries map
    to '[UNK]'.

    Args:
        name: the name to tokenize.
        max_tokens: length of the returned list; longer sequences are
            truncated and shorter ones are right-padded with '[PAD]'.

    Returns:
        A list of ``max_tokens`` integer token ids.
    """
    pad_id = phoneme_vocab['[PAD]']
    unk_id = phoneme_vocab['[UNK]']
    phonemized = espeak.phonemize([name], separator=separator, strip=True)[0]
    phonemes = phonemized.split(' ')
    if use_bigrams:
        phonemes.append('END')

    token_ids = []
    previous = 'START'
    for phoneme in phonemes[:max_tokens]:
        bigram = f"{previous},{phoneme}"
        if use_bigrams and bigram in phoneme_vocab:
            token_ids.append(phoneme_vocab[bigram])
        else:
            # no bigram entry (or bigrams disabled): use the phoneme itself
            token_ids.append(phoneme_vocab.get(phoneme, unk_id))
        previous = phoneme

    # right-pad up to the fixed length
    token_ids.extend([pad_id] * (max_tokens - len(token_ids)))
    return token_ids

In [None]:
# Pool of anchor names from which the random "easy" negatives are drawn.
all_anchors = triplets_df['anchor'].tolist()

In [None]:
# Build the training examples: a list of dicts with keys
#   'anchor' / 'name' : tensors of max_tokens token ids
#   'target'          : scalar float score the embedding dot product should match
# Each triplet yields (anchor, positive), (anchor, hard negative) and, unless it
# tokenizes identically to the anchor, (anchor, random easy negative) with target 0.

# Names can recur across triplets (every easy negative is itself drawn from the
# anchors), so tokenizations are memoized to avoid re-running espeak on them.
_token_cache = {}


def _tokens_for(raw_name):
    """Return the (cached) token ids for a raw name taken from triplets_df."""
    if raw_name not in _token_cache:
        # [1:-1] drops the first and last character of the stored name
        # (presumably boundary markers around the name -- TODO confirm)
        _token_cache[raw_name] = tokenize(raw_name[1:-1], max_tokens)
    return _token_cache[raw_name]


def _example(anchor_tokens, name_tokens, target):
    """Package one (anchor, name, target) training example as tensors."""
    return {
        'anchor': torch.tensor(anchor_tokens),
        'name': torch.tensor(name_tokens),
        'target': torch.tensor(target, dtype=torch.float),
    }


all_data = []
# itertuples() is a generator with no len(); pass total so tqdm can show progress/ETA
for tup in tqdm(triplets_df.itertuples(), total=len(triplets_df)):
    anchor_tokens = _tokens_for(tup.anchor)
    pos_tokens = _tokens_for(tup.positive)
    neg_tokens = _tokens_for(tup.negative)
    easy_neg_tokens = _tokens_for(random.choice(all_anchors))
    # anchor, positive
    all_data.append(_example(anchor_tokens, pos_tokens, tup.positive_score))
    # anchor, hard-negative
    all_data.append(_example(anchor_tokens, neg_tokens, tup.negative_score))
    # anchor, easy-negative (skipped when it tokenizes to the anchor itself)
    if anchor_tokens != easy_neg_tokens:
        all_data.append(_example(anchor_tokens, easy_neg_tokens, 0.0))

In [None]:
# Hold out 10% of the examples for validation.
# NOTE(review): the split is unseeded and made per example, so pairs built from
# the same triplet (same anchor) can land on both sides -- confirm that is acceptable.
train_data, val_data = train_test_split(all_data, test_size=0.10)
print(len(train_data), len(val_data))

## Train bi-encoder

In [None]:
def loss_fn(anchors, names, labels):
    """Mean squared error between embedding dot products and target scores.

    Args:
        anchors: anchor embeddings.
        names: embeddings of the paired names, same shape as ``anchors``.
        labels: target similarity for each pair.

    Returns:
        Scalar tensor: MSE between the dot products taken over the last
        dimension of ``anchors`` and ``names``, and ``labels``.
    """
    predicted_similarity = torch.sum(torch.mul(anchors, names), dim=-1)
    return F.mse_loss(input=predicted_similarity, target=labels)

In [None]:
# Bi-encoder model: the same encoder embeds both names of a pair
class BiEncoder(nn.Module):
    """Encode a padded token-id sequence into a single fixed-size vector.

    Every token embedding is multiplied element-wise by a forward positional
    embedding (position counted from the start of the name) and a backward
    positional embedding (position counted from the end of the name); the
    products are then averaged over the sequence dimension.

    Args:
        embedding_dim: size of the token, positional and output vectors.
        vocab_size: number of token ids (including the special tokens).
        max_tokens: maximum sequence length the positional tables support.
        pad_token: token id used for padding; padded slots get position 0.
    """

    def __init__(self, embedding_dim, vocab_size, max_tokens, pad_token):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_tokens = max_tokens
        self.pad_token = pad_token
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        # positional tables: index 0 is used for padding, 1..max_tokens for real tokens
        self.forward_positional_embedding = nn.Embedding(num_embeddings=max_tokens+1, embedding_dim=embedding_dim)
        self.backward_positional_embedding = nn.Embedding(num_embeddings=max_tokens+1, embedding_dim=embedding_dim)
        self.pooling = nn.AdaptiveAvgPool1d(1)  # Pooling layer to create a single vector

    def forward(self, input):
        """Embed a batch of token-id sequences.

        Args:
            input: integer tensor of shape (batch_size, seq_len) with
                seq_len <= max_tokens, padded with ``pad_token``.

        Returns:
            Float tensor of shape (batch_size, embedding_dim).
        """
        not_pad = input != self.pad_token  # Shape: (batch_size, seq_len), bool
        # Build index tensors on the input's device so the model also works on GPU.
        offsets = torch.arange(input.shape[1], device=input.device).unsqueeze(0)  # Shape: (1, seq_len)

        # forward positions: 1, 2, 3, ... for real tokens and 0 for padding
        forward_positions = (offsets + 1) * not_pad.long()

        # backward positions: the first n slots of a row holding n real tokens
        # count down n, n-1, ..., 1; every other slot keeps its 0/1 pad mask
        # (0 for the trailing padding produced by right-padded input).
        n_tokens = not_pad.sum(dim=1, keepdim=True)  # Shape: (batch_size, 1)
        backward_positions = torch.where(offsets < n_tokens, n_tokens - offsets, not_pad.long())

        embedded = self.embedding(input)  # Shape: (batch_size, seq_len, embedding_dim)
        forward_positional_embedded = self.forward_positional_embedding(forward_positions)
        backward_positional_embedded = self.backward_positional_embedding(backward_positions)

        # multiply embeddings, then average over the sequence dimension
        embedded = embedded * forward_positional_embedded * backward_positional_embedded
        pooled = self.pooling(embedded.permute(0, 2, 1)).squeeze(2)  # Shape: (batch_size, embedding_dim)
        return pooled

In [None]:
# Training loop
def _batch_loss(model, loss_fn, data):
    """Embed both sides of one batch with the shared encoder and return its loss."""
    anchor_embeddings = model(data['anchor'])  # Shape: (batch_size, embedding_dim)
    name_embeddings = model(data['name'])  # Shape: (batch_size, embedding_dim)
    return loss_fn(anchor_embeddings, name_embeddings, data['target'])


def train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, verbose=True):
    """Train ``model`` and return the validation loss of the final epoch.

    Args:
        model: encoder applied separately to the 'anchor' and 'name' batches.
        train_loader: iterable of dict batches with keys 'anchor', 'name', 'target'.
        val_loader: iterable of dict batches with the same keys.
        loss_fn: callable(anchor_embeddings, name_embeddings, targets) -> scalar loss.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        verbose: when True, print the running training loss every
            ``report_size`` batches (module-level setting) and the validation
            loss after every epoch.

    Returns:
        The mean of the per-batch validation losses from the last epoch, or
        NaN when no epoch ran or the validation loader produced no batches.
    """
    # NaN rather than an UnboundLocalError when num_epochs == 0
    val_loss = float('nan')

    for epoch in range(num_epochs):
        # make sure gradient tracking is on
        model.train()
        running_loss = 0.0

        for ix, data in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _batch_loss(model, loss_fn, data)

            # Backward pass and optimization step
            loss.backward()
            optimizer.step()

            # Accumulate and periodically report the running training loss
            if verbose:
                running_loss += loss.item()
                if ix % report_size == report_size - 1:
                    avg_loss = running_loss / report_size  # loss per batch
                    print(f"Epoch {epoch} batch {ix} loss {avg_loss}")
                    running_loss = 0.0

        # set model to evaluation mode
        model.eval()

        # validation pass with gradient computation disabled
        running_loss = 0.0
        num_val_batches = 0
        with torch.no_grad():
            for data in val_loader:
                running_loss += _batch_loss(model, loss_fn, data).item()
                num_val_batches += 1

        # average validation loss per batch; NaN (not ZeroDivisionError) if there were none
        val_loss = running_loss / num_val_batches if num_val_batches else float('nan')
        if verbose:
            print(f"VALIDATION: Epoch {epoch} loss {val_loss}")
        # To checkpoint every epoch (note that state_dict must be called):
        # torch.save(model.state_dict(), f"{model_path}-{epoch}")

    # return final epoch validation loss
    return val_loss

## Hyperparameter search

In [None]:
def trainer(config, train_data, val_data, pad_token):
    """Ray Tune trainable: train one BiEncoder and report its validation loss.

    Args:
        config: sampled hyperparameters with keys 'learning_rate',
            'batch_size', 'embedding_dim' and 'num_epochs'.
        train_data: training examples (dicts with 'anchor', 'name', 'target').
        val_data: validation examples in the same format.
        pad_token: id of the padding token, passed on to the model.

    Reports ``{'loss': <final-epoch validation loss>}`` to the Ray session.
    Uses the module-level ``vocab_size`` and ``max_tokens``.
    """
    # read every hyperparameter up front so a malformed config fails immediately
    learning_rate = config['learning_rate']
    batch_size = config['batch_size']
    embedding_dim = config['embedding_dim']
    num_epochs = config['num_epochs']

    # Create an instance of the bi-encoder model (left on its default device;
    # no .to(device) call is made here)
    model = BiEncoder(embedding_dim, vocab_size, max_tokens, pad_token)

    # Define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    # Create data loaders. Only the training data is shuffled: a fixed
    # validation order keeps the reported mean-of-batch-losses identical for a
    # given model, so trials are compared on exactly the same batches.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    val_loss = train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, verbose=False)

    # Report the metrics to Ray
    session.report({'loss': val_loss})

In [None]:
# Hyperparameter ranges explored by the tuner.
search_space = dict(
    learning_rate=tune.loguniform(1e-4, 1e-2),
    batch_size=tune.choice([8, 16, 32, 64]),
    embedding_dim=tune.choice([8, 16, 32, 64]),
    num_epochs=tune.choice([5, 10, 20]),
)

# First configuration handed to the search algorithm: the notebook defaults.
starting_parameters = [
    dict(
        learning_rate=learning_rate,
        batch_size=batch_size,
        embedding_dim=embedding_dim,
        num_epochs=num_epochs,
    )
]

In [None]:
# Configure the Ray Tune search.
# https://docs.ray.io/en/latest/tune/api/doc/ray.tune.Tuner.html#ray.tune.Tuner

# HyperOpt search seeded with the notebook's default hyperparameters
search_alg = HyperOptSearch(points_to_evaluate=starting_parameters)

# No callbacks are active; the block below is a disabled Weights & Biases logger.
# NOTE(review): it references `wandb_api_key_file` and `config`, neither of which
# is defined in the visible notebook -- define them before re-enabling.
callbacks = []
# if wandb_api_key_file:
#     callbacks.append(WandbLoggerCallback(
#         project="nama",
#         entity="nama",
#         group="80_cluster_tune_"+given_surname+"_agglomerative",
#         notes="",
#         config=config._asdict(),
#         api_key_file=wandb_api_key_file
#     ))

# Minimise the reported 'loss' over up to 200 trials, 2 at a time,
# within a time budget of 8 hours (8*60*60 seconds).
tuner = tune.Tuner(
    tune.with_parameters(
        trainer,
        train_data=train_data,
        val_data=val_data,
        pad_token=phoneme_vocab['[PAD]']
    ),
    param_space=search_space,
    tune_config=tune.TuneConfig(
        mode='min',
        metric='loss',
        search_alg=search_alg,
        num_samples=200,
        max_concurrent_trials=2,
        time_budget_s=8*60*60,
    ),
    run_config=air.RunConfig(
        callbacks=callbacks,
    ),
)

In [None]:
# Run the hyperparameter search (bounded by the 8-hour time budget set on the tuner).
results = tuner.fit()

In [None]:
# All trial results as a pandas dataframe; print how many rows it has, then display it
df = results.get_dataframe()
print(len(df))
df

In [None]:
# Get result that has the lowest validation loss
best_result = results.get_best_result(metric='loss', mode='min', scope='all')

# Metrics reported by that trial (its hyperparameters are in best_result.config)
best_result.metrics

## Review predictions

In [None]:
# Hyperparameters for the final model, fixed by hand here.
batch_size = 64
embedding_dim = 32
learning_rate = 0.0010
num_epochs = 10
# To use the best trial from the search instead:
# batch_size = best_result.config['batch_size']
# learning_rate = best_result.config['learning_rate']
# embedding_dim = best_result.config['embedding_dim']
# num_epochs = best_result.config['num_epochs']

# Create an instance of the bi-encoder model
model = BiEncoder(embedding_dim, vocab_size, max_tokens, phoneme_vocab['[PAD]'])

# Define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Create data loaders. Only the training data is shuffled: a fixed validation
# order keeps the per-epoch validation loss (a mean of per-batch losses)
# independent of how the examples happen to be batched on a given run.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Train verbosely; the cell output is the final-epoch validation loss
train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs)

In [None]:
def predict(name1, name2):
    """Return the cosine similarity between the embeddings of two names.

    Both names are tokenized to ``max_tokens`` ids and encoded in one batch by
    the module-level ``model``, which is switched to evaluation mode.

    Note: training fits the raw dot product of the embeddings, whereas this
    returns the cosine similarity.

    Args:
        name1: first name.
        name2: second name.

    Returns:
        The cosine similarity as a Python float.
    """
    token_batch = torch.tensor([tokenize(name, max_tokens) for name in (name1, name2)])
    model.eval()
    with torch.no_grad():
        first, second = model(token_batch)
    # dot-product alternative: (first * second).sum(dim=-1).item()
    return F.cosine_similarity(first, second, dim=-1).item()

In [None]:
# Spot-check ranking accuracy: a triplet counts as correct when the positive
# scores strictly higher than the hard negative. Low-scoring positives (< 0.6)
# and errors are printed for inspection.
num_correct = 0
num_error = 0
# clamp the sample size: sample(1000) raises ValueError when fewer than 1000
# triplets were loaded (e.g. with a small sample_frac)
num_eval_rows = min(1000, len(triplets_df))
for row in triplets_df.sample(num_eval_rows).itertuples():
    # strip the first and last character of each stored name before predicting
    anchor = row.anchor[1:-1]
    pos = row.positive[1:-1]
    neg = row.negative[1:-1]
    pos_predict = predict(anchor, pos)
    neg_predict = predict(anchor, neg)
    correct = pos_predict > neg_predict
    if correct:
        num_correct += 1
    else:
        num_error += 1
    if pos_predict < 0.6 or not correct:
        print(anchor, pos, pos_predict, row.positive_score, '' if correct else 'ERROR')
        print(anchor, neg, neg_predict, row.negative_score)
print(f"num_correct={num_correct} num_error={num_error}")

In [None]:
# Sanity check: print sampled anchors the model scores above 0.5 against 'john'.
# The sample size is clamped because sample(100) raises ValueError when fewer
# than 100 triplets were loaded.
for row in triplets_df.sample(min(100, len(triplets_df))).itertuples():
    anchor = row.anchor[1:-1]
    pos = 'john'
    sim = predict(anchor, pos)
    if sim > 0.5:
        print(anchor, pos, sim)