In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
from collections import namedtuple

import pandas as pd
import torch
import wandb

from src.data.utils import load_datasets, select_frequent_k
from src.data.filesystem import fopen
from src.eval.metrics import (
    precision_weighted_recall_curve_at_threshold, 
    avg_precision_at_threshold,
    avg_weighted_recall_at_threshold,
    get_auc,
)
from src.models.swivel import SwivelModel, get_swivel_embeddings
from src.models.cluster import (
    get_sorted_similarities,
    generate_closures,
    generate_clusters,
    get_clusters,
    get_best_cluster_matches,
    write_clusters,
)

In [None]:
# Notebook configuration: every tunable value and data location lives here.

# Interpolated into the data/model paths below and used as the wandb group.
given_surname = "given"

# --- clustering ---
NAMES_TO_CLUSTER = 600_000
CLUSTER_ONLY_INPUT_NAMES = False
CLOSURE_THRESHOLD = 20_000
CLUSTER_ALGO = "agglomerative"
CLUSTER_LINKAGE = "average"
CLUSTER_THRESHOLD = 0.8

# --- search / evaluation ---
# NOTE(review): the F1 cell further down reassigns SEARCH_THRESHOLD before
# using it — confirm which value is the intended one.
SEARCH_THRESHOLD = 0.6
MAX_SEARCH_CLUSTERS = 20
EVAL_SEARCH_CLUSTERS = 10
names_to_test = 200_000

# --- swivel model ---
# The vocabulary size is part of the vocab/model/cluster file names.
if given_surname == "given":
    vocab_size = 600_000
else:
    vocab_size = 2_100_000
embed_dim = 100

Config = namedtuple(
    "Config",
    [
        "train_path",
        "embed_dim",
        "swivel_vocab_path",
        "swivel_model_path",
        "cluster_path",
    ],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}.csv",
    # TODO: the trailing "-50" in the model file name is hard-coded
    # (presumably the training epoch count — confirm) and should be a parameter.
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-50.pth",
    cluster_path=f"s3://nama-data/processed/tree-hr-{given_surname}-clusters-{vocab_size}-{embed_dim}.json.gz",
)

In [None]:
# Start a Weights & Biases run for this notebook. The Config namedtuple is
# converted with _asdict() so its fields are passed as the run's config.
wandb_run_settings = {
    "project": "nama",
    "entity": "nama",
    "name": "81_cluster",
    "group": given_surname,
    "notes": "",
    "config": config._asdict(),
}
wandb.init(**wandb_run_settings)

### Load data

In [None]:
# Load the training split. The single-element unpack requires load_datasets
# to hand back exactly one dataset for the one path given.
[train] = load_datasets([config.train_path])
input_names_train, weighted_actual_names_train, candidate_names_train = train

# Build the name -> index lookup from the "name" and "index" columns of the
# swivel vocabulary CSV.
vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"))
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

# Restore the trained swivel weights. map_location="cpu" lets a checkpoint
# that was saved from a GPU load on a machine without CUDA; load_state_dict
# copies the values into the freshly constructed model either way.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
swivel_model.load_state_dict(
    torch.load(fopen(config.swivel_model_path, "rb"), map_location="cpu")
)
# Switch to inference mode (as the last expression this also displays the model).
swivel_model.eval()

### Get names to cluster and embeddings

In [None]:
# Narrow the training data to NAMES_TO_CLUSTER names via select_frequent_k
# (presumably the most frequent ones — confirm in src.data.utils).
(
    input_names_cluster,
    weighted_actual_names_cluster,
    candidate_names_cluster,
) = select_frequent_k(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
    NAMES_TO_CLUSTER,
    input_names_only=CLUSTER_ONLY_INPUT_NAMES,
)

# Names to cluster are the de-duplicated union of input and candidate names;
# the position of a name in cluster_names matches its row in cluster_embeddings.
cluster_names = list(set(input_names_cluster) | set(candidate_names_cluster))
cluster_embeddings = get_swivel_embeddings(
    swivel_model, swivel_vocab, cluster_names
).astype("float32")
print("cluster_names", len(cluster_names))

In [None]:
# Evaluation subset: names_to_test names selected from the training data with
# the same helper used for the clustering set (default input_names_only).
(
    input_names_test,
    weighted_actual_names_test,
    candidate_names_test,
) = select_frequent_k(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
    names_to_test,
)

# De-duplicated union of input and candidate names, embedded as float32.
test_names = list(set(input_names_test) | set(candidate_names_test))
test_embeddings = get_swivel_embeddings(
    swivel_model, swivel_vocab, test_names
).astype("float32")

print("test_names", len(test_names))
print("test_input_names", len(input_names_test))

### Generate closures

In [None]:
%%time
# Similarities over the cluster embeddings, consumed by generate_closures in
# the next cell. threshold=0.4 presumably drops pairs scoring below 0.4 —
# confirm in src.models.cluster.
# NOTE(review): 0.4 is a magic number that is not in the configuration cell.
sorted_similarities = get_sorted_similarities(cluster_embeddings, threshold=0.4)

In [None]:
%%time
# Build closures from the sorted similarities, bounded by CLOSURE_THRESHOLD.
# The first element of the result is not needed here and is discarded.
closure_result = generate_closures(sorted_similarities, CLOSURE_THRESHOLD)
_, closure2ids, not_merged, max_score_not_merged = closure_result

# Summary statistics, one "label value" line each.
for stat_label, stat_value in (
    ("total scores", len(sorted_similarities)),
    ("total not merged", not_merged),
    ("max score not merged", max_score_not_merged),
    ("num closures", len(closure2ids)),
):
    print(stat_label, stat_value)

### Compute clusters

In [None]:
%%time
# Cluster the names inside each closure using the algorithm, linkage and
# threshold chosen in the configuration cell.
n_jobs = 4  # forwarded to generate_clusters as its n_jobs argument
cluster_options = {
    "cluster_algo": CLUSTER_ALGO,
    "cluster_linkage": CLUSTER_LINKAGE,
    "cluster_threshold": CLUSTER_THRESHOLD,
    "n_jobs": n_jobs,
}
id2cluster = generate_clusters(closure2ids, cluster_embeddings, **cluster_options)

# Report how many distinct cluster labels were assigned.
distinct_clusters = set(id2cluster.values())
print("clusters", len(distinct_clusters))

### Write clusters

In [None]:
%%time
# Map every clustered name to its clusters (at most MAX_SEARCH_CLUSTERS each).
# The same embeddings are passed as both the names being looked up and the
# clustered reference set.
# NOTE(review): k=1024 is a magic number that is not in the configuration cell.
name2clusters, cluster2names = get_clusters(
    cluster_names,
    cluster_embeddings,
    id2cluster,
    cluster_embeddings,
    k=1024,
    max_clusters=MAX_SEARCH_CLUSTERS,
    verbose=True,
)

In [None]:
# Persist the name -> clusters mapping to the cluster_path set in the config.
write_clusters(config.cluster_path, name2clusters)

### Eval

#### Get a list of (cluster_id, score) tuples for each name

In [None]:
%%time
# Recompute the name -> clusters lookup for the evaluation names.
# EVAL_SEARCH_CLUSTERS comes from the configuration cell; it is deliberately
# not re-declared here, so changing it there takes effect in this cell.
# When the test set has the same size as the clustered set and both
# max-cluster settings agree, the mapping from the cluster-writing step is
# reused as is.
# NOTE(review): equal lengths do not guarantee the two name lists are the same,
# and k differs here (100) from the earlier call (1024) — confirm intended.
if len(test_names) != len(cluster_names) or EVAL_SEARCH_CLUSTERS != MAX_SEARCH_CLUSTERS:
    name2clusters, cluster2names = get_clusters(
        test_names,
        test_embeddings,
        id2cluster,
        cluster_embeddings,
        k=100,
        max_clusters=EVAL_SEARCH_CLUSTERS,
        verbose=True,
    )

#### Get best matches

In [None]:
%%time
# Best cluster-based matches for each test input name, derived from the
# name -> clusters and cluster -> names lookups; fed to the metrics below.
best_matches = get_best_cluster_matches(name2clusters, cluster2names, input_names_test)

#### Compute F1 at the search threshold

In [None]:
%%time
# NOTE(review): this reassignment overrides the 0.6 set in the configuration
# cell, so the metrics below are reported at 0.5 — confirm which value is
# intended and keep a single definition.
SEARCH_THRESHOLD=0.5
precision = avg_precision_at_threshold(weighted_actual_names_test, best_matches, SEARCH_THRESHOLD)
recall = avg_weighted_recall_at_threshold(weighted_actual_names_test, best_matches, SEARCH_THRESHOLD)
# F1 is the harmonic mean of precision and recall. When both are zero the
# formula divides by zero, so report 0.0 in that case instead of failing.
precision_plus_recall = precision + recall
f1 = 0.0 if precision_plus_recall == 0 else 2 * (precision * recall) / precision_plus_recall
print("f1", f1, "precision", precision, "recall", recall)

#### Show PR curve

In [None]:
%%time
# Precision vs. weighted-recall curve over thresholds from 0.01 to 1.0 in
# steps of 0.05. distances=False is passed because best_matches carries
# similarity scores (presumably, higher is better — confirm in src.eval.metrics).
precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
    distances=False,
)

In [None]:
%%time
# Area under the curve, using the same threshold sweep as the PR-curve cell.
auc = get_auc(
    weighted_actual_names_test,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
    distances=False,
)
print(auc)

In [None]:
# Close the Weights & Biases run opened at the top of the notebook.
wandb.finish()