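## Medical Concept Word Embedding Retrieval (HuggingFace)

This notebook embeds the names of medical codes (CCSCM, CCSPROC, ATC, ICD9CM, ICD9PROC) and two special tokens (`<pad>`, `<unk>`) with a pretrained HuggingFace language model. It saves `{code: embedding}` maps as JSON under `../resource/embeddings/LM/{MODEL_NAME}/`. It also writes TSV embedding matrices that line up row-by-row with the BioGPT metadata files.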

In [1]:
import csv
import json
import os
import pickle

import numpy as np
import pandas as pd
import retrying
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, BioGptTokenizer, BioGptForCausalLM

from pyhealth.medcode import ICD9CM, ICD9PROC
from pyhealth.medcode.pretrained_embeddings.lm_emb.huggingface_retriever import embedding_retrieve as embedding_retriever

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def embedding_retrieve(model, tokenizer, phrase):
    """Embed a phrase as the mean of the model's last hidden states.

    Note: the cells below do not call this function. They go through
    `retrieve_embedding`, which uses pyhealth's imported `embedding_retriever`.

    Args:
        model: HuggingFace model whose output exposes `last_hidden_state`.
        tokenizer: tokenizer paired with `model`; called with `return_tensors='pt'`.
        phrase: text to embed.

    Returns:
        list[float]: the pooled embedding of the first (and only) sequence in the batch.
    """
    import torch

    inputs = tokenizer(phrase, return_tensors='pt')

    # Inference only: with autograd disabled, no graph is built on every call.
    # Building one wastes memory and time when embedding thousands of phrases.
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool over dim=1. This is presumably the token axis of a
    # (batch, seq, hidden) tensor; confirm for non-standard models. No attention
    # mask is applied, which is fine for one unpadded phrase.
    embedding = outputs.last_hidden_state.mean(dim=1)

    # Move to CPU first so this also works when the model lives on a GPU.
    return embedding[0].cpu().tolist()


In [15]:
# Pretrained HuggingFace checkpoints that can embed concept names. The keys are
# the short names used in the output layout ../resource/embeddings/LM/{MODEL_NAME}/...
MODEL_CHECKPOINTS = {
    "bio_clinicalbert": "emilyalsentzer/Bio_ClinicalBERT",
    "sapbert": "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    "biogpt": "microsoft/biogpt",
}

# Switch models by changing this key. The tokenizer and the model are always
# loaded from the same checkpoint, so they cannot drift apart.
MODEL_NAME = "biogpt"
# AutoModel loads the base model without a task head. The "weights ... not used"
# warning for BioGPT's `output_projection.weight` is therefore expected.
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[MODEL_NAME])
MODEL = AutoModel.from_pretrained(MODEL_CHECKPOINTS[MODEL_NAME])

Some weights of the model checkpoint at microsoft/biogpt were not used when initializing BioGptModel: ['output_projection.weight']
- This IS expected if you are initializing BioGptModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BioGptModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
def retrieve_embedding(term):
    """Embed one concept name with the notebook-wide MODEL and TOKENIZER.

    This is the single entry point every cell below uses, so all vocabularies are
    embedded the same way. It returns whatever pyhealth's `embedding_retriever`
    returns; presumably a JSON-serializable vector, since the results are
    `json.dump`-ed later. Confirm this against pyhealth.
    """
    model, tokenizer = MODEL, TOKENIZER
    return embedding_retriever(model, tokenizer, term)

### Special Tokens

In [7]:
# Embed the literal strings "<pad>" and "<unk>" with the same model used for the codes.
special_tokens = ["<pad>", "<unk>"]
st_id2emb = {token: retrieve_embedding(term=token) for token in tqdm(special_tokens)}

st_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/special_tokens/special_tokens.json"
# Create the output directory so a fresh run does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(st_out_path), exist_ok=True)
with open(st_out_path, "w") as f:
    json.dump(st_id2emb, f)

100%|██████████| 2/2 [00:00<00:00, 21.25it/s]


### CCSCM

In [8]:
# CCSCM: map code (column 0) -> lower-cased name (column 1), embed each name,
# and save {code: embedding} under conditions/.
ccscm_id2name = {}
# Parse with the csv module so quoted names containing commas stay intact.
# A naive str.split(',') would cut such names in half.
with open('../resource/CCSCM.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:  # tolerate blank lines instead of raising IndexError
            continue
        ccscm_id2name[row[0]] = row[1].lower()

ccscm_id2emb = {
    code: retrieve_embedding(term=name)
    for code, name in tqdm(ccscm_id2name.items())
}

ccscm_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/conditions/ccscm.json"
os.makedirs(os.path.dirname(ccscm_out_path), exist_ok=True)
with open(ccscm_out_path, "w") as f:
    json.dump(ccscm_id2emb, f)

100%|██████████| 285/285 [00:12<00:00, 22.10it/s]


### CCSPROC

In [9]:
# CCSPROC: map code (column 0) -> lower-cased name (column 1), embed each name,
# and save {code: embedding} under procedures/.
ccsproc_id2name = {}
# Parse with the csv module so quoted names containing commas stay intact.
with open('../resource/CCSPROC.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:  # tolerate blank lines instead of raising IndexError
            continue
        ccsproc_id2name[row[0]] = row[1].lower()

ccsproc_id2emb = {
    code: retrieve_embedding(term=name)
    for code, name in tqdm(ccsproc_id2name.items())
}

ccsproc_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/procedures/ccsproc.json"
os.makedirs(os.path.dirname(ccsproc_out_path), exist_ok=True)
with open(ccsproc_out_path, "w") as f:
    json.dump(ccsproc_id2emb, f)

100%|██████████| 231/231 [00:10<00:00, 22.86it/s]


### ATC

In [10]:
# Optional ATC hierarchy filter: keep only rows whose 'level' column equals this
# value (e.g. '3.0'). None embeds every code in ATC.csv.
ATC_LEVEL = None

atc_id2name = {}
with open("../resource/ATC.csv", newline='') as csvfile:
    for row in csv.DictReader(csvfile):
        if ATC_LEVEL is not None and row['level'] != ATC_LEVEL:
            continue
        atc_id2name[row['code']] = row['name'].lower()

atc_id2emb = {
    code: retrieve_embedding(term=name)
    for code, name in tqdm(atc_id2name.items())
}

atc_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/drugs/atc.json"
os.makedirs(os.path.dirname(atc_out_path), exist_ok=True)
with open(atc_out_path, "w") as f:
    json.dump(atc_id2emb, f)

100%|██████████| 6440/6440 [04:01<00:00, 26.69it/s]


In [17]:
# Embed every name in the BioGPT ATC metadata TSV. The vectors are written as a
# header-less tab-separated matrix whose i-th row matches the metadata's i-th row.
# NOTE(review): this was an absolute /home/pj20/PyHealth/pyhealth/medcode/... path.
# "../resource" assumes the same notebook location the other cells rely on; confirm.
detailed_emb_dir = "../resource/embeddings/detailed/word_embedding/BioGPT"
# The target folder is BioGPT-specific, so refuse to write another model's vectors into it.
assert MODEL_NAME == "biogpt", f"{detailed_emb_dir} holds BioGPT embeddings, but MODEL_NAME={MODEL_NAME!r}"

# Read every field as a literal string. With default NA parsing, names such as
# "NA" or "null" would become NaN and crash `.lower()`.
df = pd.read_csv(f"{detailed_emb_dir}/ATC_Metadata.tsv", sep='\t', dtype=str, keep_default_na=False)

# Iterating the Series (rather than iterrows) gives tqdm a known total.
emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]

output_file_path = f"{detailed_emb_dir}/ATC_embedding.tsv"
with open(output_file_path, "w", newline='\n') as file:
    for emb in emb_list:
        file.write('\t'.join(map(str, emb)) + '\n')

6920it [04:56, 23.31it/s]


### ICD9CM

In [11]:
# ICD-9-CM: column 0 is the raw code and column 2 the name, per the indices used here.
# Embed each name and key the result by the standardized, dot-free code.
icd9cm_id2name = {}
# Parse with the csv module so quoted names containing commas stay intact.
with open('../resource/ICD9CM.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:  # tolerate blank lines instead of raising IndexError
            continue
        icd9cm_id2name[row[0]] = row[2].lower()

icd9cm_id2emb = {
    ICD9CM.standardize(code).replace('.', ''): retrieve_embedding(term=name)
    for code, name in tqdm(icd9cm_id2name.items())
}

icd9cm_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/conditions/icd9cm.json"
os.makedirs(os.path.dirname(icd9cm_out_path), exist_ok=True)
with open(icd9cm_out_path, "w") as f:
    json.dump(icd9cm_id2emb, f)

100%|██████████| 17736/17736 [11:34<00:00, 25.54it/s]


In [16]:
# Embed every name in the BioGPT ICD9CM metadata TSV. The vectors are written as a
# header-less tab-separated matrix whose i-th row matches the metadata's i-th row.
# NOTE(review): this was an absolute /home/pj20/PyHealth/pyhealth/medcode/... path.
# "../resource" assumes the same notebook location the other cells rely on; confirm.
detailed_emb_dir = "../resource/embeddings/detailed/word_embedding/BioGPT"
# The target folder is BioGPT-specific, so refuse to write another model's vectors into it.
assert MODEL_NAME == "biogpt", f"{detailed_emb_dir} holds BioGPT embeddings, but MODEL_NAME={MODEL_NAME!r}"

# Read every field as a literal string. With default NA parsing, names such as
# "NA" or "null" would become NaN and crash `.lower()`.
df = pd.read_csv(f"{detailed_emb_dir}/ICD9CM_Metadata.tsv", sep='\t', dtype=str, keep_default_na=False)

# Iterating the Series (rather than iterrows) gives tqdm a known total.
emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]

output_file_path = f"{detailed_emb_dir}/ICD9CM_embedding.tsv"
with open(output_file_path, "w", newline='\n') as file:
    for emb in emb_list:
        file.write('\t'.join(map(str, emb)) + '\n')

17736it [13:30, 21.89it/s]


### ICD9PROC

In [12]:
# ICD-9 procedures: column 0 is the raw code and column 2 the name, per the indices used here.
# Embed each name and key the result by the standardized, dot-free code.
icd9proc_id2name = {}
# Parse with the csv module so quoted names containing commas stay intact.
with open('../resource/ICD9PROC.csv', 'r', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row
    for row in reader:
        if not row:  # tolerate blank lines instead of raising IndexError
            continue
        icd9proc_id2name[row[0]] = row[2].lower()

icd9proc_id2emb = {
    ICD9PROC.standardize(code).replace('.', ''): retrieve_embedding(term=name)
    for code, name in tqdm(icd9proc_id2name.items())
}

icd9proc_out_path = f"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json"
os.makedirs(os.path.dirname(icd9proc_out_path), exist_ok=True)
with open(icd9proc_out_path, "w") as f:
    json.dump(icd9proc_id2emb, f)

100%|██████████| 4671/4671 [03:12<00:00, 24.30it/s]


In [19]:
# Embed every name in the BioGPT ICD9PROC metadata TSV. The vectors are written as a
# header-less tab-separated matrix whose i-th row matches the metadata's i-th row.
# NOTE(review): this was an absolute /home/pj20/PyHealth/pyhealth/medcode/... path.
# "../resource" assumes the same notebook location the other cells rely on; confirm.
detailed_emb_dir = "../resource/embeddings/detailed/word_embedding/BioGPT"
# The target folder is BioGPT-specific, so refuse to write another model's vectors into it.
assert MODEL_NAME == "biogpt", f"{detailed_emb_dir} holds BioGPT embeddings, but MODEL_NAME={MODEL_NAME!r}"

# Read every field as a literal string. With default NA parsing, names such as
# "NA" or "null" would become NaN and crash `.lower()`.
df = pd.read_csv(f"{detailed_emb_dir}/ICD9PROC_Metadata.tsv", sep='\t', dtype=str, keep_default_na=False)

# Iterating the Series (rather than iterrows) gives tqdm a known total.
emb_list = [retrieve_embedding(term=name.lower()) for name in tqdm(df['name'])]

output_file_path = f"{detailed_emb_dir}/ICD9PROC_embedding.tsv"
with open(output_file_path, "w", newline='\n') as file:
    for emb in emb_list:
        file.write('\t'.join(map(str, emb)) + '\n')

4670it [03:40, 21.22it/s]


In [13]:
# Number of ICD-9 procedure codes that received an embedding.
len(icd9proc_id2emb)

4671

In [27]:
# Reload the saved ICD-9 procedure embeddings and normalize keys to the dot-free form.
# Then add alias codes that reuse an existing code's embedding, and write back in place.
# Re-running this cell gives the same file.
icd9proc_json_path = f"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json"
with open(icd9proc_json_path, "r") as f:
    icd9proc_id2emb = json.load(f)

icd9proc_id2emb_new = {key.replace('.', ''): value for key, value in icd9proc_id2emb.items()}

# alias code -> existing code whose embedding it reuses.
# NOTE(review): presumably these codes appear in data but not in ICD9PROC.csv
# (36.02/36.05 look like retired ICD-9 PTCA codes); confirm the chosen targets.
ICD9PROC_ALIASES = {'3605': '0066', '3602': '36'}
# Apply the aliases once, after the loop. Previously they were re-assigned on every
# iteration and silently skipped when the dict was empty.
for alias_code, source_code in ICD9PROC_ALIASES.items():
    icd9proc_id2emb_new[alias_code] = icd9proc_id2emb[source_code]

with open(icd9proc_json_path, "w") as f:
    json.dump(icd9proc_id2emb_new, f)

In [25]:
# Re-read the ICD-9 procedure embeddings from disk to inspect what was saved.
with open(f"../resource/embeddings/LM/{MODEL_NAME}/procedures/icd9proc.json", mode="r") as fh:
    icd9proc_id2emb = json.load(fp=fh)

In [26]:
# Check that the alias code '3602' is present in the reloaded embeddings.
# NOTE: the saved output (False) came from running this cell (In [26]) before the
# alias-patching cell (In [27]); run the cells in order to get an up-to-date answer.
'3602' in icd9proc_id2emb

False