In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Train a swivel model

## DEPRECATED

[Swivel](https://arxiv.org/abs/1602.02215) turns out to be another key component. It is an improvement over GloVe because it also trains non-matching name pairs
to have a similarity of 0.

In [None]:
from collections import namedtuple
from datetime import datetime

import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MinMaxScaler
import torch
from tqdm import tqdm
import umap

from nama.data.utils import load_dataset
from nama.data.filesystem import download_file_from_s3, save_file
from nama.eval import metrics
from nama.models.swivel import SwivelDataset, SwivelModel, train_swivel, get_best_swivel_matches

In [None]:
# Config

# run this on a 256GB standard instance

# TODO run both given and surname
given_surname = "given"
# given_surname = "surname"

# Hyper-parameters that are also embedded in the artifact file names below.
embed_dim = 100
n_epochs = 200
num_matches = 500
vocab_size = 2100000 if given_surname != "given" else 610000

# Immutable bag of settings; the field order is part of the tuple's identity.
Config = namedtuple(
    "Config",
    [
        "train_path",
        "eval_path",
        "vocab_size",
        "embed_dim",
        "confidence_base",
        "confidence_scale",
        "confidence_exponent",
        "n_epochs",
        "submatrix_size",
        "lr",
        "vocab_path",
        "model_path",
    ],
)

config = Config(
    # data: train on the augmented pairs, evaluate on the original (unaugmented) pairs
    train_path=f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-train-augmented.csv.gz",
    eval_path=f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    # model size
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    # per-name-type tuning (given vs. surname)
    confidence_base=0.18 if given_surname == "given" else 0.14,
    confidence_scale=0.5 if given_surname == "given" else 0.45,
    confidence_exponent=0.3 if given_surname == "given" else 0.3,
    # optimisation
    n_epochs=n_epochs,
    submatrix_size=4096,
    lr=0.14 if given_surname == "given" else 0.24,
    # output artifacts
    vocab_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    model_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
)

In [None]:
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

### Load data

In [None]:
%%time
train_path = download_file_from_s3(config.train_path) if config.train_path.startswith("s3://") else config.train_path
input_names_train, record_name_frequencies_train, candidate_names_train = load_dataset(train_path)

In [None]:
# keep only the most-frequent vocab_size names
# input_names_train, record_name_frequencies_train, candidate_names_train = \
#     select_frequent_k(input_names_train, 
#                       record_name_frequencies_train, 
#                       candidate_names_train,
#                       config.vocab_size)

In [None]:
print("input_names_train", len(input_names_train))
print("record_name_frequencies_train", sum(len(rnf) for rnf in record_name_frequencies_train))
print("total pairs", sum(freq for rnfs in record_name_frequencies_train for _, freq in rnfs))
print("candidate_names_train", len(candidate_names_train))
print("total names", len(set(input_names_train).union(set(candidate_names_train))))

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
%%time
dataset = SwivelDataset(input_names_train, record_name_frequencies_train, config.vocab_size, symmetric=True)
vocab = dataset.get_vocab()

In [None]:
%%time
# get vocab names in order by id
vocab_names = list(name_id[0] for name_id in sorted(vocab.items(), key=lambda x: x[1]))
print(len(vocab_names))

In [None]:
%%time
model = SwivelModel(len(vocab), config.embed_dim, config.confidence_base, config.confidence_scale, config.confidence_exponent)

### Initialize vectors

In [None]:
%%time
# create vectors with tfidf values
max_ngram = 3
min_df = 10
max_df = 0.8
tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, max_ngram), analyzer="char_wb", min_df=min_df, max_df=max_df)
tfidf_X_train = tfidf_vectorizer.fit_transform(vocab_names)
print(tfidf_X_train.shape)

In [None]:
%%time
# reducer = TruncatedSVD(n_components=config.embed_dim)
# Reduce the tf-idf matrix to config.embed_dim columns; the result is later handed to
# model.init_params as the initial embedding values.
# NOTE(review): no random_state is passed to UMAP, so the reduced vectors (and hence the
# initial weights) presumably differ from run to run — confirm whether that matters here.
reducer = umap.UMAP(n_components=config.embed_dim)
tfidf_X_train = reducer.fit_transform(tfidf_X_train)
print(tfidf_X_train.shape)

In [None]:
%%time
# Linearly rescale every column into [-sqrt(1/embed_dim), sqrt(1/embed_dim)].
# MinMaxScaler only maps each column's min/max onto that range; it does not make the
# values uniformly distributed inside it.
scaler_max = math.sqrt(1 / config.embed_dim)
scaler = MinMaxScaler(feature_range=(-scaler_max, scaler_max))
tfidf_X_train = scaler.fit_transform(tfidf_X_train)

In [None]:
%%time
# init weights to be tfidf values (does this help?)
model.init_params(dataset.get_row_sums(), dataset.get_col_sums(), tfidf_X_train)

### Train

In [None]:
%%time
# Train for config.n_epochs epochs, collecting every loss value and writing a checkpoint
# after each epoch so a long run can be resumed or inspected part-way through.

# Passed straight through to train_swivel as n_steps.
# NOTE(review): presumably 0 means "one full pass over the dataset" — confirm in nama.models.swivel.
n_steps_per_epoch = 0

model.to(device)
# torch.optim is reachable through the top-level `import torch`; no extra import needed here.
optimizer = torch.optim.Adagrad(model.parameters(), lr=config.lr)

all_loss_values = []
for epoch in tqdm(range(config.n_epochs)):
    print("Epoch", epoch, datetime.now())
    loss_values = train_swivel(model, dataset, n_steps=n_steps_per_epoch,
                               submatrix_size=config.submatrix_size,
                               lr=config.lr, device=device, optimizer=optimizer)
    all_loss_values.extend(loss_values)
    # Hand torch.save the path (not an open handle) so the file is flushed and closed
    # before save_file does anything further with it.
    save_file(f"{config.model_path}.{epoch}",
              lambda local_out_path: torch.save(model.state_dict(), local_out_path))

#### Save vocab

In [None]:
%%time
# Persist the name -> id mapping next to the model so embeddings can be looked up later.
vocab_df = pd.DataFrame(vocab.items(), columns=["name", "index"])
# Give to_csv the path (not an open handle) so pandas opens, flushes and closes the file itself.
save_file(config.vocab_path,
          lambda local_out_path: vocab_df.to_csv(local_out_path, index=False))

#### Save model

In [None]:
%%time
# Save the final weights. torch.save is given the path (not an open handle) so the file is
# flushed and closed before save_file does anything further with it.
save_file(config.model_path,
          lambda local_out_path: torch.save(model.state_dict(), local_out_path))

In [None]:
print("Vocab and model saved")

#### Reload model

In [None]:
# Rebuild vocab and model from the saved artifacts so the evaluation below can run
# without retraining.
vocab_path = download_file_from_s3(config.vocab_path) if config.vocab_path.startswith("s3://") else config.vocab_path
with open(vocab_path, "rb") as vocab_file:
    # keep_default_na=False: names must stay strings even if one happens to look like "NA"/"null".
    vocab_df = pd.read_csv(vocab_file, keep_default_na=False)
vocab = dict(zip(vocab_df["name"], vocab_df["index"]))

# Construct the model with the same confidence settings used for training.
model = SwivelModel(len(vocab), config.embed_dim, config.confidence_base, config.confidence_scale, config.confidence_exponent)
model_path = download_file_from_s3(config.model_path) if config.model_path.startswith("s3://") else config.model_path
with open(model_path, "rb") as model_file:
    # map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only one;
    # load_state_dict then copies the tensors into the freshly built model.
    model.load_state_dict(torch.load(model_file, map_location="cpu"))
model.eval()

### Eval

In [None]:
ax = plt.gca()
ax.set_ylim([0, 1.0])
plt.plot(all_loss_values[::1000])

### PR Curve

In [None]:
def filter_dataset(input_names, record_name_frequencies, candidate_names=None, known_names=None):
    """Restrict a dataset to names the model knows about.

    An input name is kept only if
      * it is in the vocabulary,
      * at least one of its record names is in the vocabulary with a non-zero frequency, and
      * no other record name is associated with it more frequently than the input name itself.

    Args:
        input_names: sequence of input names.
        record_name_frequencies: for each input name, a list of (record_name, frequency) pairs.
        candidate_names: ignored; accepted only so existing three-argument calls keep working.
            The candidate set is always rebuilt from the pairs that survive filtering.
        known_names: container supporting ``in`` that holds the known names. Defaults to the
            notebook-level ``vocab``.

    Returns:
        (input_names_filtered, record_name_frequencies_filtered, candidate_names_filtered), where
        the first two are parallel lists and the last is a numpy array of every name that appears
        in the filtered data, sorted so the order is the same on every run.
    """
    if known_names is None:
        # fall back to the vocabulary defined earlier in the notebook
        known_names = vocab

    input_names_filtered = []
    record_name_frequencies_filtered = []
    surviving_names = set()
    for input_name, rnfs in zip(input_names, record_name_frequencies):
        if input_name not in known_names:
            continue
        rnfs_filtered = []
        max_freq = 0
        input_name_freq = 0
        for name, freq in rnfs:
            if name not in known_names or freq == 0:
                continue
            max_freq = max(max_freq, freq)
            if name == input_name:
                input_name_freq = freq
            rnfs_filtered.append((name, freq))
        # skip if there are no associated record names, or if another name is associated
        # with the input name more frequently than the input name itself
        if not rnfs_filtered or input_name_freq < max_freq:
            continue
        input_names_filtered.append(input_name)
        record_name_frequencies_filtered.append(rnfs_filtered)
        surviving_names.add(input_name)
        surviving_names.update(name for name, _ in rnfs_filtered)

    # sorted: set iteration order for strings varies between interpreter runs
    candidate_names_filtered = np.array(sorted(surviving_names))

    return input_names_filtered, record_name_frequencies_filtered, candidate_names_filtered

In [None]:
%%time
input_names_train_filtered, record_name_frequencies_train_filtered, candidate_names_train_filtered = \
    filter_dataset(input_names_train, record_name_frequencies_train, candidate_names_train)
print(len(input_names_train_filtered))
print(len(record_name_frequencies_train_filtered))
print(len(candidate_names_train_filtered))

In [None]:
# get best matches for every 10th name of the *filtered* training data
# (the filtered set is what the previous cell computed; the eval-data cell further down
# samples its filtered set in the same way)
# NOTE: only considers as potential matches names in candidate_names_train_filtered
eval_batch_size = 1024
add_context = True
n_jobs = 1
input_names_sample = input_names_train_filtered[::10]
record_name_frequencies_sample = record_name_frequencies_train_filtered[::10]
best_matches = get_best_swivel_matches(model,
                                       vocab,
                                       input_names_sample,
                                       candidate_names_train_filtered,
                                       k=num_matches,
                                       batch_size=eval_batch_size,
                                       add_context=add_context,
                                       n_jobs=n_jobs)

In [None]:
%%time
metrics.precision_weighted_recall_curve_at_threshold(
    record_name_frequencies_sample, best_matches, min_threshold=0.05, max_threshold=1.0, step=0.05, distances=False
)

In [None]:
%%time
metrics.get_auc(
    record_name_frequencies_sample, best_matches, min_threshold=0.05, max_threshold=1.0, step=0.05, distances=False
)

### Eval on original (unaugmented) data

In [None]:
%%time
eval_path = download_file_from_s3(config.eval_path) if config.eval_path.startswith("s3://") else config.eval_path
input_names_eval, record_name_frequencies_eval, candidate_names_eval = load_dataset(eval_path)

In [None]:
%%time
input_names_eval_filtered, record_name_frequencies_eval_filtered, candidate_names_eval_filtered = \
    filter_dataset(input_names_eval, record_name_frequencies_eval, candidate_names_eval)
print(len(input_names_eval_filtered))
print(len(record_name_frequencies_eval_filtered))
print(len(candidate_names_eval_filtered))

In [None]:
# get best matches
# NOTE: only considers as potential matches names in candidate_names_eval, not names in input_names_eval
eval_batch_size = 1024
add_context = True
n_jobs=1
input_names_sample = input_names_eval_filtered[::10]
record_name_frequencies_sample = record_name_frequencies_eval_filtered[::10]
best_matches = get_best_swivel_matches(model, 
                                       vocab, 
                                       input_names_sample,
                                       candidate_names_eval_filtered, 
                                       k=num_matches, 
                                       batch_size=eval_batch_size,
                                       add_context=add_context, 
                                       n_jobs=n_jobs)

### PR Curve

In [None]:
pos = 120
input_names_sample[pos]
# for ix, name in enumerate(input_names_sample):
#     print(ix, name)

In [None]:
input_names_test = input_names_sample[pos:pos+1]
record_name_frequencies_test = record_name_frequencies_sample[pos:pos+1]
print(record_name_frequencies_test)
best_matches_test = best_matches[pos:pos+1]
print(best_matches_test)

In [None]:
metrics.precision_at_threshold(record_name_frequencies_test[0], best_matches_test[0], 0.7725, False)

In [None]:
metrics.weighted_recall_at_threshold(record_name_frequencies_test[0], best_matches_test[0], 0.7725, False)

In [None]:
%%time
metrics.precision_weighted_recall_curve_at_threshold(
    record_name_frequencies_test, best_matches_test, min_threshold=0.05, max_threshold=1.0, step=0.05, distances=False
)

In [None]:
%%time
metrics.precision_weighted_recall_curve_at_threshold(
    record_name_frequencies_sample, best_matches, min_threshold=0.05, max_threshold=1.0, step=0.05, distances=False
)

In [None]:
%%time
metrics.get_auc(
    record_name_frequencies_sample, best_matches, min_threshold=0.05, max_threshold=1.0, step=0.05, distances=False
)

## Review (don't run)

In [None]:
# Spot-check individual names: for every 10,000th position, list which of the name's record
# names appear among its best matches above `threshold` and which do not.
# NOTE(review): `i` indexes input_names_eval / record_name_frequencies_eval, but the most
# recent best_matches above was computed for input_names_sample
# (input_names_eval_filtered[::10]), so row i of best_matches presumably belongs to a
# different name — confirm before trusting this output.
threshold = 0.6
for i in range(100001, 400000, 10000):
    print(i, input_names_eval[i])
    # assumes best_matches is a 3-D array holding (name, score) pairs per input name — TODO confirm
    matches_above_threshold = best_matches[i][best_matches[i,:,1] > threshold]
    matched_rnfs = []
    unmatched_rnfs = []
    for rnf in record_name_frequencies_eval[i]:
        # rnf is a (record_name, frequency) pair
        if rnf[0] in matches_above_threshold[:, 0]:
            matched_rnfs.append(rnf)
        elif rnf[1] > 0:
            # zero-frequency record names are not reported as misses
            unmatched_rnfs.append(rnf)
    print("  matched rnfs", matched_rnfs)
    print("  unmatched rnfs", unmatched_rnfs)
    print("  matches above threshold", len(matches_above_threshold), matches_above_threshold)

In [None]:
i = 390000

In [None]:
best_matches[i]

In [None]:
record_name_frequencies_train[input_names_train.index(input_names_eval[i])]

In [None]:
variant = "<shirley>"

In [None]:
record_name_frequencies_train[input_names_train.index(variant)]

In [None]:
record_name_frequencies_eval[input_names_eval.index(variant)]

### Test (don't run)

In [None]:
import numpy as np
from nama.models.swivel import get_swivel_embeddings
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

In [None]:
# demo names
input_names_train = ["<john>", "<mary>"]
record_name_frequencies_train = [
    [("<johnny>", 20), ("<jonathan>", 50), ("<jon>", 30)],
    [("<marie>", 70), ("<maria>", 30)],
]
candidate_names_train = np.array(["<johnny>", "<jonathan>", "<marie>", "<maria>", "<jon>"])

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
symmetric = True
dataset = SwivelDataset(input_names_train, record_name_frequencies_train, config.vocab_size, symmetric=symmetric)
vocab = dataset.get_vocab()
print(vocab)

In [None]:
print(dataset._sparse_cooc)

In [None]:
# get vocab names in order by id
vocab_names = list(name_id[0] for name_id in sorted(vocab.items(), key=lambda x: x[1]))
print(vocab_names)

In [None]:
# create vectors with tfidf values
max_ngram = 5  # 3
min_df = 1  # 10
max_df = 1.0  # 0.5
tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, max_ngram), analyzer="char_wb", min_df=min_df, max_df=max_df)
tfidf_X_train = tfidf_vectorizer.fit_transform(vocab_names)
print(tfidf_X_train.shape)

In [None]:
embed_dim = 2

In [None]:
# reduce tfidf values to embed_dim
svd = TruncatedSVD(n_components=embed_dim)
tfidf_X_train = svd.fit_transform(tfidf_X_train)
tfidf_X_train.shape

In [None]:
# create swivel model
model = SwivelModel(len(vocab), embed_dim, config.confidence_base, config.confidence_scale, config.confidence_exponent)

In [None]:
# init weights to tfidf values
# model.init_params(dataset.get_row_sums(), dataset.get_col_sums(), tfidf_X_train)
model.init_params(dataset.get_row_sums(), dataset.get_col_sums(), tfidf_X_train)

In [None]:
# device="cpu"
n_steps = 10
submatrix_size = 64
learning_rate = 0.05
loss_values = train_swivel(model, dataset, n_steps=n_steps, submatrix_size=submatrix_size, lr=learning_rate, device=device)

In [None]:
ax = plt.gca()
# ax.set_ylim([0, 0.1])
plt.plot(loss_values)

In [None]:
k = 10
add_context = True

all_names = np.array(input_names_train + candidate_names_train.tolist())
all_embeddings = get_swivel_embeddings(model, vocab, all_names, add_context=add_context)

In [None]:
print(all_names)

In [None]:
demo_name = '<john>'
demo_name_pos = 0
demo_embeddings = get_swivel_embeddings(model, vocab, [demo_name], add_context=add_context)

In [None]:
# try cosine similarity
# totals = all_embeddings.sum(axis=0)
# all_embeddings_norm = all_embeddings / totals
# demo_embeddings_norm = all_embeddings_norm[[demo_name_pos]]
# scores = cosine_similarity(demo_embeddings_norm, all_embeddings_norm)
# ixs = np.argsort(-scores)[:, :k]
# sorted_scores = scores[:, ixs[0]]
# sorted_names = all_names[ixs[0]]
# best_matches = np.dstack((sorted_names, sorted_scores))
# print("cosine_norm_0", best_matches)

# totals = demo_embeddings.sum(axis=1)
# demo_embeddings_norm = demo_embeddings / totals[:, np.newaxis]
# totals = all_embeddings.sum(axis=1)
# all_embeddings_norm = all_embeddings / totals[:, np.newaxis]
# scores = cosine_similarity(demo_embeddings_norm, all_embeddings_norm)
# ixs = np.argsort(-scores)[:, :k]
# sorted_scores = scores[:, ixs[0]]
# sorted_names = all_names[ixs[0]]
# best_matches = np.dstack((sorted_names, sorted_scores))
# print("cosine_norm_1", best_matches)

scores = cosine_similarity(demo_embeddings, all_embeddings)
ixs = np.argsort(-scores)[:, :k]
sorted_scores = scores[:, ixs[0]]
sorted_names = all_names[ixs[0]]
best_matches = np.dstack((sorted_names, sorted_scores))
print("cosine", best_matches)

In [None]:
# try euclidean similarity
totals = all_embeddings.sum(axis=0)
all_embeddings_norm = all_embeddings / totals
demo_embeddings_norm = all_embeddings_norm[[demo_name_pos]]
scores = euclidean_distances(demo_embeddings_norm, all_embeddings_norm)
ixs = np.argsort(scores)[:, :k]
sorted_scores = scores[:, ixs[0]]
sorted_names = all_names[ixs[0]]
best_matches = np.dstack((sorted_names, sorted_scores))
print("euclidean_norm_0", best_matches)

# totals = demo_embeddings.sum(axis=1)
# demo_embeddings_norm = demo_embeddings / totals[:, np.newaxis]
# totals = all_embeddings.sum(axis=1)
# all_embeddings_norm = all_embeddings / totals[:, np.newaxis]
# scores = euclidean_distances(demo_embeddings_norm, all_embeddings_norm)
# ixs = np.argsort(scores)[:, :k]
# sorted_scores = scores[:, ixs[0]]
# sorted_names = all_names[ixs[0]]
# best_matches = np.dstack((sorted_names, sorted_scores))
# print("euclidean_norm_1", best_matches)

scores = euclidean_distances(demo_embeddings, all_embeddings)
ixs = np.argsort(scores)[:, :k]
sorted_scores = scores[:, ixs[0]]
sorted_names = all_names[ixs[0]]
best_matches = np.dstack((sorted_names, sorted_scores))
print("euclidean", best_matches)

In [None]:
# plot embeddings
xs = list(x for x, _ in all_embeddings)
ys = list(y for _, y in all_embeddings)
plt.scatter(xs, ys)
for ix, name in enumerate(all_names):
    plt.annotate(name, xy=(xs[ix], ys[ix]), xytext=(5, 2),
                 textcoords='offset points', ha='right', va='bottom')

In [None]:
source_names = np.array(["tom", "dick", "harry"])
source_names_X = np.array([[1,2,3],[4,5,6],[7,8,9]])
rows = np.array([[1,2,3],[4,5,6],[7,8,9]])

In [None]:
scores = cosine_similarity(rows, source_names_X)
scores

In [None]:
sorted_scores_idx = np.argsort(scores, axis=1)
sorted_scores_idx = np.flip(sorted_scores_idx, axis=1)
sorted_scores_idx

In [None]:
sorted_scores_idx = sorted_scores_idx[:, :2]
sorted_scores_idx

In [None]:
sorted_scores = np.take_along_axis(scores, sorted_scores_idx, axis=1)
sorted_scores

In [None]:
sorted_source_names_X = source_names_X[sorted_scores_idx]
sorted_source_names_X

In [None]:
# Zero out the score wherever a row was matched with an identical vector (i.e. itself).
# The loop variables deliberately do not reuse the name `source_names_X`: doing so would
# overwrite the notebook-level array of that name and change the result of re-running
# the earlier cells.
for i, (row, matched_vectors) in enumerate(zip(rows, sorted_source_names_X)):
    for j, matched_vector in enumerate(matched_vectors):
        if np.array_equal(row, matched_vector):
            sorted_scores[i, j] = 0
sorted_scores

In [None]:
re_sorted_scores_idx = np.argsort(sorted_scores, axis=1)
re_sorted_scores_idx

In [None]:
re_sorted_scores_idx = np.flip(re_sorted_scores_idx, axis=1)
re_sorted_scores_idx

In [None]:
re_sorted_scores_idx = re_sorted_scores_idx[:, :1]
re_sorted_scores_idx

In [None]:
sorted_scores = np.take_along_axis(sorted_scores, re_sorted_scores_idx, axis=1)
sorted_scores

In [None]:
sorted_scores_idx = np.take_along_axis(sorted_scores_idx, re_sorted_scores_idx, axis=1)
sorted_scores_idx

In [None]:
sorted_source_names = source_names[sorted_scores_idx]
sorted_source_names