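# Pre-trained Knowledge Graph Embeddings on UMLS

This notebook loads a TransE model pre-trained on the UMLS knowledge graph and exports its entity embeddings for ATC, ICD-9-CM, ICD-9-PROC, CCS-CM and CCS-PROC codes, plus vectors for special tokens.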

In [3]:
## Import packages

import csv
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from pyhealth.medcode import ICD9CM, ICD9PROC
from pyhealth.medcode.pretrained_embeddings.kg_emb.models import TransE, RotatE, ComplEx, DistMult
from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import UMLSDataset, split
from pyhealth.medcode.pretrained_embeddings.kg_emb.tasks import link_prediction_fn

## Load Pre-trained KGE model

In [None]:
# Rebuild the UMLS KG dataset (TransE sizes its embedding tables from it),
# then restore the pre-trained TransE weights from the checkpoint.
umls_ds = UMLSDataset(
    root="/data/pj20/umls/",
    # root="https://storage.googleapis.com/pyhealth/umls/",
    dev=False,
    refresh_cache=False
)

# check the dataset statistics before setting task.
# stat() prints its own report and returns None, so call it directly rather than
# wrapping it in print() (which only echoed a stray "None").
umls_ds.stat()

# check the relation numbers in the dataset
print("Relations in KG:", umls_ds.relation2id)

umls_ds = umls_ds.set_task(link_prediction_fn, negative_sampling=64, save=False)

model = TransE(
    dataset=umls_ds,
    e_dim=512,
    r_dim=512,
)

# map_location="cpu" lets the checkpoint load on machines without the device it
# was saved from; load_state_dict then copies the tensors into the model.
state_dict = torch.load(
    "/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt",
    map_location="cpu",
)
model.load_state_dict(state_dict)
# Report only after the weights are actually in place.
print('Loaded model: ', model)

INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
Loading UMLS knowledge graph...
Processing UMLS knowledge graph...
Building UMLS knowledge graph...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43842950/43842950 [00:29<00:00, 1486771.41it/s]



Statistics of base dataset (dev=False):
	- Dataset: UMLSDataset
	- Number of triples: 43842950
	- Number of entities: 3110571
	- Number of relations: 965
	- Task name: Null
	- Number of samples: 0

None
Relations in KG: {'RB': 0, 'translation_of': 1, 'permuted_term_of': 2, 'SY': 3, 'AQ': 4, 'PAR': 5, 'mapped_to': 6, 'associated_with': 7, 'has_permuted_term': 8, 'has_translation': 9, 'has_transliterated_form': 10, 'measures': 11, 'parent_of': 12, 'form_of': 13, 'CHD': 14, 'has_component': 15, 'transliterated_form_of': 16, 'RO': 17, 'RN': 18, 'inverse_isa': 19, 'disposition_of': 20, 'exhibited_by': 21, 'see_from': 22, 'see': 23, 'entry_combination_of': 24, 'mapped_from': 25, 'has_causative_agent': 26, 'used_for': 27, 'use': 28, 'isa': 29, 'subset_includes_concept': 30, 'has_direct_substance': 31, 'has_ingredient': 32, 'has_tradename': 33, 'mapping_qualifier_of': 34, 'contains': 35, 'active_ingredient_of': 36, 'has_active_ingredient': 37, 'has_active_moiety': 38, 'has_member': 39, 'has_b

 20%|███████████████████████▍                                                                                           | 8928541/43842950 [02:08<06:51, 84924.09it/s]

In [None]:
import pickle

# Serialize the full `model` object to disk for quick reuse.
model_pickle_path = "/data/pj20/umls_kge/pretrained_model/umls_transe_new/model.pkl"
with open(model_pickle_path, "wb") as model_file:
    pickle.dump(model, model_file)


In [None]:
# Number of rows in the relation and entity embedding tables, shown as a tuple.
(len(model.R_emb), len(model.E_emb))

In [None]:
# Persist the entity and relation embedding tables as separate pickles so they can
# be reloaded without rebuilding the dataset/model.
for emb_path, emb_table in (
    ("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl", model.E_emb),
    ("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl", model.R_emb),
):
    with open(emb_path, "wb") as out_file:
        pickle.dump(emb_table, out_file)

### Load Pre-trained Entity and Relation Embeddings

In [None]:
import pickle


def _load_pickle(path):
    """Return the object unpickled from ``path``.

    Only use on trusted, locally produced files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as in_file:
        return pickle.load(in_file)


E_emb = _load_pickle("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl")
R_emb = _load_pickle("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl")

In [None]:
import json

# id -> name lookups saved as JSON (so the ids are string keys).
with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2entity.json", "r") as entity_file:
    id2entity = json.load(entity_file)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2relation.json", "r") as relation_file:
    id2relation = json.load(relation_file)

# Inverted name -> id lookups; the ids remain strings, hence the int() casts downstream.
entity2id = dict(zip(id2entity.values(), id2entity.keys()))
relation2id = dict(zip(id2relation.values(), id2relation.keys()))

## ATC

In [None]:
# load the mapping from ATC to UMLS
# Column 0 holds ATC codes, column 1 the UMLS CUIs; the file's first row is a header.
atc_umls = pd.read_csv("../resource/ATC_to_UMLS.csv", header=None)
atc_umls_cuis = atc_umls.iloc[1:, 1].tolist()


In [None]:
# Check if there are any ATC codes that are not in the UMLS
# i.e. mapped CUIs that have no entry in the KG entity vocabulary.
not_covered_cui_atc = [cui for cui in tqdm(atc_umls_cuis) if cui not in entity2id]
cnt = len(not_covered_cui_atc)

cnt, not_covered_cui_atc

In [None]:
# get the embeddings

# ATC code -> UMLS CUI, skipping the header row (its CUI column reads "UMLS").
atc_to_umls = {}
for atc_code, cui in tqdm(zip(atc_umls[0], atc_umls[1]), total=len(atc_umls)):
    if cui == "UMLS":
        continue
    atc_to_umls[atc_code] = cui

# Pull each CUI's row out of the entity-embedding table (entity2id holds string ids).
atc_id2emb = {}
for atc_code, cui in tqdm(atc_to_umls.items()):
    atc_id2emb[atc_code] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

with open("../resource/embeddings/KG/drugs/atc.json", "w") as out_file:
    json.dump(atc_id2emb, out_file, indent=6)


In [None]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')

# Rebuild ATC code -> UMLS CUI, skipping the header row.
atc_to_umls = {}
for atc_code, cui in tqdm(zip(atc_umls[0], atc_umls[1]), total=len(atc_umls)):
    if cui != "UMLS":
        atc_to_umls[atc_code] = cui

# Per mapped ATC code: its CUI and its KG entity embedding.
atc_umls_dict = defaultdict(dict)
for atc_code, cui in tqdm(atc_to_umls.items()):
    entry = atc_umls_dict[atc_code]
    entry['UMLS CUI'] = cui
    entry['UMLS-KG Embedding'] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

# New columns default to '' for codes that have no UMLS mapping.
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for row_idx, code in data['code'].items():
    entry = atc_umls_dict.get(code)
    if entry is None:
        continue
    data.at[row_idx, 'UMLS CUI'] = entry['UMLS CUI']
    data.at[row_idx, 'UMLS-KG Embedding'] = entry['UMLS-KG Embedding']

In [None]:
import pandas as pd
from tqdm import tqdm

# Export ATC embeddings + metadata as two row-aligned TSVs: row i of the embedding
# file belongs to row i of the metadata file. Assumes `E_emb`, `entity2id` and
# `atc_to_umls` from earlier cells are defined.
# The metadata frame gets its own name so this cell does not overwrite `data`
# (the UMLS-enriched ATC table that is displayed and saved as ATC.csv afterwards).
atc_metadata_raw = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')
atc_metadata_raw = atc_metadata_raw.drop(columns=['description', 'indication'])

# Define the output file paths
embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv'
metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Metadata.tsv'

valid_embeddings = []
valid_metadata = []

for _, row in tqdm(atc_metadata_raw.iterrows(), total=atc_metadata_raw.shape[0]):
    atc_id = row['code']

    # Only ATC codes with a UMLS CUI get an embedding row (and a metadata row).
    if atc_id not in atc_to_umls:
        continue
    umls_cui = atc_to_umls[atc_id]

    # Keep the vector as a list of floats so to_csv writes one tab-separated value
    # per column. A pre-joined '\t' string contains the separator, so to_csv would
    # wrap every line in double quotes.
    valid_embeddings.append(E_emb[int(entity2id[umls_cui])].detach().numpy().tolist())

    # Metadata row = original ATC row plus its UMLS CUI.
    row_dict = row.to_dict()
    row_dict['UMLS CUI'] = umls_cui
    valid_metadata.append(row_dict)

valid_embeddings_df = pd.DataFrame(valid_embeddings)
valid_metadata_df = pd.DataFrame(valid_metadata)

# Embedding file has no header (one vector per line); metadata keeps its header.
valid_embeddings_df.to_csv(embedding_file_path, sep='\t', index=False, header=False)
valid_metadata_df.to_csv(metadata_file_path, sep='\t', index=False, header=True)


In [None]:
# Row counts of the exported embedding and metadata frames (should match).
(len(valid_embeddings_df), len(valid_metadata_df))

In [None]:
# Inspect the DataFrame currently bound to `data` (whichever cell assigned it last).
data

In [None]:
# Write the DataFrame currently bound to `data` to the "detailed" ATC CSV.
atc_detailed_csv_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC.csv'
data.to_csv(atc_detailed_csv_path, index=False)

## ICD-9-CM

In [None]:
# load the mapping from ICD-9-CM to UMLS
# Column 0 holds ICD-9-CM codes, column 1 the UMLS CUIs; the first row is a header.
icd9cm_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
icd9cm_umls_cuis = icd9cm_umls.iloc[1:, 1].tolist()

In [None]:
# Check if there are any ICD-9-CM codes that are not in the UMLS
# i.e. mapped CUIs that have no entry in the KG entity vocabulary.
not_covered_cui_icd9cm = [cui for cui in tqdm(icd9cm_umls_cuis) if cui not in entity2id]
cnt = len(not_covered_cui_icd9cm)

cnt, not_covered_cui_icd9cm

In [None]:
# get the embeddings

# ICD-9-CM code -> UMLS CUI, skipping the header row.
icd9cm_to_umls = {}
for cm_code, cui in tqdm(zip(icd9cm_umls[0], icd9cm_umls[1]), total=len(icd9cm_umls)):
    if cui == "UMLS":
        continue
    icd9cm_to_umls[cm_code] = cui

# JSON keys are the standardized codes with dots removed.
icd9cm_id2emb = {}
for cm_code, cui in tqdm(icd9cm_to_umls.items()):
    flat_code = ICD9CM.standardize(cm_code).replace('.', '')
    icd9cm_id2emb[flat_code] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

with open("../resource/embeddings/KG/conditions/icd9cm.json", "w") as out_file:
    json.dump(icd9cm_id2emb, out_file, indent=6)

In [None]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9CM.csv')

# Rebuild ICD-9-CM code -> UMLS CUI, skipping the header row.
icd9cm_to_umls = {}
for cm_code, cui in tqdm(zip(icd9cm_umls[0], icd9cm_umls[1]), total=len(icd9cm_umls)):
    if cui != "UMLS":
        icd9cm_to_umls[cm_code] = cui

# Per mapped code: its CUI and its KG entity embedding.
icd9cm_umls_dict = defaultdict(dict)
for icd9cm_key, cui in tqdm(icd9cm_to_umls.items()):
    entry = icd9cm_umls_dict[icd9cm_key]
    entry['UMLS CUI'] = cui
    entry['UMLS-KG Embedding'] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

# New columns default to '' for codes without a UMLS mapping.
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for row_idx, code in data['code'].items():
    entry = icd9cm_umls_dict.get(code)
    if entry is None:
        continue
    data.at[row_idx, 'UMLS CUI'] = entry['UMLS CUI']
    data.at[row_idx, 'UMLS-KG Embedding'] = entry['UMLS-KG Embedding']

data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv', index=False)

In [None]:
import ast

import pandas as pd

# Split the UMLS-enriched ICD-9-CM table into two row-aligned TSVs: row i of the
# embedding file belongs to row i of the metadata file.
data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv')

# Codes without a UMLS mapping were saved with an empty embedding cell, which
# read_csv returns as NaN. Keep only rows that actually carry an embedding, so the
# parser never sees NaN and both output files cover exactly the same rows.
embedded_rows = data[data['UMLS-KG Embedding'].notna()]

# 1. Create Embedding File
# Each cell holds a Python list literal (written by to_csv as str(list)).
# ast.literal_eval parses literals only, unlike eval, which would execute file contents.
embedding_data = pd.DataFrame(
    [ast.literal_eval(cell) for cell in embedded_rows['UMLS-KG Embedding']]
)

# Save to TSV without headers and index
embedding_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Embedding.tsv', sep='\t', index=False, header=False)

# 2. Create Metadata File
# All columns except 'UMLS-KG Embedding', for the same rows as the embedding file.
metadata_data = embedded_rows.drop(columns=['UMLS-KG Embedding'])

# Save to TSV with headers and without index
metadata_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Metadata.tsv', sep='\t', index=False, header=True)


## ICD-9-PROC

In [None]:
# load the mapping from ICD-9-proc to UMLS
# NOTE(review): this reads ICD9CM_to_UMLS.csv, the same file the ICD-9-CM section uses,
# not a procedure-specific mapping. It may be deliberate (if that file also lists
# procedure codes) or a copy-paste slip; confirm before trusting the icd9proc outputs.
icd9proc_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
# Column 1 holds the CUIs; the first row is a header and is skipped.
icd9proc_umls_cuis = icd9proc_umls.iloc[1:, 1].tolist()

In [None]:
# Check if there are any ICD-9-proc codes that are not in the UMLS
# i.e. mapped CUIs that have no entry in the KG entity vocabulary.
not_covered_cui_icd9proc = [cui for cui in tqdm(icd9proc_umls_cuis) if cui not in entity2id]
cnt = len(not_covered_cui_icd9proc)

cnt, not_covered_cui_icd9proc

In [None]:
# get the embeddings

# ICD-9-PROC code -> UMLS CUI, skipping the header row.
icd9proc_to_umls = {}
for proc_code, cui in tqdm(zip(icd9proc_umls[0], icd9proc_umls[1]), total=len(icd9proc_umls)):
    if cui == "UMLS":
        continue
    icd9proc_to_umls[proc_code] = cui

# JSON keys are the standardized codes with dots removed.
icd9proc_id2emb = {}
for proc_code, cui in tqdm(icd9proc_to_umls.items()):
    flat_code = ICD9PROC.standardize(proc_code).replace('.', '')
    icd9proc_id2emb[flat_code] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

with open("../resource/embeddings/KG/procedures/icd9proc.json", "w") as out_file:
    json.dump(icd9proc_id2emb, out_file, indent=6)

In [None]:
# Inspect the ICD-9-PROC code -> UMLS CUI mapping (displays the whole dict).
icd9proc_to_umls

In [None]:
from collections import defaultdict

data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')

# Rebuild ICD-9-PROC code -> UMLS CUI, skipping the header row.
icd9proc_to_umls = {}
for proc_code, cui in tqdm(zip(icd9proc_umls[0], icd9proc_umls[1]), total=len(icd9proc_umls)):
    if cui != "UMLS":
        icd9proc_to_umls[proc_code] = cui

# Per mapped code: its CUI and its KG entity embedding.
icd9proc_umls_dict = defaultdict(dict)
for icd9proc_key, cui in tqdm(icd9proc_to_umls.items()):
    entry = icd9proc_umls_dict[icd9proc_key]
    entry['UMLS CUI'] = cui
    entry['UMLS-KG Embedding'] = E_emb[int(entity2id[cui])].detach().numpy().tolist()

# New columns default to '' for codes without a UMLS mapping.
data['UMLS CUI'] = ''
data['UMLS-KG Embedding'] = ''

for row_idx, code in data['code'].items():
    entry = icd9proc_umls_dict.get(code)
    if entry is None:
        continue
    data.at[row_idx, 'UMLS CUI'] = entry['UMLS CUI']
    data.at[row_idx, 'UMLS-KG Embedding'] = entry['UMLS-KG Embedding']

data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC.csv', index=False)

In [None]:
import pandas as pd
from tqdm import tqdm

# Export ICD-9-PROC embeddings + metadata as two row-aligned TSVs: row i of the
# embedding file belongs to row i of the metadata file. Assumes `E_emb`,
# `entity2id` and `icd9proc_to_umls` from earlier cells are defined.
# The raw table gets its own name so this cell does not overwrite `data`.
icd9proc_metadata_raw = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')

# Define the output file paths
embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Embedding.tsv'
metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Metadata.tsv'

valid_embeddings = []
valid_metadata = []

for _, row in tqdm(icd9proc_metadata_raw.iterrows(), total=icd9proc_metadata_raw.shape[0]):
    icd9proc_id = row['code']

    # Only ICD-9-PROC codes with a UMLS CUI get an embedding row (and a metadata row).
    if icd9proc_id not in icd9proc_to_umls:
        continue
    umls_cui = icd9proc_to_umls[icd9proc_id]

    # Keep the vector as a list of floats so to_csv writes one tab-separated value
    # per column. A pre-joined '\t' string contains the separator, so to_csv would
    # wrap every line in double quotes (stripping '"' before writing cannot help).
    valid_embeddings.append(E_emb[int(entity2id[umls_cui])].detach().numpy().tolist())

    # Metadata row = original ICD-9-PROC row plus its UMLS CUI.
    row_dict = row.to_dict()
    row_dict['UMLS CUI'] = umls_cui
    valid_metadata.append(row_dict)

valid_embeddings_df = pd.DataFrame(valid_embeddings)
valid_metadata_df = pd.DataFrame(valid_metadata)

# Embedding file has no header (one vector per line); metadata keeps its header.
valid_embeddings_df.to_csv(embedding_file_path, sep='\t', index=False, header=False)
valid_metadata_df.to_csv(metadata_file_path, sep='\t', index=False, header=True)


In [None]:
# Inspect the DataFrame currently bound to `data` (whichever cell assigned it last).
data

In [None]:
# Remove double quotes from the exported embedding TSVs. If an export cell wrote
# each embedding as a single tab-joined string, to_csv quoted that field (it contains
# the separator). Both the ATC and the ICD-9-PROC exports can be affected, so clean
# both. Files that contain no quotes are rewritten unchanged.
quoted_embedding_files = [
    '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv',
    '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Embedding.tsv',
]

for tsv_path in quoted_embedding_files:
    with open(tsv_path, 'r') as f:
        contents = f.read()
    with open(tsv_path, 'w') as f:
        f.write(contents.replace('\"', ''))

## CCSCM

In [None]:
# ICD-9-CM -> CCS-CM lookup; the header row (second column 'CCSCM') is skipped.
icd9cm_to_ccscm = {}
with open("../resource/ICD9CM_to_CCSCM.csv", "r") as mapping_file:
    for fields in csv.reader(mapping_file):
        if fields[1] == 'CCSCM':
            continue
        icd9cm_to_ccscm[fields[0]] = fields[1]

# Invert to CCS-CM category -> list of ICD-9-CM codes (dictionary order).
ccscm_to_icd9cm = defaultdict(list)
for icd_code, ccs_code in icd9cm_to_ccscm.items():
    ccscm_to_icd9cm[ccs_code].append(icd_code)

# Each category is represented by the first ICD-9-CM code listed for it.
ccscm_icd9cm = {ccs_code: icd_codes[0] for ccs_code, icd_codes in ccscm_to_icd9cm.items()}

In [None]:
# get the embeddings
# Each CCS-CM category takes the KG embedding of its representative ICD-9-CM code.
ccscm_id2emb = {}
for ccscm_id, icd9cm_code in tqdm(ccscm_icd9cm.items()):
    try:
        emb = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_code]])]
    except KeyError:
        # Lookup miss: retry with '.00' removed from the code. Only a missing key
        # triggers the retry; other errors (and a second miss) propagate.
        emb = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_code.replace('.00', '')]])]
    ccscm_id2emb[ccscm_id] = emb.detach().numpy().tolist()

with open("../resource/embeddings/KG/conditions/ccscm.json", "w") as f:
    json.dump(ccscm_id2emb, f, indent=6)

## CCSPROC

In [None]:
# ICD-9-PROC -> CCS-PROC lookup; the header row (second column 'CCSPROC') is skipped.
icd9proc_to_ccsproc = {}
with open("../resource/ICD9PROC_to_CCSPROC.csv", "r") as mapping_file:
    for fields in csv.reader(mapping_file):
        if fields[1] == 'CCSPROC':
            continue
        icd9proc_to_ccsproc[fields[0]] = fields[1]

# Invert to CCS-PROC category -> list of ICD-9-PROC codes (dictionary order).
ccsproc_to_icd9proc = defaultdict(list)
for proc_code, ccs_code in icd9proc_to_ccsproc.items():
    ccsproc_to_icd9proc[ccs_code].append(proc_code)

# Each category is represented by the first ICD-9-PROC code listed for it.
ccsproc_icd9proc = {ccs_code: proc_codes[0] for ccs_code, proc_codes in ccsproc_to_icd9proc.items()}

In [None]:
# get the embeddings
# Each CCS-PROC category takes the KG embedding of its representative ICD-9-PROC
# code (`ccsproc_icd9proc`). If that code is not found as-is in `icd9proc_to_umls`,
# the code string is trimmed step by step and the lookup retried.
# NOTE(review): the bare `except:` clauses also catch KeyboardInterrupt/SystemExit and
# unrelated errors; presumably only KeyError is expected here. Confirm, then narrow.
ccsproc_id2emb = {}
for ccsproc_id in tqdm(ccsproc_icd9proc.keys()):
    try:
        # Direct lookup: ICD-9-PROC code -> UMLS CUI -> entity row of E_emb.
        ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[ccsproc_icd9proc[ccsproc_id]]])].detach().numpy().tolist()
    except:
        try:
            icd9procid = ccsproc_icd9proc[ccsproc_id]
            # Drop one leading '0'.
            if icd9procid[0] == '0':
                icd9procid = icd9procid[1:]
            # Drop one trailing '0'.
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-1]
            # If it still ends in '0', drop the last TWO characters (that '0' and the
            # character before it). NOTE(review): asymmetric with the step above;
            # presumably meant to strip a '.0'-style suffix. Confirm this is intended.
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-2]
            # Drop a dangling trailing '.'.
            if icd9procid[-1] == '.':
                icd9procid = icd9procid[:-1]

            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()
            
        except:
            # Last resort: drop one more trailing character and retry; a failure here
            # propagates and stops the cell.
            icd9procid = icd9procid[:-1]
            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/procedures/ccsproc.json", "w") as f:
    json.dump(ccsproc_id2emb, f, indent=6)

## Special Tokens

In [None]:
# Random vectors for the special tokens, exported alongside the code embeddings.
# A fixed seed makes re-runs write the same file (the unseeded global RNG did not).
# The width matches the TransE model built above (e_dim=512).
SPECIAL_TOKEN_SEED = 42
SPECIAL_TOKEN_DIM = 512

rng = np.random.default_rng(SPECIAL_TOKEN_SEED)
tokens = ['<pad>', '<unk>']
special_tokens = {token: rng.standard_normal(SPECIAL_TOKEN_DIM).tolist() for token in tokens}

with open("../resource/embeddings/KG/special_tokens/special_tokens.json", "w") as f:
    json.dump(special_tokens, f, indent=6)