In [None]:
# Reload edited project modules automatically so changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from collections import namedtuple
from datetime import datetime
import os

import numpy as np
import pandas as pd
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.wandb import WandbLoggerCallback
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_datasets, train_test_split
from src.eval import metrics
from src.models.swivel import SwivelModel, get_swivel_embeddings, get_best_swivel_matches
from src.models.swivel_encoder import SwivelEncoderModel, convert_names_to_model_inputs, train_swivel_encoder

In [None]:
# Notebook configuration: every value a reader might tune lives in this cell.

# Path to a Weights & Biases API key file, e.g. "../.wandb-api-key".
# Leave empty to skip creating the W&B logger callback.
wandb_api_key_file = ""
given_surname = "given"
# The vocabulary size depends on which name type is selected above.
if given_surname == "given":
    vocab_size = 600000
else:
    vocab_size = 2100000
embed_dim = 100
n_epochs = 15
# Fractions passed to train_test_split as train_size / test_size.
optimize_size = 0.8
validate_size = 0.2

Config = namedtuple(
    "Config",
    ["train_path", "test_path", "embed_dim", "swivel_vocab_path", "swivel_model_path"],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-test.csv.gz",
    embed_dim=embed_dim,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}.csv",
    # TODO fix (note carried over from the original author; presumably about the
    # hard-coded "-50" suffix below -- confirm)
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-50.pth",
)

In [None]:
# Report GPU availability and memory usage.
# The device queries are only made when CUDA is present, so this cell also
# runs on a CPU-only host (the notebook selects a CPU device as a fallback).
torch.cuda.empty_cache()
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

### Load data

In [None]:
# Load the train and test datasets from the paths held in `config`.
train, test = load_datasets([config.train_path, config.test_path])

In [None]:
# Each dataset unpacks into three parts:
# input names, weighted actual names, and candidate names.
(input_names_train,
 weighted_actual_names_train,
 candidate_names_train) = train
(input_names_test,
 weighted_actual_names_test,
 candidate_names_test) = test

In [None]:
# Summarise dataset sizes. For the weighted actual names the reported figure
# is the total number of entries summed over all inputs.
_size_report = (
    ("input_names_train", len(input_names_train)),
    ("weighted_actual_names_train", sum(map(len, weighted_actual_names_train))),
    ("candidate_names_train", len(candidate_names_train)),
    ("input_names_test", len(input_names_test)),
    ("weighted_actual_names_test", sum(map(len, weighted_actual_names_test))),
    ("candidate_names_test", len(candidate_names_test)),
)
for _label, _size in _size_report:
    print(_label, _size)

In [None]:
# Prefer the first CUDA device when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Read the swivel vocabulary CSV and preview its first five rows.
vocab_df = pd.read_csv(filepath_or_buffer=fopen(config.swivel_vocab_path, "rb"))
print(vocab_df.head())

In [None]:
# Map each entry of the "name" column to the matching "index" column value.
swivel_vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
print(len(swivel_vocab))

In [None]:
# Rebuild the swivel model and restore its saved weights.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# deserialised on a CPU-only host; load_state_dict then copies the values into
# the model's own parameters, so the loaded model is the same either way.
swivel_model.load_state_dict(
    torch.load(fopen(config.swivel_model_path, "rb"), map_location="cpu")
)
# Inference mode: the swivel model is only used to produce embeddings here.
swivel_model.eval()
print(swivel_model)

### Optimize Hyperparameters

#### Create optimization and validation sets from the training set

In [None]:
# Split the training data into an optimization set and a validation set,
# sized by optimize_size and validate_size respectively.
optimize, validate = train_test_split(
    input_names_train,
    weighted_actual_names_train,
    candidate_names_train,
    train_size=optimize_size,
    test_size=validate_size,
)
(input_names_optimize,
 weighted_actual_names_optimize,
 candidate_names_optimize) = optimize
(input_names_validate,
 weighted_actual_names_validate,
 candidate_names_validate) = validate

for _label, _names in (
    ("input_names_optimize", input_names_optimize),
    ("candidate_names_optimize", candidate_names_optimize),
    ("input_names_validate", input_names_validate),
    ("candidate_names_validate", candidate_names_validate),
):
    print(_label, len(_names))

#### Use Ray to perform the search

In [None]:
def ray_training_function(config,
                          swivel_model,
                          swivel_vocab,
                          input_names_optimize,
                          candidate_names_optimize,
                          input_names_validate,
                          weighted_actual_names_validate,
                          candidate_names_validate,
                          device,
                          checkpoint_dir=None):
    """Train a swivel encoder for one Ray Tune trial, reporting metrics after every epoch.

    The encoder is trained on the swivel embeddings of the distinct optimization
    names (inputs plus candidates). After each epoch the validation AUC is
    computed, a checkpoint is written, and ``auc`` and ``mean_loss`` are
    reported to Tune.

    Args:
        config: Trial hyperparameters. Keys read: ``n_layers``, ``char_embed_dim``,
            ``n_hidden_units``, ``embed_dim``, ``bidirectional``, ``lr``,
            ``use_adam_opt``, ``batch_size``, ``pack`` and ``n_epochs``.
        swivel_model: Swivel model used to produce the target embeddings.
        swivel_vocab: Vocabulary passed to ``get_swivel_embeddings``.
        input_names_optimize: Input names of the optimization split.
        candidate_names_optimize: Candidate names of the optimization split.
        input_names_validate: Input names scored after each epoch.
        weighted_actual_names_validate: Ground truth used for the validation AUC.
        candidate_names_validate: Candidate names searched during validation.
        device: Torch device the encoder is moved to.
        checkpoint_dir: If set, model and optimizer state are restored from the
            ``checkpoint`` file in this directory before training. The epoch
            counter still starts at 0 after a restore.
    """
    # Training set: every distinct name from the optimization split, paired
    # with its swivel embedding as the regression target.
    train_names = list(set(input_names_optimize) | set(candidate_names_optimize))
    train_inputs = convert_names_to_model_inputs(train_names)
    train_embeddings = torch.Tensor(
        get_swivel_embeddings(swivel_model, swivel_vocab, train_names)
    )

    # create model
    encoder_model = SwivelEncoderModel(
        n_layers=config["n_layers"],
        char_embed_dim=config["char_embed_dim"],
        n_hidden_units=config["n_hidden_units"],
        output_dim=config["embed_dim"],
        bidirectional=config["bidirectional"],
        device=device,
    )
    encoder_model.to(device=device)

    # Adam or Adagrad, selected by the trial configuration.
    if config["use_adam_opt"]:
        optimizer = torch.optim.Adam(encoder_model.parameters(), lr=config["lr"])
    else:
        optimizer = torch.optim.Adagrad(encoder_model.parameters(), lr=config["lr"])

    # Restore model and optimizer state when Tune supplies a checkpoint.
    if checkpoint_dir:
        restore_path = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(restore_path)
        encoder_model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for epoch in range(config["n_epochs"]):
        # One training epoch per iteration so that metrics are reported per epoch.
        losses = train_swivel_encoder(
            encoder_model,
            train_inputs,
            train_embeddings,
            num_epochs=1,
            batch_size=config["batch_size"],
            use_adam_opt=config["use_adam_opt"],
            pack=config["pack"],
            verbose=False,
            optimizer=optimizer,
        )

        # Score the validation names using the encoder only (no swivel model/vocab).
        best_matches = get_best_swivel_matches(
            None,
            None,
            input_names_validate,
            candidate_names_validate,
            k=100,
            batch_size=1024,
            add_context=True,
            encoder_model=encoder_model,
            n_jobs=1,
        )
        auc = metrics.get_auc(
            weighted_actual_names_validate,
            best_matches,
            min_threshold=0.01,
            max_threshold=2.0,
            step=0.05,
            distances=False,
        )

        # Checkpoint the model; a separate name keeps the `checkpoint_dir`
        # argument from being rebound inside the loop.
        with tune.checkpoint_dir(epoch) as epoch_checkpoint_dir:
            save_path = os.path.join(epoch_checkpoint_dir, "checkpoint")
            torch.save((encoder_model.state_dict(), optimizer.state_dict()), save_path)

        # Report the metrics to Ray
        tune.report(auc=auc, mean_loss=np.mean(losses))

In [None]:
# Search space handed to tune.run: entries wrapped in tune.grid_search are
# swept, plain values are held fixed for every trial. The trailing comments
# record alternative search spaces that were tried.
config_params = {
    "embed_dim": embed_dim,
    "n_layers": tune.grid_search([2, 3]),
    "char_embed_dim": 32,  # tune.choice([0, 32, 64]),
    "n_hidden_units": 400,  # tune.choice([200, 300, 400]),
    "bidirectional": True,
    "lr": tune.grid_search([0.02, 0.03]),  # tune.quniform(0.02, 0.07, 0.01),
    "batch_size": tune.grid_search([128, 256]),  # tune.choice([128, 256]),
    "use_adam_opt": False,
    # The comma after this entry was missing, which made the cell a SyntaxError.
    "pack": tune.grid_search([True, False]),
    "n_epochs": n_epochs,
}

# Best configuration found so far; used to seed the search algorithm
# (points_to_evaluate).
current_best_params = [{
    "embed_dim": embed_dim,
    "n_layers": 2,
    "char_embed_dim": 32,
    "n_hidden_units": 400,
    "bidirectional": True,
    "lr": 0.02,
    "batch_size": 256,
    "use_adam_opt": False,
    "pack": False,
    "n_epochs": n_epochs,
}]

In [None]:
# ASHA scheduler: tries to terminate bad trials early.
# https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
scheduler = ASHAScheduler(max_t=100, grace_period=3, reduction_factor=4)

# HyperOpt search, seeded with the current best parameters as the first point to evaluate.
# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
search_alg = HyperOptSearch(points_to_evaluate=current_best_params)

In [None]:
# Restart Ray so that re-running this cell starts from a clean runtime.
ray.shutdown()
ray.init(_redis_max_memory=4 * 10**9)  # give redis extra memory

# Optional Weights & Biases logging, enabled by setting wandb_api_key_file.
# Note that the list only takes effect once `callbacks=callbacks` is
# uncommented in the tune.run call below.
callbacks = []
if wandb_api_key_file:
    wandb_callback = WandbLoggerCallback(
        project="nama",
        entity="nama",
        group="62_swivel_encoder_tune_" + given_surname,
        notes="",
        config=config._asdict(),
        api_key_file=wandb_api_key_file,
    )
    callbacks.append(wandb_callback)

# Bind the fixed (non-hyperparameter) arguments to the training function.
trainable = tune.with_parameters(
    ray_training_function,
    swivel_model=swivel_model,
    swivel_vocab=swivel_vocab,
    input_names_optimize=input_names_optimize,
    candidate_names_optimize=candidate_names_optimize,
    input_names_validate=input_names_validate,
    weighted_actual_names_validate=weighted_actual_names_validate,
    candidate_names_validate=candidate_names_validate,
    device=device,
)

# Run the search. The commented-out options switch from the plain grid search
# to a scheduled / HyperOpt-driven search.
result = tune.run(
    trainable,
    resources_per_trial={"cpu": 0.5, "gpu": 1.0},
    config=config_params,
#     num_samples=100,
#     scheduler=scheduler,
#     search_alg=search_alg,
#     metric="auc",
#     mode="max",
#     checkpoint_score_attr="auc",
#     time_budget_s=0.3*3600,
#     keep_checkpoints_num=100,
#     progress_reporter=tune.JupyterNotebookReporter(
#         overwrite=False,
#         max_report_frequency=5*60
#     ),
#     callbacks=callbacks
)

### Get best model

In [None]:
# Select the trial with the highest AUC, considering every reported step
# (scope="all"). mean_loss or any other reported metric could be used instead.
best_trial_auc = result.get_best_trial(
    metric="auc",
    mode="max",
    scope="all",
)

In [None]:
# Hyperparameters of the trial with the highest AUC (shown as the cell output)
best_trial_auc.config

In [None]:
# last_result holds the metrics most recently reported by the trial.
# mean_loss is the mean of the losses returned by training, whereas auc is
# computed on the validation names, so each value is labelled accordingly.
print(f"Best trial final train loss: {best_trial_auc.last_result['mean_loss']}")
print(f"Best trial final validation auc: {best_trial_auc.last_result['auc']}")

In [None]:
# Directory of the checkpoint attached to the best trial.
# NOTE(review): this looks like the trial's most recent checkpoint, which is
# not necessarily the epoch that achieved the highest AUC -- confirm.
best_checkpoint_dir = best_trial_auc.checkpoint.value

# Rebuild the encoder with the best trial's hyperparameters and restore its weights.
best_checkpoint_path = os.path.join(best_checkpoint_dir, "checkpoint")
model_state, optimizer_state = torch.load(best_checkpoint_path)
best_config = best_trial_auc.config
best_trained_model = SwivelEncoderModel(
    n_layers=best_config["n_layers"],
    char_embed_dim=best_config["char_embed_dim"],
    n_hidden_units=best_config["n_hidden_units"],
    output_dim=embed_dim,
    bidirectional=best_config["bidirectional"],
    device=device,
)
best_trained_model.load_state_dict(model_state)
# Switch to inference mode, then move the model to the target device.
best_trained_model.eval()
best_trained_model.to(device=device)

### Get all trials as DF

In [None]:
# Results of all trials collected into a single pandas DataFrame
df = result.results_df

In [None]:
# Show only the trials that combine a high AUC with a low mean loss.
auc_above_cutoff = df["auc"] > 0.78
loss_below_cutoff = df["mean_loss"] < 0.31
df[auc_above_cutoff & loss_below_cutoff]

### Plot PR curve on the validation set

In [None]:
# Find the best matches for the validation inputs with the best encoder,
# then plot the precision / weighted-recall curve over a range of thresholds.
best_matches = get_best_swivel_matches(
    None,
    None,
    input_names_validate,
    candidate_names_validate,
    k=100,
    batch_size=256,
    add_context=True,
    encoder_model=best_trained_model,
    n_jobs=4,
)

metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_validate,
    best_matches,
    min_threshold=0.01,
    max_threshold=2.0,
    step=0.05,
    distances=False,
)

### Plot PR curve on the test set

In [None]:
# Find the best matches for the test inputs with the best encoder, then plot
# the precision / weighted-recall curve over a range of thresholds.
# NOTE(review): unlike the validation cell, the swivel model and vocabulary
# are passed here instead of None -- confirm this difference is intended.
best_matches = get_best_swivel_matches(
    swivel_model,
    swivel_vocab,
    input_names_test,
    candidate_names_test,
    k=100,
    batch_size=1024,
    add_context=True,
    encoder_model=best_trained_model,
    n_jobs=8,
)

metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test,
    best_matches,
    min_threshold=0.01,
    max_threshold=2.0,
    step=0.05,
    distances=False,
)

### Demo

In [None]:
# Pick a random test input and show the first ten of its best matches.
rndmx_name_idx = np.random.randint(len(input_names_test))
demo_input_name = input_names_test[rndmx_name_idx]
demo_matches = best_matches[rndmx_name_idx][:10]
print(f"Input name:  {demo_input_name}")
print("Nearest names:")
print(demo_matches)

In [None]:
# Show the first ten weighted actual names for the same input, ordered by the
# second element of each entry (presumably the weight -- confirm), largest first.
print("Actual names:")
first_ten_actuals = weighted_actual_names_test[rndmx_name_idx][:10]
sorted(first_ten_actuals, key=lambda entry: entry[1], reverse=True)

### Test a specific threshold

In [None]:
# precision and recall at a specific threshold

from src.eval.metrics import precision_at_threshold, weighted_recall_at_threshold

threshold = 0.4
# best_matches was last computed for the test inputs, so it has to be paired
# with the test actuals. Zipping it with the validation actuals would silently
# compare unrelated names (zip truncates to the shorter list without an error).
precision = np.mean([precision_at_threshold(actuals, matches, threshold, distances=False)
                     for actuals, matches in zip(weighted_actual_names_test, best_matches)])
recall = np.mean([weighted_recall_at_threshold(actuals, matches, threshold, distances=False)
                  for actuals, matches in zip(weighted_actual_names_test, best_matches)])

print(precision, recall)