In [None]:
# Reload edited project modules automatically so changes under src/ are picked up live.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from collections import namedtuple
import heapq

import joblib
import numpy as np
import pandas as pd
import torch
import wandb

from src.data.normalize import normalize_freq_names
from src.data.utils import load_dataset
from src.data.filesystem import fopen
from src.models.cluster import (
    get_names_to_cluster,
    get_distances,
    generate_clusters_from_distances,
    write_clusters,
)
from src.models.swivel import SwivelModel
from src.models.utils import remove_padding

In [None]:
# configure
given_surname = "given"
save_partitions = False
save_clusters = True
n_to_cluster = 200000
n_jobs = 64

# settings that differ between given names and surnames
if given_surname == "given":
    vocab_size = 610000
    cluster_threshold = 0.4
else:
    vocab_size = 2100000
    cluster_threshold = 0.6

# with max_partitions == 0 the partition-splitting loop never runs and the
# hierarchy root stays the only partition
if save_partitions:
    max_partitions = 720
else:
    max_partitions = 0

embed_dim = 100
encoder_layers = 2
num_matches = 5000
batch_size = 256
verbose = True

# Input/output locations plus the embedding size; this is also what gets logged to wandb.
Config = namedtuple("Config", [
    "eval_path",
    "freq_path",
    "embed_dim",
    "swivel_vocab_path",
    "swivel_model_path",
    "tfidf_path",
    "ensemble_model_path",
    "name_partition_path",
    "cluster_path",
])
config = Config(
    eval_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    freq_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
    tfidf_path=f"s3://nama-data/data/models/fs-{given_surname}-tfidf.joblib",
    ensemble_model_path=f"s3://nama-data/data/models/fs-{given_surname}-ensemble-model-{vocab_size}-{embed_dim}-augmented-100.joblib",
    name_partition_path=f"s3://nama-data/data/models/fs-{given_surname}-cluster_partitions.csv",
    cluster_path=f"s3://nama-data/data/models/fs-{given_surname}-cluster_names.csv",
)

In [None]:
# Start a Weights & Biases run named "81_cluster" in the "nama" project, grouped by
# given_surname, and log every Config field as the run configuration.
wandb.init(
    project="nama",
    entity="nama",
    name="81_cluster",
    group=given_surname,
    notes="",
    config=config._asdict()
)

### Load data

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")
print(device)

In [None]:
# Load the evaluation dataset from config.eval_path. weighted_actual_names_eval is consumed in
# the next cell as an iterable of collections of 3-tuples whose first element is a name, and
# candidate_names_eval is rebuilt there from those names.
input_names_eval, weighted_actual_names_eval, candidate_names_eval = load_dataset(config.eval_path, is_eval=True)

In [None]:
# Rebuild the candidate names as the distinct names that appear in weighted_actual_names_eval.
actual_names_eval = {
    name
    for wans in weighted_actual_names_eval
    for name, _, _ in wans
}
candidate_names_eval = np.array(list(actual_names_eval))
# the intermediate set is no longer needed once the array exists
del actual_names_eval
print(len(candidate_names_eval))

In [None]:
# Read the raw name frequencies and normalize them into name_freq (padded names).
is_surname = given_surname != "given"
freq_df = pd.read_csv(config.freq_path, na_filter=False)
name_freq = normalize_freq_names(
    freq_df,
    is_surname=is_surname,
    add_padding=True,
)
# drop the reference to the raw frame so its memory can be reclaimed
freq_df = None

In [None]:
# Load the swivel vocabulary and map each name to its index.
vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"))
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Read the saved weights straight onto the selected device, then build the model around them.
swivel_state = torch.load(fopen(config.swivel_model_path, "rb"), map_location=device)
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
swivel_model.load_state_dict(swivel_state)
swivel_model.to(device)
# inference only: switch to evaluation mode
swivel_model.eval()

In [None]:
# Load the fitted TF-IDF vectorizer.
# NOTE(review): joblib.load unpickles the file, which can execute code embedded in it —
# only load artifacts from a trusted location.
tfidf_vectorizer = joblib.load(fopen(config.tfidf_path, mode='rb'))

In [None]:
# Load the fitted ensemble model.
# NOTE(review): joblib.load unpickles the file, which can execute code embedded in it —
# only load artifacts from a trusted location.
ensemble_model = joblib.load(fopen(config.ensemble_model_path, mode='rb'))

### Get names to cluster

In [None]:
# Choose the names to cluster from name_freq, passing n_to_cluster as the requested count
# (presumably the most frequent names — verify against get_names_to_cluster).
# Later cells index names_to_cluster by leaf node id.
names_to_cluster = get_names_to_cluster(name_freq, n_to_cluster)

### Compute cluster hierarchy

In [None]:
%%time
# Compute the distances between the names to cluster using the swivel model, the TF-IDF
# vectorizer and the ensemble model loaded above; num_matches and n_jobs come from the
# configuration cell.
distances = get_distances(name_freq, 
                          names_to_cluster,
                          swivel_model=swivel_model,
                          swivel_vocab=swivel_vocab,
                          tfidf_vectorizer=tfidf_vectorizer,
                          ensemble_model=ensemble_model,
                          num_matches=num_matches,
                          verbose=verbose,
                          n_jobs=n_jobs,
                         )

In [None]:
%%time
# Build the full average-linkage agglomerative hierarchy over the precomputed distances.
# Only `model` (its children_ and distances_ attributes) is used by the following cells;
# name_cluster is rebuilt later when save_clusters is set.
model, name_cluster = generate_clusters_from_distances(
                            cluster_algo="agglomerative",
                            cluster_linkage="average",
                            cluster_threshold=-10.0,  # initially put everything into a single cluster
                            distances=distances,
                            names_to_cluster=names_to_cluster,
                            verbose=verbose,
                            n_jobs=n_jobs)

#### Split into partitions

In [None]:
# Each row of model.children_ lists the immediate children of one non-leaf node of the cluster
# hierarchy. Node ids below leaf_node_count are leaves (indexes into names_to_cluster); row i of
# model.children_ describes node leaf_node_count + i.
leaf_node_count = len(names_to_cluster)
non_leaf_node_count = len(model.children_)
total_node_count = leaf_node_count + non_leaf_node_count

# Total name frequency under every node, leaves first.
cluster_freq = np.zeros(total_node_count)
cluster_freq[:leaf_node_count] = [
    name_freq[names_to_cluster[leaf_ix]] for leaf_ix in range(leaf_node_count)
]

# Non-leaf nodes are filled in row order; this relies on every child having already been
# filled in by the time its parent is reached.
for node_ix, children in enumerate(model.children_, start=leaf_node_count):
    cluster_freq[node_ix] = sum(cluster_freq[child] for child in children)

In [None]:
# Start from the root of the cluster hierarchy and keep splitting the heaviest partition until
# there are max_partitions of them. Heap entries are (-frequency, node id, number of partitions
# the node counts for), so the min-heap always pops the heaviest partition first.
root_node = total_node_count - 1
partitions = []
heapq.heappush(partitions, (-cluster_freq[root_node], root_node, 1))
total_partitions = 1

while total_partitions < max_partitions:
    _, node, n_splits = heapq.heappop(partitions)
    total_partitions -= n_splits

    if node >= leaf_node_count:
        # a non-leaf node is replaced by its immediate children, one partition each
        for child in model.children_[node - leaf_node_count]:
            total_partitions += 1
            heapq.heappush(partitions, (-cluster_freq[child], child, 1))
        continue

    # a leaf cannot be split further, so it becomes a multi-partition leaf: it counts for one
    # more partition and is ranked by its frequency per partition
    n_splits += 1
    total_partitions += n_splits
    heapq.heappush(partitions, (-cluster_freq[node] / n_splits, node, n_splits))

In [None]:
# TODO merge smaller partitions?

In [None]:
# Histogram of partition sizes; heap entries store the negated frequency, so flip the sign back.
partition_sizes = [-neg_freq for neg_freq, _, _ in partitions]
partition_sizes_df = pd.DataFrame(partition_sizes)
partition_sizes_df.hist()

#### Split partition(s) into clusters

In [None]:
if save_clusters:
    def _neg_merge_distance(node_id):
        """Return the negated merge distance of a node; leaves have distance 0.0."""
        if node_id < leaf_node_count:
            return -0.0
        return -model.distances_[node_id - leaf_node_count]

    # Seed the heap with the partition nodes. Entries are (-distance, node id), so the node with
    # the largest merge distance is always popped first.
    clusters = []
    for _, partition, _ in partitions:
        heapq.heappush(clusters, (_neg_merge_distance(partition), partition))

    # cluster_threshold is measured in terms of (1 - distance); express it on the negated scale
    neg_distance_bound = -(1 - cluster_threshold)

    # Keep replacing the widest node by its children until even the widest one is within bounds.
    while True:
        neg_distance, node_id = heapq.heappop(clusters)
        if neg_distance >= neg_distance_bound:
            # nothing left above the threshold: put the node back and stop
            heapq.heappush(clusters, (neg_distance, node_id))
            break
        for child in model.children_[node_id - leaf_node_count]:
            heapq.heappush(clusters, (_neg_merge_distance(child), child))

#### Save partitions and clusters

In [None]:
# node id -> number of partitions that node counts for (more than 1 only for split leaves)
partition_map = {node_id: n_splits for _, node_id, n_splits in partitions}

def get_most_frequent_name(names, freqs=None):
    """Return the name with the highest frequency.

    Args:
        names: iterable of names to choose from.
        freqs: optional mapping of name -> frequency; defaults to the notebook-level
            ``name_freq``. Names missing from the mapping are treated as frequency -1.

    Returns:
        The first name with the highest frequency, or ``None`` when ``names`` is empty.
        When none of the names has a known frequency the first name is returned, so a
        non-empty input never yields ``None``.
    """
    if freqs is None:
        freqs = name_freq
    most_freq_name = None
    highest_freq = None
    for name in names:
        freq = freqs.get(name, -1)
        # strict comparison keeps the first name on ties
        if highest_freq is None or freq > highest_freq:
            most_freq_name = name
            highest_freq = freq
    return most_freq_name

def partition_finder(node_id):
    """Return (name, n_partitions) for a leaf node, or None for a non-leaf node.

    n_partitions comes from partition_map and defaults to 1 for leaves not listed there.
    """
    if node_id >= leaf_node_count:
        return None
    leaf_name = names_to_cluster[node_id]
    return (leaf_name, partition_map.get(node_id, 1))

def name_finder(node_id):
    """Return the name of a leaf node, or None for a non-leaf node."""
    if node_id >= leaf_node_count:
        return None
    return names_to_cluster[node_id]

def gather_children(node_id, fn, result, children=None, n_leaves=None):
    """Collect fn(node) for the nodes at or below node_id, in depth-first pre-order.

    A node whose fn(node) is not None contributes that item to ``result`` and is not descended
    into; otherwise, if it is a non-leaf node, its children are visited in order.

    The traversal uses an explicit stack rather than recursion, so a deep hierarchy cannot hit
    Python's recursion limit.

    Args:
        node_id: id of the node to start from.
        fn: callable taking a node id and returning an item to collect, or None.
        result: list that collected items are appended to (modified in place).
        children: optional sequence where row i lists the children of node ``n_leaves + i``;
            defaults to the notebook-level ``model.children_``.
        n_leaves: optional number of leaf nodes; defaults to the notebook-level
            ``leaf_node_count``.
    """
    stack = [node_id]
    while stack:
        current = stack.pop()
        item = fn(current)
        # compare against None so falsy-but-valid items (e.g. an empty string) are kept
        if item is not None:
            result.append(item)
            continue
        if n_leaves is None:
            n_leaves = leaf_node_count
        if current >= n_leaves:
            if children is None:
                children = model.children_
            # push in reverse so the first child is popped (and visited) first
            stack.extend(reversed(children[current - n_leaves]))

In [None]:
if save_partitions:
    # Walk the cluster hierarchy to get the (name, n_partitions) pairs under each partition node.
    partition2names = {}
    for _, partition, _ in partitions:
        name_partitions = []
        gather_children(partition, partition_finder, name_partitions)
        if len(name_partitions) == 0:
            # Report and skip: an empty partition has no names to write, and the numbering
            # below would otherwise fail when it reads name_partitions[0].
            print("ERROR empty partition", partition)
            continue
        partition2names[partition] = name_partitions

    # Invert partition2names into parallel columns: name, first partition number assigned to
    # it, and how many consecutive partitions it occupies.
    partition_number = 0
    name_partition_name = []
    name_partition_partition = []
    name_partition_count = []
    for partition, name_partitions in partition2names.items():
        for name, n_partitions in name_partitions:
            name_partition_name.append(remove_padding(name))
            name_partition_partition.append(partition_number)
            name_partition_count.append(n_partitions)
        # A partition with several names uses one partition number; a single-name partition
        # uses as many numbers as that name was split into.
        partition_number += 1 if len(name_partitions) > 1 else name_partitions[0][1]
    name_partition_df = pd.DataFrame({
        "name": name_partition_name,
        "start_partition": name_partition_partition,
        "n_partitions": name_partition_count,
    })

    # write the dataframe to a csv file
    name_partition_df.to_csv(config.name_partition_path, index=False)

In [None]:
if save_clusters:
    # Walk the cluster hierarchy to collect the names under every cluster node.
    cluster2names = {}
    for _, cluster in clusters:
        members = []
        gather_children(cluster, name_finder, members)
        if not members:
            print("ERROR: empty cluster", cluster)
        # a cluster is labelled with its most-frequent member name, with padding removed
        label = remove_padding(get_most_frequent_name(members))
        cluster2names[label] = members

    # Invert cluster2names so that every member name points at its cluster label.
    name_cluster = {
        member: label
        for label, members in cluster2names.items()
        for member in members
    }

    # write the name -> cluster mapping to config.cluster_path
    write_clusters(config.cluster_path, name_cluster)

In [None]:
# Close the Weights & Biases run started at the top of the notebook.
wandb.finish()