In [None]:
# Mount Google Drive into the Colab filesystem under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Paths used by the rest of the notebook (relative to the working directory).
model_nemo_config_path: str = './xlmr-sapbert-large_entity_linking_config.yaml'  # NeMo model YAML config
data_path: str = './train/subtask2-linking'  # directory holding the gazetteer and query TSV files
output_path: str = 'output_path'  # directory that receives the .npy / .xlsx outputs

In [None]:
%%capture
# Replace the preinstalled pytorch-lightning with the pinned 2.0.9.
# -y: without it `pip uninstall` waits for a "Proceed (Y/n)?" answer, which a
# (captured) notebook cell cannot provide.
!pip uninstall -y pytorch-lightning
!pip install pytorch-lightning==2.0.9

In [None]:
%%capture
# NeMo does not support XLMR-SapBERT-Large out of the box, so we use a modified NeMo.
# (%%capture must be the very first line of a cell, so this comment sits below it.)
# -o overwrites existing files without prompting, so re-running the cell does not hang.
!unzip -o 'NeMo-1.20.0-modified.zip'

In [None]:
%%capture
# Install order matters: cython<3 first, then pyyaml 5.4.1 built without build isolation
# so that it builds against that cython (NOTE(review): presumably because pyyaml 5.4.1
# fails to build with Cython 3 — confirm).
!pip install "cython<3.0.0" && pip install --no-build-isolation pyyaml==5.4.1
# The modified NeMo 1.20.0 extracted by the unzip cell.
!pip install '/content/NeMo-1.20.0-modified'
!pip install wget
!pip install faiss-gpu
# Remaining runtime dependencies (unpinned). pip leaves an already-installed package
# (e.g. the pinned pytorch-lightning) in place unless another requirement forces a change.
!pip install ipywidgets hydra-core pytorch-lightning transformers accelerate sentencepiece jedi youtokentome braceexpand webdataset ijson sacremoses sacrebleu rouge_score einops opencc pangu

In [None]:
import faiss
import torch
import wget
import os
import numpy as np
import pandas as pd

from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from IPython.display import display
from tqdm import tqdm

from nemo.collections import nlp as nemo_nlp
from nemo.utils.exp_manager import exp_manager

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the entity-linking model from the NeMo YAML config
base_model_cfg = OmegaConf.load(model_nemo_config_path)

# Set train/val datasets to None to avoid loading datasets associated with training
base_model_cfg.model.train_ds = None
base_model_cfg.model.validation_ds = None

# This notebook only runs inference, so switch to eval mode: dropout is disabled and
# embeddings are deterministic across runs.
base_model = nemo_nlp.models.EntityLinkingModel(base_model_cfg.model).to(device).eval()

In [None]:
# Helper function to get data embeddings
def get_embeddings_with_texts(model, dataloader, target_device=None, show_progress=True):
    """Embed every example yielded by ``dataloader`` and decode its input text.

    The model is put in eval mode for the duration of the call (dropout off, so
    embeddings are deterministic) and restored to its previous train/eval mode
    afterwards, even if an exception is raised.

    Args:
        model: model exposing ``forward(input_ids=..., token_type_ids=...,
            attention_mask=...)``, ``tokenizer.decode`` and the standard
            ``training`` / ``eval()`` / ``train(mode)`` module API.
        dataloader: iterable of ``(input_ids, token_type_ids, attention_mask, cids)``
            batches.
        target_device: device the input tensors are moved to before the forward
            pass; defaults to the notebook-level ``device``.
        show_progress: wrap the dataloader in a ``tqdm`` progress bar.

    Returns:
        ``(embeddings, cids, concept_texts)``: three parallel lists holding, per
        example, the embedding row as a numpy array, the concept ID as yielded by
        the dataloader, and the decoded input text (special tokens skipped).
    """
    if target_device is None:
        target_device = device  # notebook-level device from the model-loading cell

    embeddings, cids, concept_texts = [], [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batches = tqdm(dataloader) if show_progress else dataloader
            for input_ids, token_type_ids, attention_mask, batch_cids in batches:
                batch_embeddings = model.forward(
                    input_ids=input_ids.to(target_device),
                    token_type_ids=token_type_ids.to(target_device),
                    attention_mask=attention_mask.to(target_device),
                )

                # Accumulate index embeddings and their corresponding IDs
                embeddings.extend(batch_embeddings.cpu().numpy())
                cids.extend(batch_cids)
                # One decoded string per row of the batch, kept aligned with the embeddings
                concept_texts.extend(
                    model.tokenizer.decode(row, skip_special_tokens=True) for row in input_ids
                )
    finally:
        model.train(was_training)

    return embeddings, cids, concept_texts

In [None]:
from collections import defaultdict

def _save_kb_index(kb_embs, kb_cids, kb_texts, out_dir):
    """Write KB embeddings, codes and terms as parallel ``.npy`` arrays in ``out_dir``.

    Row ``i`` of ``embeddings.npy``, ``snomed_codes.npy`` and ``terms.npy`` all
    describe the same knowledge-base entry. ``out_dir`` is created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'embeddings.npy'), np.array(kb_embs, dtype=np.float32))
    np.save(os.path.join(out_dir, 'snomed_codes.npy'), np.array(kb_cids, dtype=np.str_))
    # KB texts (not query texts), so terms.npy lines up row-for-row with the arrays above.
    np.save(os.path.join(out_dir, 'terms.npy'), np.array(kb_texts, dtype=np.str_))


def _top_indices(scores, k):
    """Indices of the ``k`` largest entries of the 1-D array ``scores``, best first."""
    k = min(k, scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    top = np.argpartition(-scores, k - 1)[:k]
    return top[np.argsort(-scores[top], kind='stable')]


def _rank_candidates(query_embs, query_cids, query_texts, kb_embs, kb_cids, ks):
    """Score queries against the KB and compute top-k accuracy plus ranked candidates.

    Scores are dot products of query and KB embeddings (higher = closer); they equal
    cosine similarity only if the model outputs L2-normalised vectors
    (NOTE(review): confirm for this model). A query counts as correct at ``k`` when
    its gold code is among the codes of its ``k`` highest-scoring KB entries.

    Returns:
        ``(accs, candidates)``. ``accs[k]`` is the fraction of queries correct at
        ``k`` (0.0 when there are no queries). ``candidates`` is keyed by query
        position (so queries sharing a gold code do not overwrite each other); each
        value holds ``'term'``, ``'code'`` and, per k, the best-first top-k codes
        (key ``k``), their scores (``f"top_{k}_scores"``) and a hit flag (``f"in_{k}"``).
    """
    accs = {k: 0.0 for k in ks}
    candidates = {}
    n_queries = len(query_cids)
    if n_queries == 0 or not ks:
        return accs, candidates

    score_matrix = np.matmul(np.asarray(query_embs), np.asarray(kb_embs).T)
    # Rank the max(ks) best KB entries once per query; every smaller k is a prefix,
    # so candidate lists are ordered best-first and nested across k.
    k_max = max(ks)

    for query_idx in tqdm(range(n_queries)):
        query_scores = score_matrix[query_idx]
        ranked = _top_indices(query_scores, k_max)
        ranked_cids = [kb_cids[idx] for idx in ranked]
        ranked_scores = query_scores[ranked].tolist()

        query_cid = query_cids[query_idx]
        entry = {'term': query_texts[query_idx], 'code': query_cid}
        for k in ks:
            topk_cids = ranked_cids[:k]
            # The entity is correctly linked if its gold code is among the top-k KB codes
            is_code_in_topk = query_cid in topk_cids
            accs[k] += int(is_code_in_topk)

            entry[k] = topk_cids
            entry[f"top_{k}_scores"] = ranked_scores[:k]
            entry[f"in_{k}"] = is_code_in_topk
        candidates[query_idx] = entry

    for k in ks:
        accs[k] /= n_queries

    return accs, candidates


def evaluate_and_save_embeddings(model, test_kb, test_queries, ks):
    """Embed KB and queries, save the KB index, and compute top-k linking accuracy.

    Args:
        model: entity-linking model providing ``setup_dataloader(cfg, is_index_data=...)``
            and usable by ``get_embeddings_with_texts``.
        test_kb: dataloader config for the knowledge base (gazetteer).
        test_queries: dataloader config for the queries.
        ks: iterable of k values for top-k accuracy.

    Side effects:
        Writes ``embeddings.npy``, ``snomed_codes.npy`` and ``terms.npy`` (KB entries)
        into the notebook-level ``output_path`` directory, creating it if needed.

    Returns:
        ``(accs, candidates)`` as described in ``_rank_candidates``.
    """
    # Both datasets are loaded through the index-data path of the model's dataloader.
    kb_loader = model.setup_dataloader(test_kb, is_index_data=True)
    query_loader = model.setup_dataloader(test_queries, is_index_data=True)

    kb_embs, kb_cids, kb_texts = get_embeddings_with_texts(model, kb_loader)
    query_embs, query_cids, query_texts = get_embeddings_with_texts(model, query_loader)

    _save_kb_index(kb_embs, kb_cids, kb_texts, output_path)
    return _rank_candidates(query_embs, query_cids, query_texts, kb_embs, kb_cids, ks)

In [None]:
# Dataloader settings shared by the knowledge-base and query test datasets.
EVAL_LOADER_SETTINGS = {"max_seq_length": 128, "batch_size": 64, "shuffle": False}

# Knowledge base: the gazetteer of concepts that queries are linked against.
test_kb = OmegaConf.create(
    {"data_file": f'{data_path}/alt_umls_symptoms_gazetteer.tsv', **EVAL_LOADER_SETTINGS}
)

# Queries to be linked to a knowledge-base concept.
test_queries = OmegaConf.create(
    {"data_file": f'{data_path}/query_no_exact_match_code_term_no_header.tsv', **EVAL_LOADER_SETTINGS}
)

In [None]:
ks = [1, 5, 10, 25]

base_accs, candidates_by_query_term = evaluate_and_save_embeddings(base_model, test_kb, test_queries, ks)

# One row per query; integer k columns get readable headers derived from `ks`.
candidates_table = (
    pd.DataFrame(data=candidates_by_query_term)
    .transpose()
    .rename(columns={k: f'top_{k}' for k in ks})
)
os.makedirs(output_path, exist_ok=True)
# No `encoding=`: it was removed from DataFrame.to_excel in pandas 2.0 (raises TypeError).
candidates_table.to_excel(f'{output_path}/candidates_with_scores.xlsx')

In [None]:
accuracy_ks = list(base_accs)
print(f"Top-k accuracy (k = {', '.join(map(str, accuracy_ks))}):")

# Label the row explicitly so the "Model" column holds a name rather than NaN.
results_table = pd.DataFrame([{"Model": "base_model", **base_accs}], columns=["Model", *accuracy_ks])
results_df = results_table.style.set_properties(**{'text-align': 'left'}).set_table_styles(
    [dict(selector='th', props=[('text-align', 'left')])]
)
display(results_df)