In [1]:
import polars as pl
import utils.embed_concepts as embed_concepts
import torch
import os
# Compute device for every model below: the GPU when torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
# All file and directory locations used by this notebook.
# Every entry lives under a single root directory. The root defaults to the
# original local path and can be overridden with the LORA_EVAL_ROOT environment
# variable, so the notebook runs on another machine without editing this cell.
# Directory entries keep a trailing "/" because other cells build file paths by
# plain string concatenation (e.g. FILE_CONFIG[...] + model_key + ".pt").
_DATA_ROOT = os.environ.get("LORA_EVAL_ROOT", "D:/lora_finetune_eval").rstrip("/\\") + "/"

FILE_CONFIG = {
    # Tables under basic_info/
    "path_all_concept": _DATA_ROOT + "basic_info/concept_all.parquet",
    "path_is_a": _DATA_ROOT + "basic_info/graph_is_a_invariant.csv",
    "mapped_concept": _DATA_ROOT + "basic_info/mapped_concepts_2025-04-01.csv",
    "training_triplet_idx": _DATA_ROOT + "basic_info/training_anchor_idx_1M.parquet",

    # Directory containing snomed_info.csv and icd_info.csv
    "icd_snomed": _DATA_ROOT + "icd_snomed/",

    # Output directories: the embedding cells write one <model_key>.pt file into each
    "embedding_save_path": _DATA_ROOT + "embedding_by_model/",
    "syn_embedding_save_path": _DATA_ROOT + "syn_embedding_by_model/",
    "new_exp_embedding_save_path": _DATA_ROOT + "new_exp_embedding_by_model/",
    "embedding_icd_snomed_save_path": _DATA_ROOT + "embedding_icd_snomed_by_model/",
}

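# 1. Embed

Encode the concept expressions and labels, the synonyms of defined concepts, the new expressions, and the ICD/SNOMED labels with every model in `embed_concepts.MODEL_CONFIG`. Each cell writes one `<model_key>.pt` file per model and skips models whose file already exists, so the cells can be re-run to resume an interrupted run.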

In [3]:
# --- Concept table: one row per concept -------------------------------------
all_concept = pl.read_parquet(FILE_CONFIG["path_all_concept"])
expressions = all_concept.get_column("expression").to_list()
labels = all_concept.get_column("n.label").to_list()

# --- Synonyms of "defined" concepts, one row per synonym --------------------
# Filtering on `status` before exploding gives the same rows as filtering after,
# because explode copies `status` onto every synonym row.
# NOTE(review): drop_nulls() with no subset removes a row if ANY column is null,
# not only `syns_list` -- confirm that is intended (vs. drop_nulls("syns_list")).
df_syn_defined = (
    all_concept
    .filter(pl.col("status") == "defined")
    .explode("syns_list")
    .drop_nulls()
)
syns = df_syn_defined.get_column("syns_list").to_list()
idx = df_syn_defined.get_column("idx").to_list()  # idx of the concept each synonym came from

# --- New expressions with their true concept idx ----------------------------
new_expression_is_a = pl.read_csv(FILE_CONFIG["path_is_a"])
new_exp = new_expression_is_a.get_column("new_exp").to_list()
idx_true = new_expression_is_a.get_column("idx").to_list()

# --- ICD / SNOMED label tables ----------------------------------------------
snomed_info = pl.read_csv(FILE_CONFIG["icd_snomed"] + "snomed_info.csv")
icd_info = pl.read_csv(FILE_CONFIG["icd_snomed"] + "icd_info.csv")

In [4]:
# Settings shared by every embedding cell below.
# Keys of model_config are the model identifiers passed to embed_concepts.load_model.
model_config = embed_concepts.MODEL_CONFIG
batch_size = 128  # texts per model.encode batch

In [None]:
# expression and label
# Embed every concept's expression and label with each configured model and save
# one <model_key>.pt file per model. Models whose file already exists are skipped,
# so re-running the cell resumes after an interruption.
os.makedirs(FILE_CONFIG["embedding_save_path"], exist_ok=True)  # the save helper may not create directories

for model_key in model_config:
    save_path = FILE_CONFIG["embedding_save_path"] + model_key + ".pt"

    if os.path.exists(save_path):
        print(f"Skipping {model_key}: Embeddings already saved at {save_path}")
        continue  # skip to next model

    print(f"\nEncoding with {model_key}...")

    model = embed_concepts.load_model(model_key, device=device)

    with torch.no_grad():
        print("Encoding expressions...", model_key)
        embeddings_exp = model.encode(expressions, batch_size=batch_size)
        print("Encoding labels...", model_key)
        embeddings_labels = model.encode(labels, batch_size=batch_size)

    embed_concepts.save_embeddings({
        "model_name": model_key,
        "expressions_embeddings": embeddings_exp.cpu(),
        "labels_embeddings": embeddings_labels.cpu(),
    }, save_path)

    # Release this model and its (possibly GPU-resident) outputs before the next
    # model is loaded. Otherwise the next load_model call runs while this model
    # is still referenced, holding two models in memory at once.
    del model, embeddings_exp, embeddings_labels
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised


        

Skipping finetune_1M: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/finetune_1M.pt
Skipping baseline: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/baseline.pt
Skipping clinicalbert: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/clinicalbert.pt
Skipping bio_clinicalbert: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/bio_clinicalbert.pt
Skipping e5_base: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/e5_base.pt
Skipping gte_base: Embeddings already saved at D:/lora_finetune_eval/embedding_by_model/gte_base.pt

Encoding with sapbert...
Encoding expressions... sapbert


Encoding:   0%|          | 0/3068 [00:00<?, ?it/s]

In [None]:
# synonyms
# Embed every synonym of the "defined" concepts with each configured model and save
# one <model_key>.pt file per model. `syns` and `idx` come from the data-loading cell;
# idx[i] is the concept idx that syns[i] belongs to. Models whose file already
# exists are skipped.
os.makedirs(FILE_CONFIG["syn_embedding_save_path"], exist_ok=True)  # the save helper may not create directories

for model_key in model_config:
    save_path = FILE_CONFIG["syn_embedding_save_path"] + model_key + ".pt"

    if os.path.exists(save_path):
        print(f"Skipping {model_key}: Embeddings already saved at {save_path}")
        continue  # skip to next model

    print(f"\nEncoding with {model_key}...")

    model = embed_concepts.load_model(model_key, device=device)
    with torch.no_grad():
        print("Encoding synonyms...", model_key)
        embeddings_syn = model.encode(syns, batch_size=batch_size)

    embed_concepts.save_embeddings({
        "model_name": model_key,
        "synonyms": syns,
        # NOTE(review): synonym vectors are stored under "expressions_embeddings";
        # keep this key unless the code that loads these files is updated too.
        "expressions_embeddings": embeddings_syn.cpu(),
        "idx": idx,
    }, save_path)

    # Release this model and its outputs before the next model is loaded, so
    # two models are never held in (GPU) memory at once.
    del model, embeddings_syn
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised

In [None]:
# new exp
# Embed the new expressions from the is-a file with each configured model and save
# one <model_key>.pt file per model, together with idx_true (the idx column read
# alongside each new expression). Models whose file already exists are skipped.
os.makedirs(FILE_CONFIG["new_exp_embedding_save_path"], exist_ok=True)  # the save helper may not create directories

for model_key in model_config:
    save_path = FILE_CONFIG["new_exp_embedding_save_path"] + model_key + ".pt"

    if os.path.exists(save_path):
        print(f"Skipping {model_key}: Embeddings already saved at {save_path}")
        continue  # skip to next model

    print(f"\nEncoding with {model_key}...")
    model = embed_concepts.load_model(model_key, device=device)
    with torch.no_grad():
        print("Encoding new expressions...", model_key)
        embeddings_exp_new = model.encode(new_exp, batch_size=batch_size)

    embed_concepts.save_embeddings({
        "model_name": model_key,
        "new_expressions_embeddings": embeddings_exp_new.cpu(),
        "idx_true": idx_true,
    }, save_path)

    # Release this model and its outputs before the next model is loaded, so
    # two models are never held in (GPU) memory at once.
    del model, embeddings_exp_new
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised

In [None]:
# icd and snomed matching
# Embed the ICD and SNOMED labels with each configured model and save one
# <model_key>.pt file per model. Models whose file already exists are skipped.
snomed_labels = snomed_info["SNOMED_label"].to_list()
icd_labels = icd_info["ICD_label"].to_list()

os.makedirs(FILE_CONFIG["embedding_icd_snomed_save_path"], exist_ok=True)  # the save helper may not create directories

for model_key in model_config:
    save_path = FILE_CONFIG["embedding_icd_snomed_save_path"] + model_key + ".pt"

    if os.path.exists(save_path):
        print(f"Skipping {model_key}: Embeddings already saved at {save_path}")
        continue  # skip to next model

    print(f"\nEncoding with {model_key}...")
    model = embed_concepts.load_model(model_key, device=device)
    with torch.no_grad():
        # Progress messages are printed right before each encode, after the model is
        # loaded, as in the other embedding cells.
        print("Encoding ICD...")
        embeddings_icd = model.encode(icd_labels, batch_size=batch_size)
        print("Encoding SNOMED...")
        embeddings_snomed = model.encode(snomed_labels, batch_size=batch_size)

    embed_concepts.save_embeddings({
        "model_name": model_key,
        "icd_embeddings": embeddings_icd.cpu(),
        "snomed_embeddings": embeddings_snomed.cpu(),
    }, save_path)

    # Release this model and its outputs before the next model is loaded, so
    # two models are never held in (GPU) memory at once.
    del model, embeddings_icd, embeddings_snomed
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised
