# GraphCare Tutorial

This notebook demonstrates how to use the **GraphCare** model for healthcare predictions with personalized knowledge graphs in PyHealth.

**Contributors:** Josh Steier

**Paper:** Pengcheng Jiang et al. *GraphCare: Enhancing Healthcare Predictions with Personalized Knowledge Graphs.* ICLR 2024.

## Overview

1. Understand the GraphCare architecture
2. Build synthetic patient knowledge graphs
3. Instantiate the model with different GNN backbones
4. Train and evaluate on binary classification
5. Inspect attention weights for interpretability

## 1. Environment Setup

Import required libraries and set seeds for reproducibility.

In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cpu


In [2]:
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader

from pyhealth.models.graphcare import GraphCare, BiAttentionGNNConv

print("All imports successful!")

All imports successful!


## 2. Understanding GraphCare Architecture

GraphCare operates on **personalized patient knowledge graphs** derived from EHR codes.

### Architecture Overview

```
Patient EHR Record
  ├── Conditions  ──┐
  ├── Procedures  ──┤── Concept-specific KGs ── Patient KG
  └── Medications ──┘                              │
                                                    ▼
                                           ┌───────────────┐
                                           │  GNN Encoder   │
                                           │  (BAT/GAT/GIN) │
                                           │  + Bi-Attention │
                                           └───────┬───────┘
                                                   │
                                    ┌──────────────┼──────────────┐
                                    ▼              ▼              ▼
                              graph pool     node embed      joint (concat)
                                    │              │              │
                                    └──────────────┴──────────────┘
                                                   │
                                                   ▼
                                              MLP Head
                                                   │
                                                   ▼
                                              Prediction
```

### Key Components

- **Bi-Attention (BAT):** Node-level (α) and visit-level (β) attention, with temporal decay applied to the visit-level weights
- **Temporal Decay:** λ_j = exp(γ(V−j)) — more recent visits get higher weight
- **Patient Modes:** `graph` (global pool), `node` (EHR-node avg), `joint` (both concatenated)
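
To make the decay term concrete, the sketch below evaluates λ_j = exp(γ(V−j)) exactly as written above, for V = 5 visits and γ = 0.01 (the `decay_rate` used later in this notebook). It is a standalone illustration rather than the library's implementation; the 1-based visit index `j` and the helper name `decay_weights` are assumptions made for the example.

```python
import math


def decay_weights(num_visits: int, gamma: float) -> list[float]:
    """Evaluate lambda_j = exp(gamma * (V - j)) for j = 1..V (1-based index assumed)."""
    return [math.exp(gamma * (num_visits - j)) for j in range(1, num_visits + 1)]


weights = decay_weights(num_visits=5, gamma=0.01)
for j, w in enumerate(weights, start=1):
    print(f"j={j}: lambda={w:.4f}")
```

With a small γ the weights stay close to 1, so the decay only gently rescales the learned attention. The expression equals 1 at j = V and, for positive γ, grows as j decreases; which end of the index range holds the most recent visit depends on how visits are ordered in `visit_padded_node`, which this sketch does not assume.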

## 3. Create Synthetic Patient Graphs

In a real setting, patient KGs are constructed from EHR codes using LLM-prompted subgraphs and medical KGs (see the [GraphCare repo](https://github.com/pat-jj/GraphCare) for the generation pipeline).

Here we create synthetic data to demonstrate the model API.

In [3]:
# --- Global KG parameters ---
NUM_NODES = 500       # Total cluster nodes in the KG
NUM_RELS = 50         # Total relation types
MAX_VISIT = 5         # Max visits per patient
EMBEDDING_DIM = 64    # Pre-trained embedding dim
HIDDEN_DIM = 64       # Model hidden dim
NUM_PATIENTS = 200    # Synthetic patients

# --- Fake pre-trained embeddings ---
# In practice these come from word2vec / TransE on the KG
node_emb = torch.randn(NUM_NODES, EMBEDDING_DIM)
rel_emb = torch.randn(NUM_RELS, EMBEDDING_DIM)

print(f"Node embeddings: {node_emb.shape}")
print(f"Relation embeddings: {rel_emb.shape}")

Node embeddings: torch.Size([500, 64])
Relation embeddings: torch.Size([50, 64])


In [4]:
def create_synthetic_patient_graph(patient_idx, num_nodes, num_rels, max_visit):
    """Create a single synthetic patient KG as a PyG Data object.
    
    Each patient graph is a subgraph of the global KG, with:
    - Node IDs (y): indices into the global KG node embedding table
    - Relation IDs: indices into the global relation embedding table
    - visit_padded_node: binary (max_visit, num_nodes) indicating which
      KG nodes appear in each visit
    - ehr_nodes: binary (num_nodes,) indicating direct EHR nodes
    - label: binary mortality label
    """
    # Random subgraph size
    n = random.randint(15, 40)   # nodes in this patient's subgraph
    e = random.randint(n, n * 3)  # edges
    
    # Local edge indices (within the subgraph)
    src = torch.randint(0, n, (e,))
    dst = torch.randint(0, n, (e,))
    
    # Node IDs: which global KG nodes are in this subgraph
    y = torch.randint(0, num_nodes, (n,))
    
    # Relation IDs per edge
    relation = torch.randint(0, num_rels, (e,))
    
    # Visit-padded node indicators
    vpn = torch.zeros(max_visit, num_nodes)
    for v in range(max_visit):
        # Each visit activates some random KG nodes
        active = y[torch.randint(0, n, (random.randint(2, 8),))]
        vpn[v, active] = 1.0
    
    # Direct EHR nodes (subset of the patient's KG nodes)
    ehr = torch.zeros(num_nodes)
    ehr_active = y[torch.randint(0, n, (random.randint(3, 10),))]
    ehr[ehr_active] = 1.0
    
    # Binary label (mortality)
    label = torch.tensor([float(patient_idx % 2)])
    
    data = Data(
        edge_index=torch.stack([src, dst]),
        y=y,
        relation=relation,
        visit_padded_node=vpn,
        ehr_nodes=ehr,
        label=label,
    )
    data.num_nodes = n
    return data


# Create dataset
all_graphs = [
    create_synthetic_patient_graph(i, NUM_NODES, NUM_RELS, MAX_VISIT)
    for i in range(NUM_PATIENTS)
]

print(f"Created {len(all_graphs)} patient graphs")
print(f"Example graph: {all_graphs[0]}")

Created 200 patient graphs
Example graph: Data(edge_index=[2, 49], y=[35], relation=[49], visit_padded_node=[5, 500], ehr_nodes=[500], label=[1], num_nodes=35)


In [5]:
# --- Train / Val / Test split ---
n_train = int(0.8 * NUM_PATIENTS)
n_val = int(0.1 * NUM_PATIENTS)

train_graphs = all_graphs[:n_train]
val_graphs = all_graphs[n_train:n_train + n_val]
test_graphs = all_graphs[n_train + n_val:]

BATCH_SIZE = 16

train_loader = PyGDataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = PyGDataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
test_loader = PyGDataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

print(f"Train/Val/Test sizes: {len(train_graphs)}, {len(val_graphs)}, {len(test_graphs)}")

Train/Val/Test sizes: 160, 20, 20
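
Because every loader above sets `drop_last=True`, any graphs left over after the last full batch are skipped. The short sketch below works through that arithmetic for the split sizes just printed; the helper name `batches_and_dropped` is hypothetical.

```python
def batches_and_dropped(num_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (full batches, samples skipped) for a loader with drop_last=True."""
    full_batches, leftover = divmod(num_samples, batch_size)
    return full_batches, leftover


for name, size in [("train", 160), ("val", 20), ("test", 20)]:
    n_batches, n_dropped = batches_and_dropped(size, batch_size=16)
    print(f"{name}: {n_batches} batches, {n_dropped} graphs dropped")
```

The validation and test loaders therefore each yield a single batch of 16 graphs and skip the remaining 4.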


## 4. Inspect Batch Structure

Examine what a PyG batch looks like after collation.

In [6]:
batch = next(iter(train_loader))

print("Batch attributes:")
print(f"  edge_index: {batch.edge_index.shape}")
print(f"  y (node_ids): {batch.y.shape}")
print(f"  relation: {batch.relation.shape}")
print(f"  batch vector: {batch.batch.shape} (max={batch.batch.max().item()})")
print(f"  visit_padded_node: {batch.visit_padded_node.shape}")
print(f"  ehr_nodes: {batch.ehr_nodes.shape}")
print(f"  label: {batch.label.shape}")
print()
print("Note: visit_padded_node and ehr_nodes need reshaping before forward().")
print(f"  visit_node reshaped: ({BATCH_SIZE}, {MAX_VISIT}, {NUM_NODES})")
print(f"  ehr_nodes reshaped:  ({BATCH_SIZE}, {NUM_NODES})")

Batch attributes:
  edge_index: torch.Size([2, 783])
  y (node_ids): torch.Size([436])
  relation: torch.Size([783])
  batch vector: torch.Size([436]) (max=15)
  visit_padded_node: torch.Size([80, 500])
  ehr_nodes: torch.Size([8000])
  label: torch.Size([16])

Note: visit_padded_node and ehr_nodes need reshaping before forward().
  visit_node reshaped: (16, 5, 500)
  ehr_nodes reshaped:  (16, 500)
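
The flattened shapes above arise because PyG collation concatenates per-graph tensors along the first dimension. The NumPy sketch below mimics that behaviour with plain arrays to show why a reshape recovers the per-patient layout. It uses toy sizes and does not call PyG itself.

```python
import numpy as np

batch_size, max_visit, num_nodes = 3, 2, 4  # toy sizes for illustration

# One (max_visit, num_nodes) matrix and one (num_nodes,) vector per patient,
# filled with the patient index so the rows are easy to trace.
per_patient_visits = [np.full((max_visit, num_nodes), float(i)) for i in range(batch_size)]
per_patient_ehr = [np.full((num_nodes,), float(i)) for i in range(batch_size)]

# Concatenating along axis 0 mirrors what collation does to these attributes.
visits_flat = np.concatenate(per_patient_visits, axis=0)  # (batch * max_visit, num_nodes)
ehr_flat = np.concatenate(per_patient_ehr, axis=0)        # (batch * num_nodes,)

# Reshaping restores one leading entry per patient.
visits = visits_flat.reshape(batch_size, max_visit, num_nodes)
ehr = ehr_flat.reshape(batch_size, num_nodes)

print(visits_flat.shape, "->", visits.shape)
print(ehr_flat.shape, "->", ehr.shape)
```

The reshape is only valid when the leading size matches the number of graphs actually in the batch, which is why the loaders in this notebook keep every batch at exactly `BATCH_SIZE`.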


## 5. Instantiate GraphCare

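Instantiate five backbone/mode combinations covering the three GNN backbones (BAT, GAT, GIN) and the three patient modes (`graph`, `node`, `joint`), and compare their parameter counts.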
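
The three patient modes can be pictured with a small NumPy sketch for a single patient graph. Mean pooling, the array names, and the toy sizes are assumptions made for illustration; the library's own pooling may differ in detail.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim = 4

# Hidden states of 5 nodes in one patient graph (toy values).
node_h = rng.normal(size=(5, hidden_dim))
# Which of those nodes correspond directly to EHR codes (assumed mask).
is_ehr_node = np.array([True, False, True, False, False])

graph_repr = node_h.mean(axis=0)                       # "graph": pool over all nodes
node_repr = node_h[is_ehr_node].mean(axis=0)           # "node": average over EHR nodes only
joint_repr = np.concatenate([graph_repr, node_repr])   # "joint": both, concatenated

print(graph_repr.shape, node_repr.shape, joint_repr.shape)
```

In this sketch the `joint` representation is twice as wide as the other two, so whatever head consumes it needs a correspondingly wider input.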

In [7]:
print("GraphCare Configuration Comparison")
print("=" * 60)

configs = [
    ("BAT", "joint"),
    ("BAT", "graph"),
    ("BAT", "node"),
    ("GAT", "joint"),
    ("GIN", "joint"),
]

for gnn, mode in configs:
    model = GraphCare(
        num_nodes=NUM_NODES,
        num_rels=NUM_RELS,
        max_visit=MAX_VISIT,
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        out_channels=1,
        layers=2,
        node_emb=node_emb,
        rel_emb=rel_emb,
        gnn=gnn,
        patient_mode=mode,
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  {gnn:>3s}/{mode:<6s} — {n_params:>10,} params")

GraphCare Configuration Comparison
  BAT/joint  —    549,941 params
  BAT/graph  —    549,877 params
  BAT/node   —    549,877 params
  GAT/joint  —    550,067 params
  GIN/joint  —    549,811 params


### 5.1 Forward Pass Verification

In [8]:
model = GraphCare(
    num_nodes=NUM_NODES,
    num_rels=NUM_RELS,
    max_visit=MAX_VISIT,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    out_channels=1,
    layers=2,
    dropout=0.5,
    decay_rate=0.01,
    node_emb=node_emb,
    rel_emb=rel_emb,
    gnn="BAT",
    patient_mode="joint",
    use_alpha=True,
    use_beta=True,
    use_edge_attn=True,
).to(device)

# Reshape batch tensors
batch = next(iter(train_loader))
batch = batch.to(device)

node_ids = batch.y
rel_ids = batch.relation
edge_index = batch.edge_index
batch_vec = batch.batch
visit_node = batch.visit_padded_node.reshape(BATCH_SIZE, MAX_VISIT, NUM_NODES).float()
ehr_nodes = batch.ehr_nodes.reshape(BATCH_SIZE, NUM_NODES).float()

# Forward pass
model.eval()
with torch.no_grad():
    logits = model(node_ids, rel_ids, edge_index, batch_vec, visit_node, ehr_nodes)

print(f"Input: {batch.y.shape[0]} total nodes across {BATCH_SIZE} patient graphs")
print(f"Output logits: {logits.shape}")
print(f"Predictions: {torch.sigmoid(logits).squeeze().cpu().numpy().round(3)}")

Input: 488 total nodes across 16 patient graphs
Output logits: torch.Size([16, 1])
Predictions: [0.508 0.484 0.51  0.493 0.505 0.522 0.501 0.493 0.528 0.482 0.509 0.447
 0.505 0.492 0.509 0.502]


## 6. Training Loop

Train GraphCare with a standard PyTorch loop. This tutorial drives training manually with `torch_geometric.loader.DataLoader` instead of PyHealth's `Trainer`, since the data pipeline requires PyG batching.

In [9]:
from sklearn.metrics import roc_auc_score, average_precision_score


def train_one_epoch(model, loader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        bs = data.num_graphs  # graphs actually in this batch; also correct for a smaller last batch
        vn = data.visit_padded_node.reshape(bs, MAX_VISIT, NUM_NODES).float()
        en = data.ehr_nodes.reshape(bs, NUM_NODES).float()
        
        logits = model(
            node_ids=data.y,
            rel_ids=data.relation,
            edge_index=data.edge_index,
            batch=data.batch,
            visit_node=vn,
            ehr_nodes=en,
            in_drop=True,
        )
        
        labels = data.label.reshape(bs, -1).float()
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    return total_loss / len(loader)


@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate and return metrics."""
    model.eval()
    y_true_all, y_prob_all = [], []
    
    for data in loader:
        data = data.to(device)
        bs = data.num_graphs  # graphs actually in this batch; also correct for a smaller last batch
        vn = data.visit_padded_node.reshape(bs, MAX_VISIT, NUM_NODES).float()
        en = data.ehr_nodes.reshape(bs, NUM_NODES).float()
        
        logits = model(
            node_ids=data.y,
            rel_ids=data.relation,
            edge_index=data.edge_index,
            batch=data.batch,
            visit_node=vn,
            ehr_nodes=en,
        )
        
        labels = data.label.reshape(bs, -1)
        y_prob_all.append(torch.sigmoid(logits).cpu())
        y_true_all.append(labels.cpu())
    
    y_true = torch.cat(y_true_all).numpy()
    y_prob = torch.cat(y_prob_all).numpy()
    
    return {
        "roc_auc": roc_auc_score(y_true, y_prob),
        "pr_auc": average_precision_score(y_true, y_prob),
    }

In [10]:
# --- Training ---
EPOCHS = 10
LR = 1e-3

model = GraphCare(
    num_nodes=NUM_NODES,
    num_rels=NUM_RELS,
    max_visit=MAX_VISIT,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    out_channels=1,
    layers=2,
    dropout=0.5,
    decay_rate=0.01,
    node_emb=node_emb,
    rel_emb=rel_emb,
    gnn="BAT",
    patient_mode="joint",
    drop_rate=0.1,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model: BAT/joint, {total_params:,} parameters")
print(f"Training for {EPOCHS} epochs...\n")

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_metrics = evaluate(model, val_loader, device)
    
    print(
        f"Epoch {epoch:2d} | loss={train_loss:.4f} | "
        f"val_roc_auc={val_metrics['roc_auc']:.4f} | "
        f"val_pr_auc={val_metrics['pr_auc']:.4f}"
    )

Model: BAT/joint, 549,941 parameters
Training for 10 epochs...

Epoch  1 | loss=0.7127 | val_roc_auc=0.5938 | val_pr_auc=0.6341
Epoch  2 | loss=0.6975 | val_roc_auc=0.5625 | val_pr_auc=0.6046
Epoch  3 | loss=0.6992 | val_roc_auc=0.5625 | val_pr_auc=0.5265
Epoch  4 | loss=0.6872 | val_roc_auc=0.5625 | val_pr_auc=0.5225
Epoch  5 | loss=0.6639 | val_roc_auc=0.5312 | val_pr_auc=0.5072
Epoch  6 | loss=0.6854 | val_roc_auc=0.4844 | val_pr_auc=0.4868
Epoch  7 | loss=0.6784 | val_roc_auc=0.5000 | val_pr_auc=0.4957
Epoch  8 | loss=0.6564 | val_roc_auc=0.4219 | val_pr_auc=0.4709
Epoch  9 | loss=0.6403 | val_roc_auc=0.3438 | val_pr_auc=0.4432
Epoch 10 | loss=0.6539 | val_roc_auc=0.3438 | val_pr_auc=0.4446
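
The validation scores above move in coarse steps because ROC-AUC on a small set is simply the fraction of positive/negative pairs that are ranked correctly. The pure-Python sketch below computes it that way. The function name `pairwise_auc` is hypothetical, and the count of 8 positives and 8 negatives follows from the alternating synthetic labels and the 16 validation graphs kept by `drop_last=True`.

```python
def pairwise_auc(y_true: list[int], y_score: list[float]) -> float:
    """ROC-AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy example: 2 positives and 2 negatives give 4 pairs, 3 of them ordered correctly.
print(pairwise_auc([1, 0, 1, 0], [0.9, 0.4, 0.3, 0.2]))  # 0.75
```

With 8 positives and 8 negatives there are 64 such pairs, so scores such as 0.5938 (38/64) and 0.3438 (22/64) differ by only a handful of swapped pairs.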


## 7. Evaluate on Test Set

The synthetic labels are assigned independently of the graph contents, and only 16 test graphs remain after `drop_last=True`, so the scores below reflect chance variation rather than model quality.

In [11]:
test_metrics = evaluate(model, test_loader, device)

print("Test Results")
print("=" * 40)
for k, v in test_metrics.items():
    print(f"  {k}: {v:.4f}")

Test Results
  roc_auc: 0.7969
  pr_auc: 0.8318


## 8. Attention Interpretability

GraphCare's BAT layers produce interpretable attention weights:
- **Alpha (α):** Node-level attention, shape `(batch, max_visit, num_nodes)`: which KG nodes matter within each visit
- **Beta (β):** Visit-level attention with temporal decay, shape `(batch, max_visit, 1)`: which visits matter, weighted by recency

In [12]:
# Get attention weights
batch = next(iter(test_loader))
batch = batch.to(device)

vn = batch.visit_padded_node.reshape(BATCH_SIZE, MAX_VISIT, NUM_NODES).float()
en = batch.ehr_nodes.reshape(BATCH_SIZE, NUM_NODES).float()

model.eval()
with torch.no_grad():
    logits, alphas, betas, attns, edge_ws = model(
        node_ids=batch.y,
        rel_ids=batch.relation,
        edge_index=batch.edge_index,
        batch=batch.batch,
        visit_node=vn,
        ehr_nodes=en,
        store_attn=True,
    )

print("Attention weight shapes (per layer):")
for i, (a, b, att, ew) in enumerate(zip(alphas, betas, attns, edge_ws)):
    print(f"  Layer {i+1}:")
    print(f"    alpha:     {a.shape}  (batch, max_visit, num_nodes)")
    print(f"    beta:      {b.shape}  (batch, max_visit, 1)")
    print(f"    attn:      {att.shape}  (num_edges_in_batch, 1)")
    print(f"    edge_w:    {ew.shape}  (num_edges_in_batch, 1)")

Attention weight shapes (per layer):
  Layer 1:
    alpha:     torch.Size([16, 5, 500])  (batch, max_visit, num_nodes)
    beta:      torch.Size([16, 5, 1])  (batch, max_visit, 1)
    attn:      torch.Size([914, 1])  (num_edges_in_batch, 1)
    edge_w:    torch.Size([914, 1])  (num_edges_in_batch, 1)
  Layer 2:
    alpha:     torch.Size([16, 5, 500])  (batch, max_visit, num_nodes)
    beta:      torch.Size([16, 5, 1])  (batch, max_visit, 1)
    attn:      torch.Size([914, 1])  (num_edges_in_batch, 1)
    edge_w:    torch.Size([914, 1])  (num_edges_in_batch, 1)


In [13]:
# --- Visualise visit-level attention for first patient ---
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend
import matplotlib.pyplot as plt

patient_idx = 0

# Beta weights for patient 0, layer 1: (max_visit, 1)
beta_patient = betas[0][patient_idx].squeeze().cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 3))
visits = [f"Visit {j+1}" for j in range(MAX_VISIT)]
bars = ax.bar(visits, beta_patient, color="steelblue")
ax.set_ylabel("β attention × temporal decay")
ax.set_title("Visit-Level Attention Weights (Patient 0, Layer 1)")
ax.axhline(y=0, color="gray", linestyle="--", linewidth=0.5)
plt.tight_layout()
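fig.savefig("beta_attention_patient0.png", dpi=150)  # Agg cannot open a window, so save to disk (filename is arbitrary)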
print("Note: Recent visits (higher index) get boosted by temporal decay λ_j = exp(γ(V−j))")

Note: Recent visits (higher index) get boosted by temporal decay λ_j = exp(γ(V−j))




In [14]:
# --- Top attended nodes ---
# Alpha for patient 0, layer 1: (max_visit, num_nodes)
alpha_patient = alphas[0][patient_idx].cpu()  # (max_visit, num_nodes)

# Sum across visits to find globally important nodes
node_importance = alpha_patient.sum(dim=0).numpy()  # (num_nodes,)

top_k = 10
top_indices = np.argsort(node_importance)[-top_k:][::-1]

print(f"Top {top_k} most attended KG nodes for Patient 0:")
print(f"{'Node ID':>8}  {'Importance':>12}")
print("-" * 24)
for idx in top_indices:
    print(f"{idx:>8d}  {node_importance[idx]:>12.4f}")
print()
print("In a real setting, these node IDs map to medical concepts")
print("(conditions, procedures, drugs) via the cluster mapping.")

Top 10 most attended KG nodes for Patient 0:
 Node ID    Importance
------------------------
     492        1.0000
     103        1.0000
     450        1.0000
      92        1.0000
     153        1.0000
     167        1.0000
     205        1.0000
     204        1.0000
     200        1.0000
     216        1.0000

In a real setting, these node IDs map to medical concepts
(conditions, procedures, drugs) via the cluster mapping.


## 9. Using Pre-Built KG Artifacts (Real Data)

For real EHR data with pre-built KGs, use the `graphcare_utils` module:

```python
from pyhealth.models.graphcare_utils import (
    load_kg_artifacts,
    prepare_graphcare_data,
    build_graphcare_dataloaders,
    reshape_batch_tensors,
)

# 1. Load pre-built KG artifacts
artifacts = load_kg_artifacts(
    sample_dataset_path="sample_dataset_mimic3_mortality_th015.pkl",
    graph_path="graph_mimic3_mortality_th015.pkl",
    ent_emb_path="entity_embedding.pkl",
    rel_emb_path="relation_embedding.pkl",
    cluster_path="clusters_th015.json",
    cluster_rel_path="clusters_rel_th015.json",
    ccscm_id2clus_path="ccscm_id2clus.json",
    ccsproc_id2clus_path="ccsproc_id2clus.json",
)

# 2. Prepare data (labels, embeddings, splits)
prepared = prepare_graphcare_data(artifacts, task="mortality")

# 3. Build PyG DataLoaders
train_loader, val_loader, test_loader = build_graphcare_dataloaders(
    prepared, batch_size=64,
)

# 4. Build model
model = GraphCare(
    num_nodes=prepared["num_nodes"],
    num_rels=prepared["num_rels"],
    max_visit=prepared["max_visit"],
    embedding_dim=prepared["node_emb"].shape[1],
    hidden_dim=128,
    out_channels=prepared["task_config"]["out_channels"],
    node_emb=prepared["node_emb"],
    rel_emb=prepared["rel_emb"],
    gnn="BAT",
    patient_mode="joint",
)

# 5. Use reshape_batch_tensors in the training loop
for data in train_loader:
    batch_tensors = reshape_batch_tensors(
        data, batch_size=64,
        max_visit=prepared["max_visit"],
        num_nodes=prepared["num_nodes"],
    )
    # Assumes batch_tensors is a dict of forward() arguments plus a 'label' entry
    labels = batch_tensors.pop("label")
    logits = model(**batch_tensors)
```

## Summary

### GraphCare Model

Predicts healthcare outcomes using personalized patient knowledge graphs with bi-attention augmented GNNs.

| Component | Options | Description |
|-----------|---------|-------------|
| **GNN Backbone** | `BAT` (default), `GAT`, `GIN` | Message-passing layer; BAT adds bi-attention + edge attention |
| **Patient Mode** | `joint` (default), `graph`, `node` | How to produce patient-level representation |
| **Attention** | α (node), β (visit) | Node-level and visit-level attention; β includes temporal decay |
| **Temporal Decay** | λ_j = exp(γ(V−j)) | Exponential weighting favouring recent visits |

### Key Features

- Pre-trained node/relation embeddings (from TransE or word2vec on KG)
- Optional edge dropout for regularisation
- Attention weights for clinical interpretability
- Supports mortality, readmission, drug recommendation, length-of-stay tasks

### Files

| File | Purpose |
|------|------|
| `pyhealth/models/graphcare.py` | Model implementation (GraphCare + BiAttentionGNNConv) |
| `pyhealth/models/graphcare_utils.py` | Data pipeline utilities (KG loading, subgraph extraction, dataloaders) |
| `examples/train_graphcare.py` | End-to-end training script with CLI |
| `tests/test_graphcare.py` | Unit tests |