In [2]:
import pickle
from pyhealth.datasets import SampleDataset
from pyhealth.datasets import split_by_patient
from torch_geometric.utils import to_networkx, from_networkx

# Alternative input (file names without the `mimic4` tag), kept for reference:
# with open('../../../data/pj20/exp_data/ccscm_ccsproc_atc3/sample_dataset_multiclass_lenofstay_th015.pkl', 'rb') as f:
#     sample_dataset = pickle.load(f)

# with open('../../../data/pj20/exp_data/ccscm_ccsproc_atc3/graph_multiclass_lenofstay_th015.pkl', 'rb') as f:
#     G = pickle.load(f)

# Load the pickled sample dataset that is split and iterated below.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('../../../data/pj20/exp_data/ccscm_ccsproc_atc3/sample_dataset_multiclass_lenofstay_mimic4_th015.pkl', 'rb') as f:
    sample_dataset = pickle.load(f)

# Load the pickled graph `G` that is converted with from_networkx below.
with open('../../../data/pj20/exp_data/ccscm_ccsproc_atc3/graph_multiclass_lenofstay_mimic4_th015.pkl', 'rb') as f:
    G = pickle.load(f)

# Convert the loaded graph into a torch_geometric object; per-patient subgraphs
# are later cut out of `G_tg`.
G_tg = from_networkx(G)

# Disabled pre-filter that would drop samples whose `node_set` is empty:
# filt_dataset = []

# for patient in sample_dataset:
#     if len(patient['node_set']) != 0:
#         filt_dataset.append(patient)
# filt_dataset = SampleDataset(samples=filt_dataset)

# Split into train/val/test with ratios 0.8/0.1/0.1 and a fixed seed for reproducibility.
train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], seed=528)

In [3]:
# Collect, per sample, the length of the 'conditions' and 'procedures' entries
# (presumably one element per visit -- confirm against the dataset builder).
c_v = []
p_v = []
d_v = []  # declared alongside the others but never filled in this cell

for patient in sample_dataset:
    c_v += [len(patient['conditions'])]
    p_v += [len(patient['procedures'])]

# The longest 'conditions' entry is what gets passed to the model as `max_visits`.
max_visits = max(c_v)
print(max_visits, max(p_v))

1 1


In [4]:
import torch
from torch_geometric.loader import DataListLoader, DataLoader

def get_subgraph(G, dataset, idx):
    """Build the subgraph of ``G`` for the sample at ``idx``.

    If the sample at ``idx`` has an empty ``node_set``, the preceding samples are
    tried instead (``idx - 1``, ``idx - 2``, ..., wrapping around to the end of
    ``dataset``) until one with a non-empty ``node_set`` is found. Each sample
    is visited at most once.

    Args:
        G: Graph object exposing ``subgraph(node_index_tensor)``.
        dataset: Indexable, sized collection of sample dicts providing the keys
            ``'node_set'``, ``'drugs_ind'``, ``'visit_node_set_condition'`` and
            ``'visit_node_set_procedure'``.
        idx: Index of the requested sample.

    Returns:
        The object returned by ``G.subgraph`` with ``label``, ``visits_cond``
        and ``visits_proc`` attributes attached from the chosen sample.

    Raises:
        IndexError: If ``idx`` is out of range, or if no sample in ``dataset``
            has a non-empty ``node_set``.
    """
    patient = dataset[idx]
    if len(patient['node_set']) == 0:
        n_samples = len(dataset)
        # Bounded backward scan: same visiting order as repeatedly decrementing
        # the index, but it can never run past the start of the dataset.
        for step in range(1, n_samples):
            candidate = dataset[(idx - step) % n_samples]
            if len(candidate['node_set']) != 0:
                patient = candidate
                break
        else:
            raise IndexError(
                "no sample with a non-empty 'node_set' found in dataset"
            )

    # L = G.edge_subgraph(torch.tensor([*patient['node_set']]))
    P = G.subgraph(torch.tensor([*patient['node_set']]))
    P.label = patient['drugs_ind']
    P.visits_cond = patient['visit_node_set_condition']
    P.visits_proc = patient['visit_node_set_procedure']

    return P

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one subgraph of ``G`` per sample.

    Items are produced lazily by ``get_subgraph`` each time they are indexed.
    """

    def __init__(self, G, dataset):
        # G: full graph to cut subgraphs from; dataset: indexable sample collection.
        self.G, self.dataset = G, dataset

    def __len__(self):
        # One item per underlying sample.
        n_samples = len(self.dataset)
        return n_samples

    def __getitem__(self, idx):
        return get_subgraph(self.G, self.dataset, idx)


In [7]:
# Debug peek: unpack the 'visit_node_set_condition' entry of the second training
# sample into a list to inspect its elements.
[*train_dataset[1]['visit_node_set_condition']]

[tensor([0., 0., 0.,  ..., 0., 0., 1.])]

In [5]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATConv, GINConv, HGTConv
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader

class GAT(torch.nn.Module):
    """Graph-level classifier: three GATConv layers, mean pooling, linear head.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head output width of every attention layer.
        out_channels: Number of output logits per graph.
        heads: Number of attention heads in the first two layers (the third
            layer uses a single head).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        # The first two layers emit `heads` blocks of `hidden_channels`, so the
        # following layer consumes hidden_channels * heads features.
        multi_head_width = hidden_channels * heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(multi_head_width, hidden_channels, heads=heads)
        self.conv3 = GATConv(multi_head_width, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the batch."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.elu(conv(x, edge_index))
        # Collapse node features to one vector per graph, then regularise.
        pooled = global_mean_pool(x, batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.fc(pooled)


class GIN(torch.nn.Module):
    """Graph-level classifier: three GINConv layers, mean pooling, linear head.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of every GIN layer.
        out_channels: Number of output logits per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Each GINConv wraps a single Linear layer as its update network.
        self.conv1 = GINConv(Linear(in_channels, hidden_channels))
        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))

        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the batch."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x, edge_index))
        # Collapse node features to one vector per graph, then regularise.
        pooled = global_mean_pool(x, batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.fc(pooled)


class GINX(torch.nn.Module):
    """GIN graph classifier that looks node features up in an embedding table.

    Args:
        num_nodes: Size of the embedding table when ``word_emb`` is not given.
        embedding_dim: Embedding width when ``word_emb`` is not given; ignored
            otherwise (the width of ``word_emb`` is used).
        hidden_channels: Output width of every GIN layer.
        out_channels: Number of output logits per graph.
        word_emb: Optional 2-D tensor of pretrained embeddings. When given, the
            table is initialised from it and stays trainable (``freeze=False``).
    """

    def __init__(self, num_nodes, embedding_dim, hidden_channels, out_channels, word_emb=None):
        super(GINX, self).__init__()

        # Identity check: `== None` on a tensor/array is an element-wise style
        # comparison and must not be used to test for a missing argument.
        if word_emb is None:
            self.embedding = torch.nn.Embedding(num_nodes, embedding_dim)
            input_dim = embedding_dim
        else:
            self.embedding = torch.nn.Embedding.from_pretrained(word_emb, freeze=False)
            input_dim = word_emb.shape[1]

        self.conv1 = GINConv(Linear(input_dim, hidden_channels))
        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))
        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, node_ids, edge_index, batch):
        """Embed ``node_ids``, run three GIN layers and return per-graph logits."""
        x = self.embedding(node_ids)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        # Collapse node features to one vector per graph, then regularise.
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv
from pyhealth.models import RETAINLayer

class GraphCare(nn.Module):
    """GINE graph classifier whose edge features come from RETAIN attention.

    Node ids are embedded, three GINEConv layers (``edge_dim=1``) are applied,
    node states are mean-pooled per graph and a linear head produces logits.
    When ``use_attn`` is true, one RETAINLayer per feature key scores every node
    for every graph, and the score of an edge is the sum of the normalised
    scores of its two endpoints.

    Args:
        num_nodes: Size of the embedding table (and RETAIN feature size) when
            ``word_emb`` is not given.
        feature_keys: Keys for which a RETAINLayer is created. ``forward`` looks
            up ``'cond'``, ``'proc'`` and ``'cross'`` when ``use_attn`` is true,
            so those keys must be present in that case.
        embedding_dim: Input width of the first GINE layer; must match the
            embedding width (the width of ``word_emb`` when it is given).
        hidden_dim: Output width of every GINE layer.
        out_channels: Number of output logits per graph.
        dropout: Dropout rate handed to the RETAIN layers.
        max_visits: Stored on the module; not used by ``forward``.
        word_emb: Optional 2-D tensor of pretrained embeddings, loaded frozen.
        use_attn: Whether edge features are derived from RETAIN attention.
    """

    def __init__(self, num_nodes, feature_keys, embedding_dim, hidden_dim, out_channels, dropout=0.5, max_visits=None, word_emb=None, use_attn=True):
        super(GraphCare, self).__init__()
        self.max_visits = max_visits
        self.embedding_dim = embedding_dim
        self.use_attn = use_attn
        # Not used in forward(); kept registered so existing state dicts that
        # contain `alpha` still load.
        self.alpha = nn.Parameter(torch.tensor(0.5))

        # Resolve the vocabulary size only after checking for a missing
        # `word_emb`: len(None) would raise with the default argument.
        if word_emb is None:
            self.max_nodes = num_nodes
            self.embedding = torch.nn.Embedding(num_nodes, embedding_dim)
        else:
            self.max_nodes = len(word_emb)
            self.embedding = torch.nn.Embedding.from_pretrained(word_emb, freeze=True)

        self.retain = nn.ModuleDict()
        for feature_key in feature_keys:
            self.retain[feature_key] = RETAINLayer(feature_size=self.max_nodes, dropout=dropout)

        self.conv1 = GINEConv(nn.Linear(embedding_dim, hidden_dim), edge_dim=1)
        self.conv2 = GINEConv(nn.Linear(hidden_dim, hidden_dim), edge_dim=1)
        self.conv3 = GINEConv(nn.Linear(hidden_dim, hidden_dim), edge_dim=1)

        self.fc = nn.Linear(hidden_dim, out_channels)

    def _attention_edge_attr(self, node_ids, edge_index, batch, visits_cond, visits_proc):
        """Return an (num_edges, 1) edge feature built from RETAIN node scores."""
        cond_attn = self.retain['cond'](visits_cond)
        proc_attn = self.retain['proc'](visits_proc)
        cross_attn = self.retain['cross'](visits_cond + visits_proc)

        # Out-of-place sum so no layer output is modified under autograd.
        # Expected to be (batch_size, max_nodes) given the indexing below --
        # confirm against RETAINLayer.
        attn = cond_attn + proc_attn + cross_attn

        # `batch` already maps every node to its graph, so it selects the row of
        # `attn`; `node_ids` selects the column for that node.
        attn_weights = attn[batch, node_ids]

        # Scale by the largest score and replace NaNs (e.g. 0 / 0) with zeros.
        attn_weights = attn_weights / torch.max(attn_weights)
        attn_weights = torch.where(torch.isnan(attn_weights), torch.zeros_like(attn_weights), attn_weights)

        # Edge score = sum of the two endpoint scores, each offset by a small
        # epsilon so an edge between two zero-score nodes is not exactly zero.
        epsilon = 1e-6
        row, col = edge_index
        return ((attn_weights[row] + epsilon) + (attn_weights[col] + epsilon)).unsqueeze(-1)

    def forward(self, node_ids, edge_index, batch, visits_cond, visits_proc):
        """Return one row of logits per graph in the batch.

        Args:
            node_ids: Node ids used to index the embedding table and the
                attention scores.
            edge_index: Edge list as a (2, num_edges) index tensor.
            batch: Graph index of every node.
            visits_cond: Input to the 'cond' RETAIN layer; only read when
                ``use_attn`` is true.
            visits_proc: Input to the 'proc' RETAIN layer; only read when
                ``use_attn`` is true.
        """
        x = self.embedding(node_ids)

        if self.use_attn:
            edge_attr = self._attention_edge_attr(node_ids, edge_index, batch, visits_cond, visits_proc)
        else:
            # The GINE layers are built with edge_dim=1 and need an edge feature;
            # without attention every edge gets the same constant weight.
            edge_attr = torch.ones(edge_index.size(1), 1, dtype=x.dtype, device=x.device)

        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = F.relu(self.conv3(x, edge_index, edge_attr))

        # Collapse node features to one vector per graph, then regularise.
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)

        logits = self.fc(x)
        return logits



In [88]:
from tqdm import tqdm
from pyhealth.metrics import multilabel_metrics_fn
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, jaccard_score

def train(model, device, train_loader, optimizer, model_):
    """Run one training epoch and return the summed batch loss.

    Args:
        model: Network to optimise.
        device: Device every batch is moved to.
        train_loader: Loader yielding collated graph batches exposing
            ``num_graphs``, ``label`` and the inputs required by ``model_``.
        optimizer: Optimizer stepped once per batch.
        model_: Selects the call layout: ``"GIN"`` feeds ``data.x``, ``"GINX"``
            feeds ``data.y`` as node ids, anything else uses the GraphCare
            layout (node ids plus per-graph condition/procedure visit tensors).

    Returns:
        float: Sum of the cross-entropy losses of all processed batches.
    """
    import warnings

    model.train()
    training_loss = 0
    tot_loss = 0.0
    pbar = tqdm(train_loader)
    for data in pbar:
        pbar.set_description(f'loss: {training_loss}')
        data = data.to(device)
        # Size of *this* batch, so a smaller final batch is reshaped correctly.
        n_graphs = data.num_graphs

        # Lay labels out as one row per graph before doing any model work.
        # Batches whose labels cannot be split evenly are skipped (best effort),
        # but only for that specific failure and with a visible warning.
        try:
            label = data.label.reshape(n_graphs, -1)
        except RuntimeError as err:
            warnings.warn(f'skipping batch: labels do not split across {n_graphs} graphs ({err})')
            continue

        optimizer.zero_grad()

        if model_ == "GIN":
            out = model(data.x, data.edge_index, data.batch)
        elif model_ == "GINX":
            out = model(data.y, data.edge_index, data.batch)
        else:
            # assumes visits_* are collated as (n_graphs * visits, features) -- TODO confirm
            out = model(
                    data.y,
                    data.edge_index,
                    data.batch,
                    data.visits_cond.reshape(n_graphs, -1, data.visits_cond.shape[1]).double(),
                    data.visits_proc.reshape(n_graphs, -1, data.visits_proc.shape[1]).double(),
                )

        loss = F.cross_entropy(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate plain floats: summing loss tensors would keep autograd
        # history alive for the whole epoch.
        training_loss = loss.item()
        tot_loss += training_loss

    return tot_loss

def evaluate(model, device, loader, model_):
    """Run ``model`` over ``loader`` and collect labels and sigmoid scores.

    Args:
        model: Network to evaluate (switched to eval mode).
        device: Device every batch is moved to.
        loader: Loader yielding collated graph batches exposing ``num_graphs``,
            ``label`` and the inputs required by ``model_``.
        model_: Selects the call layout: ``"GIN"`` feeds ``data.x``, ``"GINX"``
            feeds ``data.y`` as node ids, anything else uses the GraphCare
            layout (node ids plus per-graph condition/procedure visit tensors).

    Returns:
        tuple: ``(y_true_all, y_prob_all)`` as NumPy arrays with one row per
        evaluated graph.

    Raises:
        ValueError: If no batch could be evaluated.
    """
    import warnings

    model.eval()
    y_prob_all = []
    y_true_all = []

    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            # Size of *this* batch, so a smaller final batch is reshaped correctly.
            n_graphs = data.num_graphs

            # Batches whose labels cannot be split evenly are skipped (best
            # effort), but only for that specific failure and with a warning.
            try:
                y_true = data.label.reshape(n_graphs, -1)
            except RuntimeError as err:
                warnings.warn(f'skipping batch: labels do not split across {n_graphs} graphs ({err})')
                continue

            if model_ == "GIN":
                logits = model(data.x, data.edge_index, data.batch)
            elif model_ == "GINX":
                logits = model(data.y, data.edge_index, data.batch)
            else:
                # assumes visits_* are collated as (n_graphs * visits, features) -- TODO confirm
                logits = model(
                    data.y,
                    data.edge_index,
                    data.batch,
                    data.visits_cond.reshape(n_graphs, -1, data.visits_cond.shape[1]).double(),
                    data.visits_proc.reshape(n_graphs, -1, data.visits_proc.shape[1]).double(),
                )

            # NOTE(review): scores are independent sigmoids while train() optimises
            # cross-entropy; confirm this is the intended probability mapping.
            y_prob = torch.sigmoid(logits)
            y_prob_all.append(y_prob.cpu().numpy())
            y_true_all.append(y_true.cpu().numpy())

    if not y_true_all:
        raise ValueError('evaluate() collected no batches from the loader')

    y_true_all = np.concatenate(y_true_all, axis=0)
    y_prob_all = np.concatenate(y_prob_all, axis=0)

    return y_true_all, y_prob_all

def _validation_metrics(y_true, y_prob, threshold=0.5):
    """Compute the validation metrics reported by train_loop.

    ``y_true`` and ``y_prob`` are (n_samples, n_outputs) arrays. A single output
    column is scored as a binary problem; several columns are scored as an
    indicator matrix with support-weighted averaging. Hard predictions are
    ``y_prob >= threshold``.

    Returns:
        dict: keys ``pr_auc``, ``roc_auc``, ``acc``, ``f1``, ``precision``,
        ``recall`` and ``jaccard``.
    """
    from sklearn.metrics import accuracy_score, precision_score, recall_score

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    if y_true.ndim == 2 and y_true.shape[1] == 1:
        # One output column: flatten so the metrics see a plain binary target.
        y_true, y_prob = y_true.ravel(), y_prob.ravel()
    y_true = y_true.astype(int)
    y_pred = (y_prob >= threshold).astype(int)
    average = "binary" if y_true.ndim == 1 else "weighted"

    try:
        roc_auc = roc_auc_score(y_true, y_prob, average="weighted", multi_class="ovr")
    except ValueError:
        # Undefined when a target has a single class in this split; report NaN
        # rather than aborting a long training run.
        roc_auc = float("nan")

    return {
        "pr_auc": average_precision_score(y_true, y_prob, average="weighted"),
        "roc_auc": roc_auc,
        "acc": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average=average, zero_division=0),
        "precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "jaccard": jaccard_score(y_true, y_pred, average=average, zero_division=0),
    }


def train_loop(train_loader, val_loader, model, optimizer, device, epochs, model_,
               save_path='../../../data/pj20/exp_data/saved_weights_gin_mimic3_readmission_dynamic.pkl'):
    """Train for ``epochs`` epochs, validating and checkpointing after each one.

    After every epoch the model is evaluated on ``val_loader``. The weights are
    written to ``save_path`` whenever both validation accuracy and validation F1
    are at least as good as the best values seen so far.

    Args:
        train_loader: Loader passed to ``train``.
        val_loader: Loader passed to ``evaluate``.
        model: Network to train.
        optimizer: Optimizer passed to ``train``.
        device: Device passed to ``train`` and ``evaluate``.
        epochs: Number of epochs to run.
        model_: Model selector forwarded to ``train`` and ``evaluate``.
        save_path: Destination of the best model's ``state_dict``.
    """
    best_acc = 0
    best_f1 = 0
    for epoch in range(1, epochs+1):
        loss = train(model, device, train_loader, optimizer, model_)
        y_true_all, y_prob_all = evaluate(model, device, val_loader, model_)

        metrics = _validation_metrics(y_true_all, y_prob_all)
        val_pr_auc = metrics["pr_auc"]
        val_roc_auc = metrics["roc_auc"]
        val_acc = metrics["acc"]
        val_f1 = metrics["f1"]
        val_precision = metrics["precision"]
        val_recall = metrics["recall"]
        val_jaccard = metrics["jaccard"]

        if val_acc >= best_acc and val_f1 >= best_f1:
            torch.save(model.state_dict(), save_path)
            print("best model saved")
            best_acc = val_acc
            best_f1 = val_f1

        print(f'Epoch: {epoch}, Training loss: {loss}, Val PRAUC: {val_pr_auc:.4f}, Val ROC_AUC: {val_roc_auc:.4f}, Val acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, Val precision: {val_precision:.4f}, Val recall: {val_recall:.4f}, Val jaccard: {val_jaccard:.4f}')


In [89]:
# G_tg.x = torch.randn(G_tg.num_nodes, 256)

In [90]:
# Wrap each split so that indexing it yields a per-sample subgraph of G_tg.
train_set, val_set, test_set = (
    Dataset(G=G_tg, dataset=split)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Batches of 16 graphs; only the training loader shuffles. Incomplete final
# batches are dropped for all three loaders.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=16, shuffle=reshuffle, drop_last=True)
    for subset, reshuffle in ((train_set, True), (val_set, False), (test_set, False))
)



In [94]:
# Which architecture to build; train()/evaluate() use the same selector to pick
# the matching forward-call layout.
model_ = "GINX"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if model_ == "GIN":
    # GIN consumes raw node features, so its input width comes from the data.
    in_channels = train_set[0].x.shape[1]
    model = GIN(in_channels=in_channels, out_channels=1, hidden_channels=512).to(device)
    # model = GAT(in_channels=in_channels, out_channels=1, hidden_channels=256, heads=3).to(device)
elif model_ == "GINX":
    # The embedding table is initialised from the graph's node features.
    model = GINX(num_nodes=G_tg.num_nodes, embedding_dim=512, hidden_channels=512, out_channels=1, word_emb=G_tg.x).to(device)
elif model_ == "GraphCare":
    model = GraphCare(
        num_nodes=G_tg.num_nodes,
        feature_keys=['cond', 'proc', 'cross'],
        # Must equal the width of the pretrained embeddings fed to the first layer.
        embedding_dim=G_tg.x.shape[1],
        hidden_dim=512,
        out_channels=1,
        dropout=0.5,
        max_visits=max_visits,
        word_emb=G_tg.x,
        use_attn=True
    ).to(device)
else:
    # Fail here instead of with a NameError on `model` further down.
    raise ValueError(f'unknown model_ {model_!r}; expected "GIN", "GINX" or "GraphCare"')

# The training code feeds double-precision inputs, so cast all parameters too.
model.double()

GraphCare(
  (embedding): Embedding(1000, 1024)
  (retain): ModuleDict(
    (cond): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(1000, 1000, batch_first=True)
      (beta_gru): GRU(1000, 1000, batch_first=True)
      (alpha_li): Linear(in_features=1000, out_features=1, bias=True)
      (beta_li): Linear(in_features=1000, out_features=1000, bias=True)
    )
    (proc): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(1000, 1000, batch_first=True)
      (beta_gru): GRU(1000, 1000, batch_first=True)
      (alpha_li): Linear(in_features=1000, out_features=1, bias=True)
      (beta_li): Linear(in_features=1000, out_features=1000, bias=True)
    )
    (cross): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(1000, 1000, batch_first=True)
      (beta_gru): GRU(1000, 1000, batch_first=True)
      (alpha_li): Linear(in_features=1000, out_features=1, bias=True)
      (beta_

In [95]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train_loop takes the model selector as `model_`; it has no `static` parameter.
train_loop(train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, device=device, epochs=100, model_=model_)

loss: 0.3190099504377803: 100%|██████████| 2223/2223 [17:14<00:00,  2.15it/s] 
100%|██████████| 271/271 [00:54<00:00,  5.02it/s]


best model saved
Epoch: 1, Training loss: 31558.401586260756, Val PRAUC: 0.6461, Val ROC_AUC: 0.9004, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.23548025798127212: 100%|██████████| 2223/2223 [17:08<00:00,  2.16it/s]
100%|██████████| 271/271 [01:10<00:00,  3.87it/s]


best model saved
Epoch: 2, Training loss: 668.3661525146066, Val PRAUC: 0.6466, Val ROC_AUC: 0.9029, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.28290939595346204: 100%|██████████| 2223/2223 [18:00<00:00,  2.06it/s]
100%|██████████| 271/271 [00:56<00:00,  4.79it/s]


best model saved
Epoch: 3, Training loss: 607.9846786465552, Val PRAUC: 0.6466, Val ROC_AUC: 0.9031, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.2730543356437892: 100%|██████████| 2223/2223 [18:13<00:00,  2.03it/s] 
100%|██████████| 271/271 [01:43<00:00,  2.62it/s]


Epoch: 4, Training loss: 592.9134928369855, Val PRAUC: 0.6467, Val ROC_AUC: 0.9030, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.26231648551205394: 100%|██████████| 2223/2223 [17:35<00:00,  2.11it/s]
100%|██████████| 271/271 [00:57<00:00,  4.69it/s]


Epoch: 5, Training loss: 588.685593257485, Val PRAUC: 0.6466, Val ROC_AUC: 0.9031, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.2539298038136868: 100%|██████████| 2223/2223 [17:45<00:00,  2.09it/s] 
100%|██████████| 271/271 [00:57<00:00,  4.73it/s]


best model saved
Epoch: 6, Training loss: 587.4888132497164, Val PRAUC: 0.6467, Val ROC_AUC: 0.9031, Val F1-score: 0.4590, Val Jaccard: 0.3094


loss: 0.2947981430612893: 100%|██████████| 2223/2223 [17:55<00:00,  2.07it/s] 
100%|██████████| 271/271 [00:57<00:00,  4.71it/s]


Epoch: 7, Training loss: 587.1197404865286, Val PRAUC: 0.6467, Val ROC_AUC: 0.9031, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.24527524525398567: 100%|██████████| 2223/2223 [17:42<00:00,  2.09it/s]
100%|██████████| 271/271 [00:56<00:00,  4.79it/s]


Epoch: 8, Training loss: 586.9825388997497, Val PRAUC: 0.6466, Val ROC_AUC: 0.9031, Val F1-score: 0.4687, Val Jaccard: 0.3179


loss: 0.2842615063126055:  77%|███████▋  | 1709/2223 [13:48<04:09,  2.06it/s] 


KeyboardInterrupt: 

In [None]:
# Debug: print the `y` attribute of every collated test batch (the same attribute
# that train()/evaluate() pass to GINX and GraphCare as node ids).
for data in test_loader:
    print(data.y)

In [None]:
# torch.save(model.state_dict(), './exp_data/saved_weights_gat_mimic3_drugrec.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_gin_mimic3_drugrec_random.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_hgt_mimic3_drugrec.pkl')

In [None]:
# model.load_state_dict(torch.load('./exp_data/saved_weights_gat_mimic3_drugrec.pkl'))
# NOTE(review): this checkpoint path differs from the one train_loop saves to;
# confirm it matches the architecture built above.
model.load_state_dict(torch.load('../../../data/pj20/exp_data/saved_weights_graph_mimic3_drugrec_th02.pkl'))
model.double()

# Score the held-out test split (the metrics below are reported as test
# metrics), using the same model selector that was used for training.
y_true_all, y_prob_all = evaluate(model, device, test_loader, model_)

# Hard 0/1 predictions at a 0.5 probability threshold.
y_pred_all = (y_prob_all >= 0.5).astype(y_prob_all.dtype)

# Sample-averaged metrics: each row is scored on its own, then rows are averaged.
test_pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
test_roc_auc = roc_auc_score(y_true_all, y_prob_all, average="samples")
test_f1 = f1_score(y_true_all, y_pred_all, average='samples')
test_jaccard = jaccard_score(y_true_all, y_pred_all, average='samples')

print(f'test PRAUC: {test_pr_auc:.4f}, test ROC_AUC: {test_roc_auc:.4f}, test F1-score: {test_f1:.4f}, test Jaccard: {test_jaccard:.4f}')