In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before executing code so edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Generate an autoencoder model based upon similar name pairs

In [None]:
from collections import namedtuple
import torch
import wandb

from src.data import constants
from src.data.filesystem import fopen
from src.data.utils import load_train_test
from src.eval import metrics
from src.models.autoencoder import train_model, AutoEncoder, MAX_NAME_LENGTH, get_best_autoencoder_matches, convert_names_to_model_inputs

In [None]:
# Run selectors: both values are interpolated into every path below.
given_surname = "given"
size = "freq"

# Immutable settings record; fields are filled positionally in declaration
# order (size, train_path, test_path, model_path).
Config = namedtuple("Config", ["size", "train_path", "test_path", "model_path"])
config = Config(
    size,
    f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-train-{size}.csv.gz",
    f"s3://familysearch-names/processed/tree-hr-{given_surname}-similar-test-{size}.csv.gz",
    f"s3://nama-data/data/models/fs-{given_surname}-{size}-autoencoder-bilstm-100-512.pth",
)

In [None]:
# Start a Weights & Biases run for this notebook. The Config fields are passed
# as the run's config, and the run is grouped by the value of given_surname.
wandb.init(
    project="nama",
    entity="nama",
    name="50_autoencoder",
    group=given_surname,
    notes="",
    config=config._asdict()
)

### Load data

In [None]:
# Load the train and test datasets from the configured paths.
train, test = load_train_test([config.train_path, config.test_path])

# Each dataset unpacks into three parts. Only the candidate names are kept from
# the train split; all three parts of the test split are used further down.
_, _, candidate_names_train = train
input_names_test, weighted_actual_names_test, candidate_names_test = test

### Convert names to ids

In [None]:
# Prepare data for training.
# Inputs (X) and targets (y) hold the same names, just in two different
# representations: one-hot vs. normal sequences.
candidate_names_train_X, candidate_names_train_y = convert_names_to_model_inputs(
    candidate_names_train
)

In [None]:
# Sanity-check the shapes of the model inputs and targets.
print(candidate_names_train_X.shape, candidate_names_train_y.shape)

### Model

In [None]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Build the autoencoder with one layer of hidden size 100.
# NOTE(review): input_size is VOCAB_SIZE + 1; the extra slot is presumably
# reserved for a padding id — confirm against AutoEncoder / the name encoding.
model = AutoEncoder(
    input_size=constants.VOCAB_SIZE + 1, hidden_size=100, num_layers=1, device=device
)

In [None]:
# Train the autoencoder on the training candidate names.
# NOTE(review): the positional 100 and 512 are presumably the number of epochs
# and the batch size — confirm against train_model's signature.
train_model(model, candidate_names_train_X, candidate_names_train_y, 100, 512)

In [None]:
# Persist the trained model to config.model_path.
# The handle is opened in a `with` block so it is flushed and closed as soon as
# torch.save returns, instead of being left open for the rest of the session
# (an unclosed write handle may never finalize the write).
with fopen(config.model_path, "wb") as model_file:
    torch.save(model, model_file)

In [None]:
# Reload the saved model onto the selected device; the read handle is closed
# as soon as loading finishes. `device` is already a torch.device, so it is
# passed to map_location directly.
# NOTE: torch.load unpickles the whole model object, so only load files from a
# trusted location.
# NOTE(review): newer PyTorch releases default to weights_only=True, which
# rejects fully pickled models — confirm against the installed version.
with fopen(config.model_path, "rb") as model_file:
    model = torch.load(model_file, map_location=device)

### Understand AutoEncoder

In [None]:
# Loss used to score the reconstruction in the walkthrough below.
loss_fn = torch.nn.CrossEntropyLoss()

# Pair each input with its target and serve them in shuffled batches of 512.
dataset_train = torch.utils.data.TensorDataset(
    candidate_names_train_X,
    candidate_names_train_y,
)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=512,
    shuffle=True,
)

In [None]:
# Pull a single batch of inputs (X) and targets (y) from the loader to trace
# through the model step by step.
X, y = next(iter(data_loader))
print(X.shape, y.shape)

In [None]:
# Reset the gradients of the model's parameters before the manual forward pass.
model.zero_grad()
# Encode(input, hidden) returns (output, (h_n, c_n)):
#   output: (batch, seq, dirs*hidden)
#   h_n, c_n: (dirs*layers, batch, hidden)
# x_encoded is h_n, the last hidden state; the output and cell state are discarded.
_, (x_encoded, _) = model.lstm_encoder(X.to(device))
print(x_encoded.shape)

In [None]:
# Concatenate the left-to-right and right-to-left hidden vectors
# (entries 0 and 1 of the first dimension) along dim 1.
# NOTE: this rebinds x_encoded, so re-running this cell without first
# re-running the encoder cell operates on the already-concatenated tensor.
x_encoded = torch.cat([x_encoded[0], x_encoded[1]], dim=1)
print(x_encoded.shape)

In [None]:
# Reshape data to have seq_len time steps: insert a new dimension at position 1
# and repeat the encoding MAX_NAME_LENGTH times along it.
# TODO why do we copy x_encoded to every time step?
# (presumably so the decoder LSTM receives the encoding as its input at every
# step — confirm)
x_encoded = x_encoded.unsqueeze(1).repeat(1, MAX_NAME_LENGTH, 1)
print(x_encoded.shape)

In [None]:
# Decode(hidden*dirs, hidden) returns (output, (h_n, c_n)):
#   output: (batch, seq, dirs*hidden)
#   h_n, c_n: (dirs*layers, batch, hidden)
# x_decoded is the output; the final hidden and cell states are discarded.
x_decoded, (_, _) = model.lstm_decoder(x_encoded)
print(x_decoded.shape)

In [None]:
# Linear layer (hidden -> input size) turns each decoder output into
# per-character predictions.
x_prime = model.linear(x_decoded)
print(x_prime.shape)

In [None]:
# Reshape output to match CrossEntropyLoss input: swap dim 1 with the last dim
# so the class dimension comes right after the batch dimension.
# NOTE: this rebinds x_prime, so re-running this cell alone swaps the dims back.
x_prime = x_prime.transpose(1, -1)
print(x_prime.shape)

In [None]:
# Compute the loss between predictions (batch, classes, seq) and targets
# (batch, seq); the targets are moved to the same device as the model first.
loss = loss_fn(x_prime, y.to(device))
print(loss)

### Test

In [None]:
# Spot-check: print the 10 best matches among the test candidates for one name.
# NOTE(review): the "<" and ">" around the name are presumably begin/end-of-name
# markers expected by the model inputs — confirm.
test_name = ["<schumacher>"]
print(get_best_autoencoder_matches(model, test_name, candidate_names_test, k=10))

### Evaluation

In [None]:
# Batch size passed to get_best_autoencoder_matches during evaluation.
batch_size = 512

In [None]:
# Find the k best candidate matches for every test input name.
k = 100
best_matches = get_best_autoencoder_matches(model, input_names_test, candidate_names_test, k, batch_size)
print(best_matches.shape)
# Inspect the two entries of the first match for the first input name.
# NOTE(review): entry 0 looks like the matched name and entry 1 its
# score/distance — confirm against get_best_autoencoder_matches.
print(best_matches[0, 0, 0])
print(best_matches[0, 0, 1])

In [None]:
# Keep only entry 0 of every match (the matched names, going by the variable
# name — confirm), dropping the last axis.
best_matches_names = best_matches[:, :, 0]
print(best_matches_names.shape)

### PR Curve

In [None]:
# Lower and upper bounds of the threshold range to test (swept in steps of 0.05).
min_threshold, max_threshold = 0.5, 5.0
metrics.precision_weighted_recall_curve_at_threshold(
    weighted_actual_names_test,
    best_matches,
    min_threshold,
    max_threshold,
    step=0.05,
    distances=True,
)

### AUC

In [None]:
# AUC for the test matches (presumably the area under the precision /
# weighted-recall curve — confirm against metrics.get_auc).
# NOTE(review): the other metric calls in this notebook pass distances=True
# (and the curve uses an explicit threshold range), but this call relies on
# get_auc's defaults — confirm those defaults are appropriate for these scores.
metrics.get_auc(weighted_actual_names_test, best_matches)

### Precision and recall at a threshold

In [None]:
# Report average precision and average weighted recall at one fixed threshold.
threshold = 0.97
for label, metric_fn in (
    ("precision", metrics.avg_precision_at_threshold),
    ("recall", metrics.avg_weighted_recall_at_threshold),
):
    # Both metrics take the same arguments, so only the label and function vary.
    print(
        label,
        metric_fn(weighted_actual_names_test, best_matches, threshold=threshold, distances=True),
    )

In [None]:
# Mark the Weights & Biases run opened by wandb.init as finished.
wandb.finish()