In [None]:
# Reload imported project modules automatically before each cell runs,
# so edits to the src/ package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Save bi-encoder model weights

Save the bi-encoder's embedding weights for each token and each position so they can be used in the Java version of fs-nama.

In [None]:
import json

import torch

from src.models.biencoder import BiEncoder
from src.models.tokenizer import get_tokenize_function_and_vocab

In [None]:
# --- What to export -------------------------------------------------------
# Name part the model was trained on; it is interpolated into every path below.
given_surname = "given"
# Identifier of the trained model variant (part of the model/weights file names).
model_type = "cecommon+0+aug-0-1"
# Passed to the tokenizer factory as `max_tokens`.
max_tokens = 10

# --- Input files ----------------------------------------------------------
model_path = f"../data/models/bi_encoder-{given_surname}-{model_type}.pth"
subwords_path = f"../data/models/fs-{given_surname}-subword-tokenizer-2000f.json"

# --- Output file ----------------------------------------------------------
weights_path = f"../data/models/bi_encoder-{given_surname}-{model_type}-weights.json"

In [None]:
# Report GPU availability and current CUDA memory usage.
torch.cuda.empty_cache()
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    # Device queries raise when no CUDA device is present, so they are only
    # issued on GPU machines; on CPU-only hosts the notebook keeps running.
    print("cuda total", torch.cuda.get_device_properties(0).total_memory)
    print("cuda reserved", torch.cuda.memory_reserved(0))
    print("cuda allocated", torch.cuda.memory_allocated(0))

## Load bi-encoder and vocabulary

In [None]:
# Load the pickled bi-encoder from `model_path`.
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default to weights_only=True, which
# rejects fully pickled modules; if loading fails, pass weights_only=False —
# confirm against the installed torch version.
model = torch.load(
    model_path,
    # Keep the saved device placement when CUDA is present; otherwise remap
    # to CPU so a GPU-saved checkpoint still loads on a CPU-only machine.
    map_location=None if torch.cuda.is_available() else "cpu",
)
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

In [None]:
# Build the tokenize function and its vocabulary from the subword tokenizer
# file, using the token limit set in the configuration cell.
tokenize, tokenizer_vocab = get_tokenize_function_and_vocab(
    subwords_path=subwords_path,
    max_tokens=max_tokens,
)
# Display the vocabulary size.
len(tokenizer_vocab)

In [None]:
# Invert the vocabulary (token -> id) so tokens can be looked up by id.
tokenid2token = dict((idx, tok) for tok, idx in tokenizer_vocab.items())

## Save weights

In [None]:
# Key each row of the token-embedding matrix by its token string; the row
# index is used as the token id when looking up the token.
embedding_weights = dict(
    (tokenid2token[row_ix], row)
    for row_ix, row in enumerate(model.embedding.weight.tolist())
)

In [None]:
# Number of token embeddings exported (one entry per distinct token string).
len(embedding_weights)

In [None]:
# Length of a single token's embedding vector; assumes 'a' is a token in the
# vocabulary (raises KeyError otherwise).
len(embedding_weights['a'])

In [None]:
# Positional-embedding weights converted to nested Python lists so they can
# be serialized to JSON.
positional_weights = model.positional_embedding.weight.tolist()

In [None]:
# Number of rows in the positional-embedding matrix (presumably one per position).
len(positional_weights)

In [None]:
# Bundle both weight tables into the single object that is written to JSON:
# "tokens" maps token string -> embedding, "positions" is a list of embeddings.
weights = dict(tokens=embedding_weights, positions=positional_weights)

In [None]:
# Show where the weights file will be written.
weights_path

In [None]:
# Serialize the weights to the JSON file at `weights_path`.
with open(weights_path, mode='w') as weights_file:
    json.dump(weights, weights_file)

## Test similarity

In [None]:
# Smoke test: score one pair of similar-looking names with the loaded model.
name1, name2 = "richard", "rickert"
tokens1, tokens2 = tokenize(name1), tokenize(name2)
sim = model.predict(tokens1, tokens2)
print(sim)