In [None]:
import os

import torch
import numpy as np
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch

from src.data.utils import load_train_test
from src.eval import metrics
from src.models import swivel

In [None]:
# S3 location of the gzipped CSV that is handed to load_train_test in the next cell.
TRAIN_DATA_PATH = 's3://familysearch-names/processed/tree-hr-given-similar-train-freq.csv.gz'

In [None]:
# load_train_test is called with a one-element list of paths; element [0] of what it
# returns is unpacked into the three collections used throughout the notebook.
# NOTE(review): presumably the three values are index-aligned per input name — confirm.
input_names_train, weighted_actual_names_train, candidate_names_train = load_train_test([TRAIN_DATA_PATH])[0]

In [None]:
# Passed to swivel.SwivelDataset when the dataset is built below.
VOCAB_SIZE = 1000
# Small sample: only the first 100 entries of each list are used for this run.
input_names_train_sample = input_names_train[:100]
weighted_actual_names_train_sample = weighted_actual_names_train[:100]

In [None]:
# Default hyperparameters; they form the initial point handed to the search algorithm.
DEFAULT_EMBEDDING_DIM = 100
# Backward-compatible alias: the constant was originally introduced under this
# misspelled name and other cells still refer to it.
DEFUALT_EMBEDDING_DIM = DEFAULT_EMBEDDING_DIM
DEFAULT_CONFIDENCE_BASE = 0.1
DEFAULT_CONFIDENCE_SCALE = 0.25
DEFAULT_CONFIDENCE_EXPONENT = 0.5
DEFAULT_LEARNING_RATE = 0.05
DEFAULT_SUBMATRIX_SIZE = 64 # Needs to be adjusted with full dataset
DEFAULT_NUM_EPOCHS = 50 # Needs to be adjusted with full dataset

In [None]:
# Build the Swivel dataset from the 100-entry sample (not the full training lists).
dataset = swivel.SwivelDataset(input_names_train_sample, 
                               weighted_actual_names_train_sample,
                               VOCAB_SIZE,
                               symmetric=True)
# The vocabulary comes from the dataset; its length sizes the model and it is
# passed to swivel.get_best_swivel_matches in later cells.
vocab = dataset.get_vocab()

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def compute_auc(model, vocab, input_names, weighted_actual_names, candidate_names):
    """Return the AUC reported by `metrics.get_auc` for `model`.

    Candidates for every input name are retrieved with
    `swivel.get_best_swivel_matches` (k=100) and then scored against
    `weighted_actual_names`.

    Args:
        model: Swivel model forwarded to the matching helper.
        vocab: vocabulary forwarded to the matching helper.
        input_names: names to look up.
        weighted_actual_names: ground truth handed to `metrics.get_auc`.
        candidate_names: pool of names the matches are drawn from.

    Returns:
        Whatever `metrics.get_auc` returns for the retrieved matches.
    """
    # Fixed retrieval settings.
    retrieval_options = dict(k=100,
                             batch_size=256,
                             add_context=True,
                             n_jobs=1,
                             progress_bar=False)
    # Fixed threshold sweep used for the AUC computation.
    auc_options = dict(min_threshold=0.01,
                       max_threshold=2.0,
                       step=0.03,
                       distances=False)

    top_candidates = swivel.get_best_swivel_matches(
        model, vocab, input_names, candidate_names, **retrieval_options)
    return metrics.get_auc(weighted_actual_names, top_candidates, **auc_options)

In [None]:
def plot_pr_curve(model, 
                  vocab, 
                  input_names,
                  weighted_actual_names,
                  candidate_names):
    """Plot the precision / weighted-recall curve of `model` on the given names.

    Matches for `input_names` are retrieved with `swivel.get_best_swivel_matches`
    (k=100) and handed, together with `weighted_actual_names`, to
    `metrics.precision_weighted_recall_curve_at_threshold`.

    Args:
        model: Swivel model forwarded to the matching helper.
        vocab: vocabulary forwarded to the matching helper.
        input_names: names to look up.
        weighted_actual_names: ground truth for `input_names`; it must describe
            the same rows as `input_names`.
        candidate_names: pool of names the matches are drawn from.

    Returns:
        None; the metrics helper is called for its plotting side effect.
    """
    best_matches = swivel.get_best_swivel_matches(model, 
                                                  vocab, 
                                                  input_names,
                                                  candidate_names,
                                                  k=100, 
                                                  batch_size=256,
                                                  add_context=True, 
                                                  n_jobs=1,
                                                  progress_bar=False)

    # Score against the ground truth passed by the caller. Reading a notebook
    # global here would pair the matches with rows they were not computed for.
    metrics.precision_weighted_recall_curve_at_threshold(weighted_actual_names, 
                                                         best_matches, 
                                                         min_threshold=0.01, 
                                                         max_threshold=2.0, 
                                                         step=0.05, 
                                                         distances=False)

In [None]:
def train_model(param_config, 
                swivel_dataset,
                vocab,
                input_names,
                weighted_actual_names,
                candidate_names,
                checkpoint_dir=None):
    """Trainable run by Ray Tune for one hyperparameter configuration.

    Builds a `swivel.SwivelModel` and an Adagrad optimizer from `param_config`,
    optionally restores both from `checkpoint_dir`, then for each epoch trains,
    computes the AUC with `compute_auc`, saves a checkpoint and reports
    `auc` / `mean_loss` through `tune.report`.

    Args:
        param_config: dict with the keys `embedding_dim`, `confidence_base`,
            `confidence_scale`, `confidence_exponent`, `learning_rate`,
            `submatrix_size` and `num_epochs`.
        swivel_dataset: dataset forwarded to `swivel.train_swivel`.
        vocab: vocabulary; its length sizes the model.
        input_names: names evaluated after every epoch.
        weighted_actual_names: ground truth for `input_names`.
        candidate_names: pool of names the matches are drawn from.
        checkpoint_dir: directory holding a file named "checkpoint" to resume
            from, or None to start from scratch.

    Returns:
        None; results are delivered through `tune.report`.
    """
    # Instantiate the model
    model = swivel.SwivelModel(len(vocab), 
                               param_config['embedding_dim'], 
                               param_config['confidence_base'], 
                               param_config['confidence_scale'], 
                               param_config['confidence_exponent'])
    # Put model on device
    model.to(device)

    # Create optimizer
    optimizer = torch.optim.Adagrad(model.parameters(), 
                                   lr=param_config['learning_rate'])

    # Load checkpoint if exists. map_location keeps the restore working when the
    # checkpoint tensors were saved from a different device than the current one.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"),
            map_location=device)
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for epoch in range(param_config['num_epochs']):
        loss_values = swivel.train_swivel(model, 
                                          swivel_dataset,
                                          n_steps=0, 
                                          submatrix_size=param_config['submatrix_size'], 
                                          lr=param_config['learning_rate'], 
                                          device=device,
                                          verbose=False,
                                          optimizer=optimizer)

        # Compute AUC on the train data
        auc = compute_auc(model, 
                          vocab, 
                          input_names,
                          weighted_actual_names,
                          candidate_names)

        # Checkpoint the model. A distinct name is used for the per-epoch
        # directory so the `checkpoint_dir` argument is not rebound.
        with tune.checkpoint_dir(epoch) as epoch_checkpoint_dir:
            path = os.path.join(epoch_checkpoint_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

        # Report the metrics to Ray
        tune.report(auc=auc, mean_loss=np.mean(loss_values))

### Search space for parameters

In [None]:
# Search space handed to tune.run as `config`; the keys are the ones train_model reads.
param_config = {
    "embedding_dim": tune.choice([25, 50, 100]),
    "confidence_base": tune.quniform(0.05, 0.15, 0.01),
    "confidence_scale": tune.quniform(0.05, 0.25, 0.01),
    "confidence_exponent": tune.quniform(0.25, 1.0, 0.05),
    "learning_rate": tune.quniform(0.01, 0.2, 0.01),
    "submatrix_size": tune.choice([32, 64, 128]),
    # Fixed value, not searched: every trial loops over the same number of epochs.
    "num_epochs": 50 
}

In [None]:
# Will try to terminate bad trials early
# https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
# NOTE(review): train_model reports once per epoch and num_epochs is 50, so a
# max_t of 100 is presumably never reached — confirm this is intended.
scheduler = ASHAScheduler(max_t=100,
                          grace_period=1, 
                          reduction_factor=4)

In [None]:
# Can provide multiple points
# Each point uses the same keys as the search space and is built from the
# DEFAULT_* constants (DEFUALT_EMBEDDING_DIM is spelled as the notebook defines it).
current_best_params = [
        {
            "embedding_dim": DEFUALT_EMBEDDING_DIM,
            "confidence_base": DEFAULT_CONFIDENCE_BASE,
            "confidence_scale": DEFAULT_CONFIDENCE_SCALE,
            "confidence_exponent": DEFAULT_CONFIDENCE_EXPONENT,
            "learning_rate": DEFAULT_LEARNING_RATE,
            "submatrix_size": DEFAULT_SUBMATRIX_SIZE,
            "num_epochs": DEFAULT_NUM_EPOCHS
        }
    ]

# The points above are given to the search algorithm as points to evaluate.
# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
search_alg = HyperOptSearch(points_to_evaluate=current_best_params)

### Run HPO

In [None]:
# Run the search: train_model is the trainable, and tune.with_parameters attaches
# the dataset, vocabulary and evaluation lists to it.
result = tune.run(tune.with_parameters(train_model, 
                                       swivel_dataset=dataset,
                                       vocab=vocab,
                                       input_names=input_names_train_sample,
                                       # Use the sampled ground truth so it covers the same
                                       # first-100 rows as input_names (not the full list).
                                       weighted_actual_names=weighted_actual_names_train_sample,
                                       candidate_names=candidate_names_train),
                  resources_per_trial={'cpu': 0.5}, # Add GPU
                  scheduler=scheduler,
                  config=param_config,
                  search_alg=search_alg,
                  num_samples=25, # Adjust during actual tuning
                  metric='auc', 
                  mode='max',
                  checkpoint_score_attr='auc')

### Get best model

In [None]:
# Get trial that has the highest AUC (can also do with mean_loss or any other metric)
# NOTE(review): scope='all' presumably ranks trials by their best AUC over all
# reported iterations, while later cells read last_result and the trial's
# checkpoint — these may not refer to the same iteration; confirm.
best_trial_auc = result.get_best_trial(metric='auc', mode='max', scope='all')

In [None]:
# Parameters with the highest AUC: the configuration of the trial selected above,
# shown as the cell output.
best_trial_auc.config

In [None]:
# 'mean_loss' and 'auc' are the keys train_model passes to tune.report; last_result
# is read here for the selected trial.
print(f"Best trial final train loss: {best_trial_auc.last_result['mean_loss']}")
print(f"Best trial final train auc: {best_trial_auc.last_result['auc']}")

In [None]:
# Get checkpoint dir for best model
best_checkpoint_dir = best_trial_auc.checkpoint.value

# Load best model
# The file holds the (model state, optimizer state) pair written by train_model;
# map_location keeps the load working when the tensors were saved from another device.
model_state, optimizer_state = torch.load(os.path.join(best_checkpoint_dir, 'checkpoint'),
                                          map_location=device)
best_trained_model = swivel.SwivelModel(len(vocab),
                                        embedding_dim=best_trial_auc.config['embedding_dim'],
                                        confidence_base=best_trial_auc.config['confidence_base'],
                                        confidence_scale=best_trial_auc.config['confidence_scale'],
                                        confidence_exponent=best_trial_auc.config['confidence_exponent'])
# Place the model on the same device train_model uses before restoring the weights,
# so it is evaluated under the same conditions as during the search.
best_trained_model.to(device)
best_trained_model.load_state_dict(model_state)

### Plot PR curve

In [None]:
# plot pr curve with best model
# Inputs and ground truth are both the 100-entry sample; candidates are the full list.
plot_pr_curve(best_trained_model, 
              vocab, 
              input_names_train_sample,
              weighted_actual_names_train_sample,
              candidate_names_train)

### Demo

In [None]:
# Retrieve matches for every sampled input name with the reloaded best model,
# using the same settings (k=100, batch_size=256, ...) as compute_auc.
best_matches = swivel.get_best_swivel_matches(best_trained_model, 
                                              vocab, 
                                              input_names_train_sample,
                                              candidate_names_train,
                                              k=100, 
                                              batch_size=256,
                                              add_context=True, 
                                              n_jobs=1,
                                              progress_bar=False)

In [None]:
# Pick a random position in the sample; no seed is set, so the choice differs per run.
rndmx_name_idx = np.random.randint(len(input_names_train_sample))
print(f"Input name:  {input_names_train_sample[rndmx_name_idx]}")
print("Nearest names:")
# Show the first ten matches retrieved for that input name.
print(best_matches[rndmx_name_idx][:10])

In [None]:
print("Actual names:")
# Sort every actual name for the chosen input by its second tuple element
# (descending) and only then keep the first ten, so the ten highest-ranked entries
# are shown rather than whichever ten happen to be stored first. The index was
# drawn from the sample, so the sampled list is the one indexed here.
sorted(weighted_actual_names_train_sample[rndmx_name_idx], key=lambda k: k[1], reverse=True)[:10]

### Get all trials as DF

In [None]:
# All trials as pandas dataframe (taken from the results_df attribute of the tune.run result)
df = result.results_df

In [None]:
# Display the trials table as the cell output.
df