In [51]:
import pickle
from pyhealth.datasets import SampleDataset
from pyhealth.datasets import split_by_patient
from torch_geometric.utils import to_networkx, from_networkx

# Load the pickled patient samples used for the drug-recommendation task.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load files that come from a trusted source.
with open('../../../data/pj20/exp_data/icd9cm_icd9proc/drugrec_dataset_umls_th015.pkl', 'rb') as f:
    sample_dataset = pickle.load(f)

# Load the pickled graph (presumably a networkx graph, since it is handed to
# from_networkx below — TODO confirm).
with open('../../../data/pj20/exp_data/icd9cm_icd9proc/graph_umls_th015_drugrec.pkl', 'rb') as f:
    G = pickle.load(f)

# Convert the graph into a torch_geometric data object.
G_tg = from_networkx(G) 

# Disabled: filtering out patients whose 'node_set' is empty before splitting.
# filt_dataset = []

# for patient in sample_dataset:
#     if len(patient['node_set']) != 0:
#         filt_dataset.append(patient)
# filt_dataset = SampleDataset(samples=filt_dataset)

# Split into train/val/test with ratios 0.8/0.1/0.1 and a fixed seed of 528.
train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], seed=528)

In [52]:
# Per-patient lengths of the 'conditions' and 'procedures' entries.
# d_v is created here but never filled in this cell.
c_v, p_v, d_v = [], [], []

for patient in sample_dataset:
    # presumably each element of 'conditions' / 'procedures' is one visit — TODO confirm
    c_v.append(len(patient['conditions']))
    p_v.append(len(patient['procedures']))
print(max(c_v), max(p_v))
# Largest 'conditions' length over all patients; passed to GraphCare as max_visits.
max_visits = max(c_v)

29 29


In [53]:
import torch
from torch_geometric.loader import DataListLoader, DataLoader

def get_subgraph(G, dataset, idx):
    """Build the graph sample for the patient at ``dataset[idx]``.

    If that patient's ``'node_set'`` is empty, the nearest preceding patient
    (wrapping around to the end of the dataset) with a non-empty node set is
    used instead, so every index yields a usable sample.

    Args:
        G: Graph object exposing ``subgraph(node_index_tensor)``.
        dataset: Indexable, sized collection of patient dicts with the keys
            ``'node_set'``, ``'drugs_ind'``, ``'visit_node_set_condition'``
            and ``'visit_node_set_procedure'``.
        idx: Index of the requested patient.

    Returns:
        The object returned by ``G.subgraph`` with ``label``, ``visits_cond``
        and ``visits_proc`` attached from the selected patient.

    Raises:
        IndexError: If ``idx`` is out of range, or if no patient in
            ``dataset`` has a non-empty ``'node_set'``.
    """
    patient = dataset[idx]
    if len(patient['node_set']) == 0:
        # Walk backwards (idx-1, idx-2, ..., wrapping to the end) exactly once
        # over the other patients instead of decrementing idx without bound.
        n_patients = len(dataset)
        for offset in range(1, n_patients):
            candidate = dataset[(idx - offset) % n_patients]
            if len(candidate['node_set']) != 0:
                patient = candidate
                break
        else:
            raise IndexError(
                "no patient with a non-empty 'node_set' found in dataset"
            )

    P = G.subgraph(torch.tensor([*patient['node_set']]))
    P.label = patient['drugs_ind']
    P.visits_cond = patient['visit_node_set_condition']
    P.visits_proc = patient['visit_node_set_procedure']

    return P

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one patient subgraph per index.

    Holds a reference to the full graph ``G`` and to the patient collection
    ``dataset``; the subgraph itself is built on demand by ``get_subgraph``.
    """

    def __init__(self, G, dataset):
        self.G, self.dataset = G, dataset

    def __len__(self):
        # One sample per entry of the wrapped patient collection.
        n_samples = len(self.dataset)
        return n_samples

    def __getitem__(self, idx):
        return get_subgraph(self.G, self.dataset, idx)


In [54]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATConv, GINConv, HGTConv
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import DataParallel
from torch_geometric.loader import DataListLoader

class GAT(torch.nn.Module):
    """Graph-level classifier built from three GATConv layers.

    The first two layers use ``heads`` attention heads (the second and third
    layers therefore take ``hidden_channels * heads`` input features); the
    third layer uses a single head. Node embeddings are mean-pooled per
    graph, passed through dropout (p=0.5) and a linear output layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        widened = hidden_channels * heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(widened, hidden_channels, heads=heads)
        self.conv3 = GATConv(widened, hidden_channels, heads=1)

        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the batch."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.elu(conv(h, edge_index))
        pooled = global_mean_pool(h, batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.fc(pooled)


class GIN(torch.nn.Module):
    """Graph-level classifier built from three GINConv layers.

    Each GINConv wraps a single Linear layer. Node embeddings are mean-pooled
    per graph, passed through dropout (p=0.5) and a linear output layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GINConv(Linear(in_channels, hidden_channels))
        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))

        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the batch."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h, edge_index))
        pooled = global_mean_pool(h, batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.fc(pooled)


class GINX(torch.nn.Module):
    """GIN classifier whose node features come from an embedding table.

    Args:
        num_nodes: Number of rows of the embedding table when no pretrained
            embeddings are given.
        embedding_dim: Embedding width when no pretrained embeddings are given.
        hidden_channels: Width of the three GINConv layers.
        out_channels: Number of output logits per graph.
        word_emb: Optional pretrained embedding matrix. When given, it
            initialises the (trainable, ``freeze=False``) embedding table and
            its second dimension sets the input width of the first layer;
            ``num_nodes`` and ``embedding_dim`` are then not used.
    """

    def __init__(self, num_nodes, embedding_dim, hidden_channels, out_channels, word_emb=None):
        super(GINX, self).__init__()

        # Identity test, not `== None`: equality against a tensor/array is an
        # element-wise operation and is not a reliable "was it provided" check.
        if word_emb is None:
            self.embedding = torch.nn.Embedding(num_nodes, embedding_dim)
            self.conv1 = GINConv(Linear(embedding_dim, hidden_channels))
        else:
            self.embedding = torch.nn.Embedding.from_pretrained(word_emb, freeze=False)
            self.conv1 = GINConv(Linear(word_emb.shape[1], hidden_channels))

        self.conv2 = GINConv(Linear(hidden_channels, hidden_channels))
        self.conv3 = GINConv(Linear(hidden_channels, hidden_channels))
        self.fc = Linear(hidden_channels, out_channels)

    def forward(self, node_ids, edge_index, batch):
        """Embed ``node_ids``, run three GIN layers and return per-graph logits."""
        x = self.embedding(node_ids)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        # One pooled vector per graph, then dropout (p=0.5) and the linear head.
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        logits = self.fc(x)
        return logits

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv
from pyhealth.models import RETAINLayer

class GraphCare(nn.Module):
    """GINE-based graph classifier with RETAIN-derived edge weights.

    Node ids are embedded, optionally weighted through three RETAIN layers
    (conditions, procedures and their sum) to produce a scalar attribute per
    edge, and then passed through three GINEConv layers, mean pooling,
    dropout and a linear output layer.

    Args:
        num_nodes: Number of nodes; sizes the embedding table and the RETAIN
            feature size when ``word_emb`` is None.
        feature_keys: Names of the RETAIN layers to create. ``forward`` looks
            up the keys ``'cond'``, ``'proc'`` and ``'cross'`` when
            ``use_attn`` is true, so these must be present in that case.
        embedding_dim: Embedding width when ``word_emb`` is None.
        hidden_dim: Width of the GINEConv layers.
        out_channels: Number of output logits per graph.
        dropout: Dropout rate passed to the RETAIN layers.
        max_visits: Stored on the instance; not used by ``forward``.
        word_emb: Optional pretrained embedding matrix (kept frozen).
        use_attn: Whether to derive edge attributes from RETAIN attention.
    """

    def __init__(self, num_nodes, feature_keys, embedding_dim, hidden_dim, out_channels, dropout=0.5, max_visits=None, word_emb=None, use_attn=True):
        super(GraphCare, self).__init__()
        self.max_visits = max_visits
        self.embedding_dim = embedding_dim
        self.use_attn = use_attn
        # Not used in forward(); kept so state_dicts saved from this class
        # keep the same set of keys.
        self.alpha = nn.Parameter(torch.tensor(0.5))

        # Decide on the node count only after checking whether pretrained
        # embeddings were supplied (len(None) would raise).
        if word_emb is None:
            self.max_nodes = num_nodes
            self.embedding = torch.nn.Embedding(num_nodes, embedding_dim)
        else:
            self.max_nodes = len(word_emb)
            self.embedding = torch.nn.Embedding.from_pretrained(word_emb, freeze=True)

        self.retain = nn.ModuleDict()
        for feature_key in feature_keys:
            self.retain[feature_key] = RETAINLayer(feature_size=self.max_nodes, dropout=dropout)

        # The first layer consumes whatever width the embedding table actually
        # has, which equals `embedding_dim` unless `word_emb` says otherwise.
        self.conv1 = GINEConv(nn.Linear(self.embedding.embedding_dim, hidden_dim), edge_dim=1)
        self.conv2 = GINEConv(nn.Linear(hidden_dim, hidden_dim), edge_dim=1)
        self.conv3 = GINEConv(nn.Linear(hidden_dim, hidden_dim), edge_dim=1)

        self.fc = nn.Linear(hidden_dim, out_channels)

    def forward(self, node_ids, edge_index, batch, visits_cond, visits_proc):
        """Return one row of logits per graph in the batch.

        Args:
            node_ids: Node indices into the embedding table, one per node.
            edge_index: Edge list as a (2, num_edges) index tensor.
            batch: Graph id of every node.
            visits_cond: Per-graph visit tensor fed to the 'cond' RETAIN layer.
            visits_proc: Per-graph visit tensor fed to the 'proc' RETAIN layer.
        """
        x = self.embedding(node_ids)

        if self.use_attn:
            cond_attn = self.retain['cond'](visits_cond)
            proc_attn = self.retain['proc'](visits_proc)
            cross_attn = self.retain['cross'](visits_cond + visits_proc)

            # Out-of-place sum so no module output is modified in place.
            attn = cond_attn + proc_attn + cross_attn    # (batch_size, max_nodes)

            # Graph id of every node, rebuilt from per-graph node counts.
            # minlength keeps the counts aligned with attn's rows even if the
            # last graphs of the batch contribute no nodes.
            node_counts = torch.bincount(batch, minlength=attn.size(0))
            batch_index = torch.arange(attn.size(0), device=node_ids.device).repeat_interleave(node_counts)

            # Pick, for each node, the weight of (its graph, its node id).
            attn_weights = attn[batch_index, node_ids]

            row, col = edge_index
            epsilon = 1e-6

            # Scale by the largest weight in the batch; NaNs (e.g. 0/0) become 0.
            attn_weights = attn_weights / torch.max(attn_weights)
            attn_weights = torch.where(torch.isnan(attn_weights), torch.zeros_like(attn_weights), attn_weights)

            # Edge attribute = sum of the two endpoint weights (each offset by epsilon).
            edge_attr = ((attn_weights[row] + epsilon) + (attn_weights[col] + epsilon)).unsqueeze(-1)
        else:
            # Without attention every edge gets the same constant attribute so
            # the GINEConv layers (edge_dim=1) still receive an edge_attr.
            edge_attr = torch.ones((edge_index.size(1), 1), dtype=x.dtype, device=x.device)

        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = F.relu(self.conv3(x, edge_index, edge_attr))

        x = global_mean_pool(x, batch)
        # Fixed p=0.5 here; the constructor's `dropout` only reaches the RETAIN layers.
        x = F.dropout(x, p=0.5, training=self.training)

        logits = self.fc(x)
        return logits



In [56]:
from tqdm import tqdm
from pyhealth.metrics import multilabel_metrics_fn
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, jaccard_score

def train(model, device, train_loader, optimizer, model_):
    """Run one training epoch and return the summed batch loss.

    Args:
        model: Network to optimise.
        device: Device every batch is moved to.
        train_loader: Loader yielding batched graph objects with ``x``/``y``,
            ``edge_index``, ``batch``, ``label``, ``visits_cond``,
            ``visits_proc`` and ``num_graphs``.
        optimizer: Optimiser stepping ``model``'s parameters.
        model_: ``"GIN"`` feeds ``data.x``; ``"GINX"`` feeds ``data.y``; any
            other value also passes the reshaped visit tensors (GraphCare).

    Returns:
        Sum of the per-batch binary-cross-entropy losses as a detached 0-dim
        tensor (the int ``0`` if the loader yields no batches).
    """
    model.train()
    training_loss = 0
    tot_loss = 0
    pbar = tqdm(train_loader)
    for data in pbar:
        # Shows the loss of the previous batch.
        pbar.set_description(f'loss: {training_loss}')
        data = data.to(device)
        optimizer.zero_grad()

        # Actual number of graphs in this batch; unlike loader.batch_size it
        # is also right for a final, smaller batch.
        num_graphs = data.num_graphs

        if model_ == "GIN":
            out = model(data.x, data.edge_index, data.batch)
        elif model_ == "GINX":
            out = model(data.y, data.edge_index, data.batch)
        else:
            # Per-graph visit tensors are concatenated along dim 0 by the
            # loader; split them back into (graphs, visits, features).
            out = model(
                data.y,
                data.edge_index,
                data.batch,
                data.visits_cond.reshape(num_graphs, -1, data.visits_cond.shape[1]).double(),
                data.visits_proc.reshape(num_graphs, -1, data.visits_proc.shape[1]).double(),
            )

        # Labels are concatenated per graph as well: back to (graphs, labels).
        label = data.label.reshape(num_graphs, -1)

        loss = F.binary_cross_entropy_with_logits(out, label)
        loss.backward()
        optimizer.step()

        # Detach before bookkeeping so the running total does not hold on to
        # every batch's autograd graph.
        training_loss = loss.detach()
        tot_loss += training_loss

    return tot_loss

def evaluate(model, device, loader, model_):
    """Run ``model`` over ``loader`` without gradients and collect outputs.

    Args:
        model: Network to evaluate (switched to eval mode).
        device: Device every batch is moved to.
        loader: Loader yielding batched graph objects (same fields as in
            training).
        model_: ``"GIN"`` feeds ``data.x``; ``"GINX"`` feeds ``data.y``; any
            other value also passes the reshaped visit tensors (GraphCare).

    Returns:
        ``(y_true_all, y_prob_all)``: numpy arrays of labels and sigmoid
        probabilities, concatenated over all batches along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` if the loader yields no batches.
    """
    model.eval()
    y_prob_all = []
    y_true_all = []

    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            # Actual number of graphs in this batch, so a smaller final batch
            # reshapes correctly instead of being skipped.
            num_graphs = data.num_graphs

            if model_ == "GIN":
                logits = model(data.x, data.edge_index, data.batch)
            elif model_ == "GINX":
                logits = model(data.y, data.edge_index, data.batch)
            else:
                logits = model(
                    data.y,
                    data.edge_index,
                    data.batch,
                    data.visits_cond.reshape(num_graphs, -1, data.visits_cond.shape[1]).double(),
                    data.visits_proc.reshape(num_graphs, -1, data.visits_proc.shape[1]).double(),
                )

            y_prob = torch.sigmoid(logits)
            y_true = data.label.reshape(num_graphs, -1)

            y_prob_all.append(y_prob.cpu().numpy())
            y_true_all.append(y_true.cpu().numpy())

    y_true_all = np.concatenate(y_true_all, axis=0)
    y_prob_all = np.concatenate(y_prob_all, axis=0)

    return y_true_all, y_prob_all

def train_loop(train_loader, val_loader, model, optimizer, device, epochs, model_,
               save_path='../../../data/pj20/exp_data/saved_weights_graph_mimic3_drugrec_th02.pkl'):
    """Train for ``epochs`` epochs, validating and checkpointing after each.

    After every epoch the validation PR-AUC, ROC-AUC, F1 and Jaccard scores
    (all sample-averaged) are printed. The model's state_dict is written to
    ``save_path`` whenever both PR-AUC and ROC-AUC are at least as good as
    the values recorded at the previous save.

    Args:
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        model: Network to train.
        optimizer: Optimiser for ``model``.
        device: Device used for training and evaluation.
        epochs: Number of epochs to run.
        model_: Model-type string forwarded to ``train`` and ``evaluate``.
        save_path: Destination of the best checkpoint. Defaults to the path
            this notebook has always written to.
    """
    best_pr_auc = 0
    best_roc_auc = 0
    for epoch in range(1, epochs+1):
        loss = train(model, device, train_loader, optimizer, model_)
        y_true_all, y_prob_all = evaluate(model, device, val_loader, model_)

        # Hard 0/1 predictions at a 0.5 probability threshold, same dtype as the scores.
        y_pred_all = (y_prob_all >= 0.5).astype(y_prob_all.dtype)

        val_pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
        val_roc_auc = roc_auc_score(y_true_all, y_prob_all, average="samples")
        val_f1 = f1_score(y_true_all, y_pred_all, average='samples')
        val_jaccard = jaccard_score(y_true_all, y_pred_all, average='samples')

        # Save only when neither AUC metric got worse than at the last save.
        if val_pr_auc >= best_pr_auc and val_roc_auc >= best_roc_auc:
            torch.save(model.state_dict(), save_path)
            print("best model saved")
            best_pr_auc = val_pr_auc
            best_roc_auc = val_roc_auc

        print(f'Epoch: {epoch}, Training loss: {loss}, Val PRAUC: {val_pr_auc:.4f}, Val ROC_AUC: {val_roc_auc:.4f}, Val F1-score: {val_f1:.4f}, Val Jaccard: {val_jaccard:.4f}')

In [57]:
# G_tg.x = torch.randn(G_tg.num_nodes, 256)

In [58]:
# Wrap each patient split so that indexing it yields a subgraph of G_tg.
train_set, val_set, test_set = (
    Dataset(G=G_tg, dataset=split)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Batches of 16 graphs; only the training loader shuffles. drop_last=True
# discards the final incomplete batch of every split.
train_loader = DataLoader(train_set, shuffle=True, batch_size=16, drop_last=True)
val_loader = DataLoader(val_set, shuffle=False, batch_size=16, drop_last=True)
test_loader = DataLoader(test_set, shuffle=False, batch_size=16, drop_last=True)



In [59]:
# Which architecture to build: "GIN", "GINX" or "GraphCare".
model_ = "GIN"

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# One output logit per entry of a sample's label vector.
out_channels = len(train_set[0].label)

if model_ == "GIN":
    # GIN consumes the node feature matrix directly.
    in_channels = train_set[0].x.shape[1]
    model = GIN(in_channels=in_channels, out_channels=out_channels, hidden_channels=512).to(device)
elif model_ == "GINX":
    model = GINX(num_nodes=G_tg.num_nodes, embedding_dim=512, hidden_channels=512, out_channels=out_channels, word_emb=G_tg.x).to(device)
elif model_ == "GraphCare":
    model = GraphCare(
        num_nodes=G_tg.num_nodes,
        feature_keys=['cond', 'proc', 'cross'],
        embedding_dim=len(G_tg.x[0]),
        hidden_dim=512,
        out_channels=out_channels,
        dropout=0.5,
        max_visits=max_visits,
        word_emb=G_tg.x,
        use_attn=True
    ).to(device)
else:
    # Fail loudly instead of silently reusing a `model` left over from an
    # earlier run of this cell.
    raise ValueError(f"Unknown model_ {model_!r}; expected 'GIN', 'GINX' or 'GraphCare'")

# Cast all parameters to float64; left as the last expression so the cell
# displays the model summary.
model.double()

GIN(
  (conv1): GINConv(nn=Linear(in_features=1536, out_features=512, bias=True))
  (conv2): GINConv(nn=Linear(in_features=512, out_features=512, bias=True))
  (conv3): GINConv(nn=Linear(in_features=512, out_features=512, bias=True))
  (fc): Linear(in_features=512, out_features=197, bias=True)
)

In [60]:
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train for 100 epochs, validating and checkpointing after each one.
train_loop(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    optimizer=optimizer,
    device=device,
    epochs=100,
    model_=model_,
)

loss: 0.264652420131181: 100%|██████████| 2223/2223 [02:28<00:00, 14.97it/s]  
100%|██████████| 271/271 [00:16<00:00, 16.90it/s]


best model saved
Epoch: 1, Training loss: 620.9488562615705, Val PRAUC: 0.6562, Val ROC_AUC: 0.9117, Val F1-score: 0.4713, Val Jaccard: 0.3203


loss: 0.23851399449628402: 100%|██████████| 2223/2223 [03:03<00:00, 12.11it/s]
100%|██████████| 271/271 [00:17<00:00, 15.76it/s]


best model saved
Epoch: 2, Training loss: 565.3708916109855, Val PRAUC: 0.6958, Val ROC_AUC: 0.9200, Val F1-score: 0.5363, Val Jaccard: 0.3806


loss: 0.25836394406931: 100%|██████████| 2223/2223 [02:44<00:00, 13.49it/s]   
100%|██████████| 271/271 [00:19<00:00, 13.87it/s]


best model saved
Epoch: 3, Training loss: 546.5450063532412, Val PRAUC: 0.7062, Val ROC_AUC: 0.9245, Val F1-score: 0.5276, Val Jaccard: 0.3739


loss: 0.2656918345643904: 100%|██████████| 2223/2223 [02:48<00:00, 13.19it/s] 
100%|██████████| 271/271 [00:16<00:00, 16.71it/s]


best model saved
Epoch: 4, Training loss: 539.1200801766764, Val PRAUC: 0.7128, Val ROC_AUC: 0.9264, Val F1-score: 0.5365, Val Jaccard: 0.3839


loss: 0.23456541218674248: 100%|██████████| 2223/2223 [02:42<00:00, 13.66it/s]
100%|██████████| 271/271 [00:18<00:00, 14.59it/s]


Epoch: 5, Training loss: 530.5595686406699, Val PRAUC: 0.7159, Val ROC_AUC: 0.9262, Val F1-score: 0.5734, Val Jaccard: 0.4198


loss: 0.3102274889538489: 100%|██████████| 2223/2223 [02:37<00:00, 14.08it/s] 
100%|██████████| 271/271 [00:16<00:00, 16.43it/s]


best model saved
Epoch: 6, Training loss: 524.6495366238371, Val PRAUC: 0.7156, Val ROC_AUC: 0.9268, Val F1-score: 0.5671, Val Jaccard: 0.4144


loss: 0.21745822299046955: 100%|██████████| 2223/2223 [02:38<00:00, 14.07it/s]
100%|██████████| 271/271 [00:18<00:00, 14.56it/s]


best model saved
Epoch: 7, Training loss: 521.8044102076269, Val PRAUC: 0.7231, Val ROC_AUC: 0.9294, Val F1-score: 0.5700, Val Jaccard: 0.4173


loss: 0.23728790588903828: 100%|██████████| 2223/2223 [02:40<00:00, 13.84it/s]
100%|██████████| 271/271 [00:17<00:00, 15.78it/s]


best model saved
Epoch: 8, Training loss: 518.4774973883899, Val PRAUC: 0.7242, Val ROC_AUC: 0.9298, Val F1-score: 0.5570, Val Jaccard: 0.4059


loss: 0.19735614937025947: 100%|██████████| 2223/2223 [02:41<00:00, 13.80it/s]
100%|██████████| 271/271 [00:15<00:00, 17.16it/s]


Epoch: 9, Training loss: 517.4744460167844, Val PRAUC: 0.7184, Val ROC_AUC: 0.9279, Val F1-score: 0.5788, Val Jaccard: 0.4253


loss: 0.23712275730851895: 100%|██████████| 2223/2223 [02:45<00:00, 13.43it/s]
100%|██████████| 271/271 [00:16<00:00, 16.09it/s]


best model saved
Epoch: 10, Training loss: 520.2664088063736, Val PRAUC: 0.7265, Val ROC_AUC: 0.9299, Val F1-score: 0.5671, Val Jaccard: 0.4151


loss: 0.28236734866247293: 100%|██████████| 2223/2223 [02:44<00:00, 13.48it/s]
100%|██████████| 271/271 [00:14<00:00, 18.61it/s]


best model saved
Epoch: 11, Training loss: 514.0924009080613, Val PRAUC: 0.7298, Val ROC_AUC: 0.9313, Val F1-score: 0.5653, Val Jaccard: 0.4135


loss: 0.2023531802907622: 100%|██████████| 2223/2223 [02:26<00:00, 15.17it/s] 
100%|██████████| 271/271 [00:16<00:00, 16.09it/s]


Epoch: 12, Training loss: 513.4551000749922, Val PRAUC: 0.7278, Val ROC_AUC: 0.9300, Val F1-score: 0.5693, Val Jaccard: 0.4167


loss: 0.2252312063652598: 100%|██████████| 2223/2223 [02:36<00:00, 14.25it/s] 
100%|██████████| 271/271 [00:15<00:00, 17.98it/s]


Epoch: 13, Training loss: 512.6274367714983, Val PRAUC: 0.7294, Val ROC_AUC: 0.9309, Val F1-score: 0.5790, Val Jaccard: 0.4261


loss: 0.2771169464038104: 100%|██████████| 2223/2223 [02:27<00:00, 15.11it/s] 
100%|██████████| 271/271 [00:16<00:00, 16.88it/s]


Epoch: 14, Training loss: 511.374240786065, Val PRAUC: 0.7289, Val ROC_AUC: 0.9304, Val F1-score: 0.5663, Val Jaccard: 0.4141


loss: 0.303339931228752: 100%|██████████| 2223/2223 [02:26<00:00, 15.19it/s]  
100%|██████████| 271/271 [00:14<00:00, 18.51it/s]


best model saved
Epoch: 15, Training loss: 511.4807423247251, Val PRAUC: 0.7318, Val ROC_AUC: 0.9319, Val F1-score: 0.5625, Val Jaccard: 0.4108


loss: 0.2384440962679373: 100%|██████████| 2223/2223 [02:18<00:00, 16.02it/s] 
100%|██████████| 271/271 [00:13<00:00, 19.68it/s]


best model saved
Epoch: 16, Training loss: 510.5601419647611, Val PRAUC: 0.7331, Val ROC_AUC: 0.9323, Val F1-score: 0.5775, Val Jaccard: 0.4254


loss: 0.2207560984998378: 100%|██████████| 2223/2223 [02:15<00:00, 16.36it/s] 
100%|██████████| 271/271 [00:13<00:00, 19.62it/s]


Epoch: 17, Training loss: 509.86237250197934, Val PRAUC: 0.7305, Val ROC_AUC: 0.9316, Val F1-score: 0.5632, Val Jaccard: 0.4111


loss: 0.2333292302457498: 100%|██████████| 2223/2223 [02:13<00:00, 16.70it/s] 
100%|██████████| 271/271 [00:13<00:00, 19.68it/s]


Epoch: 18, Training loss: 509.634421370784, Val PRAUC: 0.7294, Val ROC_AUC: 0.9310, Val F1-score: 0.5620, Val Jaccard: 0.4102


loss: 0.2699796150366273: 100%|██████████| 2223/2223 [02:17<00:00, 16.11it/s] 
100%|██████████| 271/271 [00:13<00:00, 20.42it/s]


Epoch: 19, Training loss: 516.1450444874873, Val PRAUC: 0.7324, Val ROC_AUC: 0.9318, Val F1-score: 0.5783, Val Jaccard: 0.4261


loss: 0.28399227903704094: 100%|██████████| 2223/2223 [02:36<00:00, 14.20it/s]
100%|██████████| 271/271 [00:18<00:00, 14.39it/s]


Epoch: 20, Training loss: 509.32720693649316, Val PRAUC: 0.7313, Val ROC_AUC: 0.9323, Val F1-score: 0.5559, Val Jaccard: 0.4047


loss: 0.2266822409795965: 100%|██████████| 2223/2223 [02:40<00:00, 13.89it/s] 
100%|██████████| 271/271 [00:14<00:00, 19.05it/s]


Epoch: 21, Training loss: 510.92679103652785, Val PRAUC: 0.7287, Val ROC_AUC: 0.9312, Val F1-score: 0.5695, Val Jaccard: 0.4174


loss: 0.2084033884479648: 100%|██████████| 2223/2223 [02:52<00:00, 12.87it/s] 
100%|██████████| 271/271 [00:18<00:00, 14.65it/s]


Epoch: 22, Training loss: 513.1270967918309, Val PRAUC: 0.7287, Val ROC_AUC: 0.9309, Val F1-score: 0.5742, Val Jaccard: 0.4219


loss: 0.233681675821645: 100%|██████████| 2223/2223 [02:54<00:00, 12.75it/s]  
100%|██████████| 271/271 [00:16<00:00, 16.68it/s]


Epoch: 23, Training loss: 511.117312153225, Val PRAUC: 0.7312, Val ROC_AUC: 0.9317, Val F1-score: 0.5578, Val Jaccard: 0.4066


loss: 0.20919307109037283: 100%|██████████| 2223/2223 [03:01<00:00, 12.22it/s]
100%|██████████| 271/271 [00:14<00:00, 18.81it/s]


Epoch: 24, Training loss: 509.8599535965487, Val PRAUC: 0.7324, Val ROC_AUC: 0.9317, Val F1-score: 0.5766, Val Jaccard: 0.4247


loss: 0.24528638202076386: 100%|██████████| 2223/2223 [02:42<00:00, 13.66it/s]
100%|██████████| 271/271 [00:15<00:00, 16.98it/s]


best model saved
Epoch: 25, Training loss: 508.64697666377094, Val PRAUC: 0.7333, Val ROC_AUC: 0.9330, Val F1-score: 0.5813, Val Jaccard: 0.4296


loss: 0.21262393636107757: 100%|██████████| 2223/2223 [02:31<00:00, 14.71it/s]
100%|██████████| 271/271 [00:14<00:00, 18.36it/s]


Epoch: 26, Training loss: 507.14072430252537, Val PRAUC: 0.7335, Val ROC_AUC: 0.9328, Val F1-score: 0.5770, Val Jaccard: 0.4251


loss: 0.20974041947672353: 100%|██████████| 2223/2223 [02:20<00:00, 15.82it/s]
100%|██████████| 271/271 [00:13<00:00, 20.63it/s]


best model saved
Epoch: 27, Training loss: 506.3439952924706, Val PRAUC: 0.7347, Val ROC_AUC: 0.9331, Val F1-score: 0.5788, Val Jaccard: 0.4269


loss: 0.22488158366622177: 100%|██████████| 2223/2223 [02:17<00:00, 16.19it/s]
100%|██████████| 271/271 [00:16<00:00, 16.67it/s]


best model saved
Epoch: 28, Training loss: 506.0166451120067, Val PRAUC: 0.7353, Val ROC_AUC: 0.9333, Val F1-score: 0.5716, Val Jaccard: 0.4202


loss: 0.24217492389990528: 100%|██████████| 2223/2223 [02:25<00:00, 15.26it/s]
100%|██████████| 271/271 [00:14<00:00, 19.26it/s]


Epoch: 29, Training loss: 506.12700061344873, Val PRAUC: 0.7343, Val ROC_AUC: 0.9324, Val F1-score: 0.5851, Val Jaccard: 0.4329


loss: 0.2608394724586076: 100%|██████████| 2223/2223 [02:29<00:00, 14.84it/s] 
100%|██████████| 271/271 [00:17<00:00, 15.38it/s]


Epoch: 30, Training loss: 504.98880040206615, Val PRAUC: 0.7333, Val ROC_AUC: 0.9327, Val F1-score: 0.5751, Val Jaccard: 0.4230


loss: 0.21214668744232168: 100%|██████████| 2223/2223 [02:33<00:00, 14.51it/s]
100%|██████████| 271/271 [00:14<00:00, 19.15it/s]


Epoch: 31, Training loss: 504.78390528750305, Val PRAUC: 0.7351, Val ROC_AUC: 0.9336, Val F1-score: 0.5864, Val Jaccard: 0.4344


loss: 0.16835746878889973: 100%|██████████| 2223/2223 [02:44<00:00, 13.55it/s]
100%|██████████| 271/271 [00:19<00:00, 14.01it/s]


best model saved
Epoch: 32, Training loss: 504.27173944700746, Val PRAUC: 0.7358, Val ROC_AUC: 0.9335, Val F1-score: 0.5773, Val Jaccard: 0.4258


loss: 0.24326478438445132: 100%|██████████| 2223/2223 [02:41<00:00, 13.75it/s]
100%|██████████| 271/271 [00:18<00:00, 14.97it/s]


Epoch: 33, Training loss: 503.2805603499776, Val PRAUC: 0.7310, Val ROC_AUC: 0.9320, Val F1-score: 0.5755, Val Jaccard: 0.4233


loss: 0.20465569731650074: 100%|██████████| 2223/2223 [02:43<00:00, 13.56it/s]
100%|██████████| 271/271 [00:19<00:00, 13.72it/s]


Epoch: 34, Training loss: 503.292123100584, Val PRAUC: 0.7337, Val ROC_AUC: 0.9320, Val F1-score: 0.5784, Val Jaccard: 0.4266


loss: 0.23988472870351688: 100%|██████████| 2223/2223 [02:51<00:00, 12.94it/s]
100%|██████████| 271/271 [00:16<00:00, 16.09it/s]


Epoch: 35, Training loss: 503.2673965112291, Val PRAUC: 0.7354, Val ROC_AUC: 0.9335, Val F1-score: 0.5763, Val Jaccard: 0.4245


loss: 0.2062125938388466: 100%|██████████| 2223/2223 [03:01<00:00, 12.26it/s] 
100%|██████████| 271/271 [00:16<00:00, 16.24it/s]


Epoch: 36, Training loss: 502.6411985118604, Val PRAUC: 0.7356, Val ROC_AUC: 0.9339, Val F1-score: 0.5750, Val Jaccard: 0.4233


loss: 0.25239967689868986: 100%|██████████| 2223/2223 [02:40<00:00, 13.82it/s]
100%|██████████| 271/271 [00:15<00:00, 17.92it/s]


best model saved
Epoch: 37, Training loss: 502.62556518607164, Val PRAUC: 0.7373, Val ROC_AUC: 0.9343, Val F1-score: 0.5742, Val Jaccard: 0.4229


loss: 0.19268260674661697: 100%|██████████| 2223/2223 [02:50<00:00, 13.01it/s]
100%|██████████| 271/271 [00:18<00:00, 14.91it/s]


Epoch: 38, Training loss: 501.9360889644533, Val PRAUC: 0.7373, Val ROC_AUC: 0.9343, Val F1-score: 0.5828, Val Jaccard: 0.4310


loss: 0.19220278833534504: 100%|██████████| 2223/2223 [02:59<00:00, 12.41it/s]
100%|██████████| 271/271 [00:20<00:00, 13.35it/s]


Epoch: 39, Training loss: 501.81665132676704, Val PRAUC: 0.7364, Val ROC_AUC: 0.9336, Val F1-score: 0.5721, Val Jaccard: 0.4207


loss: 0.2388874401674598: 100%|██████████| 2223/2223 [03:10<00:00, 11.70it/s] 
100%|██████████| 271/271 [00:19<00:00, 13.60it/s]


Epoch: 40, Training loss: 501.4019642126533, Val PRAUC: 0.7357, Val ROC_AUC: 0.9338, Val F1-score: 0.5749, Val Jaccard: 0.4231


loss: 0.1950800610722115: 100%|██████████| 2223/2223 [03:15<00:00, 11.37it/s] 
100%|██████████| 271/271 [00:19<00:00, 13.65it/s]


Epoch: 41, Training loss: 500.89516088711434, Val PRAUC: 0.7369, Val ROC_AUC: 0.9337, Val F1-score: 0.5933, Val Jaccard: 0.4408


loss: 0.2600185637682354: 100%|██████████| 2223/2223 [03:13<00:00, 11.47it/s] 
100%|██████████| 271/271 [00:20<00:00, 13.18it/s]


Epoch: 42, Training loss: 500.7281724472741, Val PRAUC: 0.7368, Val ROC_AUC: 0.9342, Val F1-score: 0.5667, Val Jaccard: 0.4155


loss: 0.24994443603869715: 100%|██████████| 2223/2223 [03:08<00:00, 11.77it/s]
100%|██████████| 271/271 [00:17<00:00, 15.25it/s]


best model saved
Epoch: 43, Training loss: 500.9458651389644, Val PRAUC: 0.7386, Val ROC_AUC: 0.9346, Val F1-score: 0.5782, Val Jaccard: 0.4268


loss: 0.2692422061525881: 100%|██████████| 2223/2223 [03:02<00:00, 12.15it/s] 
100%|██████████| 271/271 [00:22<00:00, 12.27it/s]


Epoch: 44, Training loss: 500.23906527011775, Val PRAUC: 0.7339, Val ROC_AUC: 0.9334, Val F1-score: 0.5757, Val Jaccard: 0.4241


loss: 0.1966615091898438: 100%|██████████| 2223/2223 [02:51<00:00, 12.93it/s] 
100%|██████████| 271/271 [00:15<00:00, 17.41it/s]


Epoch: 45, Training loss: 499.8853326913479, Val PRAUC: 0.7343, Val ROC_AUC: 0.9324, Val F1-score: 0.5938, Val Jaccard: 0.4409


loss: 0.25591749245114886: 100%|██████████| 2223/2223 [02:47<00:00, 13.27it/s]
100%|██████████| 271/271 [00:22<00:00, 12.11it/s]


Epoch: 46, Training loss: 499.5433346207292, Val PRAUC: 0.7366, Val ROC_AUC: 0.9342, Val F1-score: 0.5864, Val Jaccard: 0.4340


loss: 0.2588962777731507: 100%|██████████| 2223/2223 [02:46<00:00, 13.37it/s] 
100%|██████████| 271/271 [00:18<00:00, 14.54it/s]


Epoch: 47, Training loss: 499.202745295591, Val PRAUC: 0.7360, Val ROC_AUC: 0.9337, Val F1-score: 0.5921, Val Jaccard: 0.4395


loss: 0.19312815908735756: 100%|██████████| 2223/2223 [02:46<00:00, 13.39it/s]
100%|██████████| 271/271 [00:19<00:00, 13.78it/s]


best model saved
Epoch: 48, Training loss: 499.68388534362253, Val PRAUC: 0.7394, Val ROC_AUC: 0.9347, Val F1-score: 0.5886, Val Jaccard: 0.4370


loss: 0.25292904059296445: 100%|██████████| 2223/2223 [02:56<00:00, 12.60it/s]
100%|██████████| 271/271 [00:22<00:00, 11.97it/s]


Epoch: 49, Training loss: 499.45574927639666, Val PRAUC: 0.7384, Val ROC_AUC: 0.9345, Val F1-score: 0.5927, Val Jaccard: 0.4405


loss: 0.23888105063542406: 100%|██████████| 2223/2223 [03:14<00:00, 11.43it/s]
100%|██████████| 271/271 [00:19<00:00, 13.94it/s]


Epoch: 50, Training loss: 499.08935754453796, Val PRAUC: 0.7386, Val ROC_AUC: 0.9347, Val F1-score: 0.5869, Val Jaccard: 0.4348


loss: 0.2214860826783488: 100%|██████████| 2223/2223 [03:14<00:00, 11.45it/s] 
100%|██████████| 271/271 [00:22<00:00, 11.84it/s]


Epoch: 51, Training loss: 498.93953396546306, Val PRAUC: 0.7378, Val ROC_AUC: 0.9345, Val F1-score: 0.5970, Val Jaccard: 0.4448


loss: 0.24503920857640643: 100%|██████████| 2223/2223 [03:01<00:00, 12.25it/s]
100%|██████████| 271/271 [00:20<00:00, 13.39it/s]


Epoch: 52, Training loss: 498.99925550237987, Val PRAUC: 0.7384, Val ROC_AUC: 0.9342, Val F1-score: 0.5881, Val Jaccard: 0.4363


loss: 0.20921497779866546: 100%|██████████| 2223/2223 [03:00<00:00, 12.29it/s]
100%|██████████| 271/271 [00:20<00:00, 13.21it/s]


Epoch: 53, Training loss: 498.48129641795003, Val PRAUC: 0.7378, Val ROC_AUC: 0.9343, Val F1-score: 0.5842, Val Jaccard: 0.4322


loss: 0.24957955463703496: 100%|██████████| 2223/2223 [02:56<00:00, 12.63it/s]
100%|██████████| 271/271 [00:15<00:00, 17.02it/s]


Epoch: 54, Training loss: 498.82705942120526, Val PRAUC: 0.7370, Val ROC_AUC: 0.9335, Val F1-score: 0.5897, Val Jaccard: 0.4372


loss: 0.1951572555367995: 100%|██████████| 2223/2223 [03:00<00:00, 12.32it/s] 
100%|██████████| 271/271 [00:14<00:00, 18.78it/s]


best model saved
Epoch: 55, Training loss: 498.32914848464105, Val PRAUC: 0.7397, Val ROC_AUC: 0.9350, Val F1-score: 0.5807, Val Jaccard: 0.4295


loss: 0.253510790274362: 100%|██████████| 2223/2223 [02:47<00:00, 13.26it/s]  
100%|██████████| 271/271 [00:19<00:00, 14.25it/s]


Epoch: 56, Training loss: 498.3820612854395, Val PRAUC: 0.7379, Val ROC_AUC: 0.9342, Val F1-score: 0.5752, Val Jaccard: 0.4231


loss: 0.21018016546946724: 100%|██████████| 2223/2223 [02:54<00:00, 12.76it/s]
100%|██████████| 271/271 [00:16<00:00, 16.77it/s]


Epoch: 57, Training loss: 497.9989936056451, Val PRAUC: 0.7399, Val ROC_AUC: 0.9350, Val F1-score: 0.5961, Val Jaccard: 0.4437


loss: 0.22829240612955068: 100%|██████████| 2223/2223 [02:52<00:00, 12.88it/s]
100%|██████████| 271/271 [00:19<00:00, 13.62it/s]


Epoch: 58, Training loss: 498.02106887904614, Val PRAUC: 0.7374, Val ROC_AUC: 0.9340, Val F1-score: 0.5911, Val Jaccard: 0.4390


loss: 0.2566155296914928:  67%|██████▋   | 1495/2223 [01:50<00:53, 13.51it/s] 


KeyboardInterrupt: 

: 

In [None]:
# Debug aid: print the `y` attribute of every batch produced by the test loader.
for data in test_loader:
    print(data.y)

In [None]:
# torch.save(model.state_dict(), './exp_data/saved_weights_gat_mimic3_drugrec.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_gin_mimic3_drugrec_random.pkl')
# torch.save(model.state_dict(), './exp_data/saved_weights_hgt_mimic3_drugrec.pkl')

In [None]:
# Reload the best checkpoint written by train_loop; map_location places the
# weights on the current device regardless of where they were saved from.
model.load_state_dict(torch.load('../../../data/pj20/exp_data/saved_weights_graph_mimic3_drugrec_th02.pkl', map_location=device))
model.double()

# Score the held-out test split with the same model-type switch used in training.
y_true_all, y_prob_all = evaluate(model, device, test_loader, model_)

# Hard 0/1 predictions at a 0.5 probability threshold.
y_pred_all = y_prob_all.copy()
y_pred_all[y_pred_all >= 0.5] = 1
y_pred_all[y_pred_all < 0.5] = 0

test_pr_auc = average_precision_score(y_true_all, y_prob_all, average="samples")
test_roc_auc = roc_auc_score(y_true_all, y_prob_all, average="samples")
test_f1 = f1_score(y_true_all, y_pred_all, average='samples')
test_jaccard = jaccard_score(y_true_all, y_pred_all, average='samples')

print(f'test PRAUC: {test_pr_auc:.4f}, test ROC_AUC: {test_roc_auc:.4f}, test F1-score: {test_f1:.4f}, test Jaccard: {test_jaccard:.4f}')