# Notebook for Pre-trained Knowledge Graph Embedding (on UMLS)

In [1]:
## Import packages

import csv
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from pyhealth.medcode import ICD9CM, ICD9PROC
from pyhealth.medcode.pretrained_embeddings.kg_emb.models import TransE, RotatE, ComplEx, DistMult
from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import UMLSDataset, split
from pyhealth.medcode.pretrained_embeddings.kg_emb.tasks import link_prediction_fn

  from .autonotebook import tqdm as notebook_tqdm


## Load Pre-trained KGE model

In [None]:
umls_ds = UMLSDataset(
    root="/data/pj20/umls/",
    # root="https://storage.googleapis.com/pyhealth/umls/",
    dev=False,
    refresh_cache=False
)

# check the dataset statistics before setting task
print(umls_ds.stat()) 

# check the relation numbers in the dataset
print("Relations in KG:", umls_ds.relation2id)

umls_ds = umls_ds.set_task(link_prediction_fn, negative_sampling=64, save=False)

model = TransE(
    dataset=umls_ds,
    e_dim=512, 
    r_dim=512, 
)

print('Loaded model: ', model)
state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt")
model.load_state_dict(state_dict)

In [None]:
import pickle

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/model.pkl", "wb") as f:
    pickle.dump(model, f)


In [None]:
len(model.R_emb), len(model.E_emb)

In [None]:
with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl", "wb") as f:
    pickle.dump(model.E_emb, f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl", "wb") as f:
    pickle.dump(model.R_emb, f)

### Load Pre-trained Entity Embedding and Relation Embedding

In [2]:
import pickle

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl", "rb") as f:
    E_emb = pickle.load(f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl", "rb") as f:
    R_emb = pickle.load(f)

In [3]:
import json

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2entity.json", "r") as f:
    id2entity = json.load(f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2relation.json", "r") as f:
    id2relation = json.load(f)

entity2id = {v: k for k, v in id2entity.items()}
relation2id = {v: k for k, v in id2relation.items()}

## ATC

In [4]:
# load the mapping from ATC to UMLS
atc_umls = pd.read_csv("../resource/ATC_to_UMLS.csv", header=None)
atc_umls_cuis = atc_umls[1].tolist()[1:]


In [5]:
# Check if there are any ATC codes that are not in the UMLS

cnt = 0
not_covered_cui_atc = []
for cui in tqdm(atc_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_atc.append(cui)
        cnt+=1

cnt, not_covered_cui_atc

100%|██████████| 6564/6564 [00:00<00:00, 1767099.58it/s]


(0, [])

In [52]:
# get the embeddings

atc_to_umls = {}
for i in tqdm(range(len(atc_umls))):
    if atc_umls[1][i] != "UMLS":
        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]

atc_id2emb = {}
for atc_id in tqdm(atc_to_umls.keys()):
    atc_id2emb[atc_id] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/drugs/atc.json", "w") as f:
    json.dump(atc_id2emb, f, indent=6)


100%|██████████| 6565/6565 [00:00<00:00, 56700.56it/s]
100%|██████████| 6564/6564 [00:00<00:00, 28065.41it/s]


In [22]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')

atc_to_umls = {}
for i in tqdm(range(len(atc_umls))):
    if atc_umls[1][i] != "UMLS":
        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]
        
atc_umls_dict = defaultdict(dict)

for atc_id in tqdm(atc_to_umls.keys()):
    atc_umls_dict[atc_id]['UMLS CUI'] = atc_to_umls[atc_id]
    atc_umls_dict[atc_id]['UMLS-KG Embedding'] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()
    
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for index, row in data.iterrows():
    code = row['code']
    if code in atc_umls_dict:
        data.at[index, 'UMLS CUI'] = atc_umls_dict[code]['UMLS CUI']
        data.at[index, 'UMLS-KG Embedding'] = atc_umls_dict[code]['UMLS-KG Embedding']

100%|██████████| 6565/6565 [00:00<00:00, 63963.11it/s]
100%|██████████| 6564/6564 [00:00<00:00, 23570.87it/s]


In [23]:
import pandas as pd
from tqdm import tqdm

# Assuming other necessary variables like `data`, `E_emb`, and `entity2id` are already defined...
data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')
data = data.drop(columns=['description'])
data = data.drop(columns=['indication'])
# Define the output file paths

embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv'
metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Metadata.tsv'

# Create empty lists to store valid embeddings and metadata
valid_embeddings = []
valid_metadata = []

# Loop through each row in your metadata DataFrame
for _, row in tqdm(data.iterrows(), total=data.shape[0]):
    # Get the ATC code from the current row
    atc_id = row['code']
    
    # Check if ATC code has a corresponding UMLS CUI and embedding
    if atc_id in atc_to_umls:
        umls_cui = atc_to_umls[atc_id]
        
        # Get and format the embedding
        embedding = E_emb[int(entity2id[umls_cui])].detach().numpy().tolist()
        embedding_str = '\t'.join(map(str, embedding))
        
        # Append the embedding and metadata to the respective lists
        valid_embeddings.append([embedding_str])
        
        # Add UMLS CUI to the row before appending to valid_metadata
        row_dict = row.to_dict()
        row_dict['UMLS CUI'] = umls_cui
        valid_metadata.append(row_dict)

# Convert lists to DataFrames
valid_embeddings_df = pd.DataFrame(valid_embeddings)
valid_metadata_df = pd.DataFrame(valid_metadata)

# Save the valid embeddings and metadata to TSV files
valid_embeddings_df.to_csv(embedding_file_path, sep='\t', index=False, header=False)
valid_metadata_df.to_csv(metadata_file_path, sep='\t', index=False, header=True)


100%|██████████| 6966/6966 [00:02<00:00, 3326.12it/s]


In [21]:
len(valid_embeddings_df), len(valid_metadata_df)

(6920, 6920)

In [24]:
data

Unnamed: 0,code,parent_code,name,level,description,indication,smiles,drugbank_id,UMLS CUI,UMLS-KG Embedding
0,A,,ALIMENTARY TRACT AND METABOLISM DRUGS,1.0,,,,,C3653992,"[-0.0428786464035511, -0.10927890986204147, 0...."
1,B,,BLOOD AND BLOOD FORMING ORGAN DRUGS,1.0,,,,,C3654015,"[-0.05003020912408829, -0.07041565328836441, 0..."
2,C,,CARDIOVASCULAR SYSTEM DRUGS,1.0,,,,,C3540036,"[0.037320803850889206, -0.14568206667900085, -..."
3,D,,DERMATOLOGICALS,1.0,,,,,C0011625,"[0.016971509903669357, -0.004210879094898701, ..."
4,G,,GENITO URINARY SYSTEM AND SEX HORMONES,1.0,,,,,C3653431,"[-0.11394545435905457, -0.10015183687210083, -..."
...,...,...,...,...,...,...,...,...,...,...
6961,V10XA53,V10XA,tositumomab/iodine (<o>131</o>I) tositumomab,5.0,Murine IgG2a lambda monoclonal antibody agains...,For treatment of non-Hodgkin's lymphoma (CD20 ...,,DB00081,C0768182,"[0.043125640600919724, 0.12629146873950958, 0...."
6962,V10XX01,V10XX,sodium phosphate (<o>32</o>P),5.0,,,[Na+].[Na+].O[32P]([O-])([O-])=O,DB09370,C0305007,"[0.0465690940618515, 0.05358200520277023, -0.0..."
6963,V10XX02,V10XX,ibritumomab tiuxetan (<o>90</o>Y),5.0,Indium or yttrium conjugated murine IgG1 kappa...,For treatment of non-Hodgkin's lymphoma,,DB00078,C1134535,"[-0.019149580970406532, 0.07273773849010468, -..."
6964,V10XX03,V10XX,radium (<o>223</o>Ra) dichloride,5.0,,,,,C3700396,"[-0.04269959777593613, 0.12389065325260162, -0..."


In [23]:
data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC.csv', index=False)

## ICD-9-CM

In [25]:
# load the mapping from ICD-9-CM to UMLS
icd9cm_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
icd9cm_umls_cuis = icd9cm_umls[1].tolist()[1:]

In [26]:
# Check if there are any ICD-9-CM codes that are not in the UMLS

cnt = 0
not_covered_cui_icd9cm = []
for cui in tqdm(icd9cm_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_icd9cm.append(cui)
        cnt+=1

cnt, not_covered_cui_icd9cm

100%|██████████| 22406/22406 [00:00<00:00, 1184536.54it/s]


(0, [])

In [66]:
# get the embeddings

icd9cm_to_umls = {}
for i in tqdm(range(len(icd9cm_umls))):
    if icd9cm_umls[1][i] != "UMLS":
        icd9cm_to_umls[icd9cm_umls[0][i]] = icd9cm_umls[1][i]

icd9cm_id2emb = {}
for icd9cm_id in tqdm(icd9cm_to_umls.keys()):
    key = ICD9CM.standardize(icd9cm_id).replace('.', '')
    icd9cm_id2emb[key] = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/conditions/icd9cm.json", "w") as f:
    json.dump(icd9cm_id2emb, f, indent=6)

100%|██████████| 22407/22407 [00:00<00:00, 79102.78it/s]
100%|██████████| 22406/22406 [00:00<00:00, 48074.67it/s]


In [34]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9CM.csv')

icd9cm_to_umls = {}
for i in tqdm(range(len(icd9cm_umls))):
    if icd9cm_umls[1][i] != "UMLS":
        icd9cm_to_umls[icd9cm_umls[0][i]] = icd9cm_umls[1][i]
        
icd9cm_umls_dict = defaultdict(dict)

for icd9cm_key in tqdm(icd9cm_to_umls.keys()):
    icd9cm_umls_dict[icd9cm_key]['UMLS CUI'] = icd9cm_to_umls[icd9cm_key]
    icd9cm_umls_dict[icd9cm_key]['UMLS-KG Embedding'] = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_key]])].detach().numpy().tolist()
    
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for index, row in data.iterrows():
    code = row['code']
    if code in icd9cm_umls_dict:
        data.at[index, 'UMLS CUI'] = icd9cm_umls_dict[code]['UMLS CUI']
        data.at[index, 'UMLS-KG Embedding'] = icd9cm_umls_dict[code]['UMLS-KG Embedding']

data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv', index=False)

100%|██████████| 22407/22407 [00:00<00:00, 100709.03it/s]
100%|██████████| 22406/22406 [00:00<00:00, 50675.42it/s]


In [1]:
import pandas as pd

# Load your CSV
data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv')

# 1. Create Embedding File
# Extract and process the UMLS-KG Embedding
embedding_data = data['UMLS-KG Embedding'].apply(lambda x: pd.Series(eval(x)))  # Using eval to convert string to list

# Save to TSV without headers and index
embedding_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Embedding.tsv', sep='\t', index=False, header=False)

# 2. Create Metadata File
# Use all columns except 'UMLS-KG Embedding' as metadata
metadata_data = data.drop(columns=['UMLS-KG Embedding'])

# Save to TSV with headers and without index
metadata_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Metadata.tsv', sep='\t', index=False, header=True)


## ICD-9-PROC

In [12]:
# load the mapping from ICD-9-proc to UMLS
icd9proc_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
icd9proc_umls_cuis = icd9proc_umls[1].tolist()[1:]

In [13]:
# Check if there are any ICD-9-proc codes that are not in the UMLS

cnt = 0
not_covered_cui_icd9proc = []
for cui in tqdm(icd9proc_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_icd9proc.append(cui)
        cnt+=1

cnt, not_covered_cui_icd9proc

100%|██████████| 22406/22406 [00:00<00:00, 1003787.27it/s]


(0, [])

In [None]:
# get the embeddings

icd9proc_to_umls = {}
for i in tqdm(range(len(icd9proc_umls))):
    if icd9proc_umls[1][i] != "UMLS":
        icd9proc_to_umls[icd9proc_umls[0][i]] = icd9proc_umls[1][i]

icd9proc_id2emb = {}
for icd9proc_id in tqdm(icd9proc_to_umls.keys()):
    key = ICD9PROC.standardize(icd9proc_id).replace('.', '')
    icd9proc_id2emb[key] = E_emb[int(entity2id[icd9proc_to_umls[icd9proc_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/procedures/icd9proc.json", "w") as f:
    json.dump(icd9proc_id2emb, f, indent=6)

In [16]:
icd9proc_to_umls

{'360.04': 'C0042904',
 '360.03': 'C0154774',
 '360.02': 'C0030332',
 '360.01': 'C0154773',
 '360.00': 'C0259800',
 '801.83': 'C0159307',
 '801.82': 'C0159306',
 '801.81': 'C0159305',
 '801.80': 'C0159304',
 '801.86': 'C0159310',
 '801.85': 'C0159309',
 '801.84': 'C0159308',
 '643.83': 'C0156706',
 '643.81': 'C0156705',
 '643.80': 'C0156704',
 '161.9': 'C0007107',
 '161.8': 'C0153487',
 '525.79': 'C1955810',
 '161.1': 'C0153484',
 '161.0': 'C0153483',
 '161.3': 'C0153486',
 '161.2': 'C0153485',
 '153.2': 'C0153435',
 '153.3': 'C0153436',
 '153.0': 'C0153433',
 '153.1': 'C0153434',
 '153.6': 'C0153439',
 '153.7': 'C0153440',
 '153.4': 'C0153437',
 '153.5': 'C0496779',
 '153.8': 'C0153441',
 '153.9': 'C0007102',
 '48': 'C0161905',
 '270': 'C0002514',
 '271': 'C0154249',
 '272': 'C0154251',
 '273': 'C3875058',
 '274': 'C0018099',
 '275': 'C0154260',
 '276': 'C0267994',
 '277': 'C0268329',
 '278': 'C1561827',
 '279': 'C0041806',
 '801.89': 'C0375595',
 '366.01': 'C1112690',
 '366.00': 'C26

In [32]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')

icd9proc_to_umls = {}
for i in tqdm(range(len(icd9proc_umls))):
    if icd9proc_umls[1][i] != "UMLS":
        icd9proc_to_umls[icd9proc_umls[0][i]] = icd9proc_umls[1][i]

icd9proc_umls_dict = defaultdict(dict)

for icd9proc_key in tqdm(icd9proc_to_umls.keys()):
    icd9proc_umls_dict[icd9proc_key]['UMLS CUI'] = icd9proc_to_umls[icd9proc_key]
    icd9proc_umls_dict[icd9proc_key]['UMLS-KG Embedding'] = E_emb[int(entity2id[icd9proc_to_umls[icd9proc_key]])].detach().numpy().tolist()
    
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for index, row in data.iterrows():
    code = row['code']
    if code in icd9proc_umls_dict:
        data.at[index, 'UMLS CUI'] = icd9proc_umls_dict[code]['UMLS CUI']
        data.at[index, 'UMLS-KG Embedding'] = icd9proc_umls_dict[code]['UMLS-KG Embedding']

data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC.csv', index=False)

100%|██████████| 22407/22407 [00:00<00:00, 101962.03it/s]
100%|██████████| 22406/22406 [00:00<00:00, 49370.34it/s]


In [17]:
import pandas as pd
from tqdm import tqdm

# Assuming other necessary variables like `data`, `E_emb`, and `entity2id` are already defined...
data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')
# Define the output file paths

embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Embedding.tsv'
metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Metadata.tsv'

# Create empty lists to store valid embeddings and metadata
valid_embeddings = []
valid_metadata = []

# Loop through each row in your metadata DataFrame
for _, row in tqdm(data.iterrows(), total=data.shape[0]):
    # Get the ATC code from the current row
    atc_id = row['code']
    
    # Check if ATC code has a corresponding UMLS CUI and embedding
    if atc_id in icd9proc_to_umls:
        umls_cui = icd9proc_to_umls[atc_id]
        
        # Get and format the embedding
        embedding = E_emb[int(entity2id[umls_cui])].detach().numpy().tolist()
        embedding_str = '\t'.join(map(str, embedding)).replace('\"', '')
        
        # Append the embedding and metadata to the respective lists
        valid_embeddings.append([embedding_str])
        
        # Add UMLS CUI to the row before appending to valid_metadata
        row_dict = row.to_dict()
        row_dict['UMLS CUI'] = umls_cui
        valid_metadata.append(row_dict)

# Convert lists to DataFrames
valid_embeddings_df = pd.DataFrame(valid_embeddings)
valid_metadata_df = pd.DataFrame(valid_metadata)

# Save the valid embeddings and metadata to TSV files
valid_embeddings_df.to_csv(embedding_file_path, sep='\t', index=False, header=False)
valid_metadata_df.to_csv(metadata_file_path, sep='\t', index=False, header=True)


100%|██████████| 4671/4671 [00:01<00:00, 3183.46it/s]


In [33]:
data

Unnamed: 0,code,parent_code,name,UMLS CUI,UMLS-KG Embedding
0,94.62,94.6,Alcohol detoxification,C0204597,"[-0.020793389528989792, 0.08486927300691605, -..."
1,94.69,94.6,Combined alcohol and drug rehabilitation and d...,C0204605,"[-0.05920370668172836, 0.01815624162554741, -0..."
2,94.6,94,Alcohol and drug rehabilitation and detoxifica...,C0178073,"[-0.07168742269277573, 0.07703245431184769, -0..."
3,94.61,94.6,Alcohol rehabilitation,C0204598,"[-0.08163461834192276, -0.017953498288989067, ..."
4,94.67,94.6,Combined alcohol and drug rehabilitation,C0204604,"[-0.07583430409431458, 0.10185211151838303, -0..."
...,...,...,...,...,...
4666,35.95,35.9,Revision of corrective procedure on heart,C0189750,"[-0.007993519306182861, 0.027917619794607162, ..."
4667,21.2,21,Diagnostic procedures on nose,C0176357,"[-0.022780975326895714, -0.04294693470001221, ..."
4668,21.29,21.2,Other diagnostic procedures on nose,C0176359,"[0.02899329364299774, 0.02302040532231331, 0.0..."
4669,21.21,21.2,Rhinoscopy,C0189024,"[-0.07597479969263077, -0.06356113404035568, -..."


In [24]:
with open('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv', 'r') as f:
    lines = f.readlines()
  
lines_new = []  
for line in lines:
    lines_new.append(line.replace('\"', ''))
    
with open('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv', 'w') as f:
    f.writelines(lines_new)

## CCSCM

In [70]:
icd9cm_to_ccscm = {}

with open("../resource/ICD9CM_to_CCSCM.csv", "r") as f:
    reader = csv.reader(f)
    for row in reader:
        if row[1] != 'CCSCM':
            icd9cm_to_ccscm[row[0]] = row[1]

ccscm_to_icd9cm = defaultdict(list)
for k, v in icd9cm_to_ccscm.items():
    ccscm_to_icd9cm[v].append(k)

ccscm_icd9cm = {}
for k, v in ccscm_to_icd9cm.items():
    ccscm_icd9cm[k] = v[0]

In [73]:
# get the embeddings
ccscm_id2emb = {}
for ccscm_id in tqdm(ccscm_icd9cm.keys()):
    try:
        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id]]])].detach().numpy().tolist()
    except:
        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id].replace('.00', '')]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/conditions/ccscm.json", "w") as f:
    json.dump(ccscm_id2emb, f, indent=6)

100%|██████████| 283/283 [00:00<00:00, 32360.63it/s]


## CCSPROC

In [74]:
icd9proc_to_ccsproc = {}

with open("../resource/ICD9PROC_to_CCSPROC.csv", "r") as f:
    reader = csv.reader(f)
    for row in reader:
        if row[1] != 'CCSPROC':
            icd9proc_to_ccsproc[row[0]] = row[1]

ccsproc_to_icd9proc = defaultdict(list)
for k, v in icd9proc_to_ccsproc.items():
    ccsproc_to_icd9proc[v].append(k)

ccsproc_icd9proc = {}
for k, v in ccsproc_to_icd9proc.items():
    ccsproc_icd9proc[k] = v[0]

In [82]:
# get the embeddings
ccsproc_id2emb = {}
for ccsproc_id in tqdm(ccsproc_icd9proc.keys()):
    try:
        ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[ccsproc_icd9proc[ccsproc_id]]])].detach().numpy().tolist()
    except:
        try:
            icd9procid = ccsproc_icd9proc[ccsproc_id]
            if icd9procid[0] == '0':
                icd9procid = icd9procid[1:]
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-1]
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-2]
            if icd9procid[-1] == '.':
                icd9procid = icd9procid[:-1]

            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()
            
        except:
            icd9procid = icd9procid[:-1]
            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/procedures/ccsproc.json", "w") as f:
    json.dump(ccsproc_id2emb, f, indent=6)

100%|██████████| 231/231 [00:00<00:00, 49372.41it/s]


## Special Tokens

In [57]:
special_tokens = {}
tokens = ['<pad>', '<unk>']

for token in tokens:
    special_tokens[token] = np.random.randn(512).tolist()

with open(f"../resource/embeddings/KG/special_tokens/special_tokens.json", "w") as f:
    json.dump(special_tokens, f, indent=6)