In [1]:
%load_ext autoreload
%autoreload 2

# Train a bi-encoder

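Learn name-to-vector encodings. The model embeds each name so that the cosine-similarity margin between an anchor's positive and negative names matches a target score margin.

Hyperparameters explored: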
* epochs
* embedding_dim
* hard_negs
* easy_negs
* bi_encoder_vocab_size

### Results using swivel triplets

- lr = 0.001 num_epochs = 1 embedding_dim = 256 hard_negs = 5 easy_negs = 5 vocab_size = 256 Epoch 0 loss 0.024382922810735477
- lr = 0.00005 num_epochs = 1 embedding_dim = 256 hard_negs = 15 easy_negs = 15 vocab_size = 256 Epoch 0 loss 0.023236593952613904
- lr = 0.00005 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 30 vocab_size = 256 Epoch 0 loss 0.025198661789189367
- lr = 0.000025 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 30 vocab_size = 256 Epoch 0 loss 0.025274253689720905
- lr = 0.001 num_epochs = 1 embedding_dim = 256 hard_negs = 5 easy_negs = 5 vocab_size = 2048 Epoch 0 loss 0.020269810817918666
- lr = 0.001 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 20 vocab_size = 2048 Epoch 0 loss 0.021378852258816573
- lr = 0.0001 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 20 vocab_size = 2048 Epoch 0 loss 0.019285232987143144
- lr = 0.00005 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 20 vocab_size = 2048 Epoch 0 loss 0.018867752840737512
- lr = 0.0001 num_epochs = 1 embedding_dim = 256 hard_negs = 15 easy_negs = 15 vocab_size = 2048 Epoch 0 loss 0.018372964499742593
- lr = 0.00005 num_epochs = 1 embedding_dim = 256 hard_negs = 15 easy_negs = 15 vocab_size = 2048 Epoch 0 loss 0.01798968744305818
- lr = 0.00005 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 30 vocab_size = 256 (loss not recorded)
- lr = 0.000025 num_epochs = 1 embedding_dim = 256 hard_negs = 10 easy_negs = 30 vocab_size = 256 Epoch 0 loss 0.019587983736692644

### Results using cross-encoder triplets

- lr = 0.00005 num_epochs = 8 embedding_dim = 256 vocab_size = 2048 Epoch 7 loss 0.02193056223032252
- lr = 0.0001 num_epochs = 8 embedding_dim = 256 vocab_size = 2048 Epoch 7 loss 0.022243591166723677
- lr = 0.001 num_epochs = 3 embedding_dim = 256 vocab_size = 2048 Epoch 2 loss 0.02397067276433419

In [2]:
import csv
import gzip
import random

import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

from nama.data.filesystem import download_file_from_s3, upload_file_to_s3
from nama.models.biencoder import BiEncoder

from nama.models.tokenizer import get_tokenize_function_and_vocab

In [3]:
# Config

# TODO run both given and surname
given_surname = "given"
# given_surname = "surname"

# True: train on cross-encoder-scored triplets; False: train on swivel triplets.
# This flag selects BOTH the training files and the model filename below.
use_ce = True
tree_name_min_freq = 1000

# hyperparameters
num_epochs = 12
embedding_dim = 256
hard_negs = 10  # only for swivel
easy_negs = 30  # only for swivel
bi_encoder_vocab_size = 2048
learning_rate = 0.00005

batch_size = 64
use_amsgrad = False

report_size = 10000  # print the running training loss every this many batches
max_tokens = 10      # tokenize() returns this many (padded) token ids per name

# Seed Python's and torch's global RNGs so data shuffling and weight
# initialization are reproducible across kernel restarts.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

tokenizer_path = f"s3://fs-nama-data/2024/nama-data/data/models/fs-{given_surname}-subword-tokenizer-{bi_encoder_vocab_size}.json"

if use_ce:
    triplets_paths = [
        f"s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-{given_surname}-train.csv",
        f"s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-{given_surname}-common.csv",
        f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-triplets-{tree_name_min_freq}-augmented.csv.gz",
    ]
    # NOTE(review): {num_epochs} appears twice here (where the swivel name has
    # hard_negs/easy_negs). Kept unchanged so existing artifact names still match.
    model_filename = f"bi_encoder-ce-{given_surname}-{num_epochs}-{embedding_dim}-{num_epochs}-{bi_encoder_vocab_size}-{learning_rate}"
else:
    # NOTE(review): confirm whether the augmented tree-hr triplets should also
    # be included when training from swivel triplets.
    triplets_paths = [
        f"s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-{given_surname}-triplets-{hard_negs}-{easy_negs}.csv.gz",
    ]
    model_filename = f"bi_encoder-{given_surname}-{num_epochs}-{embedding_dim}-{hard_negs}-{easy_negs}-{bi_encoder_vocab_size}-{learning_rate}"

local_model_path_prefix = f"../data/{model_filename}"
remote_model_path = f"s3://fs-nama-data/2024/nama-data/data/models/{model_filename}.pth"

In [4]:
# Report GPU availability and current memory usage before loading data.
# The device queries are guarded so the notebook still runs on a CPU-only
# host, where torch.cuda.get_device_properties(0) would raise.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    torch.cuda.empty_cache()
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

True
cuda total 8141471744
cuda reserved 0
cuda allocated 0


## Load data

In [5]:
# Fetch the tokenizer locally when it lives on S3, then build the tokenize
# function and its vocabulary.
if tokenizer_path.startswith("s3://"):
    path = download_file_from_s3(tokenizer_path)
else:
    path = tokenizer_path
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    tokenizer_path=path,
    max_tokens=max_tokens,
)
# Sanity check: vocabulary size and the token ids for a sample name.
print(len(tokenizer_vocab))
print(tokenize('dallan'))

2048
[1584, 489, 1, 1, 1, 1, 1, 1, 1, 1]


In [6]:
%%time
# Read every triplet file. Tokenize its anchor/positive/negative names and keep
# (positive_score - negative_score) as the target margin the bi-encoder learns.
all_data = []
for triplets_path in triplets_paths:
    print(triplets_path)
    if triplets_path.startswith("s3://"):
        path = download_file_from_s3(triplets_path)
    else:
        path = triplets_path
    # .gz files are decompressed transparently; others are read as plain text.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, 'rt') as f:
        for row in tqdm(csv.DictReader(f), mininterval=10):
            anchor_ids, positive_ids, negative_ids = (
                tokenize(row[column]) for column in ('anchor', 'positive', 'negative')
            )
            margin = float(row['positive_score']) - float(row['negative_score'])
            all_data.append((anchor_ids, positive_ids, negative_ids, margin))
print(len(all_data))

s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-given-train.csv


0it [00:00, ?it/s]

s3://fs-nama-data/2024/familysearch-names/processed/cross-encoder-triplets-given-common.csv


0it [00:00, ?it/s]

s3://fs-nama-data/2024/familysearch-names/processed/tree-hr-given-triplets-1000-augmented.csv.gz


0it [00:00, ?it/s]

35486190
CPU times: user 5min 23s, sys: 10.9 s, total: 5min 34s
Wall time: 6min 16s


In [7]:
%%time
# Use a dedicated, seeded RNG so this shuffle is reproducible. Otherwise the
# unseeded shuffle would defeat the random_state=42 used for the split below.
random.Random(42).shuffle(all_data)

CPU times: user 30.3 s, sys: 0 ns, total: 30.3 s
Wall time: 30.3 s


In [8]:
%%time
# Hold out 1% of the triplets for validation; random_state pins the split.
train_data, val_data = train_test_split(
    all_data,
    test_size=0.01,
    random_state=42,
)
# Drop the combined list so only the two splits stay referenced.
del all_data
print(f"{len(train_data)} {len(val_data)}")

35131328 354862
CPU times: user 20 s, sys: 377 ms, total: 20.4 s
Wall time: 20.6 s


## Train bi-encoder

In [9]:
def loss_fn(anchors, positives, negatives, labels):
    """Regress the predicted similarity margin onto the target margin.

    The predicted margin is cos(anchor, positive) - cos(anchor, negative),
    computed row-wise over the last dimension. The loss is the mean squared
    error between that margin and ``labels``. An earlier variant used raw dot
    products instead of cosine similarity.
    """
    predicted_margin = (
        F.cosine_similarity(anchors, positives, dim=-1)
        - F.cosine_similarity(anchors, negatives, dim=-1)
    )
    return F.mse_loss(predicted_margin, labels)

In [10]:
# Training loop
def _batch_loss(model, batch, loss_fn):
    """Embed one (anchors, positives, negatives, target_margins) batch and return its loss."""
    anchors, positives, negatives, target_margins = batch
    # Shapes: (batch_size, embedding_dim) for each embedding
    return loss_fn(model(anchors), model(positives), model(negatives), target_margins)


def _train_one_epoch(model, loader, loss_fn, optimizer, epoch, report_every, verbose):
    """Run one gradient-update pass over ``loader``, optionally printing running loss."""
    model.train()  # make sure training-mode behavior is on
    running_loss = 0
    for ix, batch in enumerate(loader):
        optimizer.zero_grad()
        loss = _batch_loss(model, batch, loss_fn)
        loss.backward()
        optimizer.step()
        if verbose:
            running_loss += loss.item()
            if ix % report_every == report_every - 1:
                avg_loss = running_loss / report_every  # loss per batch
                print(f"Epoch {epoch} batch {ix} loss {avg_loss}")
                running_loss = 0


def _evaluate(model, loader, loss_fn):
    """Return the mean per-batch loss over ``loader`` without gradients, or nan if it is empty."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in loader:
            total_loss += _batch_loss(model, batch, loss_fn).item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, model_path, verbose=True,
          report_every=None):
    """Train a bi-encoder on (anchor, positive, negative, target_margin) batches.

    Each epoch does the following:
      1. One pass over ``train_loader`` with gradient updates.
      2. One pass over ``val_loader`` without gradients.
      3. Writes two checkpoints:
         - ``{model_path}-{epoch}.state``: model and optimizer state dicts.
         - ``{model_path}-{epoch}.pth``: the whole pickled model.

    Args:
        model: module that maps a batch of token-id tensors to embeddings.
        train_loader: iterable of ``(anchors, positives, negatives, target_margins)``.
        val_loader: iterable with the same batch layout, used for validation.
        loss_fn: ``loss_fn(anchor_emb, pos_emb, neg_emb, target_margins)``,
            returning a scalar tensor.
        optimizer: optimizer over ``model.parameters()``.
        num_epochs: number of epochs to run.
        model_path: path prefix for the per-epoch checkpoint files.
        verbose: if True, print the running training loss and the per-epoch
            validation loss.
        report_every: print the mean training loss every this many batches.
            Defaults to the notebook-level ``report_size``.

    Returns:
        The mean per-batch validation loss of the final epoch. Returns ``nan``
        when ``num_epochs`` is 0 or ``val_loader`` yields no batches.
    """
    if report_every is None:
        report_every = report_size  # notebook-level config value
    val_loss = float("nan")
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, loss_fn, optimizer, epoch, report_every, verbose)

        val_loss = _evaluate(model, val_loader, loss_fn)
        if verbose:
            print(f"VALIDATION: Epoch {epoch} loss {val_loss}")

        # save model state + model
        epoch_model_path = f"{model_path}-{epoch}"
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, epoch_model_path + ".state")
        torch.save(model, epoch_model_path + ".pth")

    # return final epoch validation loss
    return val_loss

In [11]:
# Build the bi-encoder. Arguments are passed positionally:
# embedding dim, vocab size, max tokens, and the [PAD] token id.
model = BiEncoder(
    embedding_dim,
    len(tokenizer_vocab),
    max_tokens,
    tokenizer_vocab['[PAD]'],
)

# AdamW over all model parameters; amsgrad is toggled from the config cell.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    amsgrad=use_amsgrad,
)

In [12]:
def my_collate_fn(batch):
    """Collate a list of equal-length tuples into a tuple of tensors.

    Element ``i`` of the result stacks field ``i`` from every sample. For the
    (anchor_ids, positive_ids, negative_ids, margin) samples used here, that
    gives three integer tensors and one float tensor.
    """
    columns = zip(*batch)
    return tuple(map(torch.tensor, columns))

In [13]:
%%time
# Training batches are reshuffled every epoch. Validation order is fixed so
# each epoch's validation loss is averaged over identical batches and stays
# directly comparable across epochs.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)
train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, local_model_path_prefix)

Epoch 0 batch 9999 loss 0.12174480192176998
Epoch 0 batch 19999 loss 0.1157549730822444
Epoch 0 batch 29999 loss 0.1099530340641737
Epoch 0 batch 39999 loss 0.10535381418615579
Epoch 0 batch 49999 loss 0.10123585356995464
Epoch 0 batch 59999 loss 0.09787476022019982
Epoch 0 batch 69999 loss 0.09467543686442077
Epoch 0 batch 79999 loss 0.0913670886632055
Epoch 0 batch 89999 loss 0.0887442516848445
Epoch 0 batch 99999 loss 0.08627734028156847
Epoch 0 batch 109999 loss 0.08357494121044874
Epoch 0 batch 119999 loss 0.0814166857328266
Epoch 0 batch 129999 loss 0.07909530360940843
Epoch 0 batch 139999 loss 0.07649908151868731
Epoch 0 batch 149999 loss 0.07457070492338388
Epoch 0 batch 159999 loss 0.07250337218064815
Epoch 0 batch 169999 loss 0.0707965960489586
Epoch 0 batch 179999 loss 0.06873922563754022
Epoch 0 batch 189999 loss 0.06668035651966929
Epoch 0 batch 199999 loss 0.06527890753187239
Epoch 0 batch 209999 loss 0.06361296641919761
Epoch 0 batch 219999 loss 0.06206824116874486
Epoch

0.02196098476987604

## Save Model

In [14]:
# Persist the trained model locally, then copy it to S3.
# The upload's return value is the cell's displayed output.
local_model_path = f"{local_model_path_prefix}.pth"
model.eval()
torch.save(model, local_model_path)
upload_file_to_s3(local_model_path, remote_model_path)

True

## Hyperparameter search