In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell executes, so edits to the project modules imported below are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Analyze sub-clusters

The questions to answer are:

1. Do any names appear more than once? Yes, but that is acceptable.
2. How far away are names from sub-cluster centroids at different distance thresholds? 
3. How many sub-clusters per cluster at different distance thresholds?

In [None]:
from collections import defaultdict
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import torch
from tqdm.auto import tqdm

from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# configure

# Name type under analysis; interpolated into every path below.
given_surname = "given"
# Distance threshold encoded in the sub-clusters file name.
distance_threshold = 0.65

# Tokenizer settings (passed to get_tokenize_function_and_vocab further down).
nama_bucket = "nama-data"
subword_vocab_size = 2000
vocab_type = "f"

# Paths derived from the settings above.
sub_clusters_path = f"../data/models/sub_clusters_{given_surname}-{distance_threshold}.json"
# NOTE(review): unlike the other local paths this one has no "../" prefix --
# presumably it is resolved relative to nama_bucket; confirm.
subwords_path = f"data/models/fs-{given_surname}-subword-tokenizer-{subword_vocab_size}{vocab_type}.json"
pref_path = f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz"
model_path = f"../data/models/bi_encoder-{given_surname}.pth"

## Load data

In [None]:
# load sub-clusters
# The later cells iterate this as {cluster_name: {sub_cluster_name: [name, ...]}}.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(sub_clusters_path, "rt", encoding="utf-8") as f:
    clusters = json.load(f)
# Number of top-level clusters.
print(len(clusters))

In [None]:
# load model
# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only load checkpoints from a trusted source.
# NOTE(review): the result is used below via model.get_embedding(...), so the
# file presumably holds a whole pickled model object rather than a state_dict
# -- confirm.
model = torch.load(model_path)

In [None]:
# Build the tokenize callable and its vocabulary from the configured subword file.
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    nama_bucket=nama_bucket, subwords_path=subwords_path
)
# Vocabulary size, shown as the cell output.
len(tokenizer_vocab)

## Report names appearing more than once

In [None]:
# Map each name to the list of clusters it occurs in (one entry per occurrence),
# printing the name and its clusters every time a repeat is encountered.
name_clusters = defaultdict(list)
for cluster_name, cluster in clusters.items():
    for sub_cluster in cluster.values():
        for name in sub_cluster:
            name_clusters[name].append(cluster_name)
            if len(name_clusters[name]) > 1:
                print("Name appears more than once", name, name_clusters[name])
            

## Compute sub-cluster centroids

In [None]:
# Spot check: cosine similarity between the first name in the bucket and each
# of the remaining names.
bucket = ['altino', 'aaltje', 'altgen', 'eltje', 'aeltje', 'aalken', 'aaltjen']
reference_name, other_names = bucket[0], bucket[1:]

reference_embedding = model.get_embedding(tokenize(reference_name))
reference_embedding /= np.linalg.norm(reference_embedding)  # scale to unit length

for other_name in other_names:
    other_embedding = model.get_embedding(tokenize(other_name))
    print(other_name, cosine_similarity([reference_embedding], [other_embedding])[0][0])

In [None]:
# For every sub-cluster: embed its names, average the unit-length embeddings
# into a centroid, and record each name's cosine similarity to that centroid.
#   centroids    -> {"<cluster_name>/<sub_cluster_name>": centroid vector}
#   similarities -> flat list with one similarity per name
similarities = []
centroids = {}
# Wrap the items view itself (not enumerate) so tqdm knows the total and can
# show real progress.
for ix, (cluster_name, cluster) in enumerate(tqdm(clusters.items(), total=len(clusters))):
    for sub_cluster_name, sub_cluster in cluster.items():
        if not sub_cluster:
            # Nothing to average; skipping avoids a 0/0 centroid.
            continue
        # Embed each name once and reuse the result for both the centroid and
        # the similarity computation.
        embeddings = []
        for name in sub_cluster:
            embedding = model.get_embedding(tokenize(name))
            embedding /= np.linalg.norm(embedding)
            embeddings.append(embedding)
        embeddings = np.array(embeddings)
        centroid = embeddings.sum(axis=0) / len(embeddings)
        centroids[f"{cluster_name}/{sub_cluster_name}"] = centroid
        # One batched call per sub-cluster instead of one call per name.
        sub_cluster_similarities = cosine_similarity([centroid], embeddings)[0]
        if ix < 10:
            # Show details for the first ten clusters only.
            for name, similarity in zip(sub_cluster, sub_cluster_similarities):
                print(sub_cluster_name, name, similarity)
        similarities.extend(sub_cluster_similarities)

In [None]:
# Total number of name-to-centroid similarities collected, followed by the
# first ten values as the cell output.
print(len(similarities))
similarities[:10]

### Plot name similarity to sub-cluster centroids

In [None]:
# Number of names whose similarity to their sub-cluster centroid is below 0.5.
sum(1 for sim in similarities if sim < 0.5)

In [None]:
# Histogram of a random sample of the name-to-centroid similarities.
# random.sample raises ValueError when asked for more items than the population
# holds, so cap the sample at the number of similarities available.
sample_size = min(10000, len(similarities))

plt.figure(figsize=(10, 6))
plt.hist(random.sample(similarities, sample_size), bins=40, label="Name similarity to centroid")
plt.title('Centroid similarities')
plt.xlabel('similarity')
plt.ylabel('Frequency')

# Show the plot
plt.tight_layout()
plt.show()

### Plot number of sub-clusters

In [None]:
# Number of sub-clusters in each cluster, in the iteration order of `clusters`.
n_sub_clusters = [len(sub_clusters) for sub_clusters in clusters.values()]
# One entry per cluster.
len(n_sub_clusters)

In [None]:
# Number of clusters that were split into more than 10 sub-clusters.
sum(1 for n in n_sub_clusters if n > 10)

In [None]:
# Number of sub-clusters in the 'elizabeth' cluster.
# NOTE(review): raises KeyError if no cluster has this name -- presumably only
# meaningful when analyzing given names; confirm before running with other data.
len(clusters['elizabeth'])

In [None]:
# Histogram of the number of sub-clusters per cluster, drawn through the
# explicit figure/axes interface.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(n_sub_clusters, bins=40, label="Number of Sub-clusters")
ax.set_title('Number of Sub-clusters')
ax.set_xlabel('Number of Sub-clusters')
ax.set_ylabel('Frequency')

# Render the figure
fig.tight_layout()
plt.show()