In [None]:
# Reload edited project modules automatically before each cell executes
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

# Train a clustering model
Once we have a good scoring function (using the ensemble model), we can use a standard clustering algorithm to group names into clusters.

In [None]:
from bisect import bisect_left
from collections import namedtuple, defaultdict
import pickle
import random

import joblib
import numpy as np
import pandas as pd
import torch
import wandb

from src.data.normalize import normalize_freq_names
from src.data.utils import load_dataset
from src.data.filesystem import fopen
from src.models.cluster import (
    get_names_to_cluster,
    get_distances,
    generate_clusters_from_distances,
    write_clusters,
    read_clusters,
)
from src.models.swivel import SwivelModel
from src.models.utils import remove_padding

In [None]:
# configure
# -- what to cluster --
given_surname = "given"
# the vocabulary size is part of the swivel / ensemble artifact names built below
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
n_to_cluster = 50000
cluster_threshold = 0.3
n_jobs = 8

# -- model / scoring settings --
embed_dim = 100
encoder_layers = 2
num_matches = 1000
batch_size = 256
verbose = True

# Field name -> value, listed once; insertion order defines the Config field order
config_values = {
    "eval_path": f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    "tree_freq_path": f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    "hr_freq_path": f"s3://familysearch-names-private/hr-preferred-{given_surname}-aggr.csv.gz",
    "embed_dim": embed_dim,
    "swivel_vocab_path": f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    "swivel_model_path": f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
    "tfidf_path": f"s3://nama-data/data/models/fs-{given_surname}-tfidf.joblib",
    "ensemble_model_path": f"s3://nama-data/data/models/fs-{given_surname}-ensemble-model-{vocab_size}-{embed_dim}-augmented-100.joblib",
    "cluster_path": f"s3://nama-data/data/models/fs-{given_surname}-cluster-names-{n_to_cluster}-{cluster_threshold}.csv",
}
Config = namedtuple("Config", list(config_values))
config = Config(**config_values)

In [None]:
# Open a Weights & Biases run for this notebook; the run config is the Config namedtuple as a dict
run_settings = dict(
    project="nama",
    entity="nama",
    name="81_cluster",
    group=given_surname,
    notes="",
)
wandb.init(config=config._asdict(), **run_settings)

### Load data

In [None]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Read the tree-name frequency table; na_filter=False keeps cells as read instead of converting them to NaN
freq_df = pd.read_csv(config.tree_freq_path, na_filter=False)
is_surname = given_surname != "given"
tree_name_freq = normalize_freq_names(
    freq_df,
    is_surname=is_surname,
    add_padding=True,
)
# Drop the reference to the raw frame so it can be garbage-collected
freq_df = None

In [None]:
# create clusters based upon tree freq so we get consistent cluster names
# NOTE: this binds a second name to the same object as tree_name_freq; no copy is made
cluster_name_freq = tree_name_freq

In [None]:
# Build the name -> index lookup from the "name" and "index" columns of the swivel vocabulary CSV
# NOTE(review): the handle returned by fopen is not closed explicitly here
vocab_file = fopen(config.swivel_vocab_path, "rb")
vocab_df = pd.read_csv(vocab_file, na_filter=False)
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Rebuild the swivel model with one row per vocabulary entry, then restore the trained weights
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# map_location places the loaded tensors on the selected device
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted location
swivel_state = torch.load(
    fopen(config.swivel_model_path, "rb"),
    map_location=torch.device(device),
)
swivel_model.load_state_dict(swivel_state)
swivel_model.to(device)
# eval() is the last expression, so the cell displays its return value
swivel_model.eval()

In [None]:
# Load the saved TF-IDF vectorizer
# NOTE(review): joblib.load is pickle-based - only load artifacts from a trusted location
tfidf_vectorizer = joblib.load(fopen(config.tfidf_path, "rb"))

In [None]:
# Load the saved ensemble model
# NOTE(review): joblib.load is pickle-based - only load artifacts from a trusted location
ensemble_model = joblib.load(fopen(config.ensemble_model_path, "rb"))

### Get names to cluster

In [None]:
# Choose the names to cluster from the tree frequency table
# (presumably the n_to_cluster most frequent names - confirm in src.models.cluster)
# TODO pass in swivel_vocab and ensure that only names in swivel_vocab were selected to cluster
names_to_cluster = get_names_to_cluster(cluster_name_freq, n_to_cluster)

### Compute clusters

In [None]:
%%time
# Compute the distances used for clustering from the swivel model, TF-IDF vectorizer and ensemble model
# (the exact structure of the result is defined by get_distances in src.models.cluster)
distances = get_distances(
    cluster_name_freq,
    names_to_cluster,
    swivel_model=swivel_model,
    swivel_vocab=swivel_vocab,
    tfidf_vectorizer=tfidf_vectorizer,
    ensemble_model=ensemble_model,
    num_matches=num_matches,
    verbose=verbose,
    n_jobs=n_jobs,
)

In [None]:
%%time
# Agglomerative clustering with average linkage at cluster_threshold;
# name2cluster maps each clustered name to its cluster
model, name2cluster = generate_clusters_from_distances(
    cluster_algo="agglomerative",
    cluster_linkage="average",
    cluster_threshold=cluster_threshold,
    distances=distances,
    names_to_cluster=names_to_cluster,
    verbose=verbose,
    n_jobs=n_jobs,
)

### Save clusters

In [None]:
# Invert name -> cluster into cluster -> set of member names, stripping padding from each name
cluster2names = defaultdict(set)
for clustered_name, cluster_key in name2cluster.items():
    members = cluster2names[cluster_key]
    members.add(remove_padding(clustered_name))
# number of clusters
len(cluster2names)

In [None]:
def _unpadded_cluster_name_freq():
    """Return a frequency lookup over ``cluster_name_freq`` that also answers to unpadded names.

    ``cluster_name_freq`` is built with ``add_padding=True`` while cluster members are
    passed through ``remove_padding`` before they are looked up, so a direct ``.get``
    on the original table presumably misses every member (assumes the table's keys
    carry padding - TODO confirm against ``normalize_freq_names``).

    The lookup holds every key under its ``remove_padding`` form and, taking
    precedence, under its original form, so a table whose keys are already unpadded
    keeps its exact frequencies. It is rebuilt only when ``cluster_name_freq`` is
    rebound to a different object.
    """
    cached = getattr(_unpadded_cluster_name_freq, "_cache", None)
    if cached is None or cached[0] is not cluster_name_freq:
        lookup = {remove_padding(name): freq for name, freq in cluster_name_freq.items()}
        # original keys win over derived ones
        lookup.update(cluster_name_freq)
        cached = (cluster_name_freq, lookup)
        _unpadded_cluster_name_freq._cache = cached
    return cached[1]


def get_most_frequent_name(names, name_freq=None):
    """Pick the label for a cluster: the member name with the highest frequency.

    Args:
        names: iterable of member names (the notebook passes sets of names that have
            been through ``remove_padding``).
        name_freq: optional mapping of name -> frequency. When omitted, a lookup
            derived from the notebook-level ``cluster_name_freq`` is used.

    Returns:
        The name with the highest frequency; names missing from the mapping count
        as 0. Ties go to the alphabetically smallest name, so the result does not
        depend on set iteration order. ``None`` when ``names`` is empty.
    """
    if name_freq is None:
        name_freq = _unpadded_cluster_name_freq()
    # highest frequency first, then alphabetical order as a deterministic tie-break
    return min(names, key=lambda name: (-name_freq.get(name, 0), name), default=None)


In [None]:
# Flatten the clusters into parallel columns: every member name next to its cluster label,
# where the label is the value returned by get_most_frequent_name for that cluster
all_names, all_clusters = [], []
for members in cluster2names.values():
    label = get_most_frequent_name(members)
    all_names.extend(members)
    all_clusters.extend([label] * len(members))
df = pd.DataFrame({"name": all_names, "cluster": all_clusters})
print(len(df))
df.head(5)

In [None]:
# Persist the name -> cluster table as CSV at the configured path, without the row index
df.to_csv(path_or_buf=config.cluster_path, index=False)

In [None]:
# Close the W&B run that was opened with wandb.init earlier in this notebook
wandb.finish()