In [1]:
# Experiment configuration: these three names are passed to load_everything()
# and reused when building the checkpoint filename.
dataset = "mimic3"    # EHR dataset identifier
task = "mortality"    # prediction task
kg = "GPT-KG"         # knowledge-graph variant

In [2]:
from graphcare import *

# Unpack everything load_everything() returns for this dataset/task/KG combination.
(
    sample_dataset, G, ent2id, rel2id, ent_emb, rel_emb,
    map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_rel_inv,
    ccscm_id2clus, ccsproc_id2clus, atc3_id2clus,
) = load_everything(dataset, task, kg)
mode, out_channels, loss_function = get_mode_and_out_channels_and_loss_func(
    task=task,
    sample_dataset=sample_dataset,
)

# Label the direct EHR nodes of every sample.
print("Labeling direct ehr nodes...")
sample_dataset = label_ehr_nodes(
    task,
    sample_dataset,
    len(map_cluster),
    ccscm_id2clus,
    ccsproc_id2clus,
    atc3_id2clus,
)

# 80/10/10 split by patient with a fixed seed so the split is reproducible.
print("Splitting dataset...")
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset,
    [0.8, 0.1, 0.1],
    seed=528,
)
G_tg = from_networkx(G)

# Embeddings: rel_emb (rebinding the value returned above) comes from the
# relation cluster map, node_emb from the converted graph's `x` attribute.
print("Getting embedding...")
rel_emb = get_rel_emb(map_cluster_rel)
node_emb = G_tg.x

# Data loaders; the trailing 4 is presumably the batch size — TODO confirm
# (the evaluation loop further down also hard-codes 4).
print("Getting dataloader...")
train_loader, val_loader, test_loader = get_dataloader(
    G_tg,
    train_dataset,
    val_dataset,
    test_dataset,
    task,
    4,
)

  from .autonotebook import tqdm as notebook_tqdm


Labeling direct ehr nodes...


100%|██████████| 9717/9717 [00:00<00:00, 18743.72it/s]


Splitting dataset...
Getting embedding...
Getting dataloader...


In [5]:
# Preferred GPU index (the value the notebook was originally run with).
PREFERRED_GPU_INDEX = 2

# Pick the preferred GPU when it exists, otherwise the first GPU, otherwise CPU.
# Hard-coding "cuda:2" only fails later (at model.to(device)) on hosts with
# fewer than three GPUs, so the index is validated here.
if torch.cuda.is_available():
    _gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device('cpu')

In [6]:
# Model used for training. Keyword arguments are grouped by concern; the
# values are the ones this notebook was run with.
model = GraphCare(
    # sizes derived from the loaded embeddings / dataset
    num_nodes=node_emb.shape[0],
    num_rels=rel_emb.shape[0],
    embedding_dim=node_emb.shape[1],
    max_visit=sample_dataset[0]['visit_padded_node'].shape[0],
    out_channels=out_channels,
    # initial embeddings and whether to freeze them
    node_emb=node_emb,
    rel_emb=rel_emb,
    freeze=False,
    # architecture
    gnn="BAT",
    layers=1,
    hidden_dim=512,
    patient_mode="joint",
    use_alpha=True,
    use_beta=True,
    use_edge_attn=True,
    # regularisation
    dropout=0.5,
    decay_rate=0.01,
)
# Last expression: moves the model to `device` and displays its repr.
model.to(device)

GraphCare(
  (node_emb): Embedding(4599, 1536)
  (rel_emb): Embedding(1077, 1536)
  (lin): Linear(in_features=1536, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (alpha_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=4599, bias=True)
  )
  (beta_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=1, bias=True)
  )
  (conv): ModuleDict(
    (1): BiAttentionGNNConv(nn=Linear(in_features=512, out_features=512, bias=True))
  )
  (bn_gnn): ModuleDict()
  (leakyrelu): LeakyReLU(negative_slope=0.1)
  (relu): ReLU()
  (tahh): Tanh()
  (MLP): Linear(in_features=1024, out_features=1, bias=True)
)

In [7]:
# Adam over all parameters of the training model (learning rate 1e-3, weight decay 1e-4).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Train for 20 epochs on the training loader, validating on the validation loader.
train_loop(
    # what is being trained
    model=model,
    optimizer=optimizer,
    loss_func=loss_function,
    device=device,
    epochs=20,
    # data
    train_loader=train_loader,
    val_loader=val_loader,
    # run description
    dataset=dataset,
    task=task,
    mode=mode,
    patient_mode="joint",
    gnn=model.gnn,
)

In [9]:
# Second GraphCare instance used only for evaluation / attention extraction.
# NOTE(review): unlike the training model above, `freeze` is not passed here,
# so the class default applies — confirm it matches the trained checkpoint.
model_ = GraphCare(
    num_nodes=node_emb.shape[0],
    num_rels=rel_emb.shape[0],
    max_visit=sample_dataset[0]['visit_padded_node'].shape[0],
    embedding_dim=node_emb.shape[1],
    hidden_dim=512,
    out_channels=out_channels,
    layers=1,
    dropout=0.5,
    decay_rate=0.01,
    node_emb=node_emb,
    rel_emb=rel_emb,
    patient_mode="joint",
    use_alpha=True,
    use_beta=True,
    use_edge_attn=True,
    gnn="BAT",
)

# Checkpoint path is built from the dataset/task names set at the top of the notebook.
weights_path = f'../../../data/pj20/exp_data/saved_weights_{dataset}_{task}_BAT.pkl'

# map_location="cpu": model_ is never moved to a GPU in this notebook, and a
# checkpoint saved from a CUDA device would otherwise fail to load on a host
# that lacks that device.
# SECURITY: torch.load unpickles the file — only load checkpoints you trust.
model_.load_state_dict(torch.load(weights_path, map_location="cpu"))

# Last expression: switches to eval mode and displays the model repr.
model_.eval()

GraphCare(
  (node_emb): Embedding(4599, 1536)
  (rel_emb): Embedding(1077, 1536)
  (lin): Linear(in_features=1536, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (alpha_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=4599, bias=True)
  )
  (beta_attn): ModuleDict(
    (1): Linear(in_features=4599, out_features=1, bias=True)
  )
  (conv): ModuleDict(
    (1): BiAttentionGNNConv(nn=Linear(in_features=512, out_features=512, bias=True))
  )
  (bn_gnn): ModuleDict()
  (leakyrelu): LeakyReLU(negative_slope=0.1)
  (relu): ReLU()
  (tahh): Tanh()
  (MLP): Linear(in_features=1024, out_features=1, bias=True)
)

In [10]:
# Ids of every validation sample whose label is 1, and how many there are.
d_patient_ids = [
    sample['patient_id']
    for sample in val_dataset
    if sample['label'] == 1
]
cnt = len(d_patient_ids)

# Displayed as (validation set size, number of positive samples).
len(val_dataset), cnt

(991, 77)

In [11]:
from tqdm import tqdm

# Minimum predicted probability for a positive sample to be kept for visualisation.
PROB_THRESHOLD = 0.2

loader = val_loader
y_prob_all = []
y_true_all = []
patient_ids = []    # ids of the samples kept for visualisation
attentions = []     # alpha[0][i] for each kept sample
edge_weights = []   # rel_w[0] of the batch each kept sample came from
rel_ids_ = []       # data.relation of the batch each kept sample came from

data_ = None  # after the loop this holds the last batch
with torch.no_grad():
    for data in tqdm(loader):
        data_ = data

        # Use the number of graphs actually present in this batch instead of
        # loader.batch_size / a literal 4, so a different batch size or a
        # smaller final batch does not break the reshapes or the loop below.
        n_graphs = data.num_graphs

        node_ids = data.y
        rel_ids = data.relation
        ehr_nodes = data.ehr_nodes.reshape(n_graphs, -1).float()
        visit_node = data.visit_padded_node.reshape(
            n_graphs, -1, data.visit_padded_node.shape[1]
        ).float()

        logits, alpha, beta, attn_w, rel_w = model_(
            node_ids=node_ids,
            rel_ids=rel_ids,
            edge_index=data.edge_index,
            batch=data.batch,
            visit_node=visit_node,
            ehr_nodes=ehr_nodes,
            store_attn=True,
        )

        y_prob = torch.sigmoid(logits)
        y_true = data.label.reshape(n_graphs, -1)
        y_prob_all.append(y_prob.cpu())
        y_true_all.append(y_true.cpu())

        # Keep positive samples that the model scores at or above the threshold.
        for i in range(n_graphs):
            if y_prob[i] >= PROB_THRESHOLD and y_true[i] == 1:
                if data.patient_id[i] in d_patient_ids:
                    attentions.append(alpha[0][i].cpu())
                    # NOTE(review): rel_w[0] and rel_ids are stored for the whole
                    # batch, not sliced to sample i — presumably they cover the
                    # edges of every graph in the batch; confirm this is intended.
                    edge_weights.append(rel_w[0].cpu())
                    rel_ids_.append(rel_ids.cpu())
                    patient_ids.append(data.patient_id[i])

y_true_all = np.concatenate(y_true_all, axis=0)
y_prob_all = np.concatenate(y_prob_all, axis=0)
# np.where(y_true_all == 1)[0], y_prob_all[np.where(y_true_all == 1)[0]]

100%|██████████| 247/247 [00:26<00:00,  9.48it/s]


In [12]:
# Display the ids collected by the inference loop (empty if no sample passed its filter).
patient_ids

[]

In [9]:
# Every validation sample whose patient id was collected by the inference loop,
# in val_dataset order.
target_patients = [
    sample
    for sample in val_dataset
    if sample['patient_id'] in patient_ids
]


In [10]:
# Index (into the lists filled by the inference loop) of the sample to visualise.
TARGET_INDEX = 1

# This cell rebinds edge_weights / attentions from per-sample lists to single
# tensors, so running it twice would silently index into a tensor instead.
if not isinstance(edge_weights, list) or not isinstance(attentions, list):
    raise RuntimeError(
        "edge_weights/attentions were already reduced to one sample; "
        "re-run the inference cell before re-running this cell."
    )
if TARGET_INDEX >= len(patient_ids):
    raise IndexError(
        f"TARGET_INDEX={TARGET_INDEX} but only {len(patient_ids)} sample(s) were selected "
        "by the inference loop."
    )

# Select the patient by id: target_patients follows val_dataset order and may
# hold several samples per patient, so its position TARGET_INDEX is not
# guaranteed to be the sample behind edge_weights[TARGET_INDEX].
target_patient_id = patient_ids[TARGET_INDEX]
_id_matches = [p for p in target_patients if p['patient_id'] == target_patient_id]
if not _id_matches:
    raise LookupError(f"patient {target_patient_id!r} not found in target_patients")
# The inference loop only keeps label-1 samples, so prefer a label-1 sample here
# (presumably the same label field — falls back to the first id match otherwise).
_positive_matches = [p for p in _id_matches if p['label'] == 1]
target_patient = (_positive_matches or _id_matches)[0]

edge_weights = edge_weights[TARGET_INDEX]
rel_ids = rel_ids_[TARGET_INDEX]
attentions = attentions[TARGET_INDEX]


In [11]:
# Inspect the smallest value in edge_weights (the offset used by the min-max normalisation).
torch.min(edge_weights)

tensor(-0.0023)

In [12]:
import torch

# Min-max normalise the edge weights into [0, 1].
min_val = torch.min(edge_weights)
max_val = torch.max(edge_weights)
weight_range = max_val - min_val

if weight_range == 0:
    # All weights are identical: dividing by zero would produce NaNs, so map
    # every weight to 0 instead.
    normed_edge_weights = torch.zeros_like(edge_weights)
else:
    normed_edge_weights = (edge_weights - min_val) / weight_range

In [13]:
# Display the min-max normalised edge weights.
normed_edge_weights

tensor([[0.1510],
        [0.2756],
        [0.2756],
        ...,
        [0.1532],
        [0.1008],
        [1.0000]])

In [14]:
# Map relation id -> normalised weight. When a relation id occurs more than
# once, the weight at its last position wins (later entries overwrite earlier ones).
rel_weights = {
    int(rel_ids[pos]): float(normed_edge_weights[pos])
    for pos in range(len(rel_ids))
}

In [113]:
# map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_rel_inv

In [96]:
# Convert the attention tensor to a NumPy array. Guarded so the cell can be
# re-run: once `attentions` is already an ndarray the conversion is skipped
# instead of raising AttributeError on `.cpu()`.
if not isinstance(attentions, np.ndarray):
    attentions = attentions.cpu().numpy()

In [97]:
# NOTE(review): self-assignment with no effect — this cell can be deleted.
attentions = attentions

In [15]:
# Rebind map_cluster to the cluster JSON on disk; the graph-building cell reads
# map_cluster[str(node)]['attention_mortality'] from it.
# NOTE(review): this replaces the map_cluster returned by load_everything() and
# relies on an absolute, machine-specific path.
with open("/data/pj20/exp_data/ccscm_ccsproc_atc3/clusters_th015.json", 'r') as cluster_file:
    map_cluster = json.load(cluster_file)

In [16]:
import networkx as nx
from pyvis.network import Network
import matplotlib.pyplot as plt
from collections import defaultdict

# Errors that mean "this triple cannot be mapped to clusters/weights" and should
# be skipped: missing keys/indices in the lookup tables, or values that cannot
# be converted with int(). Narrower than the bare `except:` it replaces, which
# also swallowed KeyboardInterrupt and genuine programming errors.
_SKIPPABLE_TRIPLE_ERRORS = (LookupError, ValueError, TypeError)


def _add_code_subgraph(codes, id2clus, edge_label, graph_dir):
    """Link each EHR code to PATIENT and add the triples from its graph file.

    Reads the module-level lookup tables (ent2id, rel2id, map_cluster_inv,
    map_cluster_rel_inv, map_cluster, rel_weights) and mutates the module-level
    Graph, cluster_included_entities and cluster_included_relations.

    Args:
        codes: iterable of EHR codes for one visit.
        id2clus: mapping from EHR code to the node key used for the direct EHR node.
        edge_label: label stored on the PATIENT -> EHR node edge.
        graph_dir: directory containing one tab-separated `<code>.txt` triple file per code.

    Returns:
        Number of three-field lines that were skipped because a lookup or
        conversion failed.
    """
    skipped = 0
    for code in tqdm(codes):
        # add direct EHR node and edge
        # NOTE(review): this node key comes straight from id2clus while the
        # triple nodes below are int(...) — if id2clus values are strings the
        # EHR node and its cluster node are distinct graph nodes; confirm.
        Graph.add_node(id2clus[code])
        Graph.add_edge("PATIENT", id2clus[code], label=edge_label)

        with open(f'{graph_dir}/{code}.txt', 'r') as f:
            lines = f.readlines()

        for line in lines:
            items = line.split('\t')
            if len(items) != 3:
                continue
            try:
                h, r, t = items
                # Drops the final character, assumed to be the trailing newline.
                # NOTE(review): a last line without a newline loses a real character.
                t = t[:-1]
                h_id = ent2id[h]
                t_id = ent2id[t]
                r_id = rel2id[r]

                h_node = int(map_cluster_inv[h_id])
                t_node = int(map_cluster_inv[t_id])
                r_edge = int(map_cluster_rel_inv[r_id])

                cluster_included_entities[h_node].append(h)
                cluster_included_entities[t_node].append(t)
                cluster_included_relations[r_edge].append(r)

                # Each cluster is labelled with the first entity/relation seen for it.
                Graph.add_node(h_node, label=cluster_included_entities[h_node][0], weight=map_cluster[str(h_node)]['attention_mortality'])
                Graph.add_node(t_node, label=cluster_included_entities[t_node][0], weight=map_cluster[str(t_node)]['attention_mortality'])
                Graph.add_edge(h_node, t_node, label=cluster_included_relations[r_edge][0], weight=rel_weights[r_edge])
            except _SKIPPABLE_TRIPLE_ERRORS:
                skipped += 1
                continue
    return skipped


net = Network()        # not used in this cell
Graph = nx.Graph()
edge_labels = {}       # not used in this cell

Graph.add_node("PATIENT")

# Codes of the first visit of the selected patient.
conditions = target_patient['conditions'][0]
procedures = target_patient['procedures'][0]
drugs = target_patient['drugs'][0]   # only referenced by the commented-out drug block

node_set_all = set()   # not used in this cell
node_set_list = []     # not used in this cell
cluster_included_entities = defaultdict(list)
cluster_included_relations = defaultdict(list)

skipped_triples = _add_code_subgraph(conditions, ccscm_id2clus, "condition", './graphs/condition/CCSCM')
skipped_triples += _add_code_subgraph(procedures, ccsproc_id2clus, "procedure", './graphs/procedure/CCSPROC')

# Make the previously silent skips visible.
print(f"Skipped {skipped_triples} triple(s) that could not be mapped")

# for drug in tqdm(drugs):    
#     Graph.add_node(atc3_id2clus[drug])
#     Graph.add_edge("PATIENT", atc3_id2clus[drug], label="drug")
#     drug_file = f'./graphs/drug/ATC3/{drug}.txt'

#     with open(drug_file, 'r') as f:
#         lines = f.readlines()

#     for line in lines:
#         items = line.split('\t')
#         try:
#             if len(items) == 3:
#                 h, r, t = items
#                 t = t[:-1]
#                 h_id = ent2id[h]
#                 t_id = ent2id[t]
#                 r_id = rel2id[r]

#                 h_node = map_cluster_inv[h_id]
#                 t_node = map_cluster_inv[t_id]
#                 r_edge = map_cluster_rel_inv[r_id]

#                 cluster_included_entities[h_node].append(h)
#                 cluster_included_entities[t_node].append(t)
#                 cluster_included_relations[r_edge].append(r)

#                 Graph.add_node(h_node)
#                 Graph.add_node(t_node)
#                 Graph.add_edge(h_node, t_node, label=r_edge)

#         except:
#             continue



100%|██████████| 8/8 [00:00<00:00, 1337.42it/s]
100%|██████████| 3/3 [00:00<00:00, 2062.43it/s]


In [99]:
# Check the shape of the attention array kept for the selected sample.
attentions.shape

(28, 4599)

In [17]:
# Export the patient graph (including the label/weight attributes set above) to
# 'graph.gexf' in the current working directory.
nx.write_gexf(Graph, 'graph.gexf')