In [None]:
# Reload imported modules (e.g. the project's `src` package) before each cell runs,
# so edits to src/ take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import namedtuple
import matplotlib.pyplot as plt
import torch
import wandb

from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.models.glove import GloveDataset, GloveModel, train_glove, get_best_glove_matches


In [None]:
# Config
# `given_surname` selects the name type ("given" here; presumably "surname" for the
# surname model) and `size` selects the training-set variant. Both are baked into the
# S3 paths below; the Config fields are what gets passed to W&B as the run config.

given_surname = "given"
size = "freq"
vocab_size = 500_000
embed_dim = 200

Config = namedtuple(
    "Config",
    ["train_path", "vocab_size", "embed_dim", "glove_dict_path", "glove_model_path"],
)

config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    # The dict artifact (presumably the vocabulary) is keyed by vocab size only;
    # the model weights additionally depend on the embedding dimension.
    glove_dict_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-glove-{vocab_size}-dict.pth",
    glove_model_path=f"s3://nama-data/data/models/fs-{given_surname}-{size}-glove-{vocab_size}-{embed_dim}.pt",
)

In [None]:
# Open a W&B run for this experiment; the Config fields become the run's config and
# runs are grouped by name type (given vs. surname).
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="52_glove",
    config=config._asdict(),
    notes="",
)

### Load data

In [None]:
# One path in, so exactly one split is expected back (the unpacking fails otherwise).
(train,) = load_train_test([config.train_path])

In [None]:
# The split is a 3-item sequence: input names, their (presumably weighted) actual
# matching names, and the candidate names to search over.
(input_names_train, weighted_actual_names_train, candidate_names_train) = train

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Build the GloVe training dataset from the training name pairs, capped at
# config.vocab_size entries (presumably the most frequent names) and placed on `device`.
dataset = GloveDataset(
    input_names_train,
    weighted_actual_names_train,
    config.vocab_size,
    device,
)
# Vocabulary used to size the embedding table and to look names up when matching.
vocab = dataset.get_vocab()

In [None]:
# Seed torch's global RNG before building the model so that initialisation (presumably
# random) and any later torch-driven randomness in training are reproducible under
# Restart & Run All. CUDA kernels may still introduce some nondeterminism.
SEED = 42
torch.manual_seed(SEED)
# Record the seed on the W&B run so the run can be reproduced from its config.
wandb.config.update({"seed": SEED})

# One embedding row per vocabulary entry, each config.embed_dim wide.
model = GloveModel(len(vocab), config.embed_dim)
model.to(device=device)

In [None]:
# GloVe training hyperparameters. x_max and alpha presumably parameterise the standard
# GloVe weighting f(x) = min(1, (x / x_max) ** alpha) — confirm in src.models.glove.
N_EPOCHS = 100
TRAIN_BATCH_SIZE = 1024
X_MAX = 100
ALPHA = 0.75
LR = 0.05

# Log the hyperparameters on the W&B run; otherwise they live only in this cell.
wandb.config.update(
    {
        "n_epochs": N_EPOCHS,
        "train_batch_size": TRAIN_BATCH_SIZE,
        "x_max": X_MAX,
        "alpha": ALPHA,
        "lr": LR,
    }
)

loss_values = train_glove(
    model,
    dataset,
    n_epochs=N_EPOCHS,
    batch_size=TRAIN_BATCH_SIZE,
    x_max=X_MAX,
    alpha=ALPHA,
    lr=LR,
    device=device,
)

In [None]:
# Persist the trained weights. The handle must be closed: for remote (s3://) paths the
# buffered upload is presumably only committed on close.
with fopen(config.glove_model_path, "wb") as f:
    torch.save(model.state_dict(), f)

# Also save the vocabulary to config.glove_dict_path (otherwise unused). Without it, the
# saved weights cannot be mapped back to names after reloading.
# NOTE(review): assumes vocab is picklable (presumably a name -> index dict).
with fopen(config.glove_dict_path, "wb") as f:
    torch.save(vocab, f)

In [None]:
# Reload the saved weights into the existing model. The file holds a state_dict (see the
# save cell), not a module: assigning torch.load's result to `model` would replace the
# network with a plain dict before it is passed to get_best_glove_matches.
with fopen(config.glove_model_path, "rb") as f:
    model.load_state_dict(torch.load(f, map_location=device))

In [None]:
# Training curve: one point per entry returned by train_glove.
fig, ax = plt.subplots()
ax.plot(loss_values)
ax.set(title="GloVe training loss", xlabel="step", ylabel="loss");

In [None]:
# Use the learned GloVe embeddings to find the best-scoring candidate names for every
# training input name.
k = 100  # presumably the number of candidates kept per input name
batch_size = 256  # presumably the number of input names scored per batch
best_matches = get_best_glove_matches(
    model,
    vocab,
    input_names_train,
    candidate_names_train,
    k,
    batch_size,
)

### PR Curve

In [None]:
# Precision vs. weighted recall of best_matches over a sweep of thresholds from 0.01 to
# 1.0 in steps of 0.05. distances=False: the match values are scores, not distances.
pr_curve = metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_train,
    best_matches,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
    distances=False,
)

# Close the W&B run opened by wandb.init; in a notebook the run otherwise stays
# "running" until the kernel exits. Calling this with no active run is a no-op.
wandb.finish()

pr_curve
