In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each
# cell executes, so edits to the imported `src` package take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# Generate a triplet-loss model based upon the autoencoder and near-negatives

In [None]:
from collections import namedtuple
import pickle

import pandas as pd
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.models.autoencoder import get_best_autoencoder_matches
from src.models.triplet_loss import get_near_negatives, train_triplet_loss

In [None]:
# Run parameters; both values are interpolated into the S3 paths below.
given_surname = "given"
size = "freq"

# Immutable bundle of every input/output location used by this notebook.
Config = namedtuple(
    "Config",
    [
        "train_path",
        "test_path",
        "near_negatives_path",
        "autoencoder_model_path",
        "triplet_model_path",
    ],
)

config = Config(
    # Alternative (disabled): use the same, non train/test-suffixed file for both splits.
    # train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-{size}.csv.gz",
    # test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-{size}.csv.gz",
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-{size}.csv.gz",
    near_negatives_path=f"s3://nama-data/data/processed/tree-hr-{given_surname}-near-negatives-{size}.csv.gz",
    # NOTE(review): "freq" is hard-coded here instead of {size} — confirm this is intentional.
    autoencoder_model_path=f"s3://nama-data/data/models/fs-{given_surname}-freq-autoencoder-bilstm-100-512.pth",
    triplet_model_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-triplet-bilstm-100-512-40-05.pth",
)

In [None]:
# Start a Weights & Biases run; the notebook's Config is logged as the run config
# and runs are grouped by the `given_surname` value.
wandb_run_settings = dict(
    project="nama",
    entity="nama",
    name="51_autoencoder_triplet",
    group=given_surname,
    notes="",
    config=config._asdict(),
)
wandb.init(**wandb_run_settings)

### Load autoencoder model

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the pre-trained autoencoder that the triplet-loss fine-tuning starts from.
# The file is opened in a `with` block so the handle is closed once loading finishes
# (the original passed an unclosed handle straight to torch.load).
# NOTE: torch.load unpickles the whole model object; only load files from a trusted location.
# NOTE(review): recent PyTorch releases changed the default of `weights_only`; loading a
# fully pickled model may then need `weights_only=False` — confirm against the installed version.
with fopen(config.autoencoder_model_path, "rb") as f:
    # `device` is already a torch.device, so it can be passed as map_location directly.
    model = torch.load(f, map_location=device)
# Record the target device on the model object itself.
model.device = device

### Load data for fine-tuning and evaluation

In [None]:
# Read both splits in one call; each split unpacks into
# (input names, weighted actual names, candidate names).
dataset_paths = [config.train_path, config.test_path]
train, test = load_train_test(dataset_paths)

(input_names_train, weighted_actual_names_train, candidate_names_train) = train
(input_names_test, weighted_actual_names_test, candidate_names_test) = test

### Compute near-negatives

In [None]:
# Compute near-negatives for the training inputs from the training split's
# weighted actual names and candidate names.
# NOTE(review): k=50 is presumably the number of near-negatives kept per input name — confirm
# against get_near_negatives.
near_negatives_train = get_near_negatives(
    input_names_train, weighted_actual_names_train, candidate_names_train, k=50
)

In [None]:
# Cache the computed near-negatives so later runs can skip the computation above.
# NOTE(review): the target path ends in ".csv.gz" but the payload written here is a pickle.
with fopen(config.near_negatives_path, "wb") as near_negatives_out:
    pickle.dump(near_negatives_train, near_negatives_out)

In [None]:
# Restore the cached near-negatives written by the save cell.
# pickle.load can execute arbitrary code, so only read files from a trusted location.
with fopen(config.near_negatives_path, "rb") as near_negatives_in:
    near_negatives_train = pickle.load(near_negatives_in)

In [None]:
# Show the first four entries as "name : near-negative near-negative ...".
preview_count = 4
for ix, (key, values) in zip(range(preview_count), near_negatives_train.items()):
    print(key, ":", " ".join(values))

In [None]:
# Spot-check one training example: find the position of "<ada>" in the input names
# and display the weighted actual names stored at the same position.
# list.index raises ValueError if "<ada>" is not present in the training inputs.
ix = input_names_train.index("<ada>")
weighted_actual_names_train[ix]

In [None]:
# Training split sizes: number of input names, then number of candidate names.
print(len(input_names_train), len(candidate_names_train), sep="\n")

In [None]:
# Test split sizes: number of input names, then number of candidate names.
print(len(input_names_test), len(candidate_names_test), sep="\n")

In [None]:
# Batch size shared by the training call and the best-match retrieval calls below.
batch_size = 512

In [None]:
# Fine-tune the loaded autoencoder with triplet loss, using the training inputs,
# their weighted actual names and the near-negatives; the test split is passed as well.
# The return value is discarded and `model` is saved afterwards, so the function
# presumably updates `model` in place — confirm against train_triplet_loss.
# NOTE(review): triplet_model_path ends in "-40-05" while num_epochs=50 and margin=.05 here —
# confirm the file name still describes this run.
# NOTE(review): no random seed is set in this notebook, so repeated runs may differ.
train_triplet_loss(
    model,
    input_names_train,
    weighted_actual_names_train,
    near_negatives_train,
    input_names_test,
    weighted_actual_names_test,
    candidate_names_test,
    num_epochs=50,
    batch_size=batch_size,
    margin=.05,
    k=100,
    device=device,
)

In [None]:
# Persist the fine-tuned model. The handle is opened in a `with` block so it is
# closed (and any buffered data flushed to the destination) as soon as the save completes;
# the original left the write handle open.
with fopen(config.triplet_model_path, "wb") as f:
    torch.save(model, f)

In [None]:
# Reload the fine-tuned model from storage, closing the file handle once it has been read.
# NOTE: torch.load unpickles the whole model object; only load files from a trusted location.
with fopen(config.triplet_model_path, "rb") as f:
    # `device` is already a torch.device, so it can be passed as map_location directly.
    model = torch.load(f, map_location=device)
# Mirror the autoencoder-loading cell: the pickled `device` attribute reflects the machine
# that saved the model, so reset it to the device used in this session.
model.device = device

## Evaluation

In [None]:
# Retrieve the k best candidate matches for every test input name using the fine-tuned model.
# metric=euclidean is what TripletMarginLoss optimizes by default,
# but this means that scores are distances (lower is better), not similarities,
# so take this into account when computing precision/recall at thresholds.
k = 100
best_matches = get_best_autoencoder_matches(model, input_names_test, candidate_names_test, k, batch_size)

### Test

In [None]:
# Sanity-check the result: overall shape, then both fields of the top match for the first input.
print(best_matches.shape)
for field_ix in (0, 1):
    print(best_matches[0, 0, field_ix])

In [None]:
# Keep only index 0 of the last axis, i.e. the matched names
# (index 1 presumably holds the corresponding score — confirm).
best_matches_names = best_matches[:, :, 0]
print(best_matches_names.shape)

### PR Curve

In [None]:
# Precision / weighted-recall curve on the test split, sweeping thresholds from
# min_threshold to max_threshold in increments of `step`; scores are treated as distances.
threshold_sweep = dict(min_threshold=0.01, max_threshold=5.0, step=0.05, distances=True)
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test,
    best_matches,
    **threshold_sweep,
)

### AUC

In [None]:
# AUC on the test split over the same threshold sweep as the PR curve; scores are distances.
threshold_sweep = dict(min_threshold=0.01, max_threshold=5.0, step=0.05, distances=True)
metrics.get_auc(weighted_actual_names_test, best_matches, **threshold_sweep)

### Precision and recall at a threshold

In [None]:
# Single operating point: a distance threshold applied to the test-set matches.
threshold = 1.2

# Average precision at the threshold.
precision_at_threshold = metrics.avg_precision_at_threshold(
    weighted_actual_names_test, best_matches, threshold=threshold, distances=True
)
print("precision", precision_at_threshold)

# Average weighted recall at the same threshold.
recall_at_threshold = metrics.avg_weighted_recall_at_threshold(
    weighted_actual_names_test, best_matches, threshold=threshold, distances=True
)
print("recall", recall_at_threshold)

In [None]:
# Close the Weights & Biases run started at the top of the notebook.
# The train-set evaluation cells below run after this point.
wandb.finish()

In [None]:
# Repeat the best-match retrieval on the training split, reusing `k` and `batch_size`
# from the cells above.
# NOTE(review): n_jobs=1 is passed here but not in the test-set call — confirm the difference is intended.
train_matches = get_best_autoencoder_matches(model, input_names_train, 
                                             candidate_names_train, k, batch_size, n_jobs=1)

In [None]:
# Precision / weighted-recall curve on the training split, using the same
# distance-threshold sweep as the test-set curve.
threshold_sweep = dict(min_threshold=0.01, max_threshold=5.0, step=0.05, distances=True)
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_train,
    train_matches,
    **threshold_sweep,
)

In [None]:
# AUC on the training split over the same distance-threshold sweep.
threshold_sweep = dict(min_threshold=0.01, max_threshold=5.0, step=0.05, distances=True)
metrics.get_auc(weighted_actual_names_train, train_matches, **threshold_sweep)

In [None]:
# Total number of weighted-actual-name entries across all training inputs.
sum(map(len, weighted_actual_names_train))

In [None]:
# Peek at the first 15 training input names.
input_names_train[:15]

In [None]:
# Load the aggregated preferred-given-name table for the spot check below.
# (`pd` is imported in the notebook's top-level import cell rather than mid-notebook,
# so the notebook's dependencies are all declared in one place.)
pref_path = "s3://familysearch-names/processed/tree-preferred-given-aggr.csv.gz"
pref_df = pd.read_csv(pref_path)

In [None]:
# Show the preferred-name rows for a handful of hand-picked names.
top_names = {"aafje", "aafke", "aage", "aagje", "aagot", "dallin", "dallan"}
is_top_name = pref_df["name"].isin(top_names)
pref_df.loc[is_top_name]