In [1]:
import torch
import polars as pl
from tqdm.auto import tqdm  # Add this at top if not already
import pickle


In [2]:
# Root folder for the mapping input and the artifacts written by this notebook.
PATH = "D:/UMLS/"
# Input CSV holding the ICD <-> SNOMED mapping rows.
FILE_PATH = PATH + "exact_mapping.csv"
# Pickle file that receives the lookup dictionaries.
SAVE_PATH = PATH + "icd_snomed_mappings.pkl"


In [3]:

# Load the ICD <-> SNOMED mapping table and derive one lookup table per
# vocabulary. Every distinct (ID, label) pair receives a dense integer `idx`
# (0..n-1); both lookup tables are written to CSV for the later cells.
df_mapping = pl.read_csv(FILE_PATH)

# maintain_order=True keeps rows in first-occurrence order. Without it the row
# order returned by unique() is unspecified, so the idx assigned to a concept
# could change from run to run and silently misalign previously saved
# CSVs / embeddings that are indexed by row position.
df_mapping_snomed = df_mapping["SNOMED_ID", "SNOMED_label"].unique(maintain_order=True)
# insert_column modifies the frame in place, so `idx` is available below.
df_mapping_snomed.insert_column(0, pl.Series("idx", range(len(df_mapping_snomed))))
df_mapping_snomed.write_csv(PATH + "snomed_info.csv")

df_mapping_icd = df_mapping["ICD_ID", "ICD_label"].unique(maintain_order=True)
df_mapping_icd.insert_column(0, pl.Series("idx", range(len(df_mapping_icd))))
df_mapping_icd.write_csv(PATH + "icd_info.csv")

# Forward (idx -> ID) and reverse (ID -> idx) dictionaries.
# NOTE(review): uniqueness is over (ID, label) pairs. If one ID occurs with
# several labels, the reverse dictionaries keep only the last idx seen for that
# ID -- confirm that IDs are unique in the input.
id2snomed = dict(zip(df_mapping_snomed["idx"], df_mapping_snomed["SNOMED_ID"]))
id2icd = dict(zip(df_mapping_icd["idx"], df_mapping_icd["ICD_ID"]))
snomed2id = dict(zip(df_mapping_snomed["SNOMED_ID"], df_mapping_snomed["idx"]))
icd2id = dict(zip(df_mapping_icd["ICD_ID"], df_mapping_icd["idx"]))

In [4]:
# Attach the dense integer indices to every mapping row.
# replace_strict with an explicit return dtype is used instead of replace:
# replace keeps the dtype of the source column, so for a string-typed ID column
# the integer indices would be cast to strings and would no longer match the
# integer keys of id2snomed / id2icd. replace_strict also raises if an ID is
# missing from the lookup instead of silently passing the raw ID through.
df_mapping = df_mapping.with_columns(
    idx_snomed=pl.col("SNOMED_ID").replace_strict(snomed2id, return_dtype=pl.Int64),
    idx_icd=pl.col("ICD_ID").replace_strict(icd2id, return_dtype=pl.Int64),
)

In [5]:
# Collapse the row-wise mapping into one entry per index, in both directions.
# NOTE(review): iterating a list-typed column most likely yields polars Series
# objects, so the dictionary values are Series rather than plain Python lists --
# confirm this is what consumers of the pickle expect.

# ICD index -> SNOMED indices mapped to it.
df_icd2snomed = df_mapping.group_by("idx_icd").agg(pl.col("idx_snomed"))
icd2snomed = dict(
    zip(
        df_icd2snomed.get_column("idx_icd"),
        df_icd2snomed.get_column("idx_snomed"),
    )
)

# SNOMED index -> ICD indices mapped to it.
df_snomed2icd = df_mapping.group_by("idx_snomed").agg(pl.col("idx_icd"))
snomed2icd = dict(
    zip(
        df_snomed2icd.get_column("idx_snomed"),
        df_snomed2icd.get_column("idx_icd"),
    )
)




In [6]:
# Bundle the six lookup dictionaries under fixed string keys and persist them
# together in a single pickle file at SAVE_PATH.
all_mappings = dict(
    icd2snomed=icd2snomed,
    snomed2icd=snomed2icd,
    id2snomed=id2snomed,
    snomed2id=snomed2id,
    id2icd=id2icd,
    icd2id=icd2id,
)

with open(SAVE_PATH, mode="wb") as f:
    pickle.dump(all_mappings, f)

print(f"All mappings saved to: {SAVE_PATH}")

All mappings saved to: D:/UMLS/icd_snomed_mappings.pkl


# Embed all SNOMED and ICD labels

In [7]:
from sentence_transformers import SentenceTransformer
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [8]:
# Reload the artifacts written by the earlier cells: the pickled lookup
# dictionaries and the two per-vocabulary info tables.
# pickle.load can execute arbitrary code from the file; only load a pickle that
# this notebook produced itself.
with open(SAVE_PATH, mode="rb") as f:
    loaded_mappings = pickle.load(f)

snomed_info = pl.read_csv("D:/UMLS/snomed_info.csv")
icd_info = pl.read_csv("D:/UMLS/icd_info.csv")

# Unpack the individual dictionaries by their stored key.
id2snomed, id2icd, snomed2id, icd2id, icd2snomed, snomed2icd = (
    loaded_mappings[mapping_name]
    for mapping_name in (
        "id2snomed",
        "id2icd",
        "snomed2id",
        "icd2id",
        "icd2snomed",
        "snomed2icd",
    )
)

In [9]:
# Pull the label columns out as plain Python lists, in table row order, so they
# can be passed to the encoder.
snomed_labels = snomed_info.get_column("SNOMED_label").to_list()
icd_labels = icd_info.get_column("ICD_label").to_list()

In [14]:
# Encode every SNOMED and ICD label with the sentence-transformer model and
# store both embedding matrices in a single file under PATH.
model = SentenceTransformer(
    "yyzheng00/sapbert_lora_triplet_rank16_merged",
    device=device,
)

# Same encoding settings for both vocabularies.
encode_options = {"batch_size": 32, "show_progress_bar": True}
embeddings_snomed = model.encode(snomed_labels, **encode_options)
embeddings_icd = model.encode(icd_labels, **encode_options)

save_path_embed = PATH + "sapbert_lora_triplet_rank16_merged.pt"
# NOTE(review): the stored "model_name" omits the "_merged" suffix used by the
# model repo and the output file name -- confirm this is intentional.
# NOTE(review): encode() presumably returns NumPy arrays here; loading such a
# file may require torch.load(..., weights_only=False) on recent PyTorch --
# verify against the code that reads this file.
embedding_payload = {
    "model_name": "sapbert_lora_triplet_rank16",
    "snomed_embeddings": embeddings_snomed,
    "icd_embeddings": embeddings_icd,
}
torch.save(embedding_payload, save_path_embed)

modules.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config_sentence_transformers.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.86k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/56.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/706k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/305 [00:00<?, ?B/s]

Batches:   0%|          | 0/1424 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]

In [15]:
# Sanity check: display the shape of the ICD embedding matrix.
embeddings_icd.shape

(5156, 768)