In [None]:
# Automatically reload edited project modules (e.g. src.*) before each cell executes.
%load_ext autoreload
%autoreload 2
# NOTE(review): no visible cell plots anything; this is likely unnecessary.
%matplotlib inline

In [None]:
from collections import namedtuple
import pandas as pd
import ray
from ray import tune
from sklearn.model_selection import train_test_split
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.models.swivel import get_swivel_embeddings, get_best_swivel_matches
from src.models.swivel_encoder import SwivelEncoderModel, convert_names_to_model_inputs, train_swivel_encoder

In [None]:
# Config
# Paths to the train/test data and the pre-trained swivel model (vocabulary + weights),
# parameterized by name type, dataset size variant, vocabulary size and embedding dimension.
Config = namedtuple("Config", ["train_path", "test_path", "embed_dim", "vocab_path", "model_path"])

given_surname = "given"
size = "freq"
vocab_size = 500_000
embed_dim = 200

config = Config(
    embed_dim=embed_dim,
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-{size}.csv.gz",
    model_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-{embed_dim}-tfidf.pt",
    vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-swivel-{vocab_size}-vocab-tfidf.csv",
)

In [None]:
# Optional experiment tracking with Weights & Biases; set to True to log this run.
# (Replaces a commented-out wandb.init call so the option is explicit and runnable.)
log_to_wandb = False
if log_to_wandb:
    wandb.init(
        project="nama",
        entity="nama",
        name="54_swivel_encoder",
        group=given_surname,
        notes="",
        config=config._asdict(),
    )

### Load data

In [None]:
# One (input_names, weighted_actual_names, candidate_names) triple per path, in the same order.
split_paths = [config.train_path, config.test_path]
train, test = load_train_test(split_paths)

In [None]:
# Each split unpacks into: input names, weighted actual names (per input), candidate names.
(input_names_train, weighted_actual_names_train, candidate_names_train) = train
(input_names_test, weighted_actual_names_test, candidate_names_test) = test

In [None]:
# Size of each split: number of input names, total number of weighted actual-name
# entries summed over all inputs, and number of candidate names.
split_sizes = [
    ("input_names_train", len(input_names_train)),
    ("weighted_actual_names_train", sum(len(wan) for wan in weighted_actual_names_train)),
    ("candidate_names_train", len(candidate_names_train)),
    ("input_names_test", len(input_names_test)),
    ("weighted_actual_names_test", sum(len(wan) for wan in weighted_actual_names_test)),
    ("candidate_names_test", len(candidate_names_test)),
]
for label, count in split_sizes:
    print(label, count)

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Vocabulary of the pre-trained swivel model (the next cell reads its "name" and "index" columns).
# Open the file in a `with` block so the handle is closed once the CSV has been read.
with fopen(config.vocab_path, "rb") as vocab_file:
    vocab_df = pd.read_csv(vocab_file)
vocab_df.head(5)

In [None]:
# name -> integer id in the swivel vocabulary, and the reverse mapping id -> name.
vocab = dict(zip(vocab_df["name"], vocab_df["index"]))
id2name = {idx: name for name, idx in vocab.items()}
# Sanity check: a name round-trips through id2name.
john_id = vocab["<john>"]
print(john_id)
print(id2name[john_id])

In [None]:
# Pre-trained swivel model; its embeddings are the targets the encoder is tuned to reproduce.
# NOTE: torch.load unpickles arbitrary objects - only load model files from trusted locations.
# The `with` block closes the file handle once the model has been deserialized.
with fopen(config.model_path, "rb") as model_file:
    model = torch.load(model_file)
model.eval()
print(model)

### Optimize Hyperparameters

In [None]:
# train using all names in the vocabulary
# (a live dict-keys view of `vocab`, not a copy)
train_names = vocab.keys()

In [None]:
# Encode every vocabulary name into model inputs and show the resulting shape and dtype.
# NOTE(review): train_model_inputs is not used by the visible tuning cells (they use optimize_name_inputs).
train_model_inputs = convert_names_to_model_inputs(train_names)
print(train_model_inputs.shape, train_model_inputs.dtype, sep="\n")

#### Create optimization and validation sets from the training set

In [None]:
# split out 10% of the candidate names for validation;
# fix the seed so the split (and every downstream result) is reproducible across runs
candidate_names_optimize, candidate_names_validate = train_test_split(
    candidate_names_train, test_size=0.1, random_state=42
)

In [None]:
def _restrict_to_names(weighted_actuals, allowed_names):
    """Restrict one input's (name, weight, freq) triples to `allowed_names`.

    Keeps only triples whose name is allowed and whose freq is positive, and
    replaces each weight with freq / (total freq of the allowed names).
    """
    allowed_total = sum([freq for name, _, freq in weighted_actuals if name in allowed_names])
    return [
        (name, freq / allowed_total, freq)
        for name, _, freq in weighted_actuals
        if freq > 0 and name in allowed_names
    ]


# generate weighted_actual_names_validate from weighted_actual_names_train and candidate_names_validate
candidate_names_validate_set = set(candidate_names_validate)
weighted_actual_names_validate = []
for train_wans in weighted_actual_names_train:
    validate_wans = _restrict_to_names(train_wans, candidate_names_validate_set)
    total_weight = sum([weight for _, weight, _ in validate_wans])
    # re-normalized weights should sum to at most ~1; report any input that overshoots
    if total_weight > 1.0001:
        print("total", total_weight, "wans", validate_wans)
    weighted_actual_names_validate.append(validate_wans)
print(weighted_actual_names_validate[0:2])

In [None]:
# remove validation names from optimization set.
# Filter `vocab` in its own order instead of taking a set difference: set iteration order of
# strings changes between Python processes (hash randomization), which made this list -
# and therefore the training data order - non-reproducible.
validate_names = set(candidate_names_validate)
optimize_names = [name for name in vocab if name not in validate_names]

In [None]:
# Encode the optimization names into model inputs.
# NOTE(review): presumably row i encodes optimize_names[i], matching optimize_embeddings - confirm.
optimize_name_inputs = convert_names_to_model_inputs(optimize_names)

In [None]:
# Target vectors: the swivel model's embedding for each optimization name
# (presumably one row per name in optimize_names order - confirm against get_swivel_embeddings).
# torch.as_tensor replaces the legacy torch.Tensor(...) constructor, pins the dtype to float32
# explicitly, and avoids an extra copy when the embeddings are already float32.
optimize_embeddings = torch.as_tensor(
    get_swivel_embeddings(model, vocab, optimize_names), dtype=torch.float32
)
print(optimize_embeddings.shape)

#### Use Ray Tune to perform the hyperparameter search

In [None]:
def ray_training_function(config, data):
    """Ray Tune trainable: train a SwivelEncoderModel for one sampled hyperparameter
    configuration and report the validation loss after every epoch.

    The encoder is trained to map ``data["optimize_name_inputs"]`` to the pre-computed
    swivel embeddings ``data["optimize_embeddings"]``. After each epoch, names in
    ``data["input_names_train"]`` are matched against ``data["candidate_names_validate"]``
    and the matches are scored against ``data["weighted_actual_names_validate"]``.

    Args:
        config: hyperparameters sampled by Ray Tune; reads the keys ``n_layers``,
            ``char_embed_dim``, ``n_hidden_units``, ``bidirectional``, ``lr``,
            ``batch_size`` and ``use_adam_opt``.
        data: objects shared by every trial (passed via ``tune.with_parameters``).
            A ``"model"`` entry may be present but is not needed here.

    Reports:
        ``mean_loss`` = 1 - area under the precision/recall curve on the validation
        set, once per epoch via ``tune.report``.

    Note:
        Uses the notebook-level ``vocab`` dict, which is presumably captured when Ray
        serializes this function.
    """
    # Hyperparameters sampled for this trial
    n_layers = config["n_layers"]
    char_embed_dim = config["char_embed_dim"]
    n_hidden_units = config["n_hidden_units"]
    bidirectional = config["bidirectional"]
    lr = config["lr"]
    batch_size = config["batch_size"]
    use_adam_opt = config["use_adam_opt"]

    # Data shared by all trials
    optimize_name_inputs = data["optimize_name_inputs"]
    optimize_embeddings = data["optimize_embeddings"]
    input_names_train = data["input_names_train"]
    candidate_names_validate = data["candidate_names_validate"]
    weighted_actual_names_validate = data["weighted_actual_names_validate"]

    # Resolve the device inside the worker process that runs this trial,
    # rather than relying on the notebook's global.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create model: its output width must match the swivel embeddings it learns to reproduce
    embed_dim = optimize_embeddings.shape[-1]
    encoder_model = SwivelEncoderModel(n_layers=n_layers, char_embed_dim=char_embed_dim, n_hidden_units=n_hidden_units,
                                       output_dim=embed_dim, bidirectional=bidirectional, device=device)
    # Optimize the *encoder's* parameters (not the pre-trained swivel model's). The optimizer is
    # created once, outside the epoch loop, so its state carries over from epoch to epoch.
    if use_adam_opt:
        optimizer = torch.optim.Adam(encoder_model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adagrad(encoder_model.parameters(), lr=lr)

    n_epochs = 50
    k = 100  # presumably the number of best matches kept per input name
    eval_batch_size = 256
    add_context = True
    n_jobs = 1
    for _ in range(n_epochs):
        # Train exactly one epoch per iteration so a loss is reported to Tune after every epoch.
        train_swivel_encoder(encoder_model, optimize_name_inputs, optimize_embeddings, num_epochs=1,
                             batch_size=batch_size, lr=lr, use_adam_opt=use_adam_opt, silent=True,
                             optimizer=optimizer)
        # calculate loss over the validation set, which is 1 - area under the PR curve.
        # The swivel model argument is None: names are presumably embedded by encoder_model instead.
        best_matches = get_best_swivel_matches(None, vocab, input_names_train,
                                               candidate_names_validate, k, eval_batch_size,
                                               add_context=add_context, encoder_model=encoder_model, n_jobs=n_jobs)
        auc = metrics.get_auc(
            weighted_actual_names_validate, best_matches, min_threshold=0.01, max_threshold=2.0, step=0.05, distances=False
        )
        tune.report(mean_loss=1 - auc)

In [None]:
# Restart Ray from a clean state so this cell can be re-run.
ray.shutdown()
redis_max_memory_bytes = 4 * 10**9  # 4 GB, passed as Ray's `_redis_max_memory`
ray.init(_redis_max_memory=redis_max_memory_bytes)

In [None]:
# Overall wall-clock budget for the whole search, in seconds (4 hours).
time_budget = 3600 * 4

# Large objects shared by every trial, handed over once through tune.with_parameters.
shared_data = {
    "model": model,
    "optimize_name_inputs": optimize_name_inputs,
    "optimize_embeddings": optimize_embeddings,
    "input_names_train": input_names_train,
    "candidate_names_validate": candidate_names_validate,
    "weighted_actual_names_validate": weighted_actual_names_validate,
}

# Hyperparameter search space, sampled by Ray Tune's default search (no search_alg is passed).
# To try Bayesian optimization instead, pass
# search_alg=ConcurrencyLimiter(BayesOptSearch(metric="mean_loss", mode="min"), max_concurrent=1)
# after importing BayesOptSearch and ConcurrencyLimiter from Ray Tune.
search_space = {
    "n_layers": tune.choice([1, 2]),
    "char_embed_dim": tune.choice([0, 16, 64]),
    "n_hidden_units": tune.choice([200]),  # other candidates: 64, 100
    "bidirectional": tune.choice([True]),  # other candidate: False
    "lr": tune.uniform(0.01, 0.10),
    "batch_size": tune.choice([64, 256]),
    "use_adam_opt": tune.choice([True, False]),
}

analysis = tune.run(
    tune.with_parameters(ray_training_function, data=shared_data),
    config=search_space,
    num_samples=10,
    time_budget_s=time_budget,
    resources_per_trial={"cpu": 8, "gpu": 1},
    progress_reporter=tune.JupyterNotebookReporter(overwrite=True, max_report_frequency=5),
)
best_config = analysis.get_best_config(metric="mean_loss", mode="min")
print("Best config: ", best_config)