In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to imported project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# Run hyperparameter tuning on the Swivel model

## DEPRECATED

In [None]:
import os

import torch
import numpy as np
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch

from nama.data.utils import load_dataset, select_frequent_k
from nama.data.filesystem import download_file_from_s3
from nama.eval import metrics
from nama.models import swivel

In [None]:
# Config

# run this on a 256GB standard instance

# TODO run both given and surname
# Selects which name type to tune on; the value is interpolated into TRAIN_DATA_PATH below.
given_surname = "given"
# given_surname = "surname"

# Number of best matches requested per input name (passed as `k` when scoring a model).
num_matches = 500
# Passed to select_frequent_k when sampling the training data; also reused as the default vocab size.
SAMPLE_SIZE = 30000

# Augmented training CSV for the selected name type; an s3:// path is downloaded before loading.
TRAIN_DATA_PATH = f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-train-augmented.csv.gz"

In [None]:
# Resolve the training file: s3:// URLs go through download_file_from_s3,
# any other path is used exactly as configured.
if TRAIN_DATA_PATH.startswith("s3://"):
    train_data_path = download_file_from_s3(TRAIN_DATA_PATH)
else:
    train_data_path = TRAIN_DATA_PATH

# load_dataset returns three values, unpacked here in the order it yields them.
(input_names_train,
 record_name_frequencies_train,
 candidate_names_train) = load_dataset(train_data_path)

In [None]:
# Show the size of each part of the loaded training set
# (input names, then candidate names, then record-name frequencies).
for loaded_part in (input_names_train,
                    candidate_names_train,
                    record_name_frequencies_train):
    print(len(loaded_part))

In [None]:
# Down-sample the training data via select_frequent_k, passing SAMPLE_SIZE as its size argument.
(input_names_train_sample,
 record_name_frequencies_train_sample,
 candidate_names_train_sample) = select_frequent_k(
    input_names_train,
    record_name_frequencies_train,
    candidate_names_train,
    SAMPLE_SIZE,
)

# Summarise the sample; the frequency figure adds up the length of every per-name entry.
print("sample input names", len(input_names_train_sample))
print("sample number of record name frequencies",
      sum(map(len, record_name_frequencies_train_sample)))
print("sample candidate names", len(candidate_names_train_sample))

In [None]:
# Baseline hyperparameter values. They serve as the fixed (non-swept) entries of the
# search space and as the starting point handed to the search algorithm.
# The vocab size is tied to the sample size chosen in the config cell.
DEFAULT_VOCAB_SIZE = SAMPLE_SIZE
DEFAULT_EMBEDDING_DIM = 100
# The three confidence_* values are passed straight to the SwivelModel constructor.
DEFAULT_CONFIDENCE_BASE = 0.18
DEFAULT_CONFIDENCE_SCALE = 0.5
DEFAULT_CONFIDENCE_EXPONENT = 0.3
# Learning rate for the Adagrad optimizer.
DEFAULT_LEARNING_RATE = 0.14
DEFAULT_SUBMATRIX_SIZE = 2048 # Needs to be adjusted with full dataset
DEFAULT_NUM_EPOCHS = 30 # Needs to be adjusted with full dataset

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def compute_auc(model,
                vocab,
                input_names,
                record_name_frequencies,
                candidate_names):
    """Score `model` with the AUC metric on the given names.

    The top `num_matches` (notebook-level setting) candidates are retrieved for
    every input name with `swivel.get_best_swivel_matches`, then handed to
    `metrics.get_auc` together with the reference frequencies.

    Args:
        model: Swivel model to evaluate.
        vocab: vocabulary belonging to `model`.
        input_names: names to look up matches for.
        record_name_frequencies: reference data the matches are scored against.
        candidate_names: pool of names that matches are drawn from.

    Returns:
        Whatever `metrics.get_auc` returns for a threshold sweep from 0.01 to
        2.0 in steps of 0.03.
    """
    # Retrieval settings: single job, no progress bar, batches of 256.
    match_options = dict(
        k=num_matches,
        batch_size=256,
        add_context=True,
        n_jobs=1,
        progress_bar=False,
    )
    top_matches = swivel.get_best_swivel_matches(
        model, vocab, input_names, candidate_names, **match_options
    )

    # distances=False: presumably the match scores are similarities rather than distances -- TODO confirm.
    threshold_sweep = dict(
        min_threshold=0.01,
        max_threshold=2.0,
        step=0.03,
        distances=False,
    )
    return metrics.get_auc(record_name_frequencies, top_matches, **threshold_sweep)

In [None]:
def plot_pr_curve(model,
                  vocab,
                  input_names,
                  record_name_frequencies,
                  candidate_names):
    """Run the precision / weighted-recall threshold sweep for `model`.

    The top `num_matches` (notebook-level setting) candidates are retrieved for
    every input name with `swivel.get_best_swivel_matches` and passed to
    `metrics.precision_weighted_recall_curve_at_threshold`, which presumably
    draws the curve (its result is not captured here).

    Args:
        model: Swivel model to evaluate.
        vocab: vocabulary belonging to `model`.
        input_names: names to look up matches for.
        record_name_frequencies: reference data the matches are scored against.
        candidate_names: pool of names that matches are drawn from.

    Returns:
        None.
    """
    # Retrieval settings: single job, no progress bar, batches of 256.
    match_options = dict(
        k=num_matches,
        batch_size=256,
        add_context=True,
        n_jobs=1,
        progress_bar=False,
    )
    top_matches = swivel.get_best_swivel_matches(
        model, vocab, input_names, candidate_names, **match_options
    )

    # Thresholds run from 0.01 to 2.0 in steps of 0.05.
    threshold_sweep = dict(
        min_threshold=0.01,
        max_threshold=2.0,
        step=0.05,
        distances=False,
    )
    metrics.precision_weighted_recall_curve_at_threshold(
        record_name_frequencies, top_matches, **threshold_sweep
    )

In [None]:
def train_model(param_config,
                input_names,
                record_name_frequencies,
                candidate_names):
    """Trainable for Ray Tune: fit one Swivel model and report metrics per epoch.

    Args:
        param_config: dict of hyperparameters; the keys read are 'vocab_size',
            'embedding_dim', 'confidence_base', 'confidence_scale',
            'confidence_exponent', 'learning_rate', 'submatrix_size' and
            'num_epochs'.
        input_names: names used to build the dataset and to evaluate on.
        record_name_frequencies: frequency data used to build the dataset and
            as the evaluation reference.
        candidate_names: candidate pool used only for evaluation.

    Returns:
        None. After every epoch, `{"auc", "mean_loss"}` is sent to Ray through
        `train.report`.
    """
    # Dataset and its vocabulary.
    dataset = swivel.SwivelDataset(
        input_names,
        record_name_frequencies,
        param_config['vocab_size'],
    )
    vocab = dataset.get_vocab()

    # Model sized to the vocabulary, initialised from the dataset's row and
    # column sums, then moved to the notebook-level `device`.
    model = swivel.SwivelModel(
        len(vocab),
        param_config['embedding_dim'],
        param_config['confidence_base'],
        param_config['confidence_scale'],
        param_config['confidence_exponent'],
    )
    model.init_params(dataset.get_row_sums(), dataset.get_col_sums())
    model.to(device)

    # A single Adagrad optimizer is created once and reused for every epoch.
    learning_rate = param_config['learning_rate']
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)

    for _ in range(param_config['num_epochs']):
        epoch_losses = swivel.train_swivel(
            model,
            dataset,
            n_steps=0,
            submatrix_size=param_config['submatrix_size'],
            lr=learning_rate,
            device=device,
            verbose=False,
            optimizer=optimizer,
        )

        # The AUC is measured on the same data the model is trained on.
        train_auc = compute_auc(
            model,
            vocab,
            input_names,
            record_name_frequencies,
            candidate_names,
        )

        # One report per epoch.
        train.report({"auc": train_auc, "mean_loss": np.mean(epoch_losses)})

### Search space for parameters

In [None]:
# Search space handed to Ray Tune. Only embedding_dim is swept (grid search);
# every other entry is pinned to its default. The trailing comments keep the
# candidate ranges that are currently switched off.
param_config = dict(
    vocab_size=DEFAULT_VOCAB_SIZE,
    embedding_dim=tune.grid_search([50, 100]),
    confidence_base=DEFAULT_CONFIDENCE_BASE,          # tune.quniform(0.1, 0.2, 0.02),
    confidence_scale=DEFAULT_CONFIDENCE_SCALE,        # tune.quniform(0.4, 0.5, 0.05),
    confidence_exponent=DEFAULT_CONFIDENCE_EXPONENT,  # tune.quniform(0.2, .4, 0.05),
    learning_rate=DEFAULT_LEARNING_RATE,              # tune.quniform(0.04, 0.3, 0.02),
    submatrix_size=DEFAULT_SUBMATRIX_SIZE,            # Needs to be adjusted with full dataset
    num_epochs=DEFAULT_NUM_EPOCHS,
)

In [None]:
# Will try to terminate bad trials early
# https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
# NOTE(review): max_t=100 is larger than DEFAULT_NUM_EPOCHS (30); with one report per
# epoch this cap would presumably never be reached -- confirm the intended value.
scheduler = ASHAScheduler(max_t=100,
                          grace_period=1, 
                          reduction_factor=4)

In [None]:
# Can provide multiple points
# Each entry is a full hyperparameter set; the single one here is built from the defaults.
current_best_params = [
    dict(
        vocab_size=DEFAULT_VOCAB_SIZE,
        embedding_dim=DEFAULT_EMBEDDING_DIM,
        confidence_base=DEFAULT_CONFIDENCE_BASE,
        confidence_scale=DEFAULT_CONFIDENCE_SCALE,
        confidence_exponent=DEFAULT_CONFIDENCE_EXPONENT,
        learning_rate=DEFAULT_LEARNING_RATE,
        submatrix_size=DEFAULT_SUBMATRIX_SIZE,
        num_epochs=DEFAULT_NUM_EPOCHS,
    ),
]

# https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#tune-hyperopt
# The points above are given to HyperOpt as configurations to evaluate.
search_alg = HyperOptSearch(points_to_evaluate=current_best_params)

### Run hyperparameter optimization (HPO)

In [None]:
callbacks = []

# Only ask Ray for a GPU share when a GPU actually exists. The device cell falls
# back to CPU when CUDA is unavailable, but a fixed 'gpu': 0.5 request can never be
# satisfied on a CPU-only host, leaving every trial pending instead of running.
# Assumes Ray runs on this same machine (single node) -- TODO confirm for clusters.
gpu_per_trial = 0.5 if torch.cuda.is_available() else 0

result = tune.run(
    # Bind the sampled training data to the trainable so each trial receives it.
    tune.with_parameters(train_model,
                         input_names=input_names_train_sample,
                         record_name_frequencies=record_name_frequencies_train_sample,
                         candidate_names=candidate_names_train_sample),
    # Fractional resources allow two trials to share one CPU (and one GPU, if any).
    resources_per_trial={'cpu': 0.5, 'gpu': gpu_per_trial},
    config=param_config,
    # Optional settings, currently disabled:
    # scheduler=scheduler,
    # search_alg=search_alg,
    # num_samples=8,
    # metric='auc',
    # mode='max',
    # checkpoint_score_attr='auc',
    # Overall wall-clock budget for the experiment: 6 hours.
    time_budget_s=6*3600,
    # keep_checkpoints_num=100,
    # Append a progress table at most every 5 minutes instead of overwriting the output.
    progress_reporter=tune.JupyterNotebookReporter(
        overwrite=False,
        max_report_frequency=5*60
    ),
    callbacks=callbacks,
)


### Get best trial and its parameters

In [None]:
# Get trial that has the highest AUC (can also do with mean_loss or any other metric)
# scope='all' ranks each trial by its best AUC over all of its reported results,
# not only its last one (per Ray Tune's get_best_trial documentation).
best_trial_auc = result.get_best_trial(metric='auc', mode='max', scope='all')

In [None]:
# Parameters with the highest AUC
# (bare expression, so the notebook shows the selected trial's config as the cell output)
best_trial_auc.config

In [None]:
# Metrics from the selected trial's last report.
# NOTE(review): the trial was chosen with scope='all', so its best-epoch AUC may
# presumably be higher than the final value printed here -- confirm which one is wanted.
print(f"Best trial final train loss: {best_trial_auc.last_result['mean_loss']}")
print(f"Best trial final train auc: {best_trial_auc.last_result['auc']}")