[ICLR'24] Enhancing Healthcare Predictions with Personalized Knowledge Graphs

pat-jj/GraphCare

GraphCare

Code for the paper: GraphCare: Enhancing Healthcare Predictions with Personalized Knowledge Graphs in ICLR'24.

Requirements:

pip install torch==1.12.0
pip install torch-geometric==2.3.0
pip install pyhealth==1.1.2
pip install scikit-learn==1.2.1
pip install openai==0.27.4
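
To confirm that an environment matches the pinned versions above, a small check like the following may help. This is only a sketch: the helper name `find_mismatches` is hypothetical and not part of the repo, and it assumes that a local build suffix such as `+cu113` should be ignored when comparing versions.

```python
from importlib import metadata

# Pinned versions copied from the requirements list above.
REQUIRED = {
    "torch": "1.12.0",
    "torch-geometric": "2.3.0",
    "pyhealth": "1.1.2",
    "scikit-learn": "1.2.1",
    "openai": "0.27.4",
}


def find_mismatches(required=REQUIRED, get_version=metadata.version):
    """Return {package: (wanted, found)} for packages that are missing or differ.

    Hypothetical helper, not part of the GraphCare repo.
    """
    mismatches = {}
    for name, wanted in required.items():
        try:
            # ASSUMPTION: drop a local build suffix such as "+cu113".
            found = get_version(name).split("+")[0]
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches().items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result means every pinned package is installed at the listed version.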

We follow the flow of the methodology section (Section 3 of the paper) to explain our implementation.

1. Concept-specific Knowledge Graph (KG) Generation

1.1 LLM-based KG extraction via prompting

The Jupyter notebook for prompting an LLM to generate a KG for each EHR medical code:

/graphcare_/graph_generation/graph_gen.ipynb

We place sample KGs generated by GPT-4 at

/graphs/{condition/CCSCM,procedure/CCSPROC,drug/ATC3}/{code_id}.txt
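
The path pattern above can be expanded programmatically. The sketch below builds the file path for a given feature type and code, and parses triples from the file's text. The helper names are hypothetical, and the file layout (one tab-separated `head, relation, tail` triple per line) is an assumption that should be checked against the sample files before use.

```python
from pathlib import Path

# Sub-directories taken from the path pattern above.
KG_SUBDIRS = {
    "condition": "condition/CCSCM",
    "procedure": "procedure/CCSPROC",
    "drug": "drug/ATC3",
}


def kg_path(feature, code_id, root="graphs"):
    """Build the path of the sample KG file for one medical code (hypothetical helper)."""
    return Path(root) / KG_SUBDIRS[feature] / f"{code_id}.txt"


def parse_triples(text, sep="\t"):
    """Parse (head, relation, tail) triples from file content.

    ASSUMPTION: one triple per line with fields separated by `sep`;
    the real files may use a different layout. Malformed lines are skipped.
    """
    triples = []
    for line in text.splitlines():
        parts = [p.strip() for p in line.split(sep)]
        if len(parts) == 3 and all(parts):
            triples.append(tuple(parts))
    return triples
```

For example, `parse_triples(kg_path("drug", code_id).read_text())` would load the triples for one ATC3 code, provided the assumed layout holds.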

1.2 Subgraph sampling from existing KGs

The script for subgraph sampling from UMLS:

/KG_mapping/umls_sampling.py

We place sample 2-hop KGs, randomly subsampled from UMLS, at

/graphs/umls_2hop.csv
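
As a rough illustration of random 2-hop subsampling, the sketch below walks two hops outward from a seed entity over a list of triples, keeping a random subset of edges at each hop. It is not the logic of umls_sampling.py: the function name, the `max_per_hop` parameter, and the choice to follow edges in both directions are all assumptions made for the example.

```python
import random
from collections import defaultdict


def sample_two_hop(triples, seed_entity, max_per_hop=5, rng=None):
    """Randomly sample a 2-hop neighborhood of `seed_entity` (illustrative only)."""
    rng = rng or random.Random(0)

    # Index each triple under both its head and its tail,
    # so edges are followed in either direction (an assumption).
    by_entity = defaultdict(list)
    for h, r, t in triples:
        by_entity[h].append((h, r, t))
        by_entity[t].append((h, r, t))

    sampled, seen = [], set()
    frontier = {seed_entity}
    for _ in range(2):  # two hops
        next_frontier = set()
        for entity in sorted(frontier):
            # Unseen, de-duplicated triples touching this entity.
            candidates = list(
                dict.fromkeys(tr for tr in by_entity[entity] if tr not in seen)
            )
            picked = rng.sample(candidates, min(max_per_hop, len(candidates)))
            for h, r, t in picked:
                seen.add((h, r, t))
                sampled.append((h, r, t))
                next_frontier.add(t if h == entity else h)
        frontier = next_frontier
    return sampled
```

Lowering `max_per_hop` yields smaller subgraphs; passing a seeded `random.Random` keeps the sample reproducible.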

1.3 Word Embedding Retrieval for Nodes & Edges

The Jupyter notebooks for word embedding retrieval:

/graphcare_/graph_generation/{cond,proc,drug}_emb_ret.ipynb

Because the word embeddings are large, we do not include them in the repo. You can use the notebooks above to retrieve them and store them in either

/graphs/cond_proc/{entity_embedding.pkl, relation_embedding.pkl}
or
/graphs/cond_proc_drug/{entity_embedding.pkl, relation_embedding.pkl}

depending on the features used for the prediction tasks.
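
A minimal sketch of storing retrieved embeddings in the layout above is shown here. The helper names and the `use_drug` flag are hypothetical, and the sketch assumes the embeddings are plain picklable Python objects (for example, lists of vectors); the notebooks may store a different structure.

```python
import pickle
from pathlib import Path


def embedding_dir(use_drug, root="graphs"):
    """Pick the target directory based on the feature set (hypothetical flag)."""
    return Path(root) / ("cond_proc_drug" if use_drug else "cond_proc")


def save_embeddings(entity_emb, relation_emb, use_drug, root="graphs"):
    """Write both embedding files using the file names listed above."""
    out = embedding_dir(use_drug, root)
    out.mkdir(parents=True, exist_ok=True)
    files = (
        ("entity_embedding.pkl", entity_emb),
        ("relation_embedding.pkl", relation_emb),
    )
    for name, obj in files:
        with open(out / name, "wb") as f:
            pickle.dump(obj, f)
    return out


def load_embeddings(use_drug, root="graphs"):
    """Read back (entity_embedding, relation_embedding)."""
    src = embedding_dir(use_drug, root)
    loaded = []
    for name in ("entity_embedding.pkl", "relation_embedding.pkl"):
        with open(src / name, "rb") as f:
            loaded.append(pickle.load(f))
    return tuple(loaded)
```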

1.4 Node & Edge Clustering

The function for node & edge clustering:

clustering() in data_prepare.py

We place some clustering results in the directory below (only the "_inv" results are included, because the cluster embeddings are large):

/clustering/
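
To illustrate the idea of grouping nodes or edges whose word embeddings are similar, here is a deliberately simplified sketch. It is not the algorithm used by clustering() in data_prepare.py, which this README does not describe; it swaps in a greedy cosine-similarity threshold rule, and both the function names and the `threshold` value are assumptions.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


def greedy_cluster(embeddings, threshold=0.85):
    """Map each name to a cluster id (simplified stand-in, not the repo's method).

    Each vector joins the first cluster whose representative (its first
    member) has cosine similarity >= threshold; otherwise it starts a new cluster.
    """
    representatives = []
    assignment = {}
    for name, vec in embeddings.items():
        for cluster_id, rep in enumerate(representatives):
            if cosine(vec, rep) >= threshold:
                assignment[name] = cluster_id
                break
        else:
            representatives.append(vec)
            assignment[name] = len(representatives) - 1
    return assignment
```

With this rule, near-synonymous terms whose vectors point in almost the same direction share a cluster id, while unrelated terms get separate ids.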

2. Personalized Knowledge Graph Composition

The functions for personalized KG composition:

process_sample_dataset() and process_graph() in data_prepare.py
get_subgraph() in graphcare.py
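
Conceptually, a personalized KG can be thought of as the merge of the concept-specific KGs for the medical codes in one patient's record. The sketch below shows that idea only; it assumes each concept KG is a list of `(head, relation, tail)` triples keyed by code, and the function name is hypothetical. The repo's functions listed above operate on PyHealth datasets and clustered nodes, which this example does not model.

```python
def compose_patient_graph(patient_codes, concept_kgs):
    """Merge the concept KGs of a patient's codes into one node and edge set.

    ASSUMPTION: `concept_kgs` maps a code to a list of (head, relation, tail)
    triples. Codes without a KG are skipped. Illustrative only.
    """
    nodes, edges = set(), set()
    for code in patient_codes:
        for head, relation, tail in concept_kgs.get(code, []):
            nodes.update((head, tail))
            edges.add((head, relation, tail))
    # Sorted output keeps the result deterministic.
    return sorted(nodes), sorted(edges)
```

Because nodes and edges are collected in sets, triples shared by several codes appear only once in the merged graph.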

3. Bi-attention Augmented (BAT) Graph Neural Network

The implementation of our proposed BAT model is in

/graphcare_/model.py

4. Training and Prediction

The creation of task-specific datasets (using PyHealth) is in

data_prepare.py

The training and prediction details are in

graphcare.py

Run Baseline Models

The scripts for running baseline models are in

ehr_models.py

Cite GraphCare

@inproceedings{jiang2023graphcare,
  title={GraphCare: Enhancing Healthcare Predictions with Personalized Knowledge Graphs},
  author={Jiang, Pengcheng and Xiao, Cao and Cross, Adam Richard and Sun, Jimeng},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}

Thanks for your interest in our work! 😊