# 0. import

In [1]:
import polars as pl
import torch
import numpy as np
import yaml
from tqdm import tqdm
import torch.nn.functional as F

from utils.model_utils import load_checkpoint,get_embeddings, encode
from model import LoraSapbert


with open("configs.yaml", "r") as f:
    CONFIG = yaml.safe_load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"


# 1. load lora-SapBERT checkpoint for embeddings

In [2]:
model_emb_ft = LoraSapbert(**CONFIG["SAPBERT_PARAMETERS"]).to(device)
_ = load_checkpoint(model_emb_ft, CONFIG["EMBEDDINGS_SAVE_PATH"]["ft_model"], device="cuda", strict=False)
model_emb_ft.eval()


trainable params: 442,368 || all params: 109,924,608 || trainable%: 0.4024


LoraSapbert(
  (model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=

# 2. get dataframes: concepts labels to embed, and mapped concepts if needed (not indexed)

In [26]:
df_ulms = pl.read_parquet(CONFIG["EMBEDDINGS_SAVE_PATH"]["all_ulms"])
df_hug_snomed = pl.read_parquet(CONFIG["EMBEDDINGS_SAVE_PATH"]["graph_hug_concepts"])
df_mapped = pl.read_parquet(CONFIG["CONCEPT_DATA"]["mapped_concepts"]).drop("pointer").unique()
df_mapped = df_hug_snomed.join(df_mapped, on = "id")


# 3. (already computed, no need to run) get embeddings (the idx in dataframe = the row position in the resulting tensor)

In [None]:
emb_label_snomed_ft = get_embeddings(model_emb_ft, df_hug_snomed, "label")
emb_expression_snomed_ft = get_embeddings(model_emb_ft, df_hug_snomed, "expression")
# Save fine-tuned embeddings
torch.save({
    'emb_exp': emb_expression_snomed_ft,
    'emb_concepts': emb_label_snomed_ft,
}, CONFIG["EMBEDDINGS_SAVE_PATH"]["snomed_embedding"])


emb_label_ft_ulms = get_embeddings(model_emb_ft, df_ulms, "STR")
torch.save({
    'emb_ulms_label': emb_label_ft_ulms,
}, CONFIG["EMBEDDINGS_SAVE_PATH"]["ulms_embedding"])



# 4. load all existing embeddings

In [4]:
# get all embeddings of hug-snomed
ft_data = torch.load(CONFIG["EMBEDDINGS_SAVE_PATH"]["snomed_embedding"])
# emb_expr_ft = ft_data['emb_exp'].to("cuda", dtype=torch.float16)
emb_label_ft = ft_data['emb_concepts'].to("cuda", dtype=torch.float16)

# get all embeddings of ulms
emb_ulms_label_ft = torch.load(CONFIG["EMBEDDINGS_SAVE_PATH"]["ulms_embedding"])['emb_ulms_label']


In [5]:
emb_ulms_label_ft.shape, emb_label_ft.shape #,emb_expr_ft.shape

(torch.Size([2259432, 768]), torch.Size([394508, 768]))

# compute the cosine similarity
Example of free text and all ulms concepts, find the closest standarized concept for each free text

In [8]:
def find_closest_concept(emb_query, emb_candidate, df_query, df_candidate, col_label_query, col_label_candidate):
    K = 1
    batch_size = 1024
    chunk_size = 200_000
    device = "cuda"
    use_half = False

    torch.backends.cuda.matmul.allow_tf32=False
    torch.set_grad_enabled(False)

    dtype = torch.float16 if use_half else torch.float32
    P, D = emb_query.shape
    M    = emb_candidate.shape[0]

    # normalize for better cosine computation
    emb_query = F.normalize(emb_query, dim = -1)
    emb_candidate = F.normalize(emb_candidate, dim = -1) # need to put it in gpu only when it's necessary
    all_rows = []

    for s in tqdm(range(0, P, batch_size)):
        e = min(s + batch_size, P)
        B = e - s
        # print(f"[{s}:{e}] / {P}")

        # move query batch
        Q = emb_query[s:e].to(device=device, dtype=dtype, non_blocking=True).contiguous()  # [B, D]

        # best-K placeholders
        best_vals = torch.full((B, K), -1e9, device=device, dtype=torch.float32)
        best_pos  = torch.full((B, K),  -1, device=device, dtype=torch.long)

        # scan through UMLS in tiles
        start = 0
        while start < M:
            end = min(start + chunk_size, M)

            Ublk = emb_candidate[start:end].to(device=device, dtype=dtype, non_blocking=True).contiguous()  # [C, D]
            sims = Q @ Ublk.transpose(0,1).contiguous()                                                # [B, C]

            cK = min(K, sims.size(1))
            cand_vals, cand_pos = torch.topk(sims, k=cK, dim=1, largest=True, sorted=False)            # [B, cK]
            cand_pos += start

            # merge with running best
            merged_vals = torch.cat([best_vals, cand_vals.float()], dim=1)                             # [B, K+cK]
            merged_pos  = torch.cat([best_pos,  cand_pos], dim=1)                                      # [B, K+cK]
            best_vals, idx = torch.topk(merged_vals, k=K, dim=1, largest=True, sorted=True)
            best_pos = torch.gather(merged_pos, 1, idx)

            start = end
        
        # collect rows for this batch
        best_vals_cpu = best_vals.float().cpu()
        best_pos_cpu  = best_pos.long().cpu()

        for i in range(B):
            for k in range(K):
                query_idx = df_query[s + i]["idx"][0]
                all_rows.append({
                    # "query_idx": query_idx,
                    "query_label" : df_query.filter(pl.col("idx") == query_idx)[col_label_query][0],
                    "candidate_idx": int(best_pos_cpu[i, k]),
                    "candidate_label" : df_candidate[int(best_pos_cpu[i, k])][col_label_candidate][0],
                    "candidate_coding_system": df_candidate[int(best_pos_cpu[i, k])]["SAB"][0],
                    "similarity": float(best_vals_cpu[i, k])
                })

    # build dataframe
    df_results = pl.DataFrame(all_rows)
    return df_results


## example of free text - ULMS

In [43]:
# need to build a dataframe first for freetext
texts = ["dental implant", "father not a smoker"]
emb_free_text_ft = encode(model_emb_ft, texts)
df_freetext = pl.DataFrame({"label": texts})
df_freetext.insert_column(0, pl.Series(range(len(df_freetext))).alias("idx"))

# compute similarity
df_result = find_closest_concept(emb_free_text_ft, emb_ulms_label_ft, df_freetext, df_ulms, "label", "STR")
df_result


Encoding:   0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:02<00:00,  2.45s/it]


query_label,candidate_idx,candidate_label,candidate_coding_system,similarity
str,i64,str,str,f64
"""dental implant""",1482362,"""Prosthetic dental implant""","""ICD9CM""",0.958077
"""father not a smoker""",1365236,"""Father does not smoke""","""SNOMEDCT_US""",0.953439


## example of SNOMED - ULMS

In [42]:
df_result = find_closest_concept(emb_label_ft[:5], emb_ulms_label_ft, df_hug_snomed[:5], df_ulms, "label", "STR")
df_result

100%|██████████| 1/1 [00:02<00:00,  2.56s/it]


query_label,candidate_idx,candidate_label,candidate_coding_system,similarity
str,i64,str,str,f64
"""Laser beam guide (physical obj…",1676926,"""Laser beam guide (physical obj…","""SNOMEDCT_US""",0.99996
"""|Anaerobic microbial culture (…",1652062,"""Prevotella nigrescens or Prevo…","""SNOMEDCT_US""",0.720952
"""Footwear feature (observable e…",855088,"""Footwear feature (observable e…","""SNOMEDCT_US""",0.999913
"""Open reduction of closed humer…",108893,"""Open reduction of closed humer…","""SNOMEDCT_US""",1.000063
"""Murine sarcoma virus (organism…",34889,"""Murine sarcoma virus (organism…","""SNOMEDCT_US""",1.00009


# example post - umls

In [41]:
# isolate post firstly
df_hug_snomed_post = df_hug_snomed.filter(pl.col("concept_type") == "SCT_POST")
idx_post = df_hug_snomed_post['idx'].to_list()
emb_label_ft_post = emb_label_ft[idx_post].clone().detach()

df_result = find_closest_concept(emb_label_ft_post[10:15], emb_ulms_label_ft, df_hug_snomed_post[10:15], df_ulms, "label", "STR")
df_result

100%|██████████| 1/1 [00:02<00:00,  2.53s/it]


query_label,candidate_idx,candidate_label,candidate_coding_system,similarity
str,i64,str,str,f64
"""|Amitriptyline measurement (pr…",1062757,"""Amitriptyline:Substance Concen…","""LNC""",0.848368
"""|Aerobic microbial culture (pr…",1604598,"""Swab from nasal sinus""","""SNOMEDCT_US""",0.691654
"""|Mycobacteria culture (procedu…",136616,"""Mycobacteria culture""","""SNOMEDCT_US""",0.783549
"""|Interleukin-2 assay (procedur…",138181,"""Interleukin-2 assay (procedure…","""SNOMEDCT_US""",0.933659
"""|Aerobic microbial culture (pr…",246079,"""Klebsiella oxytoca""","""SNOMEDCT_US""",0.714688


# similarity mapped concepts and all UMLS

In [27]:
idx_mapped = df_mapped['idx'].to_list()
emb_label_ft_mapped = emb_label_ft[idx_mapped].clone().detach()
df_result = find_closest_concept(emb_label_ft_mapped[10:15], emb_ulms_label_ft, df_mapped[10:15], df_ulms, "label", "STR")
df_result


100%|██████████| 1/1 [00:02<00:00,  2.46s/it]


query_label,candidate_idx,candidate_label,candidate_coding_system,similarity
str,i64,str,str,f64
"""Nutritional marasmus (disorder…",73889,"""Nutritional marasmus (disorder…","""SNOMEDCT_US""",1.000014
"""|Aerobic microbial culture (pr…",1604535,"""Swab from peritoneal cavity st…","""SNOMEDCT_US""",0.669004
"""|Aerobic microbial culture (pr…",874012,"""Streptococcus agalactiae cultu…","""SNOMEDCT_US""",0.741366
"""|Polymerase chain reaction ana…",2172044,"""Carbapenemase enzyme panel""","""LNC""",0.723215
"""|Polymerase chain reaction ana…",1877686,"""Norovirus genogroups I & II ri…","""LNC""",0.711229


# similarity mapped concepts and all ICD-9 AND 10

In [45]:
idx_mapped = df_mapped['idx'].to_list()
emb_label_ft_mapped = emb_label_ft[idx_mapped].clone().detach()

df_ulms_icd = df_ulms.filter(pl.col("SAB").str.contains("ICD"))
idx_icd = df_ulms_icd['idx'].to_list()
emb_label_ft_icd = emb_ulms_label_ft[idx_icd].clone().detach()

df_result = find_closest_concept(emb_label_ft_mapped[100:1000], emb_label_ft_icd, df_mapped[100:1000], df_ulms_icd, "label", "STR")
df_result

100%|██████████| 1/1 [00:00<00:00,  1.06it/s]


query_label,candidate_idx,candidate_label,candidate_coding_system,similarity
str,i64,str,str,f64
"""|Human leukocyte antigen DQB1 …",22778,"""Chronic viral hepatitis B with…","""ICD10CM""",0.424819
"""Laparoscopic repair of incisio…",32518,"""Laparoscopic repair of diaphra…","""ICD9CM""",0.796782
"""Sotalol (substance)""",393536,"""IST""","""ICD10CM""",0.52341
"""Surgery of cataract of bilater…",312564,"""Cataract secondary to ocular d…","""ICD10CM""",0.709546
"""Rheumatic aortic regurgitation…",7921,"""Rheumatic aortic insufficiency""","""ICD10CM""",0.934957
…,…,…,…,…
"""|Hantavirus immunoglobulin G l…",390047,"""Transfusion of Hyperimmune Glo…","""ICD10PCS""",0.456085
"""|Tooth disorder (disorder)|+|A…",21944,"""Dental alveolar anomalies""","""ICD10CM""",0.905493
"""Primitive neuroectodermal tumo…",32317,"""Malignant poorly differentiate…","""ICD10CM""",0.718497
"""|Mycology culture (procedure)|…",361070,"""Positive culture findings of u…","""ICD10CM""",0.816652
