In [None]:
# Load the autoreload extension and set it to mode 2 so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Visualize swivel vectors in 2D space

In [None]:
from collections import namedtuple

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import umap
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_dataset, select_frequent_k
from src.eval import metrics
from src.models.swivel import SwivelModel, get_swivel_embeddings, get_best_swivel_matches
from src.models.utils import remove_padding, add_padding

In [None]:
# Notebook configuration: figure size, which name type to analyze, and the
# S3 locations of the datasets and the trained swivel artifacts.

plt.rcParams["figure.figsize"] = [12, 12]

given_surname = "given"
# The vocabulary size is part of the artifact file names and differs by name type.
vocab_size = {"given": 610000}.get(given_surname, 2100000)
embed_dim = 100

Config = namedtuple(
    "Config",
    "train_path freq_path embed_dim swivel_vocab_path swivel_model_path",
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    freq_path=f"s3://familysearch-names/processed/tree-preferred-{given_surname}-aggr.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
)

In [None]:
# Start a Weights & Biases run for this notebook and record the configuration with it.
wandb_run_settings = dict(
    project="nama",
    entity="nama",
    name="71_analyze_embeddings",
    group=given_surname,
    notes="",
    config=config._asdict(),
)
wandb.init(**wandb_run_settings)

### Load data

In [None]:
# The notebook is pinned to the CPU. The disabled alternative picks the first GPU when available:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(device)

In [None]:
# Load the dataset in evaluation mode; it yields the input names, their
# weighted actual names, and the candidate names.
(
    input_names_eval,
    weighted_actual_names_eval,
    candidate_names_eval,
) = load_dataset(config.train_path, is_eval=True)

In [None]:
# Read the name-frequency table. na_filter=False turns off pandas' missing-value
# detection, so every cell is kept as read (a name spelled like an NA marker,
# e.g. "nan" or "null", stays a string).
freq_df = pd.read_csv(config.freq_path, na_filter=False)

In [None]:
# Number of rows in the frequency table.
print(freq_df.shape[0])

In [None]:
# Read the swivel vocabulary and build a name -> index lookup.
# na_filter=False matches how the frequency table is read: without it pandas would
# turn any name spelled like an NA marker (e.g. "nan", "null") into a float NaN key.
# NOTE(review): whether vocabulary names can collide with NA markers depends on how
# they are stored (padded or not) - confirm; the flag is harmless either way.
vocab_df = pd.read_csv(fopen(config.swivel_vocab_path, "rb"), na_filter=False)
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

In [None]:
# Instantiate the swivel model with one row per vocabulary entry and the configured
# embedding width, then restore the trained weights.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# map_location remaps tensors that were saved from another device (e.g. a GPU) onto
# `device`; without it, loading a CUDA-saved checkpoint on a CPU-only machine fails.
swivel_model.load_state_dict(
    torch.load(fopen(config.swivel_model_path, "rb"), map_location=torch.device(device))
)
# Inference only: switch to evaluation mode and move the model to the target device.
swivel_model.eval()
swivel_model.to(device)

In [None]:
# No encoder model is loaded in this notebook; the name is bound to None.
encoder_model = None
# Disabled: loading a trained SwivelEncoderModel. Re-enabling these lines needs names
# that the visible cells do not define or import (SwivelEncoderModel, encoder_layers,
# config.encoder_model_path).
# encoder_model = SwivelEncoderModel(n_layers=encoder_layers, output_dim=config.embed_dim, device=device)
# encoder_model.load_state_dict(torch.load(fopen(config.encoder_model_path, "rb"), map_location=torch.device(device)))
# encoder_model.to(device)
# encoder_model.eval()

### PR Curve

In [None]:
# input_names_sample, weighted_actual_names_sample, candidate_names_sample = \
#     select_frequent_k(input_names_eval, weighted_actual_names_eval, candidate_names_eval, 
#                       50000)

In [None]:
# Padded forms of the names in the first 10000 rows of the frequency table.
freq_names = {add_padding(raw_name) for raw_name in freq_df["name"].iloc[:10000]}

In [None]:
# Keep only the evaluation inputs that belong to the frequent-name set, remembering
# each kept name's position so it stays paired with its weighted actual names.
kept = [(pos, name) for pos, name in enumerate(input_names_eval) if name in freq_names]
input_names_sample = [name for _, name in kept]
weighted_actual_names_sample = [weighted_actual_names_eval[pos] for pos, _ in kept]
# A disabled earlier variant kept only each input's 10 highest-frequency actual names,
# renormalized their weights, and used those names as the candidate set.
# Here every evaluation candidate stays eligible as a match.
candidate_names_sample = candidate_names_eval

In [None]:
# Sizes of the full evaluation set followed by the sizes of the selected sample.
for names in (
    input_names_eval,
    candidate_names_eval,
    input_names_sample,
    weighted_actual_names_sample,
    candidate_names_sample,
):
    print(len(names))

In [None]:
# Spot check: the first 100 sampled input names, and the frequency row(s) for one name.
print(input_names_sample[:100])
is_aagje = freq_df["name"] == 'aagje'
print(freq_df[is_aagje])

In [None]:
# Find the best swivel matches for every sampled input name.
# NOTE: matches are drawn only from candidate_names_sample (the eval candidates);
# names that appear only in input_names_eval are never proposed as matches.
k = 200
eval_batch_size = 1024
add_context = True
n_jobs = 1
best_matches = get_best_swivel_matches(
    swivel_model,
    swivel_vocab,
    input_names_sample,
    candidate_names_sample,
    k,
    eval_batch_size,
    add_context=add_context,
    n_jobs=n_jobs,
)

In [None]:
# Precision / weighted-recall curve over thresholds from 0.01 to 1.0 in steps of 0.05.
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_sample,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
    distances=False,
)

In [None]:
# Choose a small subset of names to plot: every `step`-th sampled input name (at most
# `total` of them) together with the top `max_matches_shown` matches of each.
input_names_graphed = set()
candidate_names_graphed = set()
step = 10  # take every step-th sampled input name
total = 20  # number of input names to graph
max_matches_shown = 20  # number of top matches printed and graphed per input name
for i, (name, matches, wans) in enumerate(zip(input_names_sample, best_matches, weighted_actual_names_sample)):
    # Check the limit before doing any work; checking it after processing (as before)
    # let one extra name through, i.e. total + 1 names were graphed.
    if i >= step * total:
        break
    if i % step != 0:
        continue
    print(name)
    input_names_graphed.add(name)
    candidate_names_graphed.add(name)
    # Map each actual name of this input to its frequency (third field of the tuple).
    true_names = {actual_name: freq for actual_name, _, freq in wans}
    for j, (match, score) in enumerate(matches):
        # Same reasoning as above: stop before the (max_matches_shown + 1)-th match.
        if j >= max_matches_shown:
            break
        # " * " marks matches that are among the input's actual names.
        print(" * " if match in true_names else "   ", j, match, score, true_names.get(match, 0))
        candidate_names_graphed.add(match)
# Sort so the array order (and everything derived from it) does not depend on set
# iteration order, which can differ between kernel sessions.
candidate_names_graphed = np.array(sorted(candidate_names_graphed))
print("input_names_graphed", len(input_names_graphed), input_names_graphed)
print(len(candidate_names_graphed))

In [None]:
# Display the shape of the array of names selected for plotting.
candidate_names_graphed.shape

### Get embeddings

In [None]:
# Look up the swivel embeddings for the names selected for plotting.
embeddings = get_swivel_embeddings(swivel_model, swivel_vocab, candidate_names_graphed)

### Use UMAP to reduce dimensionality

In [None]:
# UMAP is stochastic; fixing random_state makes the projection repeatable for a given
# input instead of changing on every run.
reducer = umap.UMAP(random_state=42)
# Project the embeddings with UMAP's default settings and show the resulting shape.
reduced = reducer.fit_transform(embeddings)
reduced.shape

### Plot embeddings

In [None]:
# Scatter the reduced points and label each one with its name.
xs, ys = reduced.T
plt.scatter(xs, ys)
for ix, label in enumerate(candidate_names_graphed):
    plt.annotate(
        label,
        xy=(xs[ix], ys[ix]),
        xytext=(5, 2),
        textcoords='offset points',
        ha='right',
        va='bottom',
    )

In [None]:
# Close the Weights & Biases run.
wandb.finish()