In [None]:
# Reload edited project modules (e.g. nama.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

# Train a bi-encoder

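Learn name-to-vector encodings (embeddings) from anchor/positive/negative name triplets: the bi-encoder is trained so that the cosine-similarity margin between (anchor, positive) and (anchor, negative) matches the target score margin.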

### Results

Hyperparameters
* epochs
* embedding_dim


```text
new data @ 12 epochs:         ??? ??? ??, 189 7, 1011 193, 61728
old data @ 12 epochs:         132 124 13,   1 0,  330 197, 56461
old data w model @ 12 epochs: 145 110 76,   8 0,  482 246, 48875
```

In [5]:
from collections import Counter
import random
import re

import matplotlib.pyplot as plt
import pandas as pd


import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

from nama.data.filesystem import download_file_from_s3
from nama.data.utils import read_csv
from nama.models.biencoder import BiEncoder

# from nama.models.tokenizer import get_tokenize_function_and_vocab

In [3]:
# --- what to train -------------------------------------------------------------
given_surname = "surname"  # name type, e.g. "given" or "surname"; used in every path below
model_type = "cecommon+0+aug"  # tag appended to the saved model path

# Resume from a checkpoint written by train(), or None to start fresh,
# e.g. '../data/models/bi_encoder-given-cecommon+0-0-2.state'
checkpoint_path = None

# Training stages as (triplets CSV, number of epochs), trained in order on the same model.
# cross-encoder-triplets-{given_surname}-common-0-augmented.csv combines
#   cross-encoder-triplets-common
#   + cross-encoder-triplets-0
#   + tree-hr-triplets-v1-1000-augmented
path_epochs = [
    (f"../data/processed/cross-encoder-triplets-{given_surname}-common-0-augmented.csv", 3),
    # (f"../data/processed/cross-encoder-triplets-{given_surname}-0.csv", 6),
    # (f"../data/processed/cross-encoder-triplets-{given_surname}-common-0.csv", 3),
    # (f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000-augmented.csv.gz", 6),
]

# --- model / optimizer hyperparameters -------------------------------------------
embedding_dim = 256
learning_rate = 1e-3
batch_size = 64
use_amsgrad = False

# --- logging and tokenizer -------------------------------------------------------
report_size = 10_000  # print the running training loss every this many batches
max_tokens = 10
vocab_type = "f"  # tokenizer based upon training name frequency
subword_vocab_size = 2000  # options: 500, 1000, 1500, 2000
nama_bucket = "nama-data"
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-{subword_vocab_size}{vocab_type}.json"

# Base path for saved model files (the training loop and train() append stage/epoch suffixes)
model_path = f"../data/models/bi_encoder-{given_surname}-{model_type}"

In [4]:
# Report GPU availability and memory usage. The device-0 queries are skipped on CPU-only
# machines, where torch.cuda.get_device_properties(0) would raise and break "Run All".
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    torch.cuda.empty_cache()
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

True
cuda total 8141471744
cuda reserved 0
cuda allocated 0


## Combine triplets (only needs to be run once)

In [17]:
# Intended as a one-off that builds the combined triplets CSV trained on in `path_epochs`
# (see the commented-out to_csv target). With the concat and to_csv lines commented out,
# it currently only downloads, reads, and previews a single file.
# NOTE(review): the recorded output shows this S3 key returned 404 and read_csv then failed,
# so this cell errors on a fresh "Run All" — confirm the key or skip this cell.
path = download_file_from_s3(f"s3://fs-nama-data/2023/nama-data/data/processed/cross-encoder-triplets-{given_surname}-0-augmented.csv")
df = read_csv(path)
#df2 = read_csv(f"../data/processed/cross-encoder-triplets-{given_surname}-common.csv")
#df3 = read_csv(f"../data/processed/tree-hr-{given_surname}-triplets-v2-1000-augmented.csv.gz")
# NOTE(review): the concat below references df1, but the file read above is bound to `df`;
# rename one of them before re-enabling it.
#df = pd.concat([df1, df2, df3])
print(df.shape)
df.head(5)
# df.to_csv(f"../data/processed/cross-encoder-triplets-{given_surname}-common-0-augmented.csv", index=False)

Error downloading file s3://fs-nama-data/2023/nama-data/data/processed/cross-encoder-triplets-surname-0-augmented.csv from S3: An error occurred (404) when calling the HeadObject operation: Not Found


FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmpzecdh15j.csv'

## Load data

### Get tokenizer

In [None]:
# The import of this helper is commented out in the imports cell, which makes this cell raise
# NameError on a fresh kernel; import it here so the cell is self-contained.
from nama.models.tokenizer import get_tokenize_function_and_vocab

# Build the `tokenize` function applied to every name below, plus its vocabulary
# (the vocabulary must contain '[PAD]', which the model cell looks up).
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    max_tokens=max_tokens,
    subwords_path=subwords_path,
    nama_bucket=nama_bucket,
)
len(tokenizer_vocab)

In [None]:
# Sanity check: display how a sample name is tokenized.
tokenize('dallan')

### Generate anchor-pos-neg triplets

In [None]:
def generate_training_data(train_triplets_df, tokenize_fn=None):
    """Tokenize (anchor, positive, hard-negative) triplets and compute their target margins.

    Args:
        train_triplets_df: DataFrame with columns 'anchor', 'positive', 'negative',
            'positive_score' and 'negative_score'.
        tokenize_fn: function mapping a name to its tokens. Defaults to the notebook's
            global ``tokenize``.

    Returns:
        list of (anchor_tokens, pos_tokens, neg_tokens, target_margin) tuples, where
        target_margin = positive_score - negative_score as a plain Python float.
    """
    if tokenize_fn is None:
        tokenize_fn = tokenize

    triplets = zip(
        train_triplets_df['anchor'],
        train_triplets_df['positive'],
        train_triplets_df['negative'],
        train_triplets_df['positive_score'],
        train_triplets_df['negative_score'],
    )
    all_data = []
    # A bare zip has no len(); passing total lets tqdm show percentage and ETA.
    for anchor, pos, neg, pos_score, neg_score in tqdm(triplets, total=len(train_triplets_df), mininterval=2):
        # float(): if the score columns are numpy float64, torch.tensor would otherwise build
        # float64 targets; Python floats collate to the default dtype (float32), presumably
        # matching the model's embeddings.
        target_margin = float(pos_score - neg_score)
        all_data.append((
            tokenize_fn(anchor),
            tokenize_fn(pos),
            tokenize_fn(neg),
            target_margin,
        ))
    return all_data

## Train bi-encoder

In [None]:
def loss_fn(anchors, positives, negatives, labels):
    """Regress the predicted cosine-similarity margin onto the target margin.

    margin_pred = cos(anchor, positive) - cos(anchor, negative), computed along the last
    dimension. The loss is the mean squared error between margin_pred and ``labels``.

    Args:
        anchors, positives, negatives: embedding tensors whose last dimension is the
            embedding dimension (e.g. (batch_size, embedding_dim)).
        labels: tensor of target margins, one per row.

    Returns:
        Scalar MSE loss tensor.
    """
    # Dot-product alternative, kept for reference:
    # anchor_pos_sim = (anchors * positives).sum(dim=-1)
    # anchor_neg_sim = (anchors * negatives).sum(dim=-1)
    anchor_pos_sim = F.cosine_similarity(anchors, positives, dim=-1)
    anchor_neg_sim = F.cosine_similarity(anchors, negatives, dim=-1)
    margin_pred = anchor_pos_sim - anchor_neg_sim
    # Align the targets with the predictions. float64 targets (e.g. collated from numpy
    # scores) against float32 predictions can make mse_loss's backward fail, and CPU
    # targets cannot be combined with GPU predictions.
    labels = labels.to(dtype=margin_pred.dtype, device=margin_pred.device)
    return F.mse_loss(margin_pred, labels)

In [None]:
# Training loop
def _batch_loss(model, loss_fn, batch):
    """Embed one (anchors, positives, negatives, target_margins) batch and return its loss."""
    anchors, positives, negatives, target_margins = batch
    anchor_embeddings = model(anchors)  # expected shape: (batch_size, embedding_dim)
    pos_embeddings = model(positives)
    neg_embeddings = model(negatives)
    return loss_fn(anchor_embeddings, pos_embeddings, neg_embeddings, target_margins)


def _validate(model, val_loader, loss_fn):
    """Return the mean per-batch validation loss, or NaN if val_loader yields no batches."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            total_loss += _batch_loss(model, loss_fn, batch).item()
            num_batches += 1
    # Guard against ZeroDivisionError when the validation split is empty.
    return total_loss / num_batches if num_batches else float("nan")


def _save_checkpoint(model, optimizer, epoch_model_path):
    """Write model+optimizer state dicts to <path>.state and the pickled model to <path>.pth."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, epoch_model_path + ".state")
    torch.save(model, epoch_model_path + ".pth")


def train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, model_path, verbose=True,
          report_every=None):
    """Train ``model`` for ``num_epochs`` epochs, validating and checkpointing after each epoch.

    Args:
        model: module mapping a batch of token tensors to embeddings.
        train_loader, val_loader: iterables of (anchors, positives, negatives, target_margins).
        loss_fn: callable(anchor_emb, pos_emb, neg_emb, target_margins) -> scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        model_path: base path; epoch ``e`` is saved as ``{model_path}-{e}.state`` / ``.pth``.
        verbose: print running training loss and per-epoch validation loss.
        report_every: batches between training-loss reports; defaults to the notebook's
            global ``report_size`` (only consulted when ``verbose``).

    Returns:
        Validation loss of the final epoch (NaN if val_loader is empty, None if num_epochs <= 0).
    """
    if verbose and report_every is None:
        report_every = report_size
    val_loss = None

    for epoch in range(num_epochs):
        # Restore training-mode behavior; _validate switches the model to eval mode.
        model.train()
        running_loss = 0.0

        for ix, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _batch_loss(model, loss_fn, batch)
            loss.backward()
            optimizer.step()

            if verbose:
                running_loss += loss.item()
                if ix % report_every == report_every - 1:
                    avg_loss = running_loss / report_every  # loss per batch
                    print(f"Epoch {epoch} batch {ix} loss {avg_loss}")
                    running_loss = 0.0

        val_loss = _validate(model, val_loader, loss_fn)
        if verbose:
            print(f"VALIDATION: Epoch {epoch} loss {val_loss}")

        _save_checkpoint(model, optimizer, f"{model_path}-{epoch}")

    return val_loss

## Hyperparameter search

_TODO: no hyperparameter search is implemented in this notebook yet._

## Train model

In [None]:
# Seed torch's global RNG so weight initialization and DataLoader shuffling are reproducible
# across "Run All" (the train/validation split below already uses random_state=42).
torch.manual_seed(42)

# Create an instance of the bi-encoder model.
# NOTE(review): positional args follow BiEncoder's signature in nama.models.biencoder —
# presumably (embedding_dim, vocab_size, max_tokens, pad_token_id); confirm there.
model = BiEncoder(embedding_dim, len(tokenizer_vocab), max_tokens, tokenizer_vocab['[PAD]'])

# Define the optimizer (AdamW; amsgrad toggled by use_amsgrad in the config cell)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, amsgrad=use_amsgrad)

In [None]:
def my_collate_fn(batch):
    """Collate a list of same-length tuples into a tuple of tensors, one per tuple field.

    E.g. [(a1, p1, n1, m1), (a2, p2, n2, m2)] -> (tensor([a1, a2]), tensor([p1, p2]), ...).
    """
    return tuple(torch.tensor(field_values) for field_values in zip(*batch))

In [None]:
# Optionally resume from a checkpoint (the model/optimizer state dict file written by train()).
if checkpoint_path:
    print(checkpoint_path)
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict then copies the tensors onto the devices of the existing parameters.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()

In [None]:
%%time

for ix, (path, num_epochs) in enumerate(path_epochs):
    print(path, num_epochs)
    train_triplets_df = read_csv(path)
    all_data = generate_training_data(train_triplets_df)
    del train_triplets_df
    train_data, val_data = train_test_split(all_data, test_size=0.01, random_state=42)
    del all_data
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
    train(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, f"{model_path}-{ix}")
    del train_loader
    del val_loader
    del train_data
    del val_data

## Save Model

In [None]:
# Display the base path used for the final saved model file.
model_path

In [None]:
# Put the model in eval mode (nn.Module.eval() returns the module itself) and pickle the whole
# model to <model_path>.pth.
torch.save(model.eval(), f"{model_path}.pth")