In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell executes (edits to project code take effect
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import torch
import numpy as np
from src.data.ancestry import load_ancestry_train_test
from src.metrics import metrics
from src.models import utils
from src.models import triplet_loss

In [None]:
# Name-length limit handed to utils.convert_names_to_model_inputs and to the
# training call further down.
MAX_NAME_LENGTH = 30

# Lookup tables between characters and indices, built by the project helper;
# char_to_idx_map is the one used when converting names to model inputs.
char_to_idx_map, idx_to_char_map = utils.build_token_idx_maps()

### Load model

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the previously saved encoder model from disk.
# map_location remaps any tensors that were saved from a GPU onto `device`, so
# the checkpoint also loads on a CPU-only machine instead of failing while
# trying to restore CUDA tensors.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects whole-model pickles; pass weights_only=False there if the file is trusted.
# NOTE(review): this path is under ../models/ while the fine-tuned model below is
# saved under ../data/models/ -- confirm which directory is intended.
model = torch.load('../models/anc-encoder-bilstm-100-512.pth', map_location=device)

In [None]:
# Report the selected device, then move the model onto it.
print(device)
model.to(device)
# The model object carries its own `device` attribute; keep it in sync with
# where the model was just moved.
model.device = device

### Load data for fine-tuning and evaluation

In [None]:
# Load the train and test splits from their raw CSV files.
train, test = load_ancestry_train_test(
    '../data/raw/records25k_data_train.csv',
    '../data/raw/records25k_data_test.csv',
)

# Each split unpacks into (input names, weighted actual names, candidate names).
input_names_train, weighted_actual_names_train, candidate_names_train = train
input_names_test, weighted_actual_names_test, candidate_names_test = test

# Keep only the name (first field) of every 3-field weighted entry.
actual_names_train = [
    [matched_name for matched_name, _, _ in weighted_matches]
    for weighted_matches in weighted_actual_names_train
]
actual_names_test = [
    [matched_name for matched_name, _, _ in weighted_matches]
    for weighted_matches in weighted_actual_names_test
]

# Combined candidate pool: train candidates first, then test candidates.
candidate_names_all = np.concatenate([candidate_names_train, candidate_names_test])

### Fine-tune

In [None]:
# Compute near negatives for the training inputs with k=50.
# NOTE(review): presumably these are candidates close to each input that are not
# among its actual names -- confirm in triplet_loss.get_near_negatives.
near_negatives_train = triplet_loss.get_near_negatives(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
    k=50,
)

In [None]:
# Persist the near negatives to disk so they can be reloaded later.
with open('../data/processed/ancestry_near_negatives.pickle', 'wb') as pickle_file:
    pickle.dump(near_negatives_train, pickle_file)

In [None]:
# Read the near negatives back from the pickle file written by the save cell.
# NOTE(review): pickle.load can execute arbitrary code from the file -- only
# load files produced by this notebook or another trusted source.
with open('../data/processed/ancestry_near_negatives.pickle', 'rb') as pickle_file:
    near_negatives_train = pickle.load(pickle_file)

In [None]:
# Fine-tune the model with triplet loss. The return value is not captured, so
# the later save relies on this call updating `model` in place.
# NOTE(review): the trailing positional values (40, 512, 0.05, 100) are defined
# by train_triplet_loss's signature, which is not visible here -- TODO confirm
# their meaning and pass them as keywords / named constants.
triplet_loss.train_triplet_loss(
    model,
    input_names_train,
    weighted_actual_names_train,
    near_negatives_train,
    input_names_test,
    weighted_actual_names_test,
    candidate_names_test,
    candidate_names_train,
    candidate_names_all,
    char_to_idx_map,
    MAX_NAME_LENGTH,
    40,
    512,
    0.05,
    100,
    device,
)

In [None]:
# Save the whole fine-tuned model object (not just a state dict) to disk.
torch.save(obj=model, f='../data/models/anc-triplet-bilstm-100-512-40-05.pth')

In [None]:
# Reload the fine-tuned model saved by the torch.save cell.
# map_location remaps tensors saved from a GPU onto `device`, so the checkpoint
# also loads on a CPU-only machine; on a GPU machine placement is unchanged.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects whole-model pickles; pass weights_only=False there if the file is trusted.
model = torch.load('../data/models/anc-triplet-bilstm-100-512-40-05.pth', map_location=device)

## Evaluation

In [None]:
# Move to the CPU for evaluation so we don't run out of GPU memory.
model.to("cpu")
# Switch to inference mode: layers such as dropout or batch norm (if the model
# has any) otherwise keep their training-time behaviour and make the
# embeddings below noisy / non-reproducible.
# NOTE(review): assumes `model` is a torch.nn.Module -- it is loaded via
# torch.load and exposes .to(), but its class is not visible here.
model.eval()
# Keep the model's own `device` attribute in sync with where it now lives.
model.device = "cpu"

In [None]:
# Get embeddings for train candidate names: first convert the names to model inputs.
candidate_names_train_X, _ = utils.convert_names_to_model_inputs(candidate_names_train,
                                                                 char_to_idx_map,
                                                                 MAX_NAME_LENGTH)
# Get embeddings for the names from the encoder. This is inference only, so run
# the forward pass under no_grad: no autograd graph is built for the whole
# candidate set, which keeps memory use down. The resulting values are unchanged.
with torch.no_grad():
    candidate_names_train_encoded = model(candidate_names_train_X, just_encoder=True).detach().numpy()

In [None]:
# Get embeddings for test input names: first convert the names to model inputs.
input_names_test_X, _ = utils.convert_names_to_model_inputs(input_names_test,
                                                            char_to_idx_map,
                                                            MAX_NAME_LENGTH)
# Get embeddings for the names from the encoder. This is inference only, so run
# the forward pass under no_grad: no autograd graph is built, which keeps
# memory use down. The resulting values are unchanged.
with torch.no_grad():
    input_names_test_encoded = model(input_names_test_X, just_encoder=True).detach().numpy()

In [None]:
# Get embeddings for test candidate names: first convert the names to model inputs.
candidate_names_test_X, _ = utils.convert_names_to_model_inputs(candidate_names_test,
                                                                char_to_idx_map,
                                                                MAX_NAME_LENGTH)
# Get embeddings for the names from the encoder. This is inference only, so run
# the forward pass under no_grad: no autograd graph is built, which keeps
# memory use down. The resulting values are unchanged.
with torch.no_grad():
    candidate_names_test_encoded = model(candidate_names_test_X, just_encoder=True).detach().numpy()

In [None]:
# Stack the encoded candidates row-wise, train first and then test -- the same
# order in which candidate_names_all was concatenated, so rows line up with names.
candidate_names_all_encoded = np.vstack([
    candidate_names_train_encoded,
    candidate_names_test_encoded,
])

In [None]:
# Sanity check: display the shape of the encoded test input names.
input_names_test_encoded.shape

### Test

In [None]:
# metric='euclidean' is what TripletMarginLoss optimizes by default.
# Scores are therefore distances, not similarities -- take this into account
# when computing precision/recall at thresholds.
k = 100
best_matches = utils.get_best_matches(
    input_names_test_encoded,
    candidate_names_all_encoded,
    candidate_names_all,
    num_candidates=k,
    metric='euclidean',
)
print(best_matches.shape)
# Peek at the first match of the first test input: entry 0 along the last axis
# is what is later taken as the matched name; entry 1 is presumably its
# distance score -- confirm in utils.get_best_matches.
print(best_matches[0, 0, 0])
print(best_matches[0, 0, 1])

In [None]:
# Keep only entry 0 along the third axis of best_matches (used as the matched
# names in the precision/recall-at-k computation) and show the resulting shape.
best_matches_names = best_matches[:, :, 0]
print(best_matches_names.shape)

### PR Curve at k

In [None]:
# Precision/recall curve over the top-k matched names (k as set when the best
# matches were retrieved).
metrics.precision_recall_curve_at_k(
    actual_names_test,
    best_matches_names,
    k,
)

In [None]:
# Average precision at a cutoff of 0.145; distances=True because the scores in
# best_matches are euclidean distances rather than similarities.
metrics.avg_precision_at_threshold(
    weighted_actual_names_test,
    best_matches,
    0.145,
    distances=True,
)

In [None]:
# Average weighted recall at a cutoff of 0.145; distances=True because the
# scores in best_matches are euclidean distances rather than similarities.
metrics.avg_weighted_recall_at_threshold(
    weighted_actual_names_test,
    best_matches,
    0.145,
    distances=True,
)

### PR Curve at threshold

In [None]:
# Precision / weighted-recall curve over a sweep of distance thresholds,
# configured with min_threshold=0.01, max_threshold=1.0 and step=0.005.
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=.005,
    distances=True,
)

### AUC

In [None]:
# AUC computed with the same threshold range and step as the
# precision / weighted-recall curve (0.01 to 1.0, step 0.005), scores as distances.
metrics.get_auc(
    weighted_actual_names_test,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=.005,
    distances=True,
)

### Precision and recall at a specific threshold

In [None]:
# Report precision and weighted recall at one specific distance threshold.
threshold = 0.14
print(
    "precision",
    metrics.avg_precision_at_threshold(
        weighted_actual_names_test, best_matches, threshold=threshold, distances=True
    ),
)
print(
    "recall",
    metrics.avg_weighted_recall_at_threshold(
        weighted_actual_names_test, best_matches, threshold=threshold, distances=True
    ),
)