In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Analyze swivel scores

In [15]:
from collections import namedtuple, defaultdict
# 
# import jellyfish
# import matplotlib.pyplot as plt
# from mpire import WorkerPool
# import pandas as pd
# import random
# from sklearn.model_selection import train_test_split

import numpy as np
import torch
# 
# from src.data.filesystem import fopen
# from src.data.utils import load_dataset
# from src.eval.utils import similars_to_ndarray
from nama.models.swivel import SwivelModel, get_best_swivel_matches
# from src.models.utils import remove_padding

from nama.data.filesystem import download_file_from_s3, save_file
from nama.data.utils import read_csv

In [3]:
# config

# TODO run both given and surname
given_surname = "given"
# given_surname = "surname"

# The vocabulary size is part of the model file names and differs between the
# given-name and surname models.
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
embed_dim = 100

# Immutable bundle of the paths and hyper-parameters the cells below read.
Config = namedtuple(
    "Config",
    ("std_path", "embed_dim", "swivel_vocab_path", "swivel_model_path"),
)
config = Config(
    std_path=f"../references/std_{given_surname}.txt",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
)

In [10]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # report the memory figures of device 0
    for label, value in (
        ("cuda total", torch.cuda.get_device_properties(0).total_memory),
        ("cuda reserved", torch.cuda.memory_reserved(0)),
        ("cuda allocated", torch.cuda.memory_allocated(0)),
    ):
        print(label, value)

cuda:0
True
cuda total 8141471744
cuda reserved 0
cuda allocated 0


In [5]:
# load buckets
# Each line of the std file is split on spaces after turning ':' into a space;
# the first token is used as the bucket name and every non-empty token
# (including the bucket name itself) is recorded as a member of that bucket.
bucket_names = defaultdict(set)  # bucket name -> names in that bucket
name_buckets = defaultdict(set)  # name -> bucket names it appears under
with open(config.std_path, 'rt') as f:
    for line in f:
        tokens = line.strip().replace(':', ' ').split(' ')
        bucket_name = tokens[0]
        # consecutive separators leave empty tokens behind; skip them
        for name in filter(None, map(str.strip, tokens)):
            bucket_names[bucket_name].add(name)
            name_buckets[name].add(bucket_name)
print(len(bucket_names), len(name_buckets))

8878 95997


In [11]:
# Fetch the vocabulary CSV through download_file_from_s3 when the configured
# path is an s3:// URL; otherwise use the configured path as-is.
swivel_vocab_path = config.swivel_vocab_path
if swivel_vocab_path.startswith("s3://"):
    swivel_vocab_path = download_file_from_s3(swivel_vocab_path)
vocab_df = read_csv(swivel_vocab_path)
# map each value of the "name" column to the matching value of the "index" column
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
print(len(swivel_vocab))

610000


In [12]:
# Fetch the model weights through download_file_from_s3 when the configured
# path is an s3:// URL; otherwise use the configured path as-is.
swivel_model_path = config.swivel_model_path
if swivel_model_path.startswith("s3://"):
    swivel_model_path = download_file_from_s3(swivel_model_path)
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# load the saved weights onto the selected device, then switch to eval mode
state_dict = torch.load(swivel_model_path, map_location=torch.device(device))
swivel_model.load_state_dict(state_dict)
swivel_model.to(device)
swivel_model.eval()

SwivelModel(
  (wi): Embedding(610000, 100)
  (wj): Embedding(610000, 100)
  (bi): Embedding(610000, 1)
  (bj): Embedding(610000, 1)
)

## Calculate swivel scores for names in the same bucket

In [22]:
# Score every member of a bucket against the bucket's own name with the swivel
# model. Each result is printed and also collected in `scores`, so the results
# stay available after the cell has run.
scores = []  # (bucket_name, name, score) triples
for ix, (bucket_name, names) in enumerate(bucket_names.items()):
    if bucket_name not in swivel_vocab:
        print("bucket name missing", bucket_name)
        continue
    for name in names:
        if name not in swivel_vocab:
            print("name missing", name)
            continue
        if name == bucket_name:
            # a name is not scored against itself
            continue
        swivel_scores = get_best_swivel_matches(model=swivel_model,
                                                vocab=swivel_vocab,
                                                input_names=np.array([bucket_name]),
                                                candidate_names=np.array([name]),
                                                encoder_model=None,
                                                k=1,
                                                batch_size=1000,
                                                add_context=True,
                                                progress_bar=False,
                                                n_jobs=1)
        # one input and one candidate with k=1: the single match sits at [0][0]
        best_match = swivel_scores[0][0]
        print(bucket_name, best_match)
        # NOTE(review): per the recorded output each match is a [name, score]
        # pair, so index 1 is taken as the score -- confirm against
        # get_best_swivel_matches.
        scores.append((bucket_name, name, best_match[1]))
    # exploration limit: stop once the bucket at position 11 has been processed
    if ix > 10:
        break

aad ['adi' 0.4426210658829054]
aaffien ['affie' 0.4656717356705998]
aafje ['aaftje' 0.6114681089816434]
name missing afje
aafje ['aefje' 0.651685494770062]
aagaard ['agard' 0.6384286223835175]
aage ['ouwe' 0.038618646644627275]
aaltje ['aalje' 0.5886518828655694]
aaltje ['aletje' 0.5381761533546043]
aaltje ['aaltjen' 0.7828434022831788]
name missing aeltijen
aaltje ['altino' 0.29617106886264494]
aaltje ['alchy' 0.1625622347454544]
aaltje ['aeltijn' 0.5674344886632808]
aaltje ['aaltji' 0.7765158614449673]
aaltje ['aelke' 0.4292968461811517]
aaltje ['aeltje' 0.722470821812775]
aaltje ['elte' 0.28094489217194796]
aaltje ['altin' 0.43206147550277313]
aaltje ['aeltien' 0.6104051100538883]
aaltje ['aal' 0.14934069488018445]
aaltje ['alche' 0.22591997856881682]
aaltje ['aleka' 0.19439982784549784]
aaltje ['eltjen' 0.4339137976514315]
aaltje ['aaltijn' 0.6836955632447763]
aaltje ['eltje' 0.492060944802463]
aaltje ['aeltge' 0.4980095389975159]
aaltje ['altjen' 0.5770803550313351]
aaltje ['aelke

In [None]:
# Keep only the "test" halves of the split as the evaluation sample; the
# candidate names are used unsampled.
# NOTE(review): train_test_split, input_names_eval, weighted_actual_names_eval,
# sample_size and candidate_names_eval are not defined or imported in the cells
# visible here -- confirm where they come from before running this cell.
eval_split = train_test_split(input_names_eval, weighted_actual_names_eval, test_size=sample_size)
_, input_names_eval_sample, _, weighted_actual_names_eval_sample = eval_split
candidate_names_eval_sample = candidate_names_eval

In [None]:
# Report the size of each evaluation sample.
for label, sample in (
    ("input_names_eval_sample", input_names_eval_sample),
    ("weighted_actual_names_eval_sample", weighted_actual_names_eval_sample),
    ("candidate_names_eval_sample", candidate_names_eval_sample),
):
    print(label, len(sample))

In [None]:
def calc_similarity_to(name):
    """Build a scorer that compares candidate rows against `name`.

    `name` is passed through `remove_padding` once, up front. The returned
    function takes a row whose first element is a candidate name, passes that
    name through `remove_padding` as well, and returns
    ``1 - levenshtein_distance / max(len(name), len(candidate))``.

    When both names are empty after `remove_padding` the scorer returns 1.0
    (the names are identical) instead of dividing by zero.
    """
    name = remove_padding(name)

    def calc_similarity(row):
        cand_name = remove_padding(row[0])
        longest = max(len(name), len(cand_name))
        if longest == 0:
            # both names are empty: identical strings, and 0/0 must be avoided
            return 1.0
        dist = jellyfish.levenshtein_distance(name, cand_name)
        return 1 - (dist / longest)

    return calc_similarity

In [None]:
def get_similars(shared, names, _=None):
    """Return the best-scoring candidates for every name in `names`.

    Parameters
    ----------
    shared : tuple
        ``(candidate_names_test, k)``: the candidate names (indexed as a numpy
        array) and the number of matches wanted per name.
    names : iterable
        Names to find matches for.
    _ : optional
        Ignored. Batches built by `create_batches` carry a start index as their
        second element; this parameter presumably absorbs it when the batch
        tuple is unpacked into arguments -- confirm against the caller.

    Returns
    -------
    list
        One entry per input name, each a list of ``(candidate_name, score)``
        pairs sorted by descending score. `k` is clamped to the number of
        candidates, and a non-positive `k` (or no candidates) yields empty lists.
    """
    candidate_names_test, k = shared
    # argpartition raises when asked for more elements than exist, so clamp
    top_k = min(k, len(candidate_names_test))

    def get_similars_for_name(name):
        if top_k <= 0:
            # nothing to rank; also avoids `[-0:]` selecting every candidate
            return []
        scores = np.apply_along_axis(calc_similarity_to(name), 1, candidate_names_test[:, None])

        # select the top_k scores first, then sort only those
        partitioned_idx = np.argpartition(scores, -top_k)[-top_k:]
        sorted_partitioned_idx = np.argsort(scores[partitioned_idx])[::-1]
        sorted_scores_idx = partitioned_idx[sorted_partitioned_idx]

        candidate_names = candidate_names_test[sorted_scores_idx]
        candidate_scores = scores[sorted_scores_idx]

        return list(zip(candidate_names, candidate_scores))

    return [get_similars_for_name(name) for name in names]

In [None]:
def create_batches(names, batch_size):
    """Split `names` into consecutive slices of at most `batch_size` items.

    Returns a list of ``(slice, start_index)`` tuples, where `start_index` is
    the offset of the slice within `names`.
    """
    # batches are tuples to keep mpire from expanding the batch
    return [
        (names[start:start + batch_size], start)
        for start in range(0, len(names), batch_size)
    ]

In [None]:
# Swivel matches for every sampled input name against all candidate names.
# NOTE(review): num_matches and batch_size are not defined in the cells visible
# here -- confirm they are set before this cell runs.
swivel_match_args = dict(
    model=swivel_model,
    vocab=swivel_vocab,
    input_names=input_names_eval_sample,
    candidate_names=candidate_names_eval_sample,
    encoder_model=None,
    k=num_matches,
    batch_size=batch_size,
    add_context=True,
    n_jobs=1,
)
swivel_names_scores = get_best_swivel_matches(**swivel_match_args)

In [None]:
# total number of (name, score) matches returned by swivel
print(sum(map(len, swivel_names_scores)))

In [None]:
# Compute the levenshtein matches in parallel, one batch of input names per task.
input_names_batches = create_batches(input_names_eval_sample, batch_size=batch_size)
with WorkerPool(shared_objects=(candidate_names_eval_sample, num_matches)) as pool:
    lev_batch_results = pool.map(get_similars, input_names_batches, progress_bar=True)
# flatten the per-batch results into one list, then convert to ndarray
lev_names_scores = similars_to_ndarray(
    [name_score for batch in lev_batch_results for name_score in batch]
)

In [None]:
# total number of (name, score) matches returned by the levenshtein search
print(sum(map(len, lev_names_scores)))

In [None]:
# find pairs in both with score above a threshold
swivel_threshold = 0.45
lev_threshold = 0.55
sample_rate = 0.01
# sampled (swivel score, levenshtein score) points and their colours
xs, ys, cs = [], [], []
xs_pos, ys_pos = [], []
xs_neg, ys_neg = [], []
weights = []
# per-input-name counts
actual_score_counts = []
swivel_score_counts = []
lev_score_counts = []
all_candidate_names = set(candidate_names_eval_sample)
eval_rows = zip(
    input_names_eval_sample,
    weighted_actual_names_eval_sample,
    swivel_names_scores,
    lev_names_scores,
)
# NOTE(review): the sampling below uses the module-level random generator
# without a seed, so the sampled points differ from run to run.
for input_name, wans, swivels, levs in eval_rows:
    # actuals - ensure names are in all_candidate_names
    actual_weights = {name: weight for name, weight, _ in wans if name in all_candidate_names}
    actual_score_counts.append(len(actual_weights))
    # swivel
    swivel_scores = {name: score for name, score in swivels if score >= swivel_threshold}
    swivel_names = set(swivel_scores.keys())
    swivel_score_counts.append(len(swivel_scores))
    # levenshtein
    lev_scores = {name: score for name, score in levs if score >= lev_threshold}
    lev_names = set(lev_scores.keys())
    lev_score_counts.append(len(lev_scores))

    # count various scores
    candidate_names = swivel_names.intersection(lev_names)
    for candidate_name in candidate_names:
        # keep roughly `sample_rate` of the pairs that pass both thresholds
        if random.random() <= sample_rate:
            swivel_score = swivel_scores[candidate_name]
            lev_score = lev_scores[candidate_name]
            xs.append(swivel_score)
            ys.append(lev_score)
            if candidate_name not in actual_weights:
                # not an actual match for this input name
                cs.append('red')
                xs_neg.append(swivel_score)
                ys_neg.append(lev_score)
            else:
                cs.append('green')
                xs_pos.append(swivel_score)
                ys_pos.append(lev_score)
                # record the weight and drop the name from the remaining actuals
                weights.append(actual_weights.pop(candidate_name))
#     for name in actual_weights.keys():
#         if name not in swivel_names:
#             print("swivel", input_name, name)
#         if name not in lev_names:
#             print("lev", input_name, name)

In [None]:
# totals above threshold, then: sampled points, sampled positives, and the
# number of actuals scaled by the sample rate
print(sum(swivel_score_counts), sum(lev_score_counts))
n_positive = cs.count('green')
print(len(cs), n_positive, sum(actual_score_counts)*sample_rate)

In [None]:
# Scatter of all sampled pairs, coloured by the entries of `cs`.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.scatter(x=xs, y=ys, c=cs)
ax.set_title("Swivel vs Levenshtein score")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim([swivel_threshold, 1.0])
ax.set_ylim([lev_threshold, 1.0])
plt.show()

In [None]:
# Scatter of the sampled pairs collected in xs_pos / ys_pos.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.scatter(x=xs_pos, y=ys_pos)
ax.set_title("Swivel vs Levenshtein score - positive only")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim([swivel_threshold, 1.0])
ax.set_ylim([lev_threshold, 1.0])
plt.show()

In [None]:
# Scatter of the sampled pairs collected in xs_neg / ys_neg.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.scatter(x=xs_neg, y=ys_neg)
ax.set_title("Swivel vs Levenshtein score - negative only")
ax.set_xlabel("swivel score")
ax.set_ylabel("levenshtein score")
ax.set_xlim([swivel_threshold, 1.0])
ax.set_ylim([lev_threshold, 1.0])
plt.show()

In [None]:
# Histogram of the values collected in `weights`.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.hist(x=weights, bins=100)
ax.set_title("Weights")
plt.show()

In [None]:
# Histogram of the values collected in `swivel_score_counts`.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.hist(x=swivel_score_counts, bins=100)
ax.set_title("Swivel score counts")
plt.show()

In [None]:
# Histogram of the values collected in `lev_score_counts`.
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
ax.hist(x=lev_score_counts, bins=100)
ax.set_title("Levenshtein score counts")
plt.show()