In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from collections import namedtuple

import pandas as pd
import ray
from ray import tune
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.wandb import WandbLoggerCallback
import torch

from src.data.utils import load_datasets, select_frequent_k
from src.data.filesystem import fopen
from src.eval.metrics import avg_precision_at_threshold, avg_weighted_recall_at_threshold
from src.models.cluster import (
    get_sorted_similarities,
    generate_closures,
    generate_clusters,
    assign_names_to_clusters,
    get_best_cluster_matches,
)
from src.models.swivel import SwivelModel, get_swivel_embeddings

In [None]:
# configure
# Path to the file holding the Weights & Biases API key; a falsy value
# disables W&B logging in the tuning cell further down.
wandb_api_key_file = "../.wandb-api-key"
# Selects which name type ("given" or anything else) the paths and vocabulary size refer to.
given_surname = "given"

# Baseline clustering hyperparameters, used both as fixed values in the search
# space and as the initial point handed to the search algorithm.
DEFAULT_NAMES_TO_CLUSTER = 400000
DEFAULT_CLOSURE_THRESHOLD = 1000
DEFAULT_CLUSTER_THRESHOLD = 0.4
DEFAULT_CLUSTER_LINKAGE = "average"
DEFAULT_K_NN = 1

# Embedding size and vocabulary size identify which saved swivel artifacts to load.
embed_dim = 100
if given_surname == "given":
    vocab_size = 600000
else:
    vocab_size = 2100000

# Immutable bundle of the data/model locations used by the rest of the notebook.
Config = namedtuple(
    "Config",
    ["train_path", "embed_dim", "swivel_vocab_path", "swivel_model_path"],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}.csv",
    # TODO fix
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-50.pth.26",
)

### Load data

In [None]:
# load_datasets is given a one-element list of paths; unpack its single result.
(train,) = load_datasets([config.train_path])

In [None]:
# Split the training dataset into its three components.
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = train

In [None]:
# Load the swivel vocabulary table, then report its row count and first five rows.
# NOTE(review): the handle returned by fopen is never closed explicitly here —
# confirm whether that matters for s3 paths.
vocab_df = pd.read_csv(
    fopen(config.swivel_vocab_path, "rb"),
)
print(vocab_df.shape[0])
print(vocab_df.head(n=5))

In [None]:
# Map each vocabulary name to its index, then show the first entry as a sanity check.
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
print(next(iter(swivel_vocab.items())))

In [None]:
# Instantiate SwivelModel sized by the vocabulary length and the configured
# embedding dimension, then restore its saved weights.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted location. No map_location is passed, so tensors are restored to the
# device they were saved from — confirm this is intended on CPU-only hosts.
# NOTE(review): the handle returned by fopen is never closed explicitly.
swivel_model.load_state_dict(torch.load(fopen(config.swivel_model_path, "rb")))
# Switch the module to evaluation mode; the notebook only uses it for inference.
swivel_model.eval()
print(swivel_model)

### Optimize hyperparameters

In [None]:
def ray_training_function(config,
                          swivel_model,
                          swivel_vocab,
                          input_names_train,
                          weighted_actual_names_train,
                          candidate_names_train,
                          checkpoint_dir=None):
    """Run one Ray Tune trial: cluster names and report precision, recall and f1.

    The most frequent names are selected from the training data, embedded with
    the swivel model, grouped into closures and then clusters, and every name
    in the (filtered) training data is assigned to a cluster. The resulting
    cluster matches are scored against the weighted actual names.

    Args:
        config: trial hyperparameters; the keys read here are "n_to_cluster",
            "closure_threshold", "cluster_threshold", "cluster_linkage" and
            "k_nn".
        swivel_model: model passed to get_swivel_embeddings.
        swivel_vocab: vocabulary passed to get_swivel_embeddings.
        input_names_train: input names of the training dataset.
        weighted_actual_names_train: weighted actual names of the training
            dataset, aligned with input_names_train.
        candidate_names_train: candidate names of the training dataset.
        checkpoint_dir: accepted for Ray Tune's function-trainable signature;
            not used by this function.

    Returns:
        None. The metrics f1, precision and recall are sent to Ray through
        tune.report.
    """
    # filter names to cluster from train
    input_names_cluster, weighted_actual_names_cluster, candidate_names_cluster = \
        select_frequent_k(input_names_train,
                          weighted_actual_names_train,
                          candidate_names_train,
                          config["n_to_cluster"])

    ### remove - validate on clustered names only
    # From here on the "train" names are the filtered subset, so evaluation
    # below only covers the names that were selected for clustering.
    input_names_train = input_names_cluster
    weighted_actual_names_train = weighted_actual_names_cluster
    candidate_names_train = candidate_names_cluster

    # get names to cluster
    cluster_names = list(set(input_names_cluster).union(set(candidate_names_cluster)))
    input_names_cluster = candidate_names_cluster = None  # release memory

    # get embeddings for names to cluster
    cluster_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, cluster_names).astype('float32')
    cluster_names = None  # release memory

    # get sorted_similarities from embeddings
    # (the similarity threshold is fixed at 0.4 and is not part of the search space)
    sorted_similarities = get_sorted_similarities(cluster_embeddings, threshold=0.4)

    # generate closures from sorted similarities
    _, closure2ids, _, max_score_not_merged = generate_closures(sorted_similarities, config["closure_threshold"])
    sorted_similarities = None  # release memory

    # generate clusters from closures and embeddings
    id2cluster = generate_clusters(closure2ids,
                                   cluster_embeddings,
                                   config["cluster_threshold"],
                                   config["cluster_linkage"],
                                   n_jobs=1,
                                   verbose=False,
                                  )
    closure2ids = None  # release memory

    # get all embeddings
    all_names = list(set(input_names_train).union(set(candidate_names_train)))
    all_embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, all_names).astype('float32')
    candidate_names_train = swivel_model = swivel_vocab = None  # release memory

    # assign all names to clusters
    name2cluster, cluster2names = assign_names_to_clusters(all_names,
                                                           all_embeddings,
                                                           id2cluster,
                                                           cluster_embeddings,
                                                           k=config["k_nn"],
                                                           verbose=False,
                                                          )
    all_names = all_embeddings = id2cluster = cluster_embeddings = None  # release memory

    num_clusters = len(cluster2names)
    # default=0 keeps a trial that produced no clusters from raising ValueError
    max_cluster_size = max((len(names) for names in cluster2names.values()), default=0)

    print("max_score_not_merged", max_score_not_merged)
    print("num_clusters", num_clusters)
    print("max_cluster_size", max_cluster_size)

    # get best matches
    best_matches = get_best_cluster_matches(name2cluster, cluster2names, input_names_train)
    name2cluster = cluster2names = input_names_train = None  # release memory

    # eval f1 (precision and recall are both measured at a fixed 0.5 threshold)
    precision = avg_precision_at_threshold(weighted_actual_names_train, best_matches, 0.5)
    recall = avg_weighted_recall_at_threshold(weighted_actual_names_train, best_matches, 0.5)
    # When precision and recall are both zero the harmonic mean is undefined;
    # report 0.0 so the trial completes with a worst-case score instead of
    # crashing (or reporting NaN) on a division by zero.
    precision_plus_recall = precision + recall
    if precision_plus_recall > 0:
        f1 = 2 * (precision * recall) / precision_plus_recall
    else:
        f1 = 0.0

    # Report the metrics to Ray
    tune.report(f1=f1,
                precision=precision,
                recall=recall,
               )

In [None]:
# Search space handed to Ray Tune: plain values stay fixed for every trial,
# tune.* entries are sampled by the search algorithm.
config_params = dict(
    n_to_cluster=DEFAULT_NAMES_TO_CLUSTER,  # tune.qrandint(100000, vocab_size, 100000),
    closure_threshold=DEFAULT_CLOSURE_THRESHOLD,
    cluster_threshold=tune.quniform(0.1, 0.8, 0.05),
    cluster_linkage=tune.choice(["average", "single", "complete", "ward"]),
    k_nn=DEFAULT_K_NN,  # tune.choice([1, 3, 5]),
)

# Known-good starting point, passed to the search algorithm as an initial suggestion.
current_best_params = [
    dict(
        n_to_cluster=DEFAULT_NAMES_TO_CLUSTER,
        closure_threshold=DEFAULT_CLOSURE_THRESHOLD,
        cluster_threshold=DEFAULT_CLUSTER_THRESHOLD,
        cluster_linkage=DEFAULT_CLUSTER_LINKAGE,
        k_nn=DEFAULT_K_NN,
    ),
]

In [None]:
# HyperOpt-backed search, given the current best parameters as its initial
# point to evaluate.
# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
search_alg = HyperOptSearch(
    points_to_evaluate=current_best_params,
)

In [None]:
# Shut down any Ray session left over from a previous run of this cell, then start a new one.
ray.shutdown()
ray.init()

# Log trials to Weights & Biases only when an API key file is configured.
# NOTE(review): the `config=` kwarg is forwarded to the W&B logger — confirm it
# does not replace the per-trial hyperparameters in the logged run config.
callbacks = (
    [
        WandbLoggerCallback(
            project="nama",
            entity="nama",
            group="80_cluster_tune_" + given_surname,
            notes="",
            config=config._asdict(),
            api_key_file=wandb_api_key_file,
        )
    ]
    if wandb_api_key_file
    else []
)

# Launch the hyperparameter search; the model, vocabulary and training data are
# attached to the trainable through tune.with_parameters.
result = tune.run(
    tune.with_parameters(
        ray_training_function,
        swivel_model=swivel_model,
        swivel_vocab=swivel_vocab,
        input_names_train=input_names_train,
        weighted_actual_names_train=weighted_actual_names_train,
        candidate_names_train=candidate_names_train,
    ),
    config=config_params,
    search_alg=search_alg,
    metric="f1",
    mode="max",
    checkpoint_score_attr="f1",
    num_samples=50,
    max_concurrent_trials=3,
    time_budget_s=10 * 3600,
    resources_per_trial={"cpu": 2.0, "gpu": 0.0},
    progress_reporter=tune.JupyterNotebookReporter(
        overwrite=False,
        max_report_frequency=5 * 60,
    ),
    callbacks=callbacks,
)

### Get best model

In [None]:
# Select the trial with the highest F1, using scope "all" for the comparison.
best_trial = result.get_best_trial(
    metric="f1",
    mode="max",
    scope="all",
)

In [None]:
# Hyperparameters of the trial with the highest F1
best_trial.config

In [None]:
# Print the last reported value of each metric for the best trial.
for metric_name in ("f1", "precision", "recall"):
    print(f"Best trial final train {metric_name}: {best_trial.last_result[metric_name]}")

### Get all trials as DF

In [None]:
# All trials as a pandas dataframe (taken from the tune.run result object)
df = result.results_df

In [None]:
# Display the full table of trial results.
df

In [None]:
# Keep only the trials whose f1 exceeds 0.44 and whose recall exceeds 0.46.
df.loc[df["f1"].gt(0.44) & df["recall"].gt(0.46)]

### Test

In [None]:
# Small hand-picked fixture: five names placed in a single closure "c",
# with one embedding looked up per name.
cluster_names = [
    "<john>",
    "<jonathan>",
    "<mary>",
    "<marie>",
    "<maria>",
]
closure2ids = {"c": list(range(len(cluster_names)))}
cluster_embeddings = get_swivel_embeddings(
    swivel_model,
    swivel_vocab,
    cluster_names,
).astype('float32')

In [None]:
# Cluster the fixture under several (threshold, linkage) settings and print
# each resulting id -> cluster assignment for visual comparison.
for threshold, linkage in (
    (0.6, "average"),
    (0.99, "average"),
    (0.01, "average"),
    (0.6, "ward"),
):
    id2cluster = generate_clusters(closure2ids, cluster_embeddings, threshold, linkage, n_jobs=1)
    print(id2cluster)
