In [None]:
# Reload imported modules before each cell executes, so edits to src/ are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# NOTE(review): nothing in this notebook plots; this line appears unnecessary.
%matplotlib inline

In [None]:
from collections import namedtuple

import pandas as pd
import torch
import wandb

from src.data.utils import load_train_test
from src.data.filesystem import fopen
from src.eval.metrics import avg_precision_at_threshold, avg_weighted_recall_at_threshold
from src.models.swivel import get_swivel_embeddings
from src.models.utils import add_padding
from src.models.cluster import (
    get_scores,
    generate_closures,
    generate_clusters,
    assign_names_to_clusters,
    get_best_cluster_matches,
    write_clusters,
)

In [None]:
# configure: which names to cluster, clustering hyperparameters and artifact locations

given_surname = "given"
size = "freq"

# clustering hyperparameters
NAMES_TO_CLUSTER = 100000
CLOSURE_THRESHOLD = 1000
CLUSTER_THRESHOLD = 0.4
CLUSTER_LINKAGE = "average"
K_NN = 1

# swivel model settings; both appear in the vocab/model artifact paths below
vocab_size = 500000
embed_dim = 200

Config = namedtuple(
    "Config",
    ["train_path", "pref_path", "vocab_path", "model_path", "cluster_path"],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    pref_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-vocab-tfidf.csv",
    # The model is read from a local path; the S3 alternative is
    # s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-{embed_dim}-tfidf.pt
    model_path=f"../../nama-data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-{embed_dim}-tfidf.pt",
    cluster_path=f"s3://nama-data/processed/tree-hr-{given_surname}-clusters-{size}.csv.gz",
)

In [None]:
# Start the W&B run for this notebook. The run config records the artifact paths AND
# the hyperparameters that determine the clustering, so runs with different settings
# can be told apart (previously only the paths were logged).
wandb.init(
    project="nama",
    entity="nama",
    name="80_cluster",
    group=given_surname,
    notes="",
    config={
        **config._asdict(),
        "given_surname": given_surname,
        "size": size,
        "vocab_size": vocab_size,
        "embed_dim": embed_dim,
        "names_to_cluster": NAMES_TO_CLUSTER,
        "closure_threshold": CLOSURE_THRESHOLD,
        "cluster_threshold": CLUSTER_THRESHOLD,
        "cluster_linkage": CLUSTER_LINKAGE,
        "k_nn": K_NN,
    },
)

### Load data

In [None]:
# A single path is passed, so exactly one dataset is unpacked from the result.
(train,) = load_train_test([config.train_path])

In [None]:
# Split the training data into its three components.
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = train

In [None]:
# load preferred names (in frequency order)
# By default read_csv turns strings such as "null", "NA", "N/A" into NaN. Here these are
# name values, so only treat empty fields as missing and keep every other value as text.
pref_df = pd.read_csv(config.pref_path, keep_default_na=False, na_values=[""])

In [None]:
# peek at the most frequent preferred names
pref_df.head(5)

In [None]:
# Prefer the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Read the swivel vocabulary table; the `with` block closes the (possibly remote)
# file handle once pandas has read it, instead of leaking it.
with fopen(config.vocab_path, "rb") as vocab_file:
    vocab_df = pd.read_csv(vocab_file)
print(len(vocab_df))
vocab_df.head(5)

In [None]:
# Map each vocabulary name to the value in its "index" column.
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
print(next(iter(swivel_vocab.items())))

In [None]:
# Load the pickled swivel model object and switch it to eval mode.
# torch.load unpickles arbitrary objects, so only load trusted artifacts.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which may reject a pickled
# nn.Module; if loading fails there, pass weights_only=False explicitly.
with fopen(config.model_path, "rb") as model_file:
    swivel_model = torch.load(model_file)
swivel_model.eval()
print(swivel_model)

### Get all_names, clustered_names, all_embeddings, clustered_embeddings

In [None]:
# Every distinct name seen in training, whether as an input or as a candidate.
all_names = set(input_names_train) | set(candidate_names_train)
print(len(all_names))

In [None]:
# Keep the preferred names (in pref_df's frequency order) that, after add_padding, also
# occur among the training names.
# - Membership is tested against a set built here. all_names is converted to a list in
#   a later cell, and re-running this cell after that would otherwise scan the list once
#   per name.
# - add_padding is computed once per name instead of twice.
# - dict.fromkeys drops repeated names (e.g. a missing value read as NaN becomes "nan")
#   while keeping first-seen, i.e. most frequent, order, so clustered_names has no
#   duplicates.
all_names_set = set(all_names)
pref_names = list(
    dict.fromkeys(
        padded
        for padded in (add_padding(str(name)) for name in pref_df["name"])
        if padded in all_names_set
    )
)
print(len(pref_names))

In [None]:
# Freeze all_names into a deterministic order. Iteration order of a set of strings changes
# between kernel sessions (string hash randomization), which would reorder all_names and
# all_embeddings from run to run. Sorting is also idempotent if this cell is re-run.
all_names = sorted(all_names)

In [None]:
# pref_names keeps pref_df's frequency order, so this takes the first NAMES_TO_CLUSTER of them.
clustered_names = pref_names[0:NAMES_TO_CLUSTER]
clustered_names[0:5]

#### Compute embeddings

In [None]:
# Embed the clustered names first, then every name, casting both to float32.
clustered_embeddings, all_embeddings = (
    get_swivel_embeddings(swivel_model, swivel_vocab, names).astype('float32')
    for names in (clustered_names, all_names)
)

### Get the sparse score matrix and sorted scores (used to generate closures)

In [None]:
%%time
clustered_scores_sparse, sorted_scores = get_scores(clustered_embeddings, threshold=0.4)

### Generate closures

In [None]:
%%time
id2closure, closure2ids, not_merged, max_score_not_merged = generate_closures(sorted_scores, CLOSURE_THRESHOLD)

print("total not merged", not_merged)
print("max score not merged", max_score_not_merged)
print(len(closure2ids))

### Compute clusters

In [None]:
%%time
id2cluster = generate_clusters(closure2ids, clustered_scores_sparse, clustered_names, 
                               CLUSTER_THRESHOLD, CLUSTER_LINKAGE, n_jobs=4)

### Write clusters

In [None]:
# Persist the cluster assignment of clustered_names to config.cluster_path.
write_clusters(
    config.cluster_path,
    clustered_names,
    id2cluster,
)

### Eval

#### Assign all names to a cluster

In [None]:
# Assign every name (not only the clustered ones) to a cluster, presumably by looking
# up its K_NN nearest clustered embeddings.
name2cluster, cluster2names = assign_names_to_clusters(
    all_names, all_embeddings, id2cluster, clustered_embeddings, k=K_NN
)

#### Get best matches

In [None]:
# Derive candidate matches for each training input name from the cluster assignment.
best_matches = get_best_cluster_matches(
    name2cluster,
    cluster2names,
    input_names_train,
)

#### Calculate precision and recall

In [None]:
# Score the cluster-based matches against the training ground truth at threshold 0.5.
# The metrics are also logged to the W&B run started by wandb.init earlier; previously
# they were only printed, so the run recorded no results.
precision = avg_precision_at_threshold(weighted_actual_names_train, best_matches, 0.5)
recall = avg_weighted_recall_at_threshold(weighted_actual_names_train, best_matches, 0.5)
print("precision", precision)
print("recall", recall)
wandb.log({"precision": precision, "recall": recall})

In [None]:
# Close the W&B run opened by wandb.init earlier in the notebook.
wandb.finish()