In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import namedtuple
import torch
import numpy as np
import wandb

from src.data import constants
from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.eval.encoder import eval_encoder
from src.models import utils
from src.models.autoencoder import AutoEncoder, train_model

In [None]:
# Selects the dataset variant; interpolated into the train/test paths below
given_surname = "surname"
# Size label; interpolated into the dataset paths and the model path
size = "freq"

# Immutable bundle of the run settings (fields are filled positionally, in declaration order)
# NOTE(review): model_path does not include given_surname, so runs with different
# given_surname values point at the same model file — confirm that is intended
Config = namedtuple("Config", "size train_path test_path model_path")
config = Config(
    size,
    f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-{size}.csv.gz",
    f"s3://nama-data/data/models/fs-{size}-autoencoder-bilstm-100-512.pth",
)

In [None]:
# Start a Weights & Biases run for this experiment: the Config fields are passed
# as the run configuration and the run is grouped by given_surname
wandb.init(
    project="nama",
    entity="nama",
    name="50_autoencoder",
    group=given_surname,
    notes="",
    config=config._asdict()
)

In [None]:
# Sequence length handed to the name-to-tensor conversion and to the model (seq_len)
MAX_NAME_LENGTH = 30
# Character <-> index lookup tables built by the project helper
# (presumably covering the model's character vocabulary — confirm in src.models.utils)
char_to_idx_map, idx_to_char_map = utils.build_token_idx_maps()

### Load data

In [None]:
# Load the train and test splits; each split is unpacked below as a 3-tuple
train, test = load_train_test([config.train_path, config.test_path])
# From the training split only the third element (candidate names) is kept
_, _, candidate_names_train = train
input_names_test, weighted_actual_names_test, candidate_names_test = test

# Each weighted entry is a 3-tuple; keep only its first field (the name)
actual_names_test = [[name for name, _, _ in name_weights] for name_weights in weighted_actual_names_test]

# Full candidate pool: train candidates first, then test candidates
candidate_names_all = np.concatenate((candidate_names_train, candidate_names_test))

### Convert names to ids

In [None]:
# Prepare data for training
# Inputs and targets hold the same names in two different representations
# (one-hot vs. plain index sequences)
candidate_names_train_X, candidate_names_train_y = utils.convert_names_to_model_inputs(
    candidate_names_train, char_to_idx_map, MAX_NAME_LENGTH
)

In [None]:
# Sanity check: shapes of the training inputs and targets
print(candidate_names_train_X.shape, candidate_names_train_y.shape)

### Model

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Build the autoencoder for sequences of MAX_NAME_LENGTH steps on the selected device
# NOTE(review): the +1 on VOCAB_SIZE presumably reserves one extra index (e.g. padding) — confirm
model = AutoEncoder(
    input_size=constants.VOCAB_SIZE + 1, hidden_size=100, num_layers=1, seq_len=MAX_NAME_LENGTH, device=device
)

In [None]:
# Train the autoencoder on the training candidate names (inputs and targets built above)
# NOTE(review): 100 and 512 are presumably the epoch count and batch size — confirm against train_model's signature
train_model(model, candidate_names_train_X, candidate_names_train_y, 100, 512)

In [None]:
# Persist the whole model object to config.model_path.
# The handle is opened in a `with` block so it is always closed; buffered or
# remote writes are typically only flushed/finalized on close, and the original
# one-liner never closed the handle.
with fopen(config.model_path, "wb") as model_file:
    torch.save(model, model_file)

In [None]:
# Reload the saved model onto the current device, closing the handle afterwards
# (the original one-liner left the handle open).
# `device` is already a torch.device, so it is passed to map_location directly.
# NOTE(review): torch.load unpickles the file — only load model files from a trusted location.
with fopen(config.model_path, "rb") as model_file:
    model = torch.load(model_file, map_location=device)

### Understand AutoEncoder

In [None]:
# Set up a loss and a batched loader over the training tensors so that a single
# batch can be traced through the model by hand in the following cells
loss_fn = torch.nn.CrossEntropyLoss()
dataset_train = torch.utils.data.TensorDataset(candidate_names_train_X, candidate_names_train_y)
# shuffle=True: unless a seed is set elsewhere, the batch inspected below differs between runs
data_loader = torch.utils.data.DataLoader(dataset_train, batch_size=512, shuffle=True)

In [None]:
# Pull one batch of (inputs, targets) from the loader to walk through the model manually
X, y = next(iter(data_loader))
print(X.shape, y.shape)

In [None]:
# Clear any gradients currently stored on the model's parameters
model.zero_grad()
# Encode(input,hidden) -> (batch,seq,dirs*hidden), ((dirs*layers,batch,hidden),(dirs*layers,batch,hidden)) - x_encoded is the last hidden state
# The per-step outputs and the cell state are discarded; the batch is moved to `device` first
_, (x_encoded, _) = model.lstm_encoder(X.to(device))
print(x_encoded.shape)

In [None]:
# Concatenate left-right hidden vectors
# (entries 0 and 1 along the first axis are joined along the feature axis)
# NOTE: this rebinds x_encoded, so the cell is not idempotent — re-run the encoder cell before re-running it
x_encoded = torch.cat([x_encoded[0], x_encoded[1]], dim=1)
print(x_encoded.shape)

In [None]:
# Reshape data to have seq_len time steps
# TODO why do we copy x_encoded to every time step?
# NOTE(review): presumably because the decoder needs an input at each of the MAX_NAME_LENGTH
# steps and the encoding is the only information it is given — confirm
# NOTE: this rebinds x_encoded, so the cell is not idempotent — re-run from the encoder cell before re-running it
x_encoded = x_encoded.unsqueeze(1).repeat(1, MAX_NAME_LENGTH, 1)
print(x_encoded.shape)

In [None]:
# Decode(hidden*dirs,hidden) -> (batch,seq,dirs*hidden), ((dirs*layers,batch,hidden),(dirs*layers,batch,hidden)) - x_decoded is the output
# Only the per-step outputs are kept; the decoder's final (hidden, cell) states are discarded
x_decoded, (_, _) = model.lstm_decoder(x_encoded)
print(x_decoded.shape)

In [None]:
# linear layer(hidden,input) predicts characters
# Applied to the decoder output of every time step
x_prime = model.linear(x_decoded)
print(x_prime.shape)

In [None]:
# Reshape output to match CrossEntropyLoss input
# Swaps axis 1 with the last axis so the result is laid out as (batch,classes,seq)
# NOTE: this rebinds x_prime, so the cell is not idempotent — re-run the linear-layer cell before re-running it
x_prime = x_prime.transpose(1, -1)
print(x_prime.shape)

In [None]:
# Compute loss (batch,classes,seq), (batch,seq)
# The targets are moved to `device`, where the encoder input was also placed
loss = loss_fn(x_prime, y.to(device))
print(loss)

### Evaluation

In [None]:
# Batch size passed to eval_encoder when encoding names in the cells below
batch_size = 512

In [None]:
# Get embeddings for the training candidate names from the encoder
candidate_names_train_encoded = eval_encoder(model, candidate_names_train_X, batch_size)

In [None]:
# Compare the shape of the model inputs with the shape of their embeddings (one per line)
print(candidate_names_train_X.shape, candidate_names_train_encoded.shape, sep="\n")

In [None]:
# Convert the test input names to model inputs (same conversion as for the training candidates)
input_names_test_X, input_names_test_y = utils.convert_names_to_model_inputs(
    input_names_test, char_to_idx_map, MAX_NAME_LENGTH
)
# Get embeddings for the test input names from the encoder
input_names_test_encoded = eval_encoder(model, input_names_test_X, batch_size)

In [None]:
# Sanity check: shapes of the converted test inputs/targets and of their embeddings
print(input_names_test_X.shape, input_names_test_y.shape)
print(input_names_test_encoded.shape)

In [None]:
# Get embeddings for the test candidate names only; they are stacked with the
# already-computed training candidate embeddings in a later cell
# The converted targets are not needed here and are discarded
candidate_names_test_X, _ = utils.convert_names_to_model_inputs(candidate_names_test, char_to_idx_map, MAX_NAME_LENGTH)
candidate_names_test_encoded = eval_encoder(model, candidate_names_test_X, batch_size)

In [None]:
# Compare the shape of the test candidate inputs with the shape of their embeddings (one per line)
print(candidate_names_test_X.shape, candidate_names_test_encoded.shape, sep="\n")

In [None]:
# Stack train embeddings first, then test embeddings — the same order in which
# candidate_names_all was concatenated, so row i belongs to candidate_names_all[i]
candidate_names_all_encoded = np.vstack((candidate_names_train_encoded, candidate_names_test_encoded))

In [None]:
# Sanity check: shape of the combined candidate embedding matrix
print(candidate_names_all_encoded.shape)

In [None]:
# Number of best-matching candidates requested per test input name
k = 100
# Match every test input embedding against all candidate embeddings
candidate_names_scores = utils.get_best_matches(
    input_names_test_encoded, candidate_names_all_encoded, candidate_names_all, num_candidates=k
)
print(candidate_names_scores.shape)
# Peek at the first match of the first input: field 0 is later extracted as the
# candidate name; field 1 is presumably its score — confirm in get_best_matches
print(candidate_names_scores[0, 0, 0])
print(candidate_names_scores[0, 0, 1])

In [None]:
# Keep only field 0 of the last axis (the candidate names), dropping the other field
candidate_names = candidate_names_scores[:, :, 0]
print(candidate_names.shape)

### Demo

In [None]:
# Demo: embed one example name and print its 10 best-matching candidates.
# The raw name list gets its own variable so test_name_X only ever holds the model input.
# NOTE(review): the angle brackets are presumably begin/end-of-name markers expected by
# convert_names_to_model_inputs — confirm
demo_names = ["<schumacher>"]
test_name_X, _ = utils.convert_names_to_model_inputs(demo_names, char_to_idx_map, MAX_NAME_LENGTH)
# Move the input to the model's device (as the encoder walk-through does) and bring the
# result back to the CPU before converting: Tensor.numpy() raises on CUDA tensors.
test_name_embedding = model(test_name_X.to(device), just_encoder=True).detach().cpu().numpy()

print(utils.get_best_matches(test_name_embedding, candidate_names_all_encoded, candidate_names_all, num_candidates=10))

## Evaluate using weighted relevant names and score thresholds

### Average precision at 0.97

In [None]:
# Average precision over the test inputs at a score threshold of 0.97
metrics.avg_precision_at_threshold(weighted_actual_names_test, candidate_names_scores, 0.97)

### Average recall at 0.97

In [None]:
# Average weighted recall over the test inputs at a score threshold of 0.97
metrics.avg_weighted_recall_at_threshold(weighted_actual_names_test, candidate_names_scores, 0.97)

### PR Curve

In [None]:
# Minimum score threshold to test; passed to the curve helper as the lower bound of the threshold sweep
# (presumably swept upwards from here — confirm in src.eval.metrics)
min_threshold = 0.5
metrics.precision_weighted_recall_curve_at_threshold(weighted_actual_names_test, candidate_names_scores, min_threshold)

### AUC

In [None]:
# AUC computed by the project helper from the weighted actual names and the scored candidates
# (presumably the area under the precision / weighted-recall curve — confirm in src.eval.metrics)
metrics.get_auc(weighted_actual_names_test, candidate_names_scores)