In [None]:
# Reload all modules before executing code, so edits under src/ take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
from collections import namedtuple
from contextlib import closing

import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import torch
import wandb

from src.data import constants
from src.data.filesystem import fopen
from src.data.utils import load_dataset, select_frequent_k
from src.eval import metrics
from src.models.swivel import SwivelModel, get_swivel_embeddings, get_best_swivel_matches
from src.models.swivel_encoder import SwivelEncoderModel, convert_names_to_model_inputs, train_swivel_encoder

In [None]:
# Config
# Every tunable setting for this run; `config` is what gets logged to W&B.

given_surname = "given"
if given_surname == "given":
    vocab_size = 610000
else:
    vocab_size = 2100000
encoder_vocab_size = vocab_size  # train the encoder on (up to) this many names
sample_size = 10000  # number of input names sampled for each evaluation
embed_dim = 100
n_layers = 2
# Deeper encoders get more epochs: 1 -> 100, 2 -> 200, 3 -> 400, anything else -> 800.
n_epochs = {1: 100, 2: 200, 3: 400}.get(n_layers, 800)
DROPOUT = 0.0
num_matches = 500  # candidates retrieved per input name during evaluation

Config = namedtuple(
    "Config",
    [
        "train_path",
        "eval_path",
        "test_path",
        "embed_dim",
        "n_layers",
        "char_embed_dim",
        "n_hidden_units",
        "bidirectional",
        "lr",
        "batch_size",
        "use_adam_opt",
        "pack",
        "n_epochs",
        "swivel_vocab_path",
        "swivel_model_path",
        "encoder_model_path",
    ],
)
config = Config(
    train_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train-augmented.csv.gz",
    eval_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-train.csv.gz",
    test_path=f"s3://familysearch-names/processed/tree-hr-{given_surname}-test.csv.gz",
    embed_dim=embed_dim,
    n_layers=n_layers,
    char_embed_dim=64,
    n_hidden_units=400,
    bidirectional=True,
    lr=0.03,
    batch_size=256,
    use_adam_opt=False,
    pack=True,
    n_epochs=n_epochs,
    swivel_vocab_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-vocab-{vocab_size}-augmented.csv",
    swivel_model_path=f"s3://nama-data/data/models/fs-{given_surname}-swivel-model-{vocab_size}-{embed_dim}-augmented.pth",
    encoder_model_path=f"s3://nama-data/data/models/fs-{given_surname}-encoder-model-{encoder_vocab_size}-{embed_dim}-{n_layers}-augmented.pth",
)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# CUDA memory statistics only exist for a CUDA device. Querying them on a
# CPU-only machine fails, which would break the CPU fallback chosen above,
# so report them only when a GPU was selected.
if device.type == "cuda":
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

In [None]:
# Start a Weights & Biases run; the full hyperparameter set in `config` is
# attached to it as the run config.
wandb.init(
    entity="nama",
    project="nama",
    group=given_surname,
    name="63_swivel_encoder",
    config=config._asdict(),
    notes="",
)

### Load data

In [None]:
# Each split unpacks into (input names, weighted actual names per input, candidate names).
# The eval and test splits are loaded with is_eval=True.
input_names_train, weighted_actual_names_train, candidate_names_train = load_dataset(
    config.train_path
)
input_names_eval, weighted_actual_names_eval, candidate_names_eval = load_dataset(
    config.eval_path, is_eval=True
)
input_names_test, weighted_actual_names_test, candidate_names_test = load_dataset(
    config.test_path, is_eval=True
)

In [None]:
# Sizes of the eval and test splits: number of input names, total number of
# (input, actual) pairs, and size of the candidate pool.
dataset_stats = [
    ("input_names_eval", len(input_names_eval)),
    ("weighted_actual_names_eval", sum(map(len, weighted_actual_names_eval))),
    ("candidate_names_eval", len(candidate_names_eval)),
    ("input_names_test", len(input_names_test)),
    ("weighted_actual_names_test", sum(map(len, weighted_actual_names_test))),
    ("candidate_names_test", len(candidate_names_test)),
]
for label, count in dataset_stats:
    print(label, count)

In [None]:
# Load the swivel vocabulary CSV. The `name` and `index` columns are used below
# to build the name -> id map.
# pandas does not close file objects it is handed, so close the fopen handle
# explicitly instead of leaking it until garbage collection.
with closing(fopen(config.swivel_vocab_path, "rb")) as vocab_file:
    swivel_vocab_df = pd.read_csv(vocab_file)
print(swivel_vocab_df.head(5))

In [None]:
# Map each vocabulary name to its id (presumably its row in the swivel
# embedding matrix). If a name repeats, the last row wins.
swivel_vocab = dict(zip(swivel_vocab_df["name"], swivel_vocab_df["index"]))
# Sanity check: names are wrapped in <...> and a common one should be present.
print(swivel_vocab["<john>"])

In [None]:
# Rebuild the pretrained swivel model and load its saved weights.
swivel_model = SwivelModel(len(swivel_vocab), config.embed_dim)
# map_location lets weights saved on a GPU load on the CPU fallback selected by
# the device cell (without it torch.load fails on a CUDA-less machine).
# closing() releases the fopen handle as soon as loading finishes.
with closing(fopen(config.swivel_model_path, "rb")) as model_file:
    swivel_model.load_state_dict(torch.load(model_file, map_location=device))
swivel_model.eval()
swivel_model.to(device)
print(swivel_model)

### Train

In [None]:
# Pick the names the encoder is trained on. When encoder_vocab_size covers the
# whole swivel vocabulary, use every vocabulary name. Otherwise, trim the
# training data with select_frequent_k (presumably keeping the most frequent
# names) and train on every input/candidate name that survives.
if encoder_vocab_size < len(swivel_vocab):
    input_names_train, weighted_actual_names_train, candidate_names_train = select_frequent_k(
        input_names_train,
        weighted_actual_names_train,
        candidate_names_train,
        encoder_vocab_size,
    )
    train_names = list(set(input_names_train).union(set(candidate_names_train)))
else:
    train_names = list(swivel_vocab)
print("train_names", len(train_names))

In [None]:
# Drop references to the training-split lists; from here on only
# `train_names` is needed, so their memory can be reclaimed.
input_names_train = None
weighted_actual_names_train = None
candidate_names_train = None

In [None]:
# Regression targets: the pretrained swivel embedding of every training name,
# as a float tensor.
train_embeddings = torch.Tensor(
    get_swivel_embeddings(swivel_model, swivel_vocab, train_names)
)
print(train_embeddings.shape)

In [None]:
# Encode every training name into the tensor form the encoder consumes.
train_inputs = convert_names_to_model_inputs(train_names)
print(train_inputs.shape, train_inputs.dtype, sep="\n")

In [None]:
# Build a fresh encoder from the config; its output size matches the swivel
# embedding size so it can regress those embeddings directly.
encoder_model = SwivelEncoderModel(
    device=device,
    n_layers=config.n_layers,
    bidirectional=config.bidirectional,
    char_embed_dim=config.char_embed_dim,
    n_hidden_units=config.n_hidden_units,
    output_dim=config.embed_dim,
    dropout=DROPOUT,
    pack=config.pack,
)

In [None]:
%%time
_ = train_swivel_encoder(encoder_model,
                         train_inputs,
                         train_embeddings,
                         num_epochs=config.n_epochs,
                         batch_size=config.batch_size,
                         lr=config.lr,
                         use_adam_opt=config.use_adam_opt,
                         use_mse_loss=False,
                         checkpoint_path=config.encoder_model_path)

### Save model

In [None]:
# Persist the trained encoder weights. closing() closes the write handle
# deterministically, flushing any buffered bytes, instead of leaving that to
# garbage collection of an anonymous file object.
with closing(fopen(config.encoder_model_path, "wb")) as model_file:
    torch.save(encoder_model.state_dict(), model_file)

### Reload model

In [None]:
# Rebuild the encoder with exactly the hyperparameters used for training,
# including dropout (the creation cell passes dropout=DROPOUT), then load the
# saved weights.
encoder_model = SwivelEncoderModel(n_layers=config.n_layers,
                                   char_embed_dim=config.char_embed_dim,
                                   n_hidden_units=config.n_hidden_units,
                                   output_dim=config.embed_dim,
                                   bidirectional=config.bidirectional,
                                   pack=config.pack,
                                   dropout=DROPOUT,
                                   device=device)
# map_location allows GPU-saved weights to load on a CPU-only machine;
# closing() releases the fopen handle once loading is done.
with closing(fopen(config.encoder_model_path, "rb")) as model_file:
    encoder_model.load_state_dict(torch.load(model_file, map_location=device))
encoder_model.eval()
encoder_model.device = device
encoder_model.to(device)

### Eval

#### On training data

In [None]:
# Sample sample_size input names (with their weighted actual names) from the eval
# split. eval_* was loaded from the *training* file (config.eval_path), hence
# "On training data". All eval candidates stay in the candidate pool.
# A fixed random_state keeps the sample, and so the PR/AUC numbers below,
# reproducible across runs.
_, input_names_sample, _, weighted_actual_names_sample = train_test_split(
    input_names_eval, weighted_actual_names_eval, test_size=sample_size, random_state=42
)
candidate_names_sample = candidate_names_eval
print("input_names_sample", len(input_names_sample))
print("candidate_names_sample", len(candidate_names_sample))

In [None]:
# Retrieve the top num_matches candidates for each sampled input name.
# model/vocab are passed as None, presumably so embeddings come from encoder_model.
# NOTE: only considers as potential matches names in candidate_names, not names in input_names
batch_size = 256
add_context = True
n_jobs = 1
best_matches = get_best_swivel_matches(
    model=None,
    vocab=None,
    input_names=input_names_sample,
    candidate_names=candidate_names_sample,
    encoder_model=encoder_model,
    k=num_matches,
    batch_size=batch_size,
    add_context=add_context,
    n_jobs=n_jobs,
    progress_bar=True,
)

##### PR Curve

In [None]:
# Precision vs. weighted recall while the threshold sweeps from 0.01 to 1.0 in
# 0.05 steps. distances=False: match scores are treated as similarities
# (presumably higher = closer).
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_sample,
    best_matches,
    distances=False,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
)

##### AUC

In [None]:
# Summary score over thresholds 0.1 .. 1.0 (step 0.05), scores treated as similarities.
print(
    metrics.get_auc(
        weighted_actual_names_sample,
        best_matches,
        distances=False,
        min_threshold=0.1,
        max_threshold=1.0,
        step=0.05,
    )
)

#### On test data

In [None]:
# sample data
_, input_names_sample, _, weighted_actual_names_sample = \
   train_test_split(input_names_test, weighted_actual_names_test, test_size=sample_size)
candidate_names_sample = candidate_names_test
print("input_names_sample", len(input_names_sample))
print("canidate_names_sample", len(candidate_names_sample))

In [None]:
# Classify every (input, actual) pair in the test sample by how many of its
# names the swivel vocabulary knows. Pairs of two *distinct* known names are
# labelled "should not be possible" (presumably excluded when the test split
# was built). A pair whose names are identical counts as one.
n_zero = n_one = n_two = 0
for input_name, wans in zip(input_names_sample, weighted_actual_names_sample):
    input_known = input_name in swivel_vocab
    for actual_name, _, _ in wans:
        actual_known = actual_name in swivel_vocab
        if not (input_known or actual_known):
            n_zero += 1
        elif input_known and actual_known and input_name != actual_name:
            n_two += 1
        else:
            n_one += 1
print("two names in vocab (should not be possible)", n_two)
print("one name in vocab", n_one)
print("zero names in vocab", n_zero)

In [None]:
# Retrieve the top num_matches candidates for each sampled test input name.
# model/vocab are passed as None, presumably so embeddings come from encoder_model.
# NOTE: only considers as potential matches names in candidate_names, not names in input_names
batch_size = 256
add_context = True
n_jobs = 1
best_matches = get_best_swivel_matches(
    model=None,
    vocab=None,
    input_names=input_names_sample,
    candidate_names=candidate_names_sample,
    encoder_model=encoder_model,
    k=num_matches,
    batch_size=batch_size,
    add_context=add_context,
    n_jobs=n_jobs,
    progress_bar=True,
)

##### PR Curve

In [None]:
# Precision vs. weighted recall on the test sample while the threshold sweeps
# from 0.01 to 1.0 in 0.05 steps. Scores are treated as similarities (distances=False).
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_sample,
    best_matches,
    distances=False,
    min_threshold=0.01,
    max_threshold=1.0,
    step=0.05,
)

##### AUC

In [None]:
# Test-sample summary score over thresholds 0.1 .. 1.0 (step 0.05).
print(
    metrics.get_auc(
        weighted_actual_names_sample,
        best_matches,
        distances=False,
        min_threshold=0.1,
        max_threshold=1.0,
        step=0.05,
    )
)

In [None]:
# Close the W&B run started by wandb.init at the top of the notebook; the
# "Test" cells below are scratch experiments that run after the run has ended.
wandb.finish()

### Test

In [None]:
# Six familiar names in two groups (John-like, then Mary-like) for quick sanity checks.
test_names = [
    "<john>", "<johnny>", "<jonathan>",
    "<mary>", "<marie>", "<maria>",
]
test_embeddings = torch.Tensor(get_swivel_embeddings(swivel_model, swivel_vocab, test_names))
print(test_embeddings.shape)

In [None]:
# Cosine similarity of "<john>"'s swivel embedding to the John-like group,
# then to the Mary-like group.
john_embedding = test_embeddings[0:1]
for name_group in (slice(0, 3), slice(3, None)):
    print(test_names[name_group])
    print(cosine_similarity(john_embedding, test_embeddings[name_group]))

In [None]:
# Encode the six test names into encoder inputs.
test_model_inputs = convert_names_to_model_inputs(test_names)
print(test_model_inputs.shape, test_model_inputs.dtype, sep="\n")

In [None]:
# Small throwaway encoder for the six test names.
# NOTE(review): this overwrites the config-cell globals n_layers and embed_dim;
# `bidirectional` is read again by the manual training-step cell below.
n_layers, char_embed_dim, n_hidden_units, embed_dim = 1, 64, 200, 100
bidirectional, pack = True, False
encoder_model = SwivelEncoderModel(
    n_layers=n_layers,
    char_embed_dim=char_embed_dim,
    n_hidden_units=n_hidden_units,
    output_dim=embed_dim,
    bidirectional=bidirectional,
    pack=pack,
    device=device,
)

In [None]:
# Train the small encoder on just the six test names, to check it can fit
# their swivel embeddings.
lr = 0.01
n_epochs = 100
use_adam_opt = False
use_mse_loss = False
train_swivel_encoder(
    encoder_model,
    test_model_inputs,
    test_embeddings,
    num_epochs=n_epochs,
    batch_size=64,
    lr=lr,
    use_adam_opt=use_adam_opt,
    use_mse_loss=use_mse_loss,
)

In [None]:
# Inference only: torch.no_grad() skips building an autograd graph for this
# forward pass, so the output needs no .detach() before .numpy().
with torch.no_grad():
    test_embeddings_predicted = encoder_model(test_model_inputs).cpu().numpy()

In [None]:
# NumPy array of the true swivel embeddings, for comparison with the predictions.
test_embeddings_numpy = test_embeddings.to("cpu").numpy()

In [None]:
# Entry [i, j]: similarity of name i's true embedding to name j's predicted one.
# The diagonal shows how well each name was fit.
cosine_similarity(X=test_embeddings_numpy, Y=test_embeddings_predicted)

In [None]:
# Baseline: pairwise similarities among the true swivel embeddings themselves.
cosine_similarity(X=test_embeddings_numpy, Y=test_embeddings_numpy)

#### Replicate model training here

In [None]:
# Optimizer and loss for replicating one training step by hand.
batch_size = 16
lr = 0.05

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(encoder_model.parameters(), lr=lr)

In [None]:
# Shuffled mini-batches of (encoded test name, target swivel embedding) pairs.
dataset_train = torch.utils.data.TensorDataset(test_model_inputs, test_embeddings)
data_loader = torch.utils.data.DataLoader(dataset_train, shuffle=True, batch_size=batch_size)

In [None]:
# Pull a single shuffled batch to step through by hand.
batch_iter = iter(data_loader)
train_batch, targets_batch = next(batch_iter)
print(train_batch.shape, targets_batch.shape, sep="\n")

In [None]:
# Replicate one optimisation step of the encoder by hand on the batch drawn above.
# NOTE(review): this mirrors what SwivelEncoderModel.forward is assumed to do
# (one-hot characters -> LSTM -> linear on the final hidden state). Confirm
# against src/models/swivel_encoder.py, which is not shown here.
encoder_model.to(device=device)
encoder_model.zero_grad()  # clear gradients left over from any earlier step

X = train_batch.to(device=device)
n_batch, seq_len = X.size()
print("batch_size", n_batch, "seq_len", seq_len)

# One-hot encode each character id: (batch, seq_len) -> (batch, seq_len, VOCAB_SIZE + 1).
eye = torch.eye(constants.VOCAB_SIZE + 1, device=device)
X = eye[X]

# Read the LSTM's own settings rather than the notebook globals, which other
# cells overwrite. h_n has shape (num_layers * num_directions, batch, hidden_size).
lstm = encoder_model.lstm
n_directions = 2 if lstm.bidirectional else 1
# Named lstm_outputs (not `all`) so the builtin all() is not shadowed.
lstm_outputs, (hidden, cell) = lstm(X)
print("hidden", hidden.shape, cell.shape)

# Final hidden state of the LAST layer for EVERY sample in the batch.
# The previous `hidden[0][-1]` picked layer 0 / forward direction of only the
# last sample. Its result then either broadcast one vector against the whole
# target batch in the loss or hit a size mismatch.
last_layer = hidden.reshape(lstm.num_layers, n_directions, n_batch, lstm.hidden_size)[-1]
if encoder_model.linear.in_features == n_directions * lstm.hidden_size:
    # linear consumes all directions concatenated: (batch, n_directions * hidden_size)
    final_hidden = torch.cat(tuple(last_layer), dim=1)
else:
    # linear consumes a single direction: use the last entry of h_n, (batch, hidden_size)
    final_hidden = hidden[-1]
embeddings = encoder_model.linear(final_hidden)
print("embeddings", embeddings.shape)

# Compute loss, backpropagate, and update parameters.
loss = loss_fn(embeddings, targets_batch.to(device))
loss.backward()
optimizer.step()

print(loss.item())