In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import namedtuple
import re

import joblib
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import torch
import wandb

from src.data.filesystem import fopen
from src.data.normalize import normalize_freq_names
from src.data.utils import load_dataset
from src.models.cluster import (
    read_clusters, 
    get_clusters, 
    write_clusters, 
    write_cluster_scores, 
)
from src.models.swivel import SwivelModel, get_swivel_embeddings, write_swivel_embeddings
from src.models.swivel_encoder import SwivelEncoderModel
from src.models.utils import add_padding

In [None]:
# Config

# Name type handled by this run; it picks the vocabulary size and is interpolated
# into every S3 path below.
given_surname = "given"
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
embed_dim = 100
MAX_SEARCH_CLUSTERS = 25
num_matches = 1000
min_search_threshold = 0.05
verbose = True
n_jobs = 1


# Immutable record of the run settings; its fields are what gets passed to
# wandb through config._asdict().
Config = namedtuple(
    "Config",
    (
        "train_path",
        "embed_dim",
        "max_search_clusters",
        "tree_freq_path",
        "swivel_vocab_path",
        "swivel_model_path",
        "tfidf_path",
        "ensemble_model_path",
        "cluster_path",
        "embed_out_path",
        "cluster_embed_out_path",
        "cluster_scores_out_path",
    ),
)
config = Config(
    # scalar settings
    embed_dim=embed_dim,
    max_search_clusters=MAX_SEARCH_CLUSTERS,
    # inputs read by this notebook
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    tree_freq_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
    tfidf_path=f"s3://nama-data/data/models/fs-{given_surname}-tfidf.joblib",
    ensemble_model_path=f"s3://nama-data/data/models/fs-{given_surname}-ensemble-model-{vocab_size}-{embed_dim}-augmented-100.joblib",
    cluster_path=f"s3://nama-data/data/models/fs-{given_surname}-cluster-names.csv",
    # outputs written by this notebook
    embed_out_path=f"s3://nama-data/data/processed/fs-{given_surname}-embeddings-{vocab_size}-{embed_dim}-precomputed.jsonl.gz",
    cluster_embed_out_path=f"s3://nama-data/data/processed/fs-{given_surname}-cluster-embeddings-{vocab_size}-{embed_dim}-precomputed.jsonl.gz",
    cluster_scores_out_path=f"s3://nama-data/data/processed/fs-{given_surname}-cluster-scores-{vocab_size}-{embed_dim}-precomputed.jsonl.gz",
)

In [None]:
# Open the Weights & Biases run for this notebook: grouped by name type, with
# every Config field attached as the run config.
wandb_settings = dict(
    project="nama",
    entity="nama",
    name="99_precompute",
    group=given_surname,
    notes="",
    config=config._asdict(),
)
wandb.init(**wandb_settings)

### Load data

In [None]:
# Use the first CUDA device when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the training dataset from config.train_path; load_dataset returns three
# values, unpacked here in the order it returns them.
(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
) = load_dataset(config.train_path)

In [None]:
# Every distinct name that occurs as an input name or as a candidate name in the
# training data.
# Sorted rather than left in set-iteration order: string hashing is randomized per
# interpreter session, so a bare list(set) would put the names (and therefore the
# rows of the files written from this list) in a different order on every run.
# Assumes the names are mutually comparable (all strings) -- TODO confirm against load_dataset.
all_names = sorted(set().union(input_names_train, candidate_names_train))
print(len(all_names))

In [None]:
# na_filter=False tells pandas not to detect missing-value markers, so cells are
# kept as read instead of being turned into NaN.
freq_df = pd.read_csv(config.tree_freq_path, na_filter=False)
tree_name_freq = normalize_freq_names(
    freq_df,
    is_surname=(given_surname != "given"),
    add_padding=True,
)
# Drop the notebook's reference to the raw frame so its memory can be reclaimed.
freq_df = None

In [None]:
# Read the vocabulary CSV through a context manager so the handle returned by
# fopen is closed as soon as parsing finishes (it was previously never closed).
# Assumes fopen returns a regular file-like object usable in a `with` block.
with fopen(config.swivel_vocab_path, "rb") as vocab_file:
    vocab_df = pd.read_csv(vocab_file)
# name -> index lookup built from the file's "name" and "index" columns.
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Build the model with one row per vocabulary entry, then restore the saved weights.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# Load the checkpoint inside a context manager so the handle from fopen is closed
# once the state dict has been read (it was previously left open).
# `device` is already a torch.device, so it is passed to map_location directly.
# Assumes fopen returns a regular file-like object usable in a `with` block.
with fopen(config.swivel_model_path, "rb") as model_file:
    swivel_model.load_state_dict(torch.load(model_file, map_location=device))
swivel_model.to(device)
# Switch to inference mode; as the last expression this also displays the model.
swivel_model.eval()

In [None]:
# Deserialize the object stored at config.tfidf_path, closing the handle from fopen
# afterwards instead of leaving it open for the life of the kernel.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary code;
# only point tfidf_path at artifacts you trust.
with fopen(config.tfidf_path, mode="rb") as tfidf_file:
    tfidf_vectorizer = joblib.load(tfidf_file)

In [None]:
# Deserialize the object stored at config.ensemble_model_path, closing the handle
# from fopen afterwards instead of leaving it open for the life of the kernel.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary code;
# only point ensemble_model_path at artifacts you trust.
with fopen(config.ensemble_model_path, mode="rb") as ensemble_file:
    ensemble_model = joblib.load(ensemble_file)

In [None]:
# Sanity check: number of entries in the swivel vocabulary lookup.
print(len(swivel_vocab))

In [None]:
# Load the cluster mapping from config.cluster_path and keep its keys as a list;
# those keys are the names embedded and written out further down.
name_cluster = read_clusters(config.cluster_path)
clustered_names = [*name_cluster.keys()]

In [None]:
# Sanity check: number of keys in name_cluster (one per clustered name).
print(len(name_cluster))

In [None]:
# A few sample names, written with the "<" / ">" boundary markers, used for the
# quick test cells below.
test_names = [
    "<richard>",
    "<rickerd>",
    "<dallan>",
    "<dallin>",
]

### Calc embeddings

In [None]:
# test: embed only the sample names as a quick check before the full run
name_embeddings = get_swivel_embeddings(
    model=swivel_model,
    vocab=swivel_vocab,
    names=test_names,
)

In [None]:
# Pairwise cosine similarities between the sample-name embeddings computed above;
# shown as the cell output for a visual sanity check.
cosine_similarity(name_embeddings)

In [None]:
%%time
# Full run: embed every name in all_names.
# NOTE: this rebinds name_embeddings, replacing the sample-name result from the test cell.
name_embeddings = get_swivel_embeddings(
    model=swivel_model,
    vocab=swivel_vocab,
    names=all_names,
)

In [None]:
# Sanity check: size of the result computed for all_names.
print(len(name_embeddings))

### Write embeddings

In [None]:
%%time
# Write all_names together with their embeddings to config.embed_out_path.
write_swivel_embeddings(config.embed_out_path, all_names, name_embeddings)

In [None]:
# do we have clustered names not in vocab?
missing_names = [name for name in clustered_names if name not in swivel_vocab]
# Show a sample of the offenders: the first 19 out-of-vocabulary names.
for missing in missing_names[:19]:
    print(missing)
# Total number of clustered names missing from the vocabulary (displayed as the cell output).
cnt = len(missing_names)
cnt

In [None]:
# Embed the clustered names.
# NOTE: this rebinds name_embeddings; from here on it holds the clustered-name
# embeddings, not the ones computed for all_names.
name_embeddings = get_swivel_embeddings(
    model=swivel_model,
    vocab=swivel_vocab,
    names=clustered_names,
)

In [None]:
%%time
# Write clustered_names together with their embeddings to config.cluster_embed_out_path.
write_swivel_embeddings(config.cluster_embed_out_path, clustered_names, name_embeddings)

### Calc cluster scores

In [None]:
# test: run the cluster lookup on the sample names only, with the same settings
# as the full run below
name2clusters, cluster2names = get_clusters(
    test_names,
    name_cluster,
    swivel_model,
    swivel_vocab,
    tfidf_vectorizer,
    ensemble_model,
    tree_name_freq,
    max_clusters=config.max_search_clusters,
    k=num_matches,
    search_threshold=min_search_threshold,
    n_jobs=n_jobs,
    verbose=verbose,
)

In [None]:
# Inspect the first value returned by the test get_clusters call.
print(name2clusters)

In [None]:
# Inspect the second value returned by the test get_clusters call.
print(cluster2names)

#### Calc all cluster scores

In [None]:
%%time
# Full run over every name in all_names.
# NOTE: this rebinds name2clusters and cluster2names, replacing the sample-name results.
name2clusters, cluster2names = get_clusters(
    all_names,
    name_cluster,
    swivel_model,
    swivel_vocab,
    tfidf_vectorizer,
    ensemble_model,
    tree_name_freq,
    max_clusters=config.max_search_clusters,
    k=num_matches,
    search_threshold=min_search_threshold,
    n_jobs=n_jobs,
    verbose=verbose,
)

### Write cluster scores

In [None]:
# Write name2clusters (from the full get_clusters run) to config.cluster_scores_out_path.
write_cluster_scores(config.cluster_scores_out_path, name2clusters)

In [None]:
# Close the wandb run opened near the top of the notebook.
wandb.finish()