# Notebook for Pre-trained Knowledge Graph Embedding (on UMLS)

In [84]:
## Import packages

import csv
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from pyhealth.medcode import ICD9CM, ICD9PROC
from pyhealth.medcode.pretrained_embeddings.kg_emb.models import TransE, RotatE, ComplEx, DistMult
from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import UMLSDataset, split
from pyhealth.medcode.pretrained_embeddings.kg_emb.tasks import link_prediction_fn

## Load Pre-trained KGE model

In [None]:
umls_ds = UMLSDataset(
    root="/data/pj20/umls/",
    # root="https://storage.googleapis.com/pyhealth/umls/",
    dev=False,
    refresh_cache=False
)

# check the dataset statistics before setting task
print(umls_ds.stat()) 

# check the relation numbers in the dataset
print("Relations in KG:", umls_ds.relation2id)

umls_ds = umls_ds.set_task(link_prediction_fn, negative_sampling=64, save=False)

model = TransE(
    dataset=umls_ds,
    e_dim=512, 
    r_dim=512, 
)

print('Loaded model: ', model)
state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt")
model.load_state_dict(state_dict)

In [None]:
import pickle

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/model.pkl", "wb") as f:
    pickle.dump(model, f)


In [None]:
len(model.R_emb), len(model.E_emb)

In [None]:
with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl", "wb") as f:
    pickle.dump(model.E_emb, f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl", "wb") as f:
    pickle.dump(model.R_emb, f)

### Load Pre-trained Entity Embedding and Relation Embedding

In [1]:
import pickle

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl", "rb") as f:
    E_emb = pickle.load(f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl", "rb") as f:
    R_emb = pickle.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2entity.json", "r") as f:
    id2entity = json.load(f)

with open("/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2relation.json", "r") as f:
    id2relation = json.load(f)

entity2id = {v: k for k, v in id2entity.items()}
relation2id = {v: k for k, v in id2relation.items()}

## ATC

In [21]:
# load the mapping from ATC to UMLS
atc_umls = pd.read_csv("../resource/ATC_to_UMLS.csv", header=None)
atc_umls_cuis = atc_umls[1].tolist()[1:]


In [22]:
# Check if there are any ATC codes that are not in the UMLS

cnt = 0
not_covered_cui_atc = []
for cui in tqdm(atc_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_atc.append(cui)
        cnt+=1

cnt, not_covered_cui_atc

100%|██████████| 6564/6564 [00:00<00:00, 680863.87it/s]


(0, [])

In [52]:
# get the embeddings

atc_to_umls = {}
for i in tqdm(range(len(atc_umls))):
    if atc_umls[1][i] != "UMLS":
        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]

atc_id2emb = {}
for atc_id in tqdm(atc_to_umls.keys()):
    atc_id2emb[atc_id] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/drugs/atc.json", "w") as f:
    json.dump(atc_id2emb, f, indent=6)


100%|██████████| 6565/6565 [00:00<00:00, 56700.56it/s]
100%|██████████| 6564/6564 [00:00<00:00, 28065.41it/s]


## ICD-9-CM

In [27]:
# load the mapping from ICD-9-CM to UMLS
icd9cm_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
icd9cm_umls_cuis = icd9cm_umls[1].tolist()[1:]

In [29]:
# Check if there are any ICD-9-CM codes that are not in the UMLS

cnt = 0
not_covered_cui_icd9cm = []
for cui in tqdm(icd9cm_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_icd9cm.append(cui)
        cnt+=1

cnt, not_covered_cui_icd9cm

100%|██████████| 22406/22406 [00:00<00:00, 1022428.91it/s]


(0, [])

In [None]:
atc_to_umls = {}
for i in tqdm(range(len(atc_umls))):
    if atc_umls[1][i] != "UMLS":
        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]

atc_id2emb = {}
for atc_id in tqdm(atc_to_umls.keys()):
    atc_id2emb[atc_id] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/drugs/atc.json", "w") as f:
    json.dump(atc_id2emb, f, indent=6)

In [66]:
# get the embeddings

icd9cm_to_umls = {}
for i in tqdm(range(len(icd9cm_umls))):
    if icd9cm_umls[1][i] != "UMLS":
        icd9cm_to_umls[icd9cm_umls[0][i]] = icd9cm_umls[1][i]

icd9cm_id2emb = {}
for icd9cm_id in tqdm(icd9cm_to_umls.keys()):
    key = ICD9CM.standardize(icd9cm_id).replace('.', '')
    icd9cm_id2emb[key] = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/conditions/icd9cm.json", "w") as f:
    json.dump(icd9cm_id2emb, f, indent=6)

100%|██████████| 22407/22407 [00:00<00:00, 79102.78it/s]
100%|██████████| 22406/22406 [00:00<00:00, 48074.67it/s]


## ICD-9-PROC

In [90]:
# load the mapping from ICD-9-proc to UMLS
icd9proc_umls = pd.read_csv("../resource/ICD9CM_to_UMLS.csv", header=None)
icd9proc_umls_cuis = icd9proc_umls[1].tolist()[1:]

In [91]:
# Check if there are any ICD-9-proc codes that are not in the UMLS

cnt = 0
not_covered_cui_icd9proc = []
for cui in tqdm(icd9proc_umls_cuis):
    if cui not in entity2id.keys():
        not_covered_cui_icd9proc.append(cui)
        cnt+=1

cnt, not_covered_cui_icd9proc

100%|██████████| 22406/22406 [00:00<00:00, 1076218.77it/s]


(0, [])

In [92]:
# get the embeddings

icd9proc_to_umls = {}
for i in tqdm(range(len(icd9proc_umls))):
    if icd9proc_umls[1][i] != "UMLS":
        icd9proc_to_umls[icd9proc_umls[0][i]] = icd9proc_umls[1][i]

icd9proc_id2emb = {}
for icd9proc_id in tqdm(icd9proc_to_umls.keys()):
    key = ICD9PROC.standardize(icd9proc_id).replace('.', '')
    icd9proc_id2emb[key] = E_emb[int(entity2id[icd9proc_to_umls[icd9proc_id]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/procedures/icd9proc.json", "w") as f:
    json.dump(icd9proc_id2emb, f, indent=6)

100%|██████████| 22407/22407 [00:00<00:00, 104368.29it/s]
100%|██████████| 22406/22406 [00:00<00:00, 54029.07it/s]


## CCSCM

In [70]:
icd9cm_to_ccscm = {}

with open("../resource/ICD9CM_to_CCSCM.csv", "r") as f:
    reader = csv.reader(f)
    for row in reader:
        if row[1] != 'CCSCM':
            icd9cm_to_ccscm[row[0]] = row[1]

ccscm_to_icd9cm = defaultdict(list)
for k, v in icd9cm_to_ccscm.items():
    ccscm_to_icd9cm[v].append(k)

ccscm_icd9cm = {}
for k, v in ccscm_to_icd9cm.items():
    ccscm_icd9cm[k] = v[0]

In [73]:
# get the embeddings
ccscm_id2emb = {}
for ccscm_id in tqdm(ccscm_icd9cm.keys()):
    try:
        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id]]])].detach().numpy().tolist()
    except:
        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id].replace('.00', '')]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/conditions/ccscm.json", "w") as f:
    json.dump(ccscm_id2emb, f, indent=6)

100%|██████████| 283/283 [00:00<00:00, 32360.63it/s]


## CCSPROC

In [74]:
icd9proc_to_ccsproc = {}

with open("../resource/ICD9PROC_to_CCSPROC.csv", "r") as f:
    reader = csv.reader(f)
    for row in reader:
        if row[1] != 'CCSPROC':
            icd9proc_to_ccsproc[row[0]] = row[1]

ccsproc_to_icd9proc = defaultdict(list)
for k, v in icd9proc_to_ccsproc.items():
    ccsproc_to_icd9proc[v].append(k)

ccsproc_icd9proc = {}
for k, v in ccsproc_to_icd9proc.items():
    ccsproc_icd9proc[k] = v[0]

In [82]:
# get the embeddings
ccsproc_id2emb = {}
for ccsproc_id in tqdm(ccsproc_icd9proc.keys()):
    try:
        ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[ccsproc_icd9proc[ccsproc_id]]])].detach().numpy().tolist()
    except:
        try:
            icd9procid = ccsproc_icd9proc[ccsproc_id]
            if icd9procid[0] == '0':
                icd9procid = icd9procid[1:]
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-1]
            if icd9procid[-1] == '0':
                icd9procid = icd9procid[:-2]
            if icd9procid[-1] == '.':
                icd9procid = icd9procid[:-1]

            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()
            
        except:
            icd9procid = icd9procid[:-1]
            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()

with open(f"../resource/embeddings/KG/procedures/ccsproc.json", "w") as f:
    json.dump(ccsproc_id2emb, f, indent=6)

100%|██████████| 231/231 [00:00<00:00, 49372.41it/s]


## Special Tokens

In [57]:
special_tokens = {}
tokens = ['<pad>', '<unk>']

for token in tokens:
    special_tokens[token] = np.random.randn(512).tolist()

with open(f"../resource/embeddings/KG/special_tokens/special_tokens.json", "w") as f:
    json.dump(special_tokens, f, indent=6)