In [None]:
# Load the IPython autoreload extension and set it to mode 2, so edits to the
# imported src.* modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from collections import namedtuple

import pandas as pd
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.wandb import WandbLoggerCallback
import torch

from src.data.utils import load_train_test
from src.data.filesystem import fopen
from src.eval.metrics import avg_precision_at_threshold, avg_weighted_recall_at_threshold
from src.models.cluster import (
    get_scores,
    generate_closures,
    generate_clusters,
    assign_names_to_clusters,
    get_best_cluster_matches,
)
from src.models.swivel import get_swivel_embeddings
from src.models.utils import add_padding

In [None]:
# configure

# W&B logging is switched on only when this names an API-key file
# (an empty string leaves the callback list empty in the tuning cell).
wandb_api_key_file = ""  # "../.wandb-api-key"

# dataset selectors; both are interpolated into the paths below
given_surname = "given"
size = "freq"

# swivel model dimensions; both are part of the vocab/model file names
vocab_size = 500000
embed_dim = 200

# default clustering hyperparameters (used as the first point handed to the search)
DEFAULT_NAMES_TO_CLUSTER = 100000
DEFAULT_CLOSURE_THRESHOLD = 1000
DEFAULT_CLUSTER_THRESHOLD = 0.4
DEFAULT_CLUSTER_LINKAGE = "average"
DEFAULT_K_NN = 1

Config = namedtuple(
    "Config",
    ["train_path", "pref_path", "embed_dim", "vocab_path", "model_path"],
)
# Fields are given positionally, in the order declared above:
# train_path, pref_path, embed_dim, vocab_path, model_path.
# Alternative model_path on S3:
#     f"s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-{embed_dim}-tfidf.pt"
config = Config(
    f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    embed_dim,
    f"s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-vocab-tfidf.csv",
    f"../data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-{embed_dim}-tfidf.pt",
)

### Load data

In [None]:
# load_train_test is given a one-element list of paths; the single-element
# unpacking fails loudly if it does not return exactly one dataset.
[train] = load_train_test([config.train_path])

In [None]:
# The train dataset is a 3-tuple: input names, weighted actual names, candidate names.
input_names_train, weighted_actual_names_train, candidate_names_train = train

In [None]:
# load preferred names (in frequency order)
# NOTE(review): read directly by pandas from the s3:// path, unlike the vocab
# file below, which goes through fopen — confirm the difference is intentional.
pref_df = pd.read_csv(config.pref_path)

In [None]:
# preview the first five preferred names
pref_df.head(5)

In [None]:
# Load the swivel vocabulary CSV; the cell below reads its "name" and "index" columns.
# NOTE(review): the handle returned by fopen is never closed here — if fopen
# returns a context manager, prefer `with fopen(...) as f:`. TODO confirm.
vocab_df = pd.read_csv(fopen(config.vocab_path, "rb"))
print(len(vocab_df))
print(vocab_df.head(5))

In [None]:
# Map each vocabulary name to the value in its "index" column.
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
# show one (name, index) pair as a sanity check
print(next(iter(swivel_vocab.items())))

In [None]:
# Load the trained swivel model and switch it to evaluation mode.
# NOTE(review): torch.load unpickles the file, so only load model files from a
# trusted location. The handle returned by fopen is also never closed here.
swivel_model = torch.load(fopen(config.model_path, "rb"))
swivel_model.eval()
print(swivel_model)

### Get all_names

In [None]:
# distinct names appearing in either input_names_train or candidate_names_train
all_names = set(input_names_train) | set(candidate_names_train)
print(len(all_names))

In [None]:
# Preferred names, in pref_df order, restricted to names present in all_names.
# A set is built for the membership test so this cell stays fast even when it
# is re-run after all_names has been converted to a list further down.
all_names_lookup = set(all_names)
pref_names = [
    padded_name
    # pad each name once instead of once for the test and once for the value
    for padded_name in (add_padding(str(name)) for name in pref_df["name"])
    if padded_name in all_names_lookup
]
del all_names_lookup  # only needed for the filter above
print(len(pref_names))

In [None]:
# Convert the set to a list: from here on all_names is passed to the trainable
# and to get_swivel_embeddings as an ordered collection.
# NOTE(review): this rebinds all_names from set to list, so cells above that
# test membership in all_names behave differently if re-run after this one.
all_names = list(all_names)

### Optimize hyperparameters

In [None]:
def ray_training_function(config,
                          swivel_model,
                          swivel_vocab,
                          input_names_train,
                          all_names,
                          pref_names):
    """Run one Ray Tune trial: cluster names with the sampled settings and report scores.

    The first ``config["names_to_cluster"]`` entries of ``pref_names`` are
    clustered, every name in ``all_names`` is then assigned to a cluster, and
    the best cluster matches for ``input_names_train`` are scored at a 0.5
    threshold.

    Args:
        config: trial hyperparameters; reads the keys ``names_to_cluster``,
            ``closure_threshold``, ``cluster_threshold``, ``cluster_linkage``
            and ``k_nn``.
        swivel_model: model passed to ``get_swivel_embeddings``.
        swivel_vocab: vocabulary passed to ``get_swivel_embeddings``.
        input_names_train: names for which best cluster matches are computed.
        all_names: every name to assign to a cluster.
        pref_names: candidate names to cluster; only a prefix is used.

    Reports (via ``tune.report``): ``f1``, ``precision``, ``recall``,
    ``num_clusters`` and ``max_cluster_size``. Returns nothing.

    NOTE(review): ``weighted_actual_names_train`` is read from the enclosing
    notebook namespace rather than passed as an argument, so it is not shipped
    through ``tune.with_parameters`` like the other large inputs.
    """
    all_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, all_names).astype('float32')
    clustered_names = pref_names[:config["names_to_cluster"]]
    clustered_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, clustered_names).astype('float32')
    # the score threshold is fixed at 0.4 here; it is not one of the tuned parameters
    clustered_scores_sparse, sorted_scores = get_scores(clustered_embeddings, threshold=0.4)

    # generate closures
    id2closure, closure2ids, _, _ = generate_closures(sorted_scores, config["closure_threshold"])

    # clean up: drop references to large intermediates as soon as they are unused
    sorted_scores = None

    # generate clusters
    id2cluster = generate_clusters(closure2ids,
                                   clustered_scores_sparse,
                                   clustered_names,
                                   config["cluster_threshold"],
                                   config["cluster_linkage"],
                                   verbose=False,
                                   n_jobs=1,
                                  )

    # clean up
    closure2ids = clustered_scores_sparse = clustered_names = None

    # assign all names to clusters
    name2cluster, cluster2names = assign_names_to_clusters(all_names,
                                                           all_embeddings,
                                                           id2cluster,
                                                           clustered_embeddings,
                                                           k=config["k_nn"])
    num_clusters = len(cluster2names)
    # default=0 keeps a trial that produced no clusters from raising ValueError
    max_cluster_size = max((len(names) for names in cluster2names.values()), default=0)

    # clean up
    all_names = all_embeddings = id2cluster = clustered_embeddings = None

    # get best matches
    best_matches = get_best_cluster_matches(name2cluster, cluster2names, input_names_train)

    # clean up
    name2cluster = cluster2names = input_names_train = None

    # eval f1
    precision = avg_precision_at_threshold(weighted_actual_names_train, best_matches, 0.5)
    recall = avg_weighted_recall_at_threshold(weighted_actual_names_train, best_matches, 0.5)
    # F1 is taken as 0 when precision and recall are both 0, instead of dividing by zero
    # and crashing the trial.
    precision_plus_recall = precision + recall
    if precision_plus_recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * (precision * recall) / precision_plus_recall

    # Report the metrics to Ray
    tune.report(f1=f1,
                precision=precision,
                recall=recall,
                num_clusters=num_clusters,
                max_cluster_size=max_cluster_size
               )


In [None]:
# Search space handed to tune.run; closure_threshold is held at its default
# rather than sampled.
config_params = dict(
    names_to_cluster=tune.qrandint(100000, 450000, 50000),
    closure_threshold=DEFAULT_CLOSURE_THRESHOLD,
    cluster_threshold=tune.quniform(0.2, 0.8, 0.05),
    cluster_linkage=tune.choice(["average", "single", "complete"]),
    k_nn=tune.choice([1, 3, 5, 7]),
)

# Current best-known parameters, given to the search algorithm as points_to_evaluate.
current_best_params = [
    dict(
        names_to_cluster=DEFAULT_NAMES_TO_CLUSTER,
        closure_threshold=DEFAULT_CLOSURE_THRESHOLD,
        cluster_threshold=DEFAULT_CLUSTER_THRESHOLD,
        cluster_linkage=DEFAULT_CLUSTER_LINKAGE,
        k_nn=DEFAULT_K_NN,
    ),
]

In [None]:
# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
# HyperOpt-based search, seeded with the current best-known parameters.
search_alg = HyperOptSearch(points_to_evaluate=current_best_params)

In [None]:
%%time
# Stop any Ray instance left over from an earlier run before starting a new one
# (presumably so this cell can be re-executed — TODO confirm).
ray.shutdown()
ray.init(_redis_max_memory=4*10**9)  # give redis extra memory

# Log to Weights & Biases only when an API-key file has been configured.
callbacks = []
if wandb_api_key_file:
    callbacks.append(WandbLoggerCallback(
        project="nama",
        entity="nama",
        group="82_cluster_tune_"+given_surname,
        notes="",
        config=config._asdict(),
        api_key_file=wandb_api_key_file
    ))

result = tune.run(
    # The model, vocabulary and name lists are bound to the trainable here.
    # NOTE(review): ray_training_function also reads weighted_actual_names_train,
    # which is not passed here and is picked up from the notebook namespace instead.
    tune.with_parameters(ray_training_function,
                         swivel_model=swivel_model,
                         swivel_vocab=swivel_vocab,
                         input_names_train=input_names_train,
                         all_names=all_names,
                         pref_names=pref_names),
    # CPU-only trials: 2 CPUs each, at most 6 running at once
    resources_per_trial={"cpu": 2.0, "gpu": 0.0},
    max_concurrent_trials=6,
    config=config_params,
    search_alg=search_alg,
    num_samples=12,
    # maximise the f1 value reported by the trainable
    metric="f1",
    mode="max",
    checkpoint_score_attr="f1",
    # overall time budget: 1 * 3600 (one hour, if the unit is seconds as the name suggests)
    time_budget_s=1*3600,
    progress_reporter=tune.JupyterNotebookReporter(
        overwrite=False,
        # 5 * 60: limits how often the progress table is printed
        max_report_frequency=5*60
    ),
    callbacks=callbacks
)

### Get best model

In [None]:
# Get trial that has the highest F1
# (scope='all' is passed explicitly; see the get_best_trial docs for how it
# differs from comparing only each trial's last result)
best_trial = result.get_best_trial(metric='f1', mode='max', scope='all')

In [None]:
# Parameters of the trial with the highest F1
best_trial.config

In [None]:
# Metrics from the best trial's last reported result; all are computed on the training data.
print(f"Best trial final train f1: {best_trial.last_result['f1']}")
print(f"Best trial final train precision: {best_trial.last_result['precision']}")
print(f"Best trial final train recall: {best_trial.last_result['recall']}")

### Get all trials as DF

In [None]:
# All trials as pandas dataframe
# NOTE(review): `df` is a generic name; something like `trials_df` would be clearer,
# but the display cell below depends on this name.
df = result.results_df

In [None]:
# Show every trial (at most num_samples=12 rows, so the full frame is small).
df