In [None]:
%load_ext autoreload
%autoreload 2

# Augment clusters with preferred names

Load clusters and preferred tree names, and determine which common tree names do not appear in the clusters.

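For each missing name, use the bi-encoder to find the most similar names that are already clustered, then use the cross-encoder to score the clusters those names belong to. The name joins the best-scoring cluster, or starts a new single-name cluster when no score reaches the threshold.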

In [None]:
from collections import defaultdict, Counter
import json
import math
import os
import re

import numpy as np
import pandas as pd
# from py4j.java_gateway import JavaGateway
from sentence_transformers.cross_encoder import CrossEncoder
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import torch
from tqdm.auto import tqdm

from src.data.normalize import normalize
from src.data.utils import read_csv
from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab
from src.models.utils import get_cross_encoder_score, top_similar_names

In [None]:
# configure
given_surname = "given"  # name type to process; compared against 'surname' when normalizing

# tokenizer and common-name selection
max_tokens = 10
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-2000f.json"
common_name_threshold = 100 # TODO 105
pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"

# bi-encoder and cross-encoder models
be_model_type = 'cecommon+0+aug-0-1'
be_model_path = f"../data/models/bi_encoder-{given_surname}-{be_model_type}.pth"
tokenizer_max_length = 32
ce_model_dir = f"../data/models/cross-encoder-{given_surname}-10m-265-same-all"

# clustering run to augment; these values only select the input/output file names here
linkage = "average"
scorer = "ce"  # be, ce, or cebe
similarity_threshold = 0.10
cluster_freq_normalizer = "none"

# Both cluster files share one stem so the input and output names cannot drift apart.
clusters_stem = f"../data/processed/clusters_{given_surname}-{scorer}-{linkage}-{similarity_threshold}-{cluster_freq_normalizer}"
clusters_path = f"{clusters_stem}.json"
augmented_clusters_path = f"{clusters_stem}-augmented.json"

In [None]:
# Report GPU availability and memory usage.
# torch.cuda.get_device_properties raises when no CUDA device is present, so the
# per-device queries are only issued when CUDA is available; this keeps
# "Restart & Run All" working on a CPU-only machine.
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load data

In [None]:
# Read the existing clustering and build the lookup structures used below.
with open(clusters_path, 'r') as f:
    clusters = json.load(f)  # cluster label -> names, centroid

# cluster label -> set of member names
cluster_names = {label: set(cluster['names']) for label, cluster in clusters.items()}

# name -> cluster label (if a name is listed in several clusters, the last one wins)
name_cluster = {
    name: label
    for label, cluster in clusters.items()
    for name in cluster['names']
}

# centroid and label for each cluster, in the same (file) order
cluster_centroids = [np.array(cluster['centroid']) for cluster in clusters.values()]
cluster_centroid_labels = np.array(list(clusters))

print(len(cluster_names), len(name_cluster))

In [None]:
# load pref names
# read_csv is the project helper imported from src.data.utils; the cells below
# read the 'name' and 'frequency' columns of the resulting frame.
pref_df = read_csv(pref_path)

In [None]:
# get total frequency, including names w frequency=1 that aren't in pref_df
# NOTE(review): the mass of the missing frequency=1 names is approximated as
# 2 * (number of names with frequency == 2) — confirm this estimate is intended.
frequencies = pref_df['frequency']
n_freq_two = int((frequencies == 2).sum())
total_freq = sum(frequencies) + n_freq_two * 2
total_freq

In [None]:
# Fraction of the total frequency covered by the first N rows of pref_df.
# NOTE(review): this is "the top N names" only if pref_df is sorted by
# descending frequency — confirm against the file.
top_n = 117000
top_frequencies = pref_df['frequency'][:top_n]
freq = sum(top_frequencies)
print(freq / total_freq)

In [None]:
# Common names: preferred names longer than one character, made up only of the
# letters a-z, whose frequency is at least common_name_threshold.
common_names = []
for pref_name, pref_freq in zip(pref_df['name'], pref_df['frequency']):
    if not len(pref_name) > 1:
        continue
    if not re.fullmatch(r'[a-z]+', pref_name):
        continue
    if not pref_freq >= common_name_threshold:
        continue
    common_names.append(pref_name)
len(common_names)

In [None]:
# load tokenize function
# get_tokenize_function_and_vocab (src.models.tokenizer) returns the tokenize
# callable used by get_embedding below, together with its vocabulary.
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    max_tokens=max_tokens,
    subwords_path=subwords_path,
)
# vocabulary size, shown as a quick sanity check
len(tokenizer_vocab)

In [None]:
# load bi-encoder model
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
# presumably the file holds a pickled BiEncoder instance (BiEncoder is imported
# above, which unpickling would need) — confirm.
be_model = torch.load(be_model_path)
# switch to evaluation mode; the model is only used to compute embeddings here
be_model.eval()

In [None]:
# load cross encoder model
# sentence_transformers CrossEncoder loaded from ce_model_dir; it is passed to
# get_cross_encoder_score when scoring a name against cluster members.
ce_model = CrossEncoder(ce_model_dir, max_length=tokenizer_max_length)

## Which names are not in the clusters?

In [None]:
# Collect the common names that are missing from the existing clusters.
# Each name is normalized first; names that do not normalize to exactly one
# piece are skipped.
print_cnt = 10             # how many unseen names to echo after each progress line
unseen_names = []
seen_unseen_names = set()  # normalized names already collected, to avoid duplicates
for ix, name in enumerate(common_names):
    name_pieces = normalize(name, is_surname=given_surname=='surname', dont_return_empty=False)
    if len(name_pieces) != 1:
        continue
    name = name_pieces[0]
    # periodic progress line; also re-arms the sample printing below
    if ix % 1000 == 0 and len(unseen_names) > 0:
        print(ix, len(unseen_names))
        print_cnt = 10
    if name in name_cluster:
        continue
    # Different common names can normalize to the same string; keep it once so
    # the name is not assigned and embedded twice later on.
    if name in seen_unseen_names:
        continue
    seen_unseen_names.add(name)
    unseen_names.append(name)
    if print_cnt > 0:
        print('   ', ix, name)
        print_cnt -= 1

In [None]:
# number of common names missing from the clusters, and a small sample of them
print(len(unseen_names))
unseen_names[:10]

### Get name embeddings

In [None]:
def get_embedding(name):
    """Return the bi-encoder embedding of `name`, scaled to unit L2 norm.

    The name is tokenized with `tokenize` and embedded with `be_model`; the
    resulting vector is divided in place by its norm and returned.
    """
    vector = be_model.get_embedding(tokenize(name))
    length = np.linalg.norm(vector)
    vector /= length
    return vector

In [None]:
# Embed every name that already belongs to a cluster. The two sequences are
# parallel: name_embeddings[i] is the embedding of name_embeddings_names[i].
name_embeddings_names = np.array(list(name_cluster))
name_embeddings = list(map(get_embedding, name_cluster))

## Figure out which cluster to put the names into

In [None]:
def get_nearest_bi_encoder_names(name, threshold=0.3, limit=10):
    """Look up the clustered names closest to `name` in bi-encoder space.

    Embeds `name` and delegates to `top_similar_names` over the global
    `name_embeddings` / `name_embeddings_names`, passing `threshold` and
    `limit` through unchanged. The helper's result is returned as is
    (callers in this notebook unpack it as `names, scores`).
    """
    query = get_embedding(name)
    nearest = top_similar_names(query, name_embeddings, name_embeddings_names, threshold, limit)
    return nearest

def get_bi_encoder_score(name, other_name):
    """Return the cosine similarity between the bi-encoder embeddings of two names."""
    first = get_embedding(name)
    second = get_embedding(other_name)
    # cosine_similarity works on 2-D inputs, so wrap each vector in a one-row
    # list and read back the single cell of the resulting 1x1 matrix
    similarity_matrix = cosine_similarity([first], [second])
    return similarity_matrix[0][0]

def get_bi_encoder_cluster_score(name, cluster, use_max=False):
    """Score `name` against the cluster labelled `cluster` with the bi-encoder.

    Every member of the cluster is scored against `name`; the result is the
    highest member score when `use_max` is true, otherwise the mean score.
    """
    member_scores = [
        get_bi_encoder_score(name, member) for member in cluster_names[cluster]
    ]
    if use_max:
        return max(member_scores)
    return sum(member_scores) / len(member_scores)

def get_nearest_bi_encoder_cluster_score(name, other_names, limit=10, use_max=False):
    """Pick the best bi-encoder cluster for `name` among its neighbours' clusters.

    The clusters of `other_names` are counted (names without a cluster are
    ignored) and the `limit` most frequent ones are scored with
    `get_bi_encoder_cluster_score`. Returns `(cluster, score)` for the
    highest-scoring candidate (the first one wins ties), or `(None, None)`
    when none of `other_names` belongs to a cluster.
    """
    # tally how many neighbours fall into each cluster
    votes = Counter(
        name_cluster[other_name]
        for other_name in other_names
        if other_name in name_cluster
    )
    if not votes:
        return None, None
    # score the most frequent candidates and keep the best one
    best_cluster, best_score = None, None
    for candidate, _ in votes.most_common(limit):
        candidate_score = get_bi_encoder_cluster_score(name, candidate, use_max=use_max)
        if best_score is None or candidate_score > best_score:
            best_cluster, best_score = candidate, candidate_score
    return best_cluster, best_score

In [None]:
def get_cross_encoder_cluster_score(name, cluster, use_max=False):
    """Score `name` against the cluster labelled `cluster` with the cross-encoder.

    Every member of the cluster is scored against `name` via
    `get_cross_encoder_score` and `ce_model`; the result is the highest member
    score when `use_max` is true, otherwise the mean score.
    """
    member_scores = [
        get_cross_encoder_score(name, member, ce_model)
        for member in cluster_names[cluster]
    ]
    if use_max:
        return max(member_scores)
    return sum(member_scores) / len(member_scores)

def get_nearest_cross_encoder_cluster_score(name, other_names, limit=10, use_max=False):
    """Pick the best cross-encoder cluster for `name` among its neighbours' clusters.

    The clusters of `other_names` are counted (names without a cluster are
    ignored) and the `limit` most frequent ones are scored with
    `get_cross_encoder_cluster_score`. Returns `(cluster, score)` for the
    highest-scoring candidate (the first one wins ties), or `(None, None)`
    when none of `other_names` belongs to a cluster.
    """
    # tally how many neighbours fall into each cluster
    votes = Counter(
        name_cluster[other_name]
        for other_name in other_names
        if other_name in name_cluster
    )
    if not votes:
        return None, None
    # score the most frequent candidates and keep the best one
    best_cluster, best_score = None, None
    for candidate, _ in votes.most_common(limit):
        candidate_score = get_cross_encoder_cluster_score(name, candidate, use_max=use_max)
        if best_score is None or candidate_score > best_score:
            best_cluster, best_score = candidate, candidate_score
    return best_cluster, best_score

In [None]:
# test
# Manual smoke test on a single name: fetch its nearest clustered names with the
# bi-encoder, then compare the cluster chosen by the cross-encoder with the one
# chosen by the bi-encoder.
name = 'ivanovna'
names, scores = get_nearest_bi_encoder_names(name, limit=20)
print(names, scores)
# Uncomment to drop the first hit (presumably the query name itself when it is
# already clustered — confirm):
# names = names[1:]
# scores = scores[1:]
ce_cluster, ce_score = get_nearest_cross_encoder_cluster_score(name, names)
print('cross-encoder', ce_cluster, ce_score)

be_cluster, be_score = get_nearest_bi_encoder_cluster_score(name, names)
print('bi-encoder', be_cluster, be_score)

In [None]:
def sample_names(cluster):
    """Return up to eight member names of `cluster`, joined by spaces.

    A falsy `cluster` (e.g. None) yields an empty string. The members come
    from a set, so which eight are shown is not specified.
    """
    if cluster:
        members = list(cluster_names[cluster])
        return ' '.join(members[:8])
    return ''

In [None]:
# sizes before augmentation: number of clusters, number of clustered names
print(len(cluster_names), len(name_cluster))

In [None]:
# Assign every unseen name to an existing cluster, or start a new cluster for it.
#
# For each name: fetch its nearest clustered names with the bi-encoder, let the
# cross-encoder pick the best cluster among those names' clusters, and accept
# that cluster when its score reaches score_threshold. Otherwise the name
# becomes a new single-name cluster labelled with the name itself. Either way
# the name's embedding is appended, so later names can match against it.
testing = False  # True: only the first 1000 names, with diagnostics near the threshold

score_threshold = 0.10  # minimum cross-encoder cluster score needed to join a cluster

unseen_names_set = set(unseen_names)  # not used in this cell; kept for interactive use
n_new_clusters = 0

for name in unseen_names[:1000] if testing else tqdm(unseen_names, mininterval=1.0):
    # Skip names that already have a cluster: a duplicate entry in unseen_names,
    # or a re-run of this cell. Without this, the name's embedding would be
    # appended again.
    if name in name_cluster:
        continue

    cluster, score = None, 0.0
    # get nearby names
    names, scores = get_nearest_bi_encoder_names(name)
    if len(names) > 0:
        # get cross-encoder vote
        cluster, score = get_nearest_cross_encoder_cluster_score(name, names)
        if cluster is None:
            # (None, None) means no candidate cluster; treat it as "no match"
            # instead of comparing None with the threshold
            score = 0.0

    # print stuff if testing; names close to the threshold are only reported, not assigned
    if testing and abs(score - score_threshold) < 0.05:
        print()
        print(name)
        print('   ce', cluster, score, sample_names(cluster))
        # same comparison as the assignment below
        if score >= score_threshold:
            print('WINNER', cluster, score)
        continue

    # add name to existing cluster, or create a new cluster
    if cluster is not None and score >= score_threshold:
        name_cluster[name] = cluster
        cluster_names[cluster].add(name)
    else:
        if name not in cluster_names:
            n_new_clusters += 1
        name_cluster[name] = name
        # setdefault keeps the members of a cluster that already uses this label
        # instead of overwriting them
        cluster_names.setdefault(name, set()).add(name)

    # add embedding
    name_embeddings_names = np.append(name_embeddings_names, [name], axis=0)
    name_embeddings = np.append(name_embeddings, [get_embedding(name)], axis=0)

print('new clusters', n_new_clusters)

In [None]:
# sizes after augmentation: number of clusters, number of clustered names
print(len(cluster_names), len(name_cluster))

## Save augmented clusters

In [None]:
def get_centroid(names):
    """Return the unit-length mean of the bi-encoder embeddings of `names`.

    Args:
        names: a non-empty sized collection of names (a set of cluster members
            in this notebook).

    Returns:
        The average of the names' embeddings, rescaled to unit L2 norm.

    Raises:
        ValueError: if `names` is empty.
    """
    embeddings = [get_embedding(name) for name in names]
    if not embeddings:
        raise ValueError("cannot compute a centroid for an empty collection of names")
    # Average into a new array; accumulating with += into the first embedding
    # would modify an array handed out by get_embedding.
    mean_embedding = np.mean(embeddings, axis=0)
    # normalize
    return mean_embedding / np.linalg.norm(mean_embedding)

In [None]:
# show where the augmented clusters are about to be written
augmented_clusters_path

In [None]:
# Recompute every cluster's centroid from its (possibly extended) member set and
# write the result in the same layout as the input file:
# cluster label -> {"names": [...], "centroid": [...]}
clusters = {
    label: {"names": list(names), "centroid": get_centroid(names).tolist()}
    for label, names in cluster_names.items()
}

with open(augmented_clusters_path, 'w') as f:
    json.dump(clusters, f, indent=2)

## Compare centroids

In [None]:
# Reload both versions from disk so the comparison below works on exactly what
# was saved. Each file maps cluster label -> names, centroid.
with open(augmented_clusters_path, 'r') as augmented_file:
    augmented_clusters = json.load(augmented_file)

with open(clusters_path, 'r') as original_file:
    clusters = json.load(original_file)
    

In [None]:
# For every original cluster, measure how far its centroid moved after
# augmentation (cosine similarity between the old and new centroid). Clusters
# whose similarity drops below 0.8 are printed with their member lists before
# and after, and counted in cnt.
sims = []
cnt = 0
for label, original in clusters.items():
    if label not in augmented_clusters:
        print('Should not happen', label)
        continue
    augmented = augmented_clusters[label]
    sim = cosine_similarity([original['centroid']], [augmented['centroid']])[0][0]
    sims.append(sim)
    if sim < 0.8:
        cnt += 1
        print()
        print(label)
        print('1', len(original['names']), original['names'])
        print('2', len(augmented['names']), augmented['names'])
cnt

In [None]:
# NOTE(review): this import is kept here so the cell still runs on its own;
# consider moving it to the import cell at the top of the notebook.
import matplotlib.pyplot as plt

# Histogram of the cosine similarity between each cluster's original and
# augmented centroid (values collected in `sims` by the comparison cell).
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(sims, bins=100, label="sim", color='green')
ax.set_title("Centroid similarity before vs. after augmentation")
ax.set_xlabel("cosine similarity (original vs. augmented centroid)")
ax.set_ylabel("number of clusters")
ax.legend(loc='upper right')
# Show the plot
fig.tight_layout()
plt.show()