In [None]:
# Reload edited project modules (e.g. src.*) automatically before executing code in each cell.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from collections import namedtuple

import pandas as pd
import ray
from ray import tune
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.wandb import WandbLoggerCallback
import torch

from src.data.utils import load_datasets, select_frequent_k
from src.data.filesystem import fopen
from src.eval.metrics import (
    avg_precision_at_threshold, 
    avg_weighted_recall_at_threshold,
    precision_weighted_recall_curve_at_threshold,
)
from src.models.cluster import (
    get_sorted_similarities,
    generate_closures,
    generate_clusters,
    get_clusters,
    get_best_cluster_matches,
)
from src.models.swivel import SwivelModel, get_swivel_embeddings

In [None]:
# configure
# W&B API key file; a falsy value ("" or None) disables the W&B logger callback
wandb_api_key_file = "../.wandb-api-key"
# which name type to work on: "given" or "surname"
given_surname = "surname"
# swivel vocabulary size for that name type
if given_surname == "given":
    vocab_size = 600000
else:
    vocab_size = 2100000
embed_dim = 100  # swivel embedding dimension

# default hyper-parameters used to build the Tune config
DEFAULT_NAMES_TO_CLUSTER = vocab_size  # k passed to select_frequent_k ("n_to_cluster")
DEFAULT_CLOSURE_THRESHOLD = 5000
DEFAULT_SEARCH_THRESHOLD = 0.55  # threshold for the precision / recall metrics
DEFAULT_ALGO = "agglomerative"  # cluster_algo passed to generate_clusters
# agglomerative options
DEFAULT_CLUSTER_THRESHOLD, DEFAULT_CLUSTER_LINKAGE = 0.7, "complete"
# optics and hdbscan options
DEFAULT_MIN_SAMPLES, DEFAULT_EPS = 2, 0.2
# optics options
DEFAULT_MAX_EPS, DEFAULT_XI = 1.0, 0.15
# hdbscan options
DEFAULT_SELECTION_METHOD, DEFAULT_MIN_CLUSTER_SIZE = "eom", 2

# locations of the training data and of the trained swivel vocab / model
Config = namedtuple(
    "Config",
    ["train_path", "embed_dim", "swivel_vocab_path", "swivel_model_path"],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}.csv",
    swivel_model_path=(
        f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}.pth"
    ),
)

### Load data

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Training split: input names, their weighted actual names, and candidate names.
[train] = load_datasets([config.train_path])
input_names_train, weighted_actual_names_train, candidate_names_train = train

# Swivel vocabulary: name -> id, from the vocab file's "name" and "index" columns.
# The file handle is closed once the CSV has been read.
with fopen(config.swivel_vocab_path, "rb") as vocab_file:
    vocab_df = pd.read_csv(vocab_file)
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

# Swivel model weights. `device` is already a torch.device, so it is passed to map_location as-is.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
with fopen(config.swivel_model_path, "rb") as model_file:
    swivel_model.load_state_dict(torch.load(model_file, map_location=device))
swivel_model.eval()

### Optimize hyperparameters

In [None]:
def ray_training_function(config,
                          swivel_model,
                          swivel_vocab,
                          input_names_train,
                          weighted_actual_names_train,
                          candidate_names_train,
                          checkpoint_dir=None):
    """Ray Tune trainable: cluster names with one hyper-parameter setting and report metrics.

    The names selected by ``select_frequent_k`` are both clustered and used for validation.

    Args:
        config: Tune trial config (dict). Required keys: "n_to_cluster", "closure_threshold",
            "search_threshold", "cluster_algo", "cluster_linkage", "cluster_threshold",
            "min_samples", "eps", "cluster_method", "max_eps", "xi", "selection_method",
            "min_cluster_size". Optional keys: "similarity_threshold" (default 0.4, passed to
            get_sorted_similarities), "assign_k" (default 100, passed as get_clusters' k) and
            "max_clusters" (default 5, passed to get_clusters).
        swivel_model: swivel model used to embed names.
        swivel_vocab: name -> id vocabulary for ``swivel_model``.
        input_names_train, weighted_actual_names_train, candidate_names_train: training data
            as unpacked from ``load_datasets``.
        checkpoint_dir: not used. NOTE(review): presumably kept to match Ray Tune's
            function-trainable signature.

    Reports via ``tune.report``: f1, precision and recall at config["search_threshold"],
    plus the clustering diagnostics num_clusters and max_cluster_size.
    """
    # filter names to cluster from train
    input_names_cluster, weighted_actual_names_cluster, candidate_names_cluster = select_frequent_k(
        input_names_train,
        weighted_actual_names_train,
        candidate_names_train,
        config["n_to_cluster"],
    )

    # validate on the names being clustered
    input_names_validate = input_names_cluster
    weighted_actual_names_validate = weighted_actual_names_cluster
    candidate_names_validate = candidate_names_cluster
    input_names_train = weighted_actual_names_train = candidate_names_train = None  # release memory

    # names to cluster: union of input and candidate names
    cluster_names = list(set(input_names_cluster).union(set(candidate_names_cluster)))
    input_names_cluster = candidate_names_cluster = None  # release memory

    # embeddings for the names to cluster
    cluster_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, cluster_names).astype('float32')
    cluster_names = None  # release memory

    # sorted similarities from the embeddings
    sorted_similarities = get_sorted_similarities(
        cluster_embeddings,
        threshold=config.get("similarity_threshold", 0.4),
    )

    # generate closures from sorted similarities
    _, closure2ids, _, _ = generate_closures(sorted_similarities, config["closure_threshold"])
    sorted_similarities = None  # release memory

    # generate clusters from closures and embeddings
    id2cluster = generate_clusters(
        closure2ids,
        cluster_embeddings,
        cluster_algo=config["cluster_algo"],
        # agglomerative options
        cluster_linkage=config["cluster_linkage"],
        cluster_threshold=config["cluster_threshold"],
        # optics or hdbscan options
        min_samples=config["min_samples"],
        eps=config["eps"],
        # optics options
        cluster_method=config["cluster_method"],
        max_eps=config["max_eps"],
        xi=config["xi"],
        # hdbscan options
        selection_method=config["selection_method"],
        min_cluster_size=config["min_cluster_size"],
        # other options
        n_jobs=1,
        verbose=False,
    )
    closure2ids = None  # release memory

    # validate names and their embeddings
    validate_names = list(set(input_names_validate).union(set(candidate_names_validate)))
    validate_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, validate_names).astype('float32')
    candidate_names_validate = swivel_model = swivel_vocab = None  # release memory

    # assign all names to clusters
    name2clusters, cluster2names = get_clusters(
        validate_names,
        validate_embeddings,
        id2cluster,
        cluster_embeddings,
        k=config.get("assign_k", 100),
        max_clusters=config.get("max_clusters", 5),
        verbose=False,
    )
    # release memory: clear the names that were actually bound above
    validate_names = validate_embeddings = id2cluster = cluster_embeddings = None

    # clustering diagnostics; default=0 keeps an empty clustering from raising ValueError
    num_clusters = len(cluster2names)
    max_cluster_size = max((len(names) for names in cluster2names.values()), default=0)

    # get best matches
    best_matches = get_best_cluster_matches(name2clusters, cluster2names, input_names_validate)
    name2clusters = cluster2names = input_names_validate = None  # release memory

    # eval f1
    search_threshold = config["search_threshold"]
    precision = avg_precision_at_threshold(weighted_actual_names_validate, best_matches, search_threshold)
    recall = avg_weighted_recall_at_threshold(weighted_actual_names_validate, best_matches, search_threshold)
    # harmonic mean of precision and recall; defined as 0 when both are 0 to avoid dividing by zero
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    # Report the metrics to Ray
    tune.report(
        f1=f1,
        precision=precision,
        recall=recall,
        num_clusters=num_clusters,
        max_cluster_size=max_cluster_size,
    )

In [None]:
# Default value for every hyper-parameter. This is the single source for both the Tune search
# space and the HyperOpt seed point, so the two cannot drift apart.
default_params = {
    "cluster_algo": DEFAULT_ALGO,
    "n_to_cluster": DEFAULT_NAMES_TO_CLUSTER,  # search alternative: tune.qrandint(100000, 500000, 100000)
    "closure_threshold": DEFAULT_CLOSURE_THRESHOLD,
    "search_threshold": DEFAULT_SEARCH_THRESHOLD,
    "cluster_threshold": DEFAULT_CLUSTER_THRESHOLD,
    "cluster_linkage": DEFAULT_CLUSTER_LINKAGE,  # search alternative: tune.choice(["average", "single", "complete", "ward"])
    "min_samples": DEFAULT_MIN_SAMPLES,
    "eps": DEFAULT_EPS,
    "max_eps": DEFAULT_MAX_EPS,
    "cluster_method": "dbscan",  # optics option
    "xi": DEFAULT_XI,
    "selection_method": DEFAULT_SELECTION_METHOD,  # search alternative: tune.choice(["eom", "leaf"])
    "min_cluster_size": DEFAULT_MIN_CLUSTER_SIZE,
}

# Search space: the defaults, with the hyper-parameters being tuned replaced by Tune search specs.
config_params = {
    **default_params,
    "cluster_threshold": tune.grid_search([0.95, 0.9, 0.85, 0.8]),
}

# Seed points for HyperOptSearch: a list of complete configs (a copy of the defaults).
current_best_params = [dict(default_params)]

In [None]:
# HyperOpt searcher seeded with the configs in current_best_params.
# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
# NOTE(review): tune.run does not currently receive this searcher (its search_alg argument is
# commented out). config_params uses tune.grid_search, which HyperOptSearch is not expected to
# accept, so confirm this before re-enabling it.
search_alg = HyperOptSearch(
    points_to_evaluate=current_best_params,
)

In [None]:
# Restart Ray so re-running this cell starts from a fresh runtime.
ray.shutdown()
ray.init()

# Log trials to Weights & Biases only when an API key file is configured.
# NOTE(review): the "_600" group suffix is hard-coded and does not follow vocab_size
# (2100000 for surnames); confirm it is intended.
callbacks = (
    [
        WandbLoggerCallback(
            project="nama",
            entity="nama",
            group="80_cluster_tune_" + given_surname + "_600",
            notes="",
            config=config._asdict(),
            api_key_file=wandb_api_key_file,
        )
    ]
    if wandb_api_key_file
    else []
)

# Runs the config_params grid: 2 CPUs per trial, at most 4 trials at once. The model and training
# data are handed to the trials through tune.with_parameters rather than captured in a closure.
# To search with HyperOpt instead, also pass: search_alg=search_alg, num_samples=6, metric="f1",
# mode="max", checkpoint_score_attr="f1", time_budget_s=4*3600.
result = tune.run(
    tune.with_parameters(
        ray_training_function,
        swivel_model=swivel_model,
        swivel_vocab=swivel_vocab,
        input_names_train=input_names_train,
        weighted_actual_names_train=weighted_actual_names_train,
        candidate_names_train=candidate_names_train,
    ),
    resources_per_trial={"cpu": 2.0, "gpu": 0.0},
    max_concurrent_trials=4,
    config=config_params,
    progress_reporter=tune.JupyterNotebookReporter(overwrite=False, max_report_frequency=5 * 60),
    callbacks=callbacks,
)

### Get best trial

In [None]:
# Trial with the highest F1. scope="all" compares every result each trial reported, not only its last one.
best_trial = result.get_best_trial(
    metric="f1",
    mode="max",
    scope="all",
)

In [None]:
# Hyper-parameter config of the trial with the highest F1 (displayed as the cell output)
best_trial.config

In [None]:
# Metrics from the best trial's last reported result
for metric_name in ("f1", "precision", "recall"):
    print(f"Best trial final train {metric_name}: {best_trial.last_result[metric_name]}")

### Get all trials as a DataFrame

In [None]:
# All trials as a pandas DataFrame (one row per trial, built from each trial's last result)
df = result.results_df

In [None]:
# Display the per-trial results table
df

### Scratch: try the clustering algorithms on a few example names

In [None]:
# A few example names for eyeballing how the clustering algorithms group similar names.
# NOTE(review): these look like given names, but given_surname may select the surname vocab/model;
# confirm they are present in swivel_vocab.
cluster_names = ["<john>", "<jonathan>", "<mary>", "<marie>", "<maria>", "<george>"]
# Put every name into one closure. The ids are presumably row positions in cluster_embeddings,
# so they are derived from the list length and stay in sync when names are added or removed.
closure2ids = {"c": list(range(len(cluster_names)))}
cluster_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, cluster_names).astype('float32')

In [None]:
# Compare agglomerative clusterings of the example names across thresholds and linkages.
# Arguments are passed by keyword (matching the call in ray_training_function) so they cannot
# bind to the wrong parameters of generate_clusters, which also takes cluster_algo.
for cluster_threshold, cluster_linkage in [
    (0.15, "average"),
    (0.99, "average"),
    (0.01, "average"),
    (0.01, "ward"),
]:
    id2cluster = generate_clusters(
        closure2ids,
        cluster_embeddings,
        cluster_algo="agglomerative",
        cluster_linkage=cluster_linkage,
        cluster_threshold=cluster_threshold,
        n_jobs=1,
    )
    print(id2cluster)


In [None]:
from sklearn.preprocessing import normalize

# Rescale each embedding row to unit L2 norm (norm="l2" is normalize's default).
# This rebinds cluster_embeddings, so later cells that use it see the normalized rows.
cluster_embeddings = normalize(cluster_embeddings, norm="l2")

In [None]:
from sklearn.cluster import OPTICS, cluster_optics_dbscan

# OPTICS settings. The trailing comments presumably record the range and step explored by hand.
min_samples = 2
max_eps = 0.7
xi = 0.05    # 0.01..0.20, 0.01
metric = "cosine"
eps = 0.45   # 0.45..0.70, 0.05

# fit() returns the estimator itself, so construct and fit in one expression
clust = OPTICS(min_samples=min_samples, xi=xi, max_eps=max_eps, metric=metric).fit(cluster_embeddings)

# Flat DBSCAN-style labels at radius `eps`, extracted from the fitted OPTICS ordering
labels = cluster_optics_dbscan(
    reachability=clust.reachability_,
    core_distances=clust.core_distances_,
    ordering=clust.ordering_,
    eps=eps,
)
labels

In [None]:
import hdbscan

# HDBSCAN settings
min_samples = 2
eps = 0.0
selection_method = "leaf"
min_cluster_size = 2

# fit() returns the estimator itself; noise points receive label -1 in labels_.
# If cluster_embeddings has been L2-normalized, Euclidean distance ranks pairs the same way
# cosine distance does.
clust = hdbscan.HDBSCAN(
    min_samples=min_samples,
    cluster_selection_epsilon=eps,
    cluster_selection_method=selection_method,
    min_cluster_size=min_cluster_size,
    metric="euclidean",
).fit(cluster_embeddings)

clust.labels_

In [None]:
# Largest label produced by the last fitted `clust` (noise is -1)
max_cluster = clust.labels_.max()
max_cluster

In [None]:
# Toy cluster labels: non-negative values are cluster ids, -1 marks noise points
labels = [0, 1, 0, 1, -1, -1, 0, 1, -1]
max_cluster = max(labels)  # largest cluster id in the toy labels

In [None]:
# Give each noise point (negative label) its own new singleton cluster id, numbered after the
# largest existing label. The counter is recomputed from `labels` on every run rather than
# carried over in `max_cluster`, so re-running this cell gives the same result.
next_label = max(labels, default=-1) + 1
results = []
for label in labels:
    if label < 0:
        label = next_label
        next_label += 1
    results.append(label)
# max_cluster ends as the largest id in `results` (-1 when `labels` is empty)
max_cluster = next_label - 1

print(results)