![alt text](04_respiratory_metrics.png)

In [1]:
import torch
import torch.utils.data
from torch.utils.data import TensorDataset

import argparse
import time
import gc
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from tqdm.notebook import tqdm

import wandb

from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
import h5py
import numpy as np
from loader import ICUVariableLengthLoaderTables, ICUVariableLengthDataset
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
torch.manual_seed(0)
torch.cuda.is_available(), torch.cuda.get_device_name(0)

(True, 'NVIDIA GeForce RTX 4060 Ti')

In [4]:
import gc
gc.collect()

11

In [5]:
from net import GNNStack
# from utils_todynet_binary import AverageMeter, accuracy, log_msg, get_default_train_val_test_loader

In [6]:
import warnings

# Suppress all warnings
warnings.filterwarnings("ignore")

In [7]:
# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         loss = (self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2, num_classes=None):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         if alpha is None:
#             self.alpha = None
#         else:
#             # If alpha is a single float number, expand it to a tensor
#             if isinstance(alpha, float):
#                 assert num_classes is not None, "num_classes must be specified when alpha is a single float"
#                 self.alpha = torch.full((num_classes,), alpha)
#             else:
#                 self.alpha = alpha

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         if self.alpha is not None:
#             if self.alpha.type() != inputs.data.type():
#                 self.alpha = self.alpha.type_as(inputs.data)
#             at = self.alpha[targets]
#         else:
#             at = 1.0
#         loss = (at * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

## aruguments

In [10]:
args = {
    'arch': 'dyGIN2d', #what other models I can put here?? dyGCN2d, dyGIN2d
    'dataset': 'Resp_failure', # "AtrialFibrillation" # 'Mortality', # 'MIMIC3' Resp_failure
    'num_layers': 2,  # the number of GNN layers  3
    'groups': 32,  # the number of time series groups (num_graphs)
    'pool_ratio': 0.1,  # the ratio of pooling for nodes
    'kern_size': [3,3],  # list of time conv kernel size for each layer [9,5,3]
    'in_dim': 64,  # input dimensions of GNN stacks
    'hidden_dim': 64,  # hidden dimensions of GNN stacks
    'out_dim': 64,  # output dimensions of GNN stacks
    'workers': 0,  # number of data loading workers
    'epochs': 50,  # number of total epochs to run
    'batch_size': 8,  # mini-batch size, this is the total batch size of all GPUs
    'val_batch_size': 8,  # validation batch size
    'lr': 0.0002,  # initial learning rate
    'weight_decay': 1e-4,  # weight decay
    'evaluate': False,  # evaluate model on validation set
    'seed': 2,  # seed for initializing training
    'gpu': 0,  # GPU id to use
    'use_benchmark': True,  # use benchmark
    'tag': 'date',  # the tag for identifying the log and model files
    'loss':'bce'
}

In [11]:
# # start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="mortality",
    
#     # track hyperparameters and run metadata
#     config=args
# )

In [12]:
# train_dataset = TensorDataset(data_train, label_train)
# val_dataset   = TensorDataset(data_val, label_val)

In [13]:
# train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args['batch_size'],shuffle=True, num_workers=args['workers'], pin_memory=True)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args['val_batch_size'], shuffle=False,num_workers=args['workers'],pin_memory=True)

In [14]:
# def main():
#     # args = parser.parse_args()
    
#     # args.kern_size = [ int(l) for l in args.kern_size.split(",") ]

#     # if args.seed is not None:
#     random.seed(args['seed'])
#     torch.manual_seed(args['seed'])

#     main_work(args)

In [15]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    
    
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
              
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        
        # print(output, target)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
           
            res.append(correct_k.mul_(100.0 / batch_size))

        return res

def log_msg(message, log_file):
    with open(log_file, 'a') as f:
        print(message, file=f)


def get_default_train_val_test_loader(args):

    # get dataset-id
    dsid = args['dataset']
    h5_path= 'h5_folder/ml_stage_12h.h5'
    # get dataset from .pt
    # data_train  = torch.load(f'data/UCR/{dsid}/X_train.pt')
    # data_val    = torch.load(f'data/UCR/{dsid}/X_valid.pt')
    # data_val    = torch.load(f'data/UCR/{dsid}/X_test.pt')
    
    # label_train = torch.load(f'data/UCR/{dsid}/y_train.pt')
    # label_val   = torch.load(f'data/UCR/{dsid}/y_valid.pt')
    # label_val   = torch.load(f'data/UCR/{dsid}/y_test.pt')

    # label_train = label_train.flatten().to(dtype=torch.int64)
    # label_val   = label_val.flatten().to(dtype=torch.int64)
    # init [num_variables, seq_length, num_classes]
    
    num_nodes = 231

    # seq_length = 288
    seq_length = 2016
    
    num_classes = 2


    # convert data & labels to TensorDataset
    # train_dataset = TensorDataset(data_train, label_train)
    # val_dataset = TensorDataset(data_val, label_val)
    
#     task              = 'Mortality_At24Hours'
#     maxlen            = 288                                                                             #patients| pos  | neg |
#     data_loader_train = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='train') #10524   | 9613 | 911 |
#     data_loader_val   = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='val')   #2205    | 2048 | 157 |
#     data_loader_test  = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='test')  #2231    | 2045 | 186 |
    
    
    
    task   = 'Dynamic_RespFailure_12Hours' ##27249
    maxlen            = 2016                                                                                   #|patients|    #-1 =8065                    #2016                      #difference
    data_loader_train = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='train') #| 19092  |    5646423(3467116+2179307)     4776429(2901908+1874521)   869994(565208+304786)
    data_loader_val   = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='val')   #| 4081   |    1220285(758822+461463)       1006329(614326+392003)     213956(144496+69460)
    data_loader_test  = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='test')  #| 4076   |    1209724(753935+455789)       1015139(620532+394607)     194585(133403+61182)
    
    train_loader = DataLoader(data_loader_train, batch_size=4, shuffle=True, num_workers=1,pin_memory=True, prefetch_factor=2)
    # val_loader   = DataLoader(data_loader_val,   batch_size=256, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)
    val_loader  = DataLoader(data_loader_test,  batch_size=4, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)
    # test_loader  = DataLoader(data_loader_test,  batch_size=256, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)


    # data_loader
    # train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args['batch_size'],shuffle=True,num_workers=args['workers'], pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=args['val_batch_size'],shuffle=False,num_workers=args['workers'],pin_memory=True)


    return train_loader, val_loader, num_nodes, seq_length, num_classes

In [16]:
def main_work(args):
    
    random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    
    
    # init acc
    best_acc1 = 0
    best_roc  = 0
    best_pr   = 0
    
    if args['tag'] == 'date':
        local_date = time.strftime('%m.%d %H:%M', time.localtime(time.time()))
        args['tag'] = local_date

    log_file = 'log/{}_gpu{}_{}_{}_exp.txt'.format(args['tag'], args['gpu'], args['arch'], args['dataset'])
    
    
    if args['gpu'] is not None:
        print("Use GPU: {} for training".format(args['gpu']))


    # dataset
    train_loader, val_loader, num_nodes, seq_length, num_classes = get_default_train_val_test_loader(args)
    
    print('features / nodes', num_nodes,'total time graphs',seq_length,'classes', num_classes)
    
    # training model from net.py
    model = GNNStack(gnn_model_type=args['arch'], num_layers=args['num_layers'], 
                     groups=args['groups'], pool_ratio=args['pool_ratio'], kern_size=args['kern_size'], 
                     in_dim=args['in_dim'], hidden_dim=args['hidden_dim'], out_dim=args['out_dim'], 
                     seq_len=seq_length, num_nodes=num_nodes, num_classes=num_classes)

    # print & log
    log_msg('epochs {}, lr {}, weight_decay {}'.format(args['epochs'], args['lr'], args['weight_decay']), log_file)
    
    log_msg(str(args), log_file)


    # determine whether GPU or not
    if not torch.cuda.is_available():
        print("Warning! Using CPU!!!")
    elif args['gpu'] is not None:
        torch.cuda.set_device(args['gpu'])

        # collect cache
        gc.collect()
        torch.cuda.empty_cache()

        model = model.cuda(args['gpu'])
        if args['use_benchmark']:
            cudnn.benchmark = True
        print('Using cudnn.benchmark.')
    else:
        print("Error! We only have one gpu!!!")


    # define loss function(criterion) and optimizer
    # class_weights = torch.tensor([0.087, 0.913], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([0.913, 0.087], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([1.0, 22.47], dtype=torch.float).cuda(args['gpu'])
    
    
    # criterion = nn.CrossEntropyLoss(weight=class_weights).cuda(args['gpu'])
    
    criterion = nn.CrossEntropyLoss().cuda(args['gpu'])
    
    # criterion = FocalLoss(gamma=0.1)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    # validation
    if args['evaluate']:
        validate(val_loader, model, criterion, args)
        return

    # train & valid
    print('****************************************************')
    print('Dataset: ', args['dataset'])

    dataset_time = AverageMeter('Time', ':6.3f')

    loss_train = []
    acc_train = []
    loss_val = []
    acc_val = []
    epoches = []
    
    ###### 4 more lists to have values
    roc_train = []
    pr_train  = []
    
    roc_val   = []
    pr_val    = []

    end = time.time()
    for epoch in tqdm(range(args['epochs'])):
        epoches += [epoch]

        # train for one epoch
        acc_train_per, loss_train_per, output_train_per, target_train_per = train(train_loader, model, criterion, optimizer, lr_scheduler, args)
        
        acc_train += [acc_train_per]
        loss_train += [loss_train_per]
        # calculate metric
        # print(len(target_train_per),len(output_train_per))
        auc_roc_value_train = roc_auc_score(target_train_per, output_train_per)
        auc_pr_value_train = average_precision_score(target_train_per, output_train_per)
        #new code
        roc_train += [auc_roc_value_train]
        pr_train  += [auc_pr_value_train]

        msg = f'TRAIN, epoch {epoch}, train_loss {loss_train_per}, train_acc {acc_train_per}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}'

        print(f'TRAIN, epoch {epoch}, train_loss {loss_train_per:.5f}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}')
        log_msg(msg, log_file)

        
        # evaluate on validation set
        acc_val_per, loss_val_per, output_val_per, target_val_per = validate(val_loader, model, criterion, args)

        acc_val  += [acc_val_per]
        loss_val += [loss_val_per]
        #calculate metric
        # calculate metric
        # print(len(target_val_per),len(output_val_per))
        auc_roc_value_val = roc_auc_score(target_val_per, output_val_per)
        auc_pr_value_val = average_precision_score(target_val_per, output_val_per)
        #new code

        msg = f'VAL, epoch {epoch}, val_loss {loss_val_per}, val_acc {acc_val_per}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f}'
        
        print(f'VAL, epoch {epoch}, val_loss {loss_val_per:.5f}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f}')
        log_msg(msg, log_file)

        
        
        # remember best acc
        best_acc1 = max(acc_val_per, best_acc1)
        
        best_roc = max(auc_roc_value_val, best_roc)
        
        best_pr  = max(auc_pr_value_val, best_pr)
        
    #     wandb.log({"train_loss": loss_train_per, "train_roc": auc_roc_value_train, "train_pr": auc_pr_value_train, \
    #                "val_loss": loss_val_per, "val_roc": auc_roc_value_val, "val_pr": auc_pr_value_val, "best_val_roc": best_roc, "best_val_pr": best_pr})
    # wandb.finish()
    # measure elapsed time
    dataset_time.update(time.time() - end)

    # log & print the best_acc
    msg = f'\n\n * BEST_ACC: {best_acc1}\n * TIME: {dataset_time}\n'
    log_msg(msg, log_file)

    print(f' * best_acc1: {best_acc1}, best_roc: {best_roc}, best_pr: {best_pr}')
    print(f' * time: {dataset_time}')
    print('****************************************************')


    # collect cache
    gc.collect()
    torch.cuda.empty_cache()


def train(train_loader, model, criterion, optimizer, lr_scheduler, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    
    output_list = []
    target_list = [] 

    # switch to train mode
    model.train()

    for count, (data, label, mask) in enumerate(train_loader):

        # data in cuda
        data = data.cuda(args['gpu']).type(torch.float)
        data = data.view(data.size(0), 1, data.size(2), data.size(1))
        mask = mask.cuda(args['gpu']).type(torch.bool)
        label = label.cuda(args['gpu']).type(torch.long)

        # compute output
        output = model(data)
        # print(len(output))
        # print('output', output.shape, 'mask', mask.shape)
        out_flat = torch.masked_select(output, mask.unsqueeze(-1)).reshape(-1, output.shape[-1])
        # print(output)
        # print(output.shape, mask.shape, out_flat.shape)
        # break
        # out_flat = torch.masked_select(output[:,:,1], mask)
        

        label_flat = torch.masked_select(label, mask)
        # print('output',output.shape, 'mask', mask.shape,'out_flat', out_flat.shape, 'label', label.shape, 'label_flat', label_flat.shape)
        loss = criterion(out_flat, label_flat)

        # measure accuracy and record loss
        acc1 = accuracy(out_flat, label_flat, topk=(1, 1))
        
        output_np = torch.softmax(out_flat, dim=1).detach().cpu().numpy()[:,1].tolist()
        
        target_np = label_flat.detach().cpu().numpy().tolist()
        
        # print(output_np, target_np)
        
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        
        # met_roc.update(roc, data.size(0))
        # met_pr.update(pr, data.size(0))
        output_list += output_np
        target_list += target_np

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    lr_scheduler.step(top1.avg)

    return top1.avg, losses.avg, output_list, target_list


def validate(val_loader, model, criterion, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    output_list = []
    target_list = [] 
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for count, (data, label, mask) in enumerate(val_loader):
            if args['gpu'] is not None:
                data = data.cuda(args['gpu'], non_blocking=True).type(torch.float)
                data = data.view(data.size(0), 1, data.size(2), data.size(1))
            if torch.cuda.is_available():
                label = label.cuda(args['gpu'], non_blocking=True).type(torch.long)
                mask  = mask.cuda(args['gpu'], non_blocking=True).type(torch.bool)

            # compute output
            output = model(data)
            
            out_flat = torch.masked_select(output, mask.unsqueeze(-1)).reshape(-1, output.shape[-1])
            label_flat = torch.masked_select(label, mask)

            loss = criterion(out_flat, label_flat)
            
            output_np = torch.softmax(out_flat, dim=1).detach().cpu().numpy()[:,1].tolist()
            target_np = label_flat.detach().cpu().numpy().tolist()

            # measure accuracy and record loss
            acc1 = accuracy(out_flat, label_flat, topk=(1, 1))
            losses.update(loss.item(), data.size(0))
            top1.update(acc1[0], data.size(0))
            
            output_list += output_np
            target_list += target_np
            
            # met_roc.update(roc, data.size(0))
            # met_pr.update(pr, data.size(0))

    return top1.avg, losses.avg, output_list, target_list

In [17]:
main_work(args)

Use GPU: 0 for training
features / nodes 231 total time graphs 2016 classes 2
Using cudnn.benchmark.
****************************************************
Dataset:  Resp_failure


  0%|                                                                                                                                                                                | 0/50 [00:00<?, ?it/s]

TRAIN, epoch 0, train_loss 0.68590, train_roc 0.53856, train_pr 0.42121


  2%|███▎                                                                                                                                                               | 1/50 [48:48<39:51:22, 2928.21s/it]

VAL, epoch 0, val_loss 0.68376, val_roc 0.54191, val_pr 0.42369
TRAIN, epoch 1, train_loss 0.68368, train_roc 0.55073, train_pr 0.43405


  4%|██████▍                                                                                                                                                          | 2/50 [1:37:08<38:49:38, 2912.05s/it]

VAL, epoch 1, val_loss 0.68321, val_roc 0.54366, val_pr 0.42611
TRAIN, epoch 2, train_loss 0.68375, train_roc 0.55363, train_pr 0.43653


  6%|█████████▋                                                                                                                                                       | 3/50 [2:25:29<37:56:52, 2906.66s/it]

VAL, epoch 2, val_loss 0.68320, val_roc 0.54546, val_pr 0.42585
TRAIN, epoch 3, train_loss 0.67882, train_roc 0.57603, train_pr 0.46107


  8%|████████████▉                                                                                                                                                    | 4/50 [3:13:49<37:06:39, 2904.33s/it]

VAL, epoch 3, val_loss 0.66857, val_roc 0.60870, val_pr 0.48228
TRAIN, epoch 4, train_loss 0.65411, train_roc 0.64110, train_pr 0.53147


 10%|████████████████                                                                                                                                                 | 5/50 [4:02:10<36:17:17, 2903.05s/it]

VAL, epoch 4, val_loss 0.64881, val_roc 0.65292, val_pr 0.53381
TRAIN, epoch 5, train_loss 0.63617, train_roc 0.67005, train_pr 0.57000


 12%|███████████████████▎                                                                                                                                             | 6/50 [4:50:31<35:28:16, 2902.20s/it]

VAL, epoch 5, val_loss 0.63498, val_roc 0.67034, val_pr 0.56181
TRAIN, epoch 6, train_loss 0.62396, train_roc 0.68534, train_pr 0.59048


 14%|██████████████████████▌                                                                                                                                          | 7/50 [5:38:54<34:40:05, 2902.46s/it]

VAL, epoch 6, val_loss 0.62686, val_roc 0.68203, val_pr 0.57642
TRAIN, epoch 7, train_loss 0.61577, train_roc 0.69744, train_pr 0.60345


 16%|█████████████████████████▊                                                                                                                                       | 8/50 [6:27:30<33:54:43, 2906.74s/it]

VAL, epoch 7, val_loss 0.62275, val_roc 0.68845, val_pr 0.58721
TRAIN, epoch 8, train_loss 0.61036, train_roc 0.70397, train_pr 0.61134


 18%|████████████████████████████▉                                                                                                                                    | 9/50 [7:16:14<33:09:57, 2912.14s/it]

VAL, epoch 8, val_loss 0.61484, val_roc 0.69989, val_pr 0.60130
TRAIN, epoch 9, train_loss 0.60732, train_roc 0.70817, train_pr 0.61816


 20%|████████████████████████████████                                                                                                                                | 10/50 [8:04:35<32:19:12, 2908.82s/it]

VAL, epoch 9, val_loss 0.61354, val_roc 0.69826, val_pr 0.59805
TRAIN, epoch 10, train_loss 0.60280, train_roc 0.71337, train_pr 0.62415


 22%|███████████████████████████████████▏                                                                                                                            | 11/50 [8:52:56<31:29:13, 2906.50s/it]

VAL, epoch 10, val_loss 0.60997, val_roc 0.70432, val_pr 0.59807
TRAIN, epoch 11, train_loss 0.60125, train_roc 0.71530, train_pr 0.62810


 24%|██████████████████████████████████████▍                                                                                                                         | 12/50 [9:41:18<30:39:46, 2904.90s/it]

VAL, epoch 11, val_loss 0.60526, val_roc 0.70807, val_pr 0.61268
TRAIN, epoch 12, train_loss 0.59893, train_roc 0.71888, train_pr 0.63232


 26%|█████████████████████████████████████████▎                                                                                                                     | 13/50 [10:29:39<29:50:39, 2903.76s/it]

VAL, epoch 12, val_loss 0.60482, val_roc 0.71177, val_pr 0.61574
TRAIN, epoch 13, train_loss 0.59801, train_roc 0.71975, train_pr 0.63419


 28%|████████████████████████████████████████████▌                                                                                                                  | 14/50 [11:18:00<29:01:48, 2903.03s/it]

VAL, epoch 13, val_loss 0.60616, val_roc 0.71121, val_pr 0.61521
TRAIN, epoch 14, train_loss 0.59409, train_roc 0.72193, train_pr 0.63645


 30%|███████████████████████████████████████████████▋                                                                                                               | 15/50 [12:06:21<28:13:08, 2902.53s/it]

VAL, epoch 14, val_loss 0.60364, val_roc 0.71319, val_pr 0.61949
TRAIN, epoch 15, train_loss 0.59457, train_roc 0.72442, train_pr 0.63949


 32%|██████████████████████████████████████████████████▉                                                                                                            | 16/50 [12:54:43<27:24:34, 2902.20s/it]

VAL, epoch 15, val_loss 0.60410, val_roc 0.71384, val_pr 0.61526
TRAIN, epoch 16, train_loss 0.59055, train_roc 0.72683, train_pr 0.64197


 34%|██████████████████████████████████████████████████████                                                                                                         | 17/50 [13:43:04<26:36:03, 2901.91s/it]

VAL, epoch 16, val_loss 0.60257, val_roc 0.71595, val_pr 0.62376
TRAIN, epoch 17, train_loss 0.58877, train_roc 0.72854, train_pr 0.64479


 36%|█████████████████████████████████████████████████████████▏                                                                                                     | 18/50 [14:31:25<25:47:34, 2901.69s/it]

VAL, epoch 17, val_loss 0.60046, val_roc 0.71710, val_pr 0.62589
TRAIN, epoch 18, train_loss 0.58888, train_roc 0.72873, train_pr 0.64564


 38%|████████████████████████████████████████████████████████████▍                                                                                                  | 19/50 [15:19:46<24:59:07, 2901.54s/it]

VAL, epoch 18, val_loss 0.60628, val_roc 0.71027, val_pr 0.62013
TRAIN, epoch 19, train_loss 0.58707, train_roc 0.73066, train_pr 0.64760


 40%|███████████████████████████████████████████████████████████████▌                                                                                               | 20/50 [16:08:57<24:18:05, 2916.18s/it]

VAL, epoch 19, val_loss 0.59651, val_roc 0.72278, val_pr 0.62810
TRAIN, epoch 20, train_loss 0.58651, train_roc 0.73195, train_pr 0.65037


 42%|██████████████████████████████████████████████████████████████████▊                                                                                            | 21/50 [16:58:01<23:33:35, 2924.67s/it]

VAL, epoch 20, val_loss 0.59280, val_roc 0.72303, val_pr 0.62828
TRAIN, epoch 21, train_loss 0.58547, train_roc 0.73304, train_pr 0.65075


 44%|█████████████████████████████████████████████████████████████████████▉                                                                                         | 22/50 [17:49:16<23:05:50, 2969.65s/it]

VAL, epoch 21, val_loss 0.59391, val_roc 0.72463, val_pr 0.63028
TRAIN, epoch 22, train_loss 0.58276, train_roc 0.73455, train_pr 0.65288


 46%|█████████████████████████████████████████████████████████████████████████▏                                                                                     | 23/50 [18:38:50<22:16:56, 2971.00s/it]

VAL, epoch 22, val_loss 0.59343, val_roc 0.72452, val_pr 0.63340
TRAIN, epoch 23, train_loss 0.58176, train_roc 0.73593, train_pr 0.65476


 48%|████████████████████████████████████████████████████████████████████████████▎                                                                                  | 24/50 [19:27:11<21:18:18, 2949.95s/it]

VAL, epoch 23, val_loss 0.58896, val_roc 0.72795, val_pr 0.63913
TRAIN, epoch 24, train_loss 0.58140, train_roc 0.73749, train_pr 0.65679


 50%|███████████████████████████████████████████████████████████████████████████████▌                                                                               | 25/50 [20:15:33<20:23:13, 2935.76s/it]

VAL, epoch 24, val_loss 0.59132, val_roc 0.72487, val_pr 0.63606
TRAIN, epoch 25, train_loss 0.57987, train_roc 0.73836, train_pr 0.65779


 52%|██████████████████████████████████████████████████████████████████████████████████▋                                                                            | 26/50 [21:03:55<19:30:12, 2925.52s/it]

VAL, epoch 25, val_loss 0.58808, val_roc 0.72733, val_pr 0.63736
TRAIN, epoch 26, train_loss 0.57897, train_roc 0.73941, train_pr 0.65917


 54%|█████████████████████████████████████████████████████████████████████████████████████▊                                                                         | 27/50 [21:52:17<18:38:43, 2918.43s/it]

VAL, epoch 26, val_loss 0.58858, val_roc 0.72739, val_pr 0.63982
TRAIN, epoch 27, train_loss 0.58027, train_roc 0.74037, train_pr 0.65991


 56%|█████████████████████████████████████████████████████████████████████████████████████████                                                                      | 28/50 [22:40:38<17:48:13, 2913.33s/it]

VAL, epoch 27, val_loss 0.58523, val_roc 0.73095, val_pr 0.63782
TRAIN, epoch 28, train_loss 0.57815, train_roc 0.74111, train_pr 0.66156


 58%|████████████████████████████████████████████████████████████████████████████████████████████▏                                                                  | 29/50 [23:29:00<16:58:26, 2909.85s/it]

VAL, epoch 28, val_loss 0.58422, val_roc 0.73265, val_pr 0.64548
TRAIN, epoch 29, train_loss 0.57660, train_roc 0.74277, train_pr 0.66270


 60%|███████████████████████████████████████████████████████████████████████████████████████████████▍                                                               | 30/50 [24:17:22<16:09:07, 2907.38s/it]

VAL, epoch 29, val_loss 0.58723, val_roc 0.72742, val_pr 0.63097
TRAIN, epoch 30, train_loss 0.57628, train_roc 0.74289, train_pr 0.66372


 62%|██████████████████████████████████████████████████████████████████████████████████████████████████▌                                                            | 31/50 [25:05:43<15:20:07, 2905.64s/it]

VAL, epoch 30, val_loss 0.58737, val_roc 0.73048, val_pr 0.63945
TRAIN, epoch 31, train_loss 0.57434, train_roc 0.74447, train_pr 0.66512


 64%|█████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                         | 32/50 [25:54:52<14:35:32, 2918.45s/it]

VAL, epoch 31, val_loss 0.58165, val_roc 0.73465, val_pr 0.64539
TRAIN, epoch 32, train_loss 0.57303, train_roc 0.74496, train_pr 0.66657


 66%|████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                      | 33/50 [26:44:49<13:53:37, 2942.23s/it]

VAL, epoch 32, val_loss 0.58400, val_roc 0.73289, val_pr 0.64438
TRAIN, epoch 33, train_loss 0.57390, train_roc 0.74631, train_pr 0.66772


 68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                   | 34/50 [27:34:37<13:08:15, 2955.96s/it]

VAL, epoch 33, val_loss 0.58053, val_roc 0.73544, val_pr 0.64414
TRAIN, epoch 34, train_loss 0.57253, train_roc 0.74661, train_pr 0.66831


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                               | 35/50 [28:24:37<12:22:14, 2968.96s/it]

VAL, epoch 34, val_loss 0.58263, val_roc 0.73293, val_pr 0.64857
TRAIN, epoch 35, train_loss 0.57046, train_roc 0.74820, train_pr 0.67012


 72%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                            | 36/50 [29:14:03<11:32:36, 2968.31s/it]

VAL, epoch 35, val_loss 0.58064, val_roc 0.73594, val_pr 0.64952
TRAIN, epoch 36, train_loss 0.57063, train_roc 0.74878, train_pr 0.66987


 74%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                         | 37/50 [30:04:38<10:47:27, 2988.28s/it]

VAL, epoch 36, val_loss 0.58398, val_roc 0.73178, val_pr 0.64664
TRAIN, epoch 37, train_loss 0.57023, train_roc 0.74955, train_pr 0.67151


 76%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 38/50 [30:53:16<9:53:26, 2967.19s/it]

VAL, epoch 37, val_loss 0.57722, val_roc 0.73720, val_pr 0.64705
TRAIN, epoch 38, train_loss 0.56914, train_roc 0.75079, train_pr 0.67330


 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                   | 39/50 [31:41:37<9:00:19, 2947.27s/it]

VAL, epoch 38, val_loss 0.57805, val_roc 0.73864, val_pr 0.65385
TRAIN, epoch 39, train_loss 0.56791, train_roc 0.75126, train_pr 0.67435


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                | 40/50 [32:29:58<8:08:53, 2933.33s/it]

VAL, epoch 39, val_loss 0.57933, val_roc 0.73685, val_pr 0.65065
TRAIN, epoch 40, train_loss 0.56580, train_roc 0.75240, train_pr 0.67592


 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                            | 41/50 [33:18:19<7:18:33, 2923.72s/it]

VAL, epoch 40, val_loss 0.57363, val_roc 0.74298, val_pr 0.65451
TRAIN, epoch 41, train_loss 0.56616, train_roc 0.75249, train_pr 0.67561


 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                         | 42/50 [34:06:40<6:28:55, 2916.96s/it]

VAL, epoch 41, val_loss 0.57235, val_roc 0.74216, val_pr 0.65796
TRAIN, epoch 42, train_loss 0.56534, train_roc 0.75298, train_pr 0.67644


 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                      | 43/50 [34:55:02<5:39:45, 2912.26s/it]

VAL, epoch 42, val_loss 0.57607, val_roc 0.73901, val_pr 0.65342
Epoch 00044: reducing learning rate of group 0 to 1.0000e-04.
TRAIN, epoch 43, train_loss 0.56594, train_roc 0.75485, train_pr 0.67872


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                   | 44/50 [35:43:23<4:50:53, 2908.93s/it]

VAL, epoch 43, val_loss 0.57294, val_roc 0.74307, val_pr 0.65301
TRAIN, epoch 44, train_loss 0.55591, train_roc 0.76314, train_pr 0.68838


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                | 45/50 [36:31:44<4:02:12, 2906.51s/it]

VAL, epoch 44, val_loss 0.56798, val_roc 0.74829, val_pr 0.66124
TRAIN, epoch 45, train_loss 0.55463, train_roc 0.76505, train_pr 0.69148


 92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏            | 46/50 [37:20:05<3:13:39, 2904.96s/it]

VAL, epoch 45, val_loss 0.56855, val_roc 0.74689, val_pr 0.66196
TRAIN, epoch 46, train_loss 0.55323, train_roc 0.76565, train_pr 0.69218


 94%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍         | 47/50 [38:08:26<2:25:11, 2903.81s/it]

VAL, epoch 46, val_loss 0.56759, val_roc 0.74904, val_pr 0.66698
TRAIN, epoch 47, train_loss 0.55202, train_roc 0.76638, train_pr 0.69351


 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌      | 48/50 [38:56:47<1:36:45, 2902.94s/it]

VAL, epoch 47, val_loss 0.56529, val_roc 0.75104, val_pr 0.66517
TRAIN, epoch 48, train_loss 0.55083, train_roc 0.76689, train_pr 0.69347


 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊   | 49/50 [39:46:21<48:44, 2924.11s/it]

VAL, epoch 48, val_loss 0.56909, val_roc 0.74779, val_pr 0.65865
TRAIN, epoch 49, train_loss 0.55104, train_roc 0.76754, train_pr 0.69483


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [40:36:16<00:00, 2923.54s/it]

VAL, epoch 49, val_loss 0.56393, val_roc 0.75209, val_pr 0.66701
 * best_acc1: tensor([70.4086], device='cuda:0'), best_roc: 0.7520907529774032, best_pr: 0.667008274085278
 * time: Time 146176.872 (146176.872)
****************************************************





In [14]:
main_work(args)

Use GPU: 0 for training
features / nodes 231 total time graphs 2016 classes 2
Using cudnn.benchmark.
****************************************************
Dataset:  Resp_failure


  0%|                                                                                                                                                                                | 0/50 [00:00<?, ?it/s]

TRAIN, epoch 0, train_loss 0.68733, train_roc 0.53742, train_pr 0.41795


  2%|███▎                                                                                                                                                               | 1/50 [24:46<20:14:07, 1486.69s/it]

VAL, epoch 0, val_loss 0.68152, val_roc 0.55437, val_pr 0.43292
TRAIN, epoch 1, train_loss 0.67359, train_roc 0.58742, train_pr 0.47100


  4%|██████▌                                                                                                                                                            | 2/50 [49:28<19:46:55, 1483.65s/it]

VAL, epoch 1, val_loss 0.67000, val_roc 0.60732, val_pr 0.47800
TRAIN, epoch 2, train_loss 0.65891, train_roc 0.62756, train_pr 0.51396


  6%|█████████▋                                                                                                                                                       | 3/50 [1:14:09<19:21:28, 1482.74s/it]

VAL, epoch 2, val_loss 0.66207, val_roc 0.62836, val_pr 0.49836
TRAIN, epoch 3, train_loss 0.65106, train_roc 0.64399, train_pr 0.53296


  8%|████████████▉                                                                                                                                                    | 4/50 [1:38:51<18:56:25, 1482.29s/it]

VAL, epoch 3, val_loss 0.65693, val_roc 0.63896, val_pr 0.50981
TRAIN, epoch 4, train_loss 0.64573, train_roc 0.65359, train_pr 0.54469


 10%|████████████████                                                                                                                                                 | 5/50 [2:03:33<18:31:34, 1482.10s/it]

VAL, epoch 4, val_loss 0.65280, val_roc 0.64655, val_pr 0.51885
TRAIN, epoch 5, train_loss 0.64129, train_roc 0.66099, train_pr 0.55406


 12%|███████████████████▎                                                                                                                                             | 6/50 [2:28:14<18:06:45, 1481.94s/it]

VAL, epoch 5, val_loss 0.64925, val_roc 0.65278, val_pr 0.52649
TRAIN, epoch 6, train_loss 0.63725, train_roc 0.66727, train_pr 0.56213


 14%|██████████████████████▌                                                                                                                                          | 7/50 [2:52:57<17:42:15, 1482.23s/it]

VAL, epoch 6, val_loss 0.64532, val_roc 0.65877, val_pr 0.53457
TRAIN, epoch 7, train_loss 0.63328, train_roc 0.67328, train_pr 0.57037


 16%|█████████████████████████▊                                                                                                                                       | 8/50 [3:19:16<17:39:04, 1512.98s/it]

VAL, epoch 7, val_loss 0.64142, val_roc 0.66470, val_pr 0.54270
TRAIN, epoch 8, train_loss 0.62929, train_roc 0.67898, train_pr 0.57783


 18%|████████████████████████████▉                                                                                                                                    | 9/50 [3:45:08<17:22:16, 1525.29s/it]

VAL, epoch 8, val_loss 0.63758, val_roc 0.67025, val_pr 0.55051
TRAIN, epoch 9, train_loss 0.62533, train_roc 0.68452, train_pr 0.58527


 20%|████████████████████████████████                                                                                                                                | 10/50 [4:11:13<17:04:56, 1537.41s/it]

VAL, epoch 9, val_loss 0.63357, val_roc 0.67564, val_pr 0.55765
TRAIN, epoch 10, train_loss 0.62125, train_roc 0.69003, train_pr 0.59224


 22%|███████████████████████████████████▏                                                                                                                            | 11/50 [4:38:31<16:59:14, 1568.08s/it]

VAL, epoch 10, val_loss 0.62976, val_roc 0.68106, val_pr 0.56479
TRAIN, epoch 11, train_loss 0.61751, train_roc 0.69476, train_pr 0.59845


 24%|██████████████████████████████████████▍                                                                                                                         | 12/50 [5:04:03<16:26:12, 1557.18s/it]

VAL, epoch 11, val_loss 0.62600, val_roc 0.68581, val_pr 0.57153
TRAIN, epoch 12, train_loss 0.61361, train_roc 0.69968, train_pr 0.60497


 26%|█████████████████████████████████████████▌                                                                                                                      | 13/50 [5:29:11<15:51:07, 1542.35s/it]

VAL, epoch 12, val_loss 0.62261, val_roc 0.69006, val_pr 0.57710
TRAIN, epoch 13, train_loss 0.61030, train_roc 0.70356, train_pr 0.61005


 28%|████████████████████████████████████████████▊                                                                                                                   | 14/50 [5:53:53<15:14:32, 1524.25s/it]

VAL, epoch 13, val_loss 0.61948, val_roc 0.69385, val_pr 0.58270
TRAIN, epoch 14, train_loss 0.60685, train_roc 0.70765, train_pr 0.61558


 30%|████████████████████████████████████████████████                                                                                                                | 15/50 [6:18:36<14:41:50, 1511.72s/it]

VAL, epoch 14, val_loss 0.61650, val_roc 0.69731, val_pr 0.58703
TRAIN, epoch 15, train_loss 0.60350, train_roc 0.71142, train_pr 0.62066


 32%|███████████████████████████████████████████████████▏                                                                                                            | 16/50 [6:43:19<14:11:45, 1503.11s/it]

VAL, epoch 15, val_loss 0.61400, val_roc 0.70040, val_pr 0.59064
TRAIN, epoch 16, train_loss 0.60027, train_roc 0.71508, train_pr 0.62550


 34%|██████████████████████████████████████████████████████▍                                                                                                         | 17/50 [7:09:00<13:52:54, 1514.39s/it]

VAL, epoch 16, val_loss 0.61185, val_roc 0.70293, val_pr 0.59444
TRAIN, epoch 17, train_loss 0.59740, train_roc 0.71841, train_pr 0.62987


 36%|█████████████████████████████████████████████████████████▌                                                                                                      | 18/50 [7:35:02<13:35:16, 1528.63s/it]

VAL, epoch 17, val_loss 0.60996, val_roc 0.70493, val_pr 0.59741
TRAIN, epoch 18, train_loss 0.59462, train_roc 0.72157, train_pr 0.63413


 38%|████████████████████████████████████████████████████████████▊                                                                                                   | 19/50 [8:00:12<13:06:55, 1523.08s/it]

VAL, epoch 18, val_loss 0.60805, val_roc 0.70716, val_pr 0.59947
TRAIN, epoch 19, train_loss 0.59181, train_roc 0.72473, train_pr 0.63831


 40%|████████████████████████████████████████████████████████████████                                                                                                | 20/50 [8:25:09<12:37:43, 1515.45s/it]

VAL, epoch 19, val_loss 0.60620, val_roc 0.70925, val_pr 0.60256
TRAIN, epoch 20, train_loss 0.58936, train_roc 0.72726, train_pr 0.64180


 42%|███████████████████████████████████████████████████████████████████▏                                                                                            | 21/50 [8:49:52<12:07:39, 1505.52s/it]

VAL, epoch 20, val_loss 0.60563, val_roc 0.71014, val_pr 0.60308
TRAIN, epoch 21, train_loss 0.58686, train_roc 0.73023, train_pr 0.64528


 44%|██████████████████████████████████████████████████████████████████████▍                                                                                         | 22/50 [9:15:33<11:47:35, 1516.28s/it]

VAL, epoch 21, val_loss 0.60455, val_roc 0.71163, val_pr 0.60498
TRAIN, epoch 22, train_loss 0.58458, train_roc 0.73262, train_pr 0.64875


 46%|█████████████████████████████████████████████████████████████████████████▌                                                                                      | 23/50 [9:41:44<11:29:41, 1532.65s/it]

VAL, epoch 22, val_loss 0.60360, val_roc 0.71242, val_pr 0.60565
TRAIN, epoch 23, train_loss 0.58247, train_roc 0.73499, train_pr 0.65182


 48%|████████████████████████████████████████████████████████████████████████████▎                                                                                  | 24/50 [10:07:40<11:07:11, 1539.68s/it]

VAL, epoch 23, val_loss 0.60313, val_roc 0.71309, val_pr 0.60632
TRAIN, epoch 24, train_loss 0.57991, train_roc 0.73801, train_pr 0.65556


 50%|███████████████████████████████████████████████████████████████████████████████▌                                                                               | 25/50 [10:32:58<10:38:50, 1533.22s/it]

VAL, epoch 24, val_loss 0.60228, val_roc 0.71415, val_pr 0.60749
TRAIN, epoch 25, train_loss 0.57746, train_roc 0.74057, train_pr 0.65877


 52%|██████████████████████████████████████████████████████████████████████████████████▋                                                                            | 26/50 [10:58:03<10:09:53, 1524.74s/it]

VAL, epoch 25, val_loss 0.60201, val_roc 0.71475, val_pr 0.60773
TRAIN, epoch 26, train_loss 0.57550, train_roc 0.74291, train_pr 0.66159


 54%|██████████████████████████████████████████████████████████████████████████████████████▍                                                                         | 27/50 [11:23:27<9:44:22, 1524.47s/it]

VAL, epoch 26, val_loss 0.60169, val_roc 0.71516, val_pr 0.60853
TRAIN, epoch 27, train_loss 0.57304, train_roc 0.74552, train_pr 0.66486


 56%|█████████████████████████████████████████████████████████████████████████████████████████▌                                                                      | 28/50 [11:49:23<9:22:26, 1533.92s/it]

VAL, epoch 27, val_loss 0.60205, val_roc 0.71510, val_pr 0.60727
TRAIN, epoch 28, train_loss 0.57134, train_roc 0.74749, train_pr 0.66727


 58%|████████████████████████████████████████████████████████████████████████████████████████████▊                                                                   | 29/50 [12:15:45<9:01:54, 1548.32s/it]

VAL, epoch 28, val_loss 0.60210, val_roc 0.71504, val_pr 0.60754
TRAIN, epoch 29, train_loss 0.56888, train_roc 0.75023, train_pr 0.67077


 60%|████████████████████████████████████████████████████████████████████████████████████████████████                                                                | 30/50 [12:41:53<8:38:04, 1554.24s/it]

VAL, epoch 29, val_loss 0.60208, val_roc 0.71534, val_pr 0.60803
TRAIN, epoch 30, train_loss 0.56715, train_roc 0.75217, train_pr 0.67316


 62%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                                                            | 31/50 [13:07:39<8:11:23, 1551.77s/it]

VAL, epoch 30, val_loss 0.60214, val_roc 0.71536, val_pr 0.60810


 62%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                                                            | 31/50 [13:30:51<8:16:58, 1569.41s/it]


KeyboardInterrupt: 

In [None]:
#TODO ignite metrics