In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import namedtuple
import numpy as np
import pickle
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.eval.encoder import eval_encoder
from src.models import utils
from src.models import triplet_loss

In [None]:
# Run parameters: both values are interpolated into every path below, and
# `given_surname` is also passed to wandb as the run group.
given_surname = "surname"
size = "freq"
Config = namedtuple("Config", "train_path test_path near_negatives_path autoencoder_model_path triplet_model_path")
config = Config(
    # Train/test splits (read by load_train_test).
    # NOTE(review): these live in the familysearch-names bucket while every other
    # artifact is under nama-data — confirm that is intentional.
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-{size}.csv.gz",
    # NOTE(review): despite the .csv.gz suffix, this notebook writes and reads
    # this object with pickle, not as CSV.
    near_negatives_path=f"s3://nama-data/data/processed/tree-hr-{given_surname}-near-negatives-{size}.csv.gz",
    # Checkpoint loaded as the starting model for triplet-loss training.
    autoencoder_model_path=f"s3://nama-data/data/models/fs-{size}-autoencoder-bilstm-100-512.pth",
    # Where the model is saved after triplet-loss training.
    triplet_model_path=f"s3://nama-data/data/models/fs-{size}-triplet-bilstm-100-512-40-05.pth"
)

In [None]:
# Start a Weights & Biases run. The Config paths are recorded as the run's
# config and the run is grouped by `given_surname`.
wandb.init(
    project="nama",
    entity="nama",
    name="51_autoencoder_triplet",
    group=given_surname,
    notes="",
    config=config._asdict()
)

In [None]:
# Maximum name length; passed to the name-to-model-input conversion and to training.
MAX_NAME_LENGTH = 30
# Token <-> index lookup tables. Only char_to_idx_map is referenced by the
# other cells visible in this notebook.
char_to_idx_map, idx_to_char_map = utils.build_token_idx_maps()

### Load model

In [None]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the pre-trained autoencoder checkpoint and map its tensors onto `device`.
# The file handle is opened in a `with` block so it is closed as soon as the
# checkpoint has been read instead of being leaked.
# NOTE: torch.load unpickles a full model object, which can execute arbitrary
# code; only load checkpoints from a trusted location.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles — pass weights_only=False explicitly if upgrading.
with fopen(config.autoencoder_model_path, "rb") as f:
    model = torch.load(f, map_location=device)
# Keep the model's own `device` attribute in sync with where it was mapped.
model.device = device

### Load data for fine-tuning and evaluation

In [None]:
# Load the train and test splits. Each split unpacks into three parts:
# input names, weighted actual names and candidate names.
train, test = load_train_test([config.train_path, config.test_path])

input_names_train, weighted_actual_names_train, candidate_names_train = train
input_names_test, weighted_actual_names_test, candidate_names_test = test

# Keep only the name (first element) of each 3-tuple in the weighted lists.
# NOTE(review): the other two tuple elements are presumably weights/frequencies — confirm.
actual_names_train = [[name for name, _, _ in name_weights] for name_weights in weighted_actual_names_train]
actual_names_test = [[name for name, _, _ in name_weights] for name_weights in weighted_actual_names_test]

# Combined candidate pool: train candidates first, then test candidates.
candidate_names_all = np.concatenate((candidate_names_train, candidate_names_test))

### Fine-tune

In [None]:
# Compute near negatives for the training inputs with k=50.
# NOTE(review): presumably up to k hard non-matching candidates per input name —
# confirm against triplet_loss.get_near_negatives.
near_negatives_train = triplet_loss.get_near_negatives(
    input_names_train, weighted_actual_names_train, candidate_names_train, k=50
)

In [None]:
# Persist near_negatives_train with pickle.
# NOTE(review): the target path ends in .csv.gz although the payload is a pickle.
with fopen(config.near_negatives_path, "wb") as f:
    pickle.dump(near_negatives_train, f)

In [None]:
# Read previously pickled near negatives back into near_negatives_train.
# NOTE: pickle.load can execute arbitrary code; only read from a trusted location.
with fopen(config.near_negatives_path, "rb") as f:
    near_negatives_train = pickle.load(f)

In [None]:
# Training-set sizes, one per line: number of input names, then number of candidate names.
print(len(input_names_train), len(candidate_names_train), sep="\n")

In [None]:
# Test-set sizes, one per line: number of input names, then number of candidate names.
print(len(input_names_test), len(candidate_names_test), sep="\n")

In [None]:
# Batch size shared by triplet-loss training and by the encoder calls in the evaluation cells.
batch_size = 512

In [None]:
# Fine-tune `model` with triplet loss. The return value is discarded and `model`
# is saved afterwards, so training presumably updates the model in place.
# NOTE(review): the three numeric literals below are unnamed positionals; the
# "-40-05" suffix of triplet_model_path appears to encode two of them.
triplet_loss.train_triplet_loss(
    model,
    input_names_train,
    weighted_actual_names_train,
    near_negatives_train,
    input_names_test,
    weighted_actual_names_test,
    candidate_names_test,
    candidate_names_train,
    candidate_names_all,
    char_to_idx_map,
    MAX_NAME_LENGTH,
    40,  # NOTE(review): presumably the number of epochs — confirm against the signature
    batch_size,
    0.05,  # NOTE(review): presumably the triplet margin — confirm against the signature
    100,  # NOTE(review): meaning not visible from this notebook — confirm against the signature
    device,
)

In [None]:
# Persist the fine-tuned model. The handle is opened in a `with` block so it is
# closed (and any buffered or remote write flushed) once torch.save returns;
# passing an unclosed handle risks an incomplete checkpoint.
with fopen(config.triplet_model_path, "wb") as f:
    torch.save(model, f)

In [None]:
# Reload the fine-tuned model and map its tensors onto `device`. The file handle
# is opened in a `with` block so it is closed once the checkpoint has been read.
# NOTE: torch.load unpickles a full model object, which can execute arbitrary
# code; only load checkpoints from a trusted location.
with fopen(config.triplet_model_path, "rb") as f:
    model = torch.load(f, map_location=device)
# The pickled model carries the `device` attribute from whenever it was saved;
# reset it so it matches the device the tensors were just mapped to.
model.device = device

## Evaluation

In [None]:
# Convert the train candidate names to model inputs; the second element of the
# returned pair is not needed here and is discarded.
candidate_names_train_X, _ = utils.convert_names_to_model_inputs(
    candidate_names_train, char_to_idx_map, MAX_NAME_LENGTH
)
# Run the inputs through the model's encoder in batches to get the embeddings.
candidate_names_train_encoded = eval_encoder(model, candidate_names_train_X, batch_size)

In [None]:
# Convert the test input names to model inputs; the second element of the
# returned pair is not needed here and is discarded.
input_names_test_X, _ = utils.convert_names_to_model_inputs(input_names_test, char_to_idx_map, MAX_NAME_LENGTH)
# Run the inputs through the model's encoder in batches to get the embeddings.
input_names_test_encoded = eval_encoder(model, input_names_test_X, batch_size)

In [None]:
# Convert the test candidate names to model inputs (second element of the
# returned pair is discarded), then encode them in batches.
candidate_names_test_X, _ = utils.convert_names_to_model_inputs(candidate_names_test, char_to_idx_map, MAX_NAME_LENGTH)
candidate_names_test_encoded = eval_encoder(model, candidate_names_test_X, batch_size)

In [None]:
# Stack train embeddings first, then test embeddings — the same order used to
# build candidate_names_all, so row i here corresponds to candidate_names_all[i].
candidate_names_all_encoded = np.vstack((candidate_names_train_encoded, candidate_names_test_encoded))

In [None]:
# Sanity check: display the shape of the encoded test inputs.
input_names_test_encoded.shape

### Test

In [None]:
# metric="euclidean" is what TripletMarginLoss optimizes by default.
# This means the scores are distances, not similarities, so take this into
# account when computing PR at thresholds (the metrics cells pass distances=True).
k = 100
best_matches = utils.get_best_matches(
    input_names_test_encoded, candidate_names_all_encoded, candidate_names_all, num_candidates=k, metric="euclidean"
)
print(best_matches.shape)
# Spot-check the first match of the first input: index 0 of the last axis is the
# candidate name; index 1 is presumably its distance — TODO confirm.
print(best_matches[0, 0, 0])
print(best_matches[0, 0, 1])

In [None]:
# Keep only the candidate names (index 0 of the last axis) for the at-k metrics.
best_matches_names = best_matches[:, :, 0]
print(best_matches_names.shape)

### PR Curve at k

In [None]:
# Precision/recall curve over the top-k matches, using the unweighted actual names.
metrics.precision_recall_curve_at_k(actual_names_test, best_matches_names, k)

In [None]:
# Average precision at a distance threshold of 0.145.
# NOTE(review): the final precision/recall cell uses threshold 0.14 — confirm which value is intended.
metrics.avg_precision_at_threshold(weighted_actual_names_test, best_matches, 0.145, distances=True)

In [None]:
# Average weighted recall at a distance threshold of 0.145.
# NOTE(review): the final precision/recall cell uses threshold 0.14 — confirm which value is intended.
metrics.avg_weighted_recall_at_threshold(weighted_actual_names_test, best_matches, 0.145, distances=True)

### PR Curve at threshold

In [None]:
# Precision vs. weighted-recall curve, called with min_threshold=0.01,
# max_threshold=1.0 and step=0.005; thresholds are distances (distances=True).
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test, best_matches, min_threshold=0.01, max_threshold=1.0, step=0.005, distances=True
)

### AUC

In [None]:
# AUC computed with the same threshold arguments as the curve cell
# (min 0.01, max 1.0, step 0.005, distances=True).
metrics.get_auc(
    weighted_actual_names_test, best_matches, min_threshold=0.01, max_threshold=1.0, step=0.005, distances=True
)

### Precision and recall at a specific threshold

In [None]:
# Report precision and weighted recall at a single distance threshold.
# NOTE(review): earlier cells evaluate these same metrics at 0.145 — confirm which value is intended.
threshold = 0.14
print(
    "precision",
    metrics.avg_precision_at_threshold(weighted_actual_names_test, best_matches, threshold=threshold, distances=True),
)
print(
    "recall",
    metrics.avg_weighted_recall_at_threshold(
        weighted_actual_names_test, best_matches, threshold=threshold, distances=True
    ),
)

In [None]:
# Close the Weights & Biases run started by wandb.init at the top of the notebook.
wandb.finish()