In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's src/ package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from collections import namedtuple
import re

import pandas as pd
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_datasets, load_nicknames
from src.models.cluster import (
    read_clusters, 
    get_clusters, 
    write_clusters, 
    write_cluster_scores, 
    merge_name2clusters
)
from src.models.swivel import SwivelModel, get_swivel_embeddings, write_swivel_embeddings
from src.models.swivel_encoder import SwivelEncoderModel
from src.models.utils import add_padding

In [None]:
# Config
#
# Everything a reader might tune lives in this cell. The values are bundled
# into the immutable `config` tuple, which is also what gets logged to wandb.

given_surname = "given"
if given_surname == "given":
    vocab_size = 600000
else:
    vocab_size = 2100000
encoder_vocab_size = vocab_size
embed_dim = 100
MAX_SEARCH_CLUSTERS = 25

# Field order is significant: it is the key order of config._asdict().
# "encoder_model_path" is intentionally absent while the encoder model is disabled.
_CONFIG_FIELDS = [
    "pref_name_path",
    "train_path",
    "test_path",
    "swivel_vocab_path",
    "swivel_model_path",
    "nicknames_path",
    "cluster_path",
    "embed_dim",
    "max_search_clusters",
    "embed_out_path",
    "cluster_scores_out_path",
]
Config = namedtuple("Config", _CONFIG_FIELDS)

config = Config(
    # --- inputs ---
    pref_name_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-test.csv.gz",
    nicknames_path="../references/givenname_nicknames.csv",
    # --- trained artifacts ---
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}.pth",
    # encoder_model_path=f"s3://nama-data/data/models/fs-{given_surname}-encoder-model-{encoder_vocab_size}-{embed_dim}.pth",
    cluster_path=f"s3://nama-data/data/models/fs-{given_surname}-clusters-{vocab_size}-{embed_dim}.csv.gz",
    # --- hyper-parameters ---
    embed_dim=embed_dim,
    max_search_clusters=MAX_SEARCH_CLUSTERS,
    # --- outputs ---
    embed_out_path=f"s3://nama-data/data/processed/fs-{given_surname}-embeddings-{vocab_size}-{embed_dim}-precomputed.jsonl.gz",
    cluster_scores_out_path=f"s3://nama-data/data/processed/fs-{given_surname}-cluster-scores-{vocab_size}-{embed_dim}-precomputed.jsonl.gz",
)

In [None]:
# Open the wandb run for this precompute job; the full config tuple is logged
# as the run configuration and runs are grouped by given/surname mode.
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="99_precompute",
    config=config._asdict(),
    notes="",
)

### Load data

In [None]:
pref_names = pd.read_csv(config.pref_name_path, na_filter=False)["name"].tolist()

In [None]:
alpha = re.compile("[a-z]+")
pref_names = [add_padding(name) for name in pref_names if name and alpha.fullmatch(name)]

In [None]:
# Load the train and test splits in one call, then unpack each split into its
# three components: input names, weighted actual names, candidate names.
dataset_paths = [config.train_path, config.test_path]
train, test = load_datasets(dataset_paths)

(input_names_train, weighted_actual_names_train, candidate_names_train) = train
(input_names_test, weighted_actual_names_test, candidate_names_test) = test

In [None]:
# Every distinct name seen in the preferred-name list or in the train/test
# input and candidate names.
# Sorted on purpose: iteration order of a set of strings changes between
# interpreter sessions (string hash randomization), which made names[:10] and
# the order of everything computed from `names` differ from run to run.
names = sorted(
    set().union(
        pref_names,
        input_names_train,
        candidate_names_train,
        input_names_test,
        candidate_names_test,
    )
)
print(len(names))

In [None]:
# Load the nickname table from config.nicknames_path. The result is consumed
# further down as a mapping of name -> variants (it is iterated with .items()).
name2variants = load_nicknames(config.nicknames_path)

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Read the swivel vocabulary CSV through the project's fopen helper (binary mode).
# NOTE(review): the handle returned by fopen is never closed explicitly; if it
# supports the context-manager protocol, a `with` block would be safer -- confirm.
swivel_vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"))
# Show the first rows as a sanity check of the columns.
print(swivel_vocab_df.head(5))

In [None]:
# Map each vocabulary name to the value in its "index" column; if a name
# repeats, the last row wins.
swivel_vocab = dict(zip(swivel_vocab_df["name"], swivel_vocab_df["index"]))
# Sanity check: look up one known vocabulary entry.
print(swivel_vocab["<john>"])

In [None]:
# Build the swivel model sized from the vocabulary, restore its saved weights
# onto `device`, and switch it to evaluation mode.
swivel_model = SwivelModel(len(swivel_vocab), embedding_dim=config.embed_dim)
swivel_model.load_state_dict(
    torch.load(
        fopen(config.swivel_model_path, "rb"),
        map_location=device,  # `device` is already a torch.device
    )
)
swivel_model.to(device)
swivel_model.eval()
print(swivel_model)

In [None]:
# The encoder model is currently disabled: None is passed as encoder_model to
# get_swivel_embeddings below.
encoder_model=None
# To re-enable it, restore "encoder_model_path" in the Config cell and
# uncomment the lines below.
# encoder_model = SwivelEncoderModel(output_dim=config.embed_dim, device=device)
# encoder_model.load_state_dict(torch.load(fopen(config.encoder_model_path, "rb"), map_location=torch.device(device)))
# encoder_model.to(device)
# encoder_model.eval()
# print(encoder_model)

In [None]:
# Load the saved clustering and split it into two parallel lists: the clustered
# names and, position by position, the cluster each one belongs to.
name2cluster = read_clusters(config.cluster_path)
clustered_names = [*name2cluster.keys()]
# Despite its name this is a list of cluster ids aligned with clustered_names.
clustered_name2cluster_id = [*name2cluster.values()]
print("cluster_names", len(clustered_names))

### Calculate embeddings

In [None]:
%%time
# Embeddings for the names that already belong to a cluster; these are the
# reference vectors passed to get_clusters below.
clustered_name_embeddings = get_swivel_embeddings(
    names=clustered_names,
    model=swivel_model,
    vocab=swivel_vocab,
    encoder_model=encoder_model,
)

In [None]:
%%time
# Embeddings for every name collected from the preferred/train/test data,
# in the same order as `names`.
name_embeddings = get_swivel_embeddings(
    names=names,
    model=swivel_model,
    vocab=swivel_vocab,
    encoder_model=encoder_model,
)

In [None]:
# Eyeball the first ten names next to their embeddings.
print(names[:10], name_embeddings[:10], sep="\n")

### Write embeddings

In [None]:
%%time
# Persist the names and their embeddings to config.embed_out_path.
write_swivel_embeddings(config.embed_out_path, names, name_embeddings)

### Calculate cluster scores

#### Smoke test on the first 10 names

In [None]:
# Try get_clusters on a ten-name sample before running it on every name.
# Note: name2clusters is overwritten by the full run further down.
test_names, test_embeddings = names[:10], name_embeddings[:10]
name2clusters, cluster2names = get_clusters(
    test_names,
    test_embeddings,
    clustered_name2cluster_id,
    clustered_name_embeddings,
    verbose=True,
    max_clusters=config.max_search_clusters,
    k=1024,
)

In [None]:
# Inspect the first result returned by get_clusters for the ten-name sample.
print(name2clusters)

In [None]:
# Inspect the second result returned by get_clusters for the ten-name sample.
print(cluster2names)

#### Calculate cluster scores for all names

In [None]:
%%time
# Full run over every name, with the same k and max_clusters as the sample run
# above. Only the first return value is kept; the second is discarded.
name2clusters, _ = get_clusters(
    names,
    name_embeddings,
    clustered_name2cluster_id,
    clustered_name_embeddings,
    verbose=True,
    max_clusters=config.max_search_clusters,
    k=1024,
)

### Write cluster scores

#### Merge cluster scores across nickname variants

In [None]:
def get_variant_cluster_scores(name2variants, name2clusters):
    """Merge the cluster scores of all variants of each name.

    For every name in ``name2variants``, the cluster scores of each of its
    variants are looked up in ``name2clusters`` and handed to
    ``merge_name2clusters`` as a ``{variant: cluster_scores}`` dict.

    Args:
        name2variants: mapping of name -> iterable of variant names.
        name2clusters: mapping of name -> cluster scores; it must contain an
            entry for every variant, otherwise the lookup fails (KeyError
            for a plain dict).

    Returns:
        dict mapping each name in ``name2variants`` to the value returned by
        ``merge_name2clusters`` for its variants.
    """
    merged_scores = {}
    for name in name2variants:
        # collect the individual scores first, then merge them in one call
        scores_by_variant = {
            variant: name2clusters[variant] for variant in name2variants[name]
        }
        merged_scores[name] = merge_name2clusters(scores_by_variant)
    return merged_scores

In [None]:
# Nickname variants are only applied when given_surname == "given".
if given_surname == "given":
    variant_cluster_scores = get_variant_cluster_scores(name2variants, name2clusters)
    # Each name that has variants gets the merged scores in place of its own.
    for name in variant_cluster_scores:
        name2clusters[name] = variant_cluster_scores[name]

In [None]:
# Persist the (possibly nickname-merged) cluster scores to config.cluster_scores_out_path.
write_cluster_scores(config.cluster_scores_out_path, name2clusters)

In [None]:
# Close the wandb run opened at the top of the notebook.
wandb.finish()

### Test

In [None]:
# Toy fixtures for checking the variant merge by hand.
# Caution: this rebinds given_surname, name2clusters and name2variants, so the
# real pipeline values computed above are gone after this cell runs.
given_surname = "given"
name2clusters = {
    "john": [("1", 0.9), ("2", 0.6), ("3", 0.5)],
    "mary": [("4", 1.0), ("5", 0.8), ("6", 0.3)],
    "johnny": [("2", 1.0), ("7", 0.5), ("3", 0.2)],
}
# john and johnny are variants of each other; mary has no variants.
name2variants = {
    "john": {"john", "johnny"},
    "johnny": {"john", "johnny"},
}

In [None]:
# Run the merge on the toy fixtures and print the outcome for manual inspection.
variant_cluster_scores = get_variant_cluster_scores(name2variants, name2clusters)
# Each name that has variants gets the merged scores in place of its own.
for name in variant_cluster_scores:
    name2clusters[name] = variant_cluster_scores[name]
print(name2clusters)