![alt text](04_respiratory_metrics.png)

In [1]:
import torch
import torch.utils.data
from torch.utils.data import TensorDataset

import argparse
import time
import gc
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from tqdm.notebook import tqdm

import wandb
import os

from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, matthews_corrcoef

In [2]:
import h5py
import numpy as np
from loader import ICUVariableLengthLoaderTables, ICUVariableLengthDataset
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
torch.manual_seed(0)
torch.cuda.is_available(), torch.cuda.get_device_name(0)

(True, 'NVIDIA GeForce RTX 4060 Ti')

In [4]:
import gc
gc.collect()

11

In [5]:
from net_dev import GNNStack
# from utils_todynet_binary import AverageMeter, accuracy, log_msg, get_default_train_val_test_loader

In [6]:
import warnings

# Suppress all warnings
warnings.filterwarnings("ignore")

In [7]:
# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         loss = (self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2, num_classes=None):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         if alpha is None:
#             self.alpha = None
#         else:
#             # If alpha is a single float number, expand it to a tensor
#             if isinstance(alpha, float):
#                 assert num_classes is not None, "num_classes must be specified when alpha is a single float"
#                 self.alpha = torch.full((num_classes,), alpha)
#             else:
#                 self.alpha = alpha

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         if self.alpha is not None:
#             if self.alpha.type() != inputs.data.type():
#                 self.alpha = self.alpha.type_as(inputs.data)
#             at = self.alpha[targets]
#         else:
#             at = 1.0
#         loss = (at * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

## aruguments

In [8]:
args = {
    'arch': 'dyGIN2d', #what other models I can put here?? dyGCN2d, dyGIN2d
    'dataset': 'Resp_failure_development_attentionMSE', # "AtrialFibrillation" # 'Mortality', # 'MIMIC3' Resp_failure
    'num_layers': 2,  # the number of GNN layers  3
    'groups': 32,  # the number of time series groups (num_graphs)
    'pool_ratio': 0.0,  # the ratio of pooling for nodes # initially 0.1 but changed to 0 because the node number was decreasing
    'kern_size': [3,3],  # list of time conv kernel size for each layer [9,5,3]
    'in_dim': 64,  # input dimensions of GNN stacks
    'hidden_dim': 64,  # hidden dimensions of GNN stacks
    'out_dim': 64,  # output dimensions of GNN stacks
    'workers': 0,  # number of data loading workers
    'epochs': 20,  # number of total epochs to run
    'batch_size': 8,  # mini-batch size, this is the total batch size of all GPUs
    'val_batch_size': 8,  # validation batch size
    'lr': 0.0002,  # initial learning rate
    'weight_decay': 1e-4,  # weight decay
    'evaluate': False,  # evaluate model on validation set
    'seed': 3,  # seed for initializing training
    'gpu': 0,  # GPU id to use
    'use_benchmark': True,  # use benchmark
    'tag': 'date',  # the tag for identifying the log and model files
    'loss':'bce',
    'resample':'5min',
    'series_length':864,
}

# resample works: 5min, 15 min, 35 min, 45min. analysis will have 5,15,45

### GPU 2 is used to training in sheet 2

### TODO
## Molly douglas has filtered 116 features from 231, ~50% reduction

In [9]:
# # start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="mortality",
    
#     # track hyperparameters and run metadata
#     config=args
# )

In [10]:
# train_dataset = TensorDataset(data_train, label_train)
# val_dataset   = TensorDataset(data_val, label_val)

In [11]:
# train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args['batch_size'],shuffle=True, num_workers=args['workers'], pin_memory=True)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args['val_batch_size'], shuffle=False,num_workers=args['workers'],pin_memory=True)

In [12]:
# def main():
#     # args = parser.parse_args()
    
#     # args.kern_size = [ int(l) for l in args.kern_size.split(",") ]

#     # if args.seed is not None:
#     random.seed(args['seed'])
#     torch.manual_seed(args['seed'])

#     main_work(args)

In [13]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    
    
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
              
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        
        # print(output, target)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            try:
                res.append(correct_k.mul_(100.0 / batch_size))
            except:
                res.append(0)

        return res

def log_msg(message, log_file):
    with open(log_file, 'a') as f:
        print(message, file=f)


def get_default_train_val_test_loader(args):

    # get dataset-id
    dsid = args['dataset']
    h5_path= 'h5_folder/ml_stage_12h.h5'
    # get dataset from .pt
    # data_train  = torch.load(f'data/UCR/{dsid}/X_train.pt')
    # data_val    = torch.load(f'data/UCR/{dsid}/X_valid.pt')
    # data_val    = torch.load(f'data/UCR/{dsid}/X_test.pt')
    
    # label_train = torch.load(f'data/UCR/{dsid}/y_train.pt')
    # label_val   = torch.load(f'data/UCR/{dsid}/y_valid.pt')
    # label_val   = torch.load(f'data/UCR/{dsid}/y_test.pt')

    # label_train = label_train.flatten().to(dtype=torch.int64)
    # label_val   = label_val.flatten().to(dtype=torch.int64)
    # init [num_variables, seq_length, num_classes]
    
    num_nodes = 231
    seq_length = 864 #288 #2016 ### now we will reduce the sequence length to see the performance in short-range
    num_classes = 2


    # convert data & labels to TensorDataset
    # train_dataset = TensorDataset(data_train, label_train)
    # val_dataset = TensorDataset(data_val, label_val)
    
#     task              = 'Mortality_At24Hours'
#     maxlen            = 288                                                                             #patients| pos  | neg |
#     data_loader_train = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='train') #10524   | 9613 | 911 |
#     data_loader_val   = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='val')   #2205    | 2048 | 157 |
#     data_loader_test  = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='test')  #2231    | 2045 | 186 |
    
    
    
    task   = 'Dynamic_RespFailure_12Hours' ##27249
    maxlen            = 864 #288 #2016                                                                         #|patients|    #-1 =8065                    #2016                      #difference
    data_loader_train = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='train') #| 19092  |    5646423(3467116+2179307)     4776429(2901908+1874521)   869994(565208+304786)
    data_loader_val   = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='val')   #| 4081   |    1220285(758822+461463)       1006329(614326+392003)     213956(144496+69460)
    data_loader_test  = ICUVariableLengthDataset(source_path=h5_path, maxlen=maxlen, task=task, split='test')  #| 4076   |    1209724(753935+455789)       1015139(620532+394607)     194585(133403+61182)
    
    train_loader = DataLoader(data_loader_train, batch_size=4, shuffle=True, num_workers=1,pin_memory=True, prefetch_factor=2)
    val_loader   = DataLoader(data_loader_val,   batch_size=4, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)
    # val_loader  = DataLoader(data_loader_test,  batch_size=4, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)
    test_loader  = DataLoader(data_loader_test,  batch_size=4, shuffle=False, num_workers=1,pin_memory=True, prefetch_factor=2)


    # data_loader
    # train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args['batch_size'],shuffle=True,num_workers=args['workers'], pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=args['val_batch_size'],shuffle=False,num_workers=args['workers'],pin_memory=True)


    return train_loader, val_loader, test_loader, num_nodes, seq_length, num_classes

In [14]:
def main_work(args):
    
    random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    
    
    # init metrics
    best_acc1 = 0
    best_roc  = 0
    best_pr   = 0
    best_f1   = 0
    best_mcc  = 0
    
    best_test_acc1 = 0
    best_test_roc  = 0
    best_test_pr   = 0
    best_test_f1   = 0
    best_test_mcc  = 0    
    
    if args['tag'] == 'date':
        local_date = time.strftime('%m.%d %H:%M', time.localtime(time.time()))
        args['tag'] = local_date
        
    # Use the 'tag' which now contains either the date or a custom tag along with the dataset name for the directory
    run_dir_name = f"{args['dataset']}_{args['tag']}"
   
    # Base directory for saving models
    base_model_save_dir = "saved_models"
   
    # Specific directory for this run
    specific_model_save_dir = os.path.join(base_model_save_dir, run_dir_name)
    os.makedirs(specific_model_save_dir, exist_ok=True)

    print(f"Models will be saved in: {specific_model_save_dir}")        


    log_file = 'log/{}_gpu{}_{}_{}_exp.txt'.format(args['tag'], args['gpu'], args['arch'], args['dataset'])
    
    
    if args['gpu'] is not None:
        print("Use GPU: {} for training".format(args['gpu']))


    # dataset
    train_loader, val_loader, test_loader, num_nodes, seq_length, num_classes = get_default_train_val_test_loader(args)
    
    print('features / nodes', num_nodes,'total time graphs',args['groups'],'time series length',seq_length,'classes', num_classes)
    
    # training model from net.py
    model = GNNStack(gnn_model_type=args['arch'], num_layers=args['num_layers'], 
                     groups=args['groups'], pool_ratio=args['pool_ratio'], kern_size=args['kern_size'], 
                     in_dim=args['in_dim'], hidden_dim=args['hidden_dim'], out_dim=args['out_dim'], 
                     seq_len=seq_length, num_nodes=num_nodes, num_classes=num_classes)

    # print & log
    log_msg('epochs {}, lr {}, weight_decay {}'.format(args['epochs'], args['lr'], args['weight_decay']), log_file)
    
    log_msg(str(args), log_file)


    # determine whether GPU or not
    if not torch.cuda.is_available():
        print("Warning! Using CPU!!!")
    elif args['gpu'] is not None:
        torch.cuda.set_device(args['gpu'])

        # collect cache
        gc.collect()
        torch.cuda.empty_cache()

        model = model.cuda(args['gpu'])
        if args['use_benchmark']:
            cudnn.benchmark = True
        print('Using cudnn.benchmark.')
    else:
        print("Error! We only have one gpu!!!")


    # define loss function(criterion) and optimizer
    # class_weights = torch.tensor([0.087, 0.913], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([0.913, 0.087], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([1.0, 22.47], dtype=torch.float).cuda(args['gpu'])
    
    
    # criterion = nn.CrossEntropyLoss(weight=class_weights).cuda(args['gpu'])
    
    criterion = nn.CrossEntropyLoss().cuda(args['gpu'])
    
    # TODO
    # criterion = FocalLoss(gamma=0.1)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    # validation
    if args['evaluate']:
        validate(val_loader, model, criterion, args)
        return

    # train & valid
    print('****************************************************')
    print('Dataset: ', args['dataset'])

    dataset_time = AverageMeter('Time', ':6.3f')

    loss_train = []
    acc_train = []
    loss_val = []
    acc_val = []
    epoches = []
    
    ########################
    loss_test = []
    acc_test  = []
    #######################
    
    ###### more lists to have values
    roc_train = []
    pr_train  = []
    f1_train  = []
    mcc_train = []
    
    roc_val   = []
    pr_val    = []
    f1_val    = []
    mcc_val   = []
    ##################################################
    roc_test   = []
    pr_test    = []
    f1_test    = []
    mcc_test   = []     
    #################################################
    end = time.time()
    for epoch in tqdm(range(args['epochs'])):
    # for epoch in range(args['epochs']):
        epoches += [epoch]

        # train for one epoch
        acc_train_per, loss_train_per, output_train_per, target_train_per = train(train_loader, model, criterion, optimizer, lr_scheduler, args)
        
        acc_train += [acc_train_per]
        loss_train += [loss_train_per]
        # calculate metric
        # print(len(target_train_per),len(output_train_per))
        auc_roc_value_train = roc_auc_score(target_train_per, output_train_per)
        auc_pr_value_train = average_precision_score(target_train_per, output_train_per)
        p2l_value_train = np.where(np.array(output_train_per) >= 0.5, 1, 0)
        f1_value_train = f1_score(target_train_per, p2l_value_train.tolist())
        mcc_value_train= matthews_corrcoef(target_train_per,p2l_value_train.tolist())
        
        #new code
        roc_train += [auc_roc_value_train]
        pr_train  += [auc_pr_value_train]
        f1_train  += [f1_value_train]
        mcc_train  += [mcc_value_train]

        msg = f'TRAIN, epoch {epoch}, train_loss {loss_train_per}, train_acc {acc_train_per}, train_roc {auc_roc_value_train:.5f}, \
                train_pr {auc_pr_value_train:.5f}, train_f1 {f1_value_train:.5f},train_mcc {mcc_value_train:.5f}'

        print(f'TRAIN, epoch {epoch}, train_loss {loss_train_per:.5f}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}, train_f1 {f1_value_train:.5f}, train_mcc {mcc_value_train:.5f}')
        # tqdm.write(f'TRAIN, epoch {epoch}, train_loss {loss_train_per:.5f}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}, train_f1 {f1_value_train:.5f}, train_mcc {mcc_value_train:.5f}')
        log_msg(msg, log_file)

        
        # evaluate on validation set
        acc_val_per, loss_val_per, output_val_per, target_val_per = validate(val_loader, model, criterion, args)

        acc_val  += [acc_val_per]
        loss_val += [loss_val_per]
        #calculate metric
        # calculate metric
        # print(len(target_val_per),len(output_val_per))
        auc_roc_value_val = roc_auc_score(target_val_per, output_val_per)
        auc_pr_value_val = average_precision_score(target_val_per, output_val_per)
        p2l_value_val = np.where(np.array(output_val_per) >= 0.5, 1, 0)
        f1_value_val = f1_score(target_val_per, p2l_value_val.tolist())
        mcc_value_val= matthews_corrcoef(target_val_per,p2l_value_val.tolist())
        #new code

        msg = f'VAL, epoch {epoch}, val_loss {loss_val_per}, val_acc {acc_val_per}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f},val_f1 {f1_value_val:.5f}, val_mcc {mcc_value_val:.5f}'
        
        print(f'VAL, epoch {epoch}, val_loss {loss_val_per:.5f}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f},val_f1 {f1_value_val:.5f}, val_mcc {mcc_value_val:.5f}')
        # tqdm.write(f'VAL, epoch {epoch}, val_loss {loss_val_per:.5f}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f},val_f1 {f1_value_val:.5f}, val_mcc {mcc_value_val:.5f}')
        log_msg(msg, log_file)
        #########################################################################################################################
        # evaluate on test set
        acc_test_per, loss_test_per, output_test_per, target_test_per = validate(test_loader, model, criterion, args)

        acc_test   += [acc_test_per]
        loss_test  += [loss_test_per]
        #calculate metric
        # calculate metric
        # print(len(target_val_per),len(output_val_per))
        auc_roc_value_test = roc_auc_score(target_test_per, output_test_per)
        auc_pr_value_test  = average_precision_score(target_test_per, output_test_per)
        p2l_value_test = np.where(np.array(output_test_per) >= 0.5, 1, 0)
        f1_value_test = f1_score(target_test_per, p2l_value_test.tolist())
        mcc_value_test= matthews_corrcoef(target_test_per,p2l_value_test.tolist())
        #new code

        msg = f'TEST, epoch {epoch},test_loss {loss_test_per},test_acc {acc_test_per},test_roc {auc_roc_value_test:.5f},test_pr {auc_pr_value_test:.5f},test_f1 {f1_value_test:.5f}, test_mcc {mcc_value_test:.5f}'
        print(f'TEST, epoch {epoch}, test_loss {loss_test_per:.5f}, test_roc {auc_roc_value_test:.5f}, test_pr {auc_pr_value_test:.5f}, test_f1 {f1_value_test:.5f}, test_mcc {mcc_value_test:.5f}')
        # tqdm.write(f'TEST, epoch {epoch}, test_loss {loss_test_per:.5f}, test_roc {auc_roc_value_test:.5f}, test_pr {auc_pr_value_test:.5f}, test_f1 {f1_value_test:.5f}, test_mcc {mcc_value_test:.5f}')
        log_msg(msg, log_file)        
        
        #########################################################################################################################
        # remember best acc
        best_acc1 = max(acc_val_per, best_acc1)
        best_roc  = max(auc_roc_value_val, best_roc)
        best_pr   = max(auc_pr_value_val, best_pr)
        best_f1   = max(f1_value_val, best_f1)
        best_mcc   = max(mcc_value_val, best_mcc)
        
        #########################################################################################################################
        
        best_test_acc1 = max(acc_test_per, best_test_acc1)
        best_test_roc  = max(auc_roc_value_test, best_test_roc)
        best_test_pr   = max(auc_pr_value_test, best_test_pr)
        best_test_f1   = max(f1_value_test, best_test_f1)
        best_test_mcc   = max(mcc_value_test, best_test_mcc)
 

    #     wandb.log({"train_loss": loss_train_per, "train_roc": auc_roc_value_train, "train_pr": auc_pr_value_train, \
    #                "val_loss": loss_val_per, "val_roc": auc_roc_value_val, "val_pr": auc_pr_value_val, "best_val_roc": best_roc, "best_val_pr": best_pr})
    # wandb.finish()
    
    
    
        # Construct the filename with metrics
        auc_roc_value_val_scalar = auc_roc_value_val.item() if isinstance(auc_roc_value_val, torch.Tensor) else auc_roc_value_val
        auc_pr_value_val_scalar  = auc_pr_value_val.item() if isinstance(auc_pr_value_val, torch.Tensor) else auc_pr_value_val
        f1_value_val_scalar      = f1_value_val.item() if isinstance(f1_value_val, torch.Tensor) else f1_value_val
        mcc_value_val_scalar     = mcc_value_val.item() if isinstance(mcc_value_val, torch.Tensor) else mcc_value_val
        #########################################################################################################################
        auc_roc_value_test_scalar = auc_roc_value_test.item() if isinstance(auc_roc_value_test, torch.Tensor) else auc_roc_value_test
        auc_pr_value_test_scalar  = auc_pr_value_test.item() if isinstance(auc_pr_value_test, torch.Tensor) else auc_pr_value_test
        f1_value_test_scalar      = f1_value_test.item() if isinstance(f1_value_test, torch.Tensor) else f1_value_test
        mcc_value_test_scalar      = mcc_value_test.item() if isinstance(mcc_value_test, torch.Tensor) else mcc_value_test        
        #########################################################################################################################
        # Now use these scalar values in the filename
        filename = f"model_epoch_{epoch}.pth"
       
        # Continue with the existing logic to save the model
        model_path = os.path.join(specific_model_save_dir, filename)
        torch.save(model.state_dict(), model_path)

    # measure elapsed time
    dataset_time.update(time.time() - end)

    # log & print the best_acc
    msg = f'\n\n * BEST_ACC: {best_acc1}\n * TIME: {dataset_time}\n'
    log_msg(msg, log_file)

    print(f' * best_acc1: {best_acc1}, best_roc: {best_roc}, best_pr: {best_pr}, best_f1: {best_f1}, best_mcc: {best_mcc}')
    print(f' * best_test_acc1: {best_test_acc1}, best_test_roc: {best_test_roc}, best_test_pr: {best_test_pr}, best_test_f1: {best_test_f1}, best_test_mcc: {best_test_mcc}')
    print(f' * time: {dataset_time}')
    print('****************************************************')


    # collect cache
    gc.collect()
    torch.cuda.empty_cache()


def train(train_loader, model, criterion, optimizer, lr_scheduler, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    
    output_list = []
    target_list = [] 

    # switch to train mode
    model.train()

    for count, (data, label, mask) in enumerate(train_loader):

        # data in cuda
        data = data.cuda(args['gpu']).type(torch.float)
        data = data.view(data.size(0), 1, data.size(2), data.size(1))
        mask = mask.cuda(args['gpu']).type(torch.bool)
        label = label.cuda(args['gpu']).type(torch.long)

        # compute output
        output = model(data)
        # print(len(output))
        # print('output', output.shape, 'mask', mask.shape)
        out_flat = torch.masked_select(output, mask.unsqueeze(-1)).reshape(-1, output.shape[-1])
        # print(output)
        # print(output.shape, mask.shape, out_flat.shape)
        # break
        # out_flat = torch.masked_select(output[:,:,1], mask)
        

        label_flat = torch.masked_select(label, mask)
        # print('output',output.shape, 'mask', mask.shape,'out_flat', out_flat.shape, 'label', label.shape, 'label_flat', label_flat.shape)
        loss = criterion(out_flat, label_flat)

        # measure accuracy and record loss
        acc1 = accuracy(out_flat, label_flat, topk=(1, 1))
        
        output_np = torch.softmax(out_flat, dim=1).detach().cpu().numpy()[:,1].tolist()
        
        target_np = label_flat.detach().cpu().numpy().tolist()
        
        # print(output_np, target_np)
        
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        
        # met_roc.update(roc, data.size(0))
        # met_pr.update(pr, data.size(0))
        output_list += output_np
        target_list += target_np

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    lr_scheduler.step(top1.avg)

    return top1.avg, losses.avg, output_list, target_list


def validate(val_loader, model, criterion, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    output_list = []
    target_list = [] 
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for count, (data, label, mask) in enumerate(val_loader):
            if args['gpu'] is not None:
                data = data.cuda(args['gpu'], non_blocking=True).type(torch.float)
                data = data.view(data.size(0), 1, data.size(2), data.size(1))
            if torch.cuda.is_available():
                label = label.cuda(args['gpu'], non_blocking=True).type(torch.long)
                mask  = mask.cuda(args['gpu'], non_blocking=True).type(torch.bool)

            # compute output
            output = model(data)
            
            out_flat = torch.masked_select(output, mask.unsqueeze(-1)).reshape(-1, output.shape[-1])
            label_flat = torch.masked_select(label, mask)

            loss = criterion(out_flat, label_flat)
            
            output_np = torch.softmax(out_flat, dim=1).detach().cpu().numpy()[:,1].tolist()
            target_np = label_flat.detach().cpu().numpy().tolist()

            # measure accuracy and record loss
            acc1 = accuracy(out_flat, label_flat, topk=(1, 1))
            losses.update(loss.item(), data.size(0))
            top1.update(acc1[0], data.size(0))
            
            output_list += output_np
            target_list += target_np
            
            # met_roc.update(roc, data.size(0))
            # met_pr.update(pr, data.size(0))

    return top1.avg, losses.avg, output_list, target_list

def load_model(model_path, model_class, *model_args, **model_kwargs):
    """
    Load the model from a saved state dictionary.

    Parameters:
    - model_path: Path to the saved model state dictionary.
    - model_class: The class of the model to instantiate.
    - model_args: Positional arguments for the model class instantiation.
    - model_kwargs: Keyword arguments for the model class instantiation.

    Returns:
    - model: The loaded model ready for prediction.
    """
    # Instantiate the model
    model = model_class(*model_args, **model_kwargs)
    # Load the saved state dictionary
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model to evaluation mode
    return model

# def predict(test_loader, model, criterion, args):
#     """
#     Predict the labels for the given data using the loaded model.

#     Parameters:
#     - model: The loaded model ready for prediction.
#     - data_loader: DataLoader for the dataset to predict on.

#     Returns:
#     - predictions: Predictions for the input data.
#     """
#     predictions = []
#     with torch.no_grad():  # No need to track gradients
#         for data in test_loader:
#             # Assuming your model expects data in a specific format, adjust accordingly
#             if torch.cuda.is_available():
#                 data = data.to('cuda')
#             output = model(data)
#             # Convert output to probabilities and then to the desired format
#             # For example, using softmax for classification
#             prob = torch.softmax(output, dim=1)
#             predicted_classes = prob.argmax(dim=1)
#             predictions.append(predicted_classes.cpu().numpy())  # Move to CPU and convert to numpy if needed
#     return predictions

In [15]:
main_work(args)

Models will be saved in: saved_models/Resp_failure_development_attentionMSE_03.03 19:33
Use GPU: 0 for training
features / nodes 231 total time graphs 32 time series length 864 classes 2
Using cudnn.benchmark.
****************************************************
Dataset:  Resp_failure_development_attentionMSE


  0%|                                                    | 0/20 [00:00<?, ?it/s]

TRAIN, epoch 0, train_loss 0.67533, train_roc 0.58273, train_pr 0.47674, train_f1 0.43833, train_mcc 0.11019
VAL, epoch 0, val_loss 0.63599, val_roc 0.67714, val_pr 0.56967,val_f1 0.52421, val_mcc 0.25493


  5%|██                                      | 1/20 [29:40<9:23:49, 1780.52s/it]

TEST, epoch 0, test_loss 0.63486, test_roc 0.67450, test_pr 0.56987, test_f1 0.51791, test_mcc 0.25297
TRAIN, epoch 1, train_loss 0.60119, train_roc 0.72098, train_pr 0.63292, train_f1 0.56573, train_mcc 0.31476
VAL, epoch 1, val_loss 0.58900, val_roc 0.73509, val_pr 0.65450,val_f1 0.58592, val_mcc 0.33404


 10%|████                                    | 2/20 [59:05<8:51:20, 1771.16s/it]

TEST, epoch 1, test_loss 0.58678, test_roc 0.73542, test_pr 0.65677, test_f1 0.58199, test_mcc 0.33631
TRAIN, epoch 2, train_loss 0.56881, train_roc 0.75773, train_pr 0.68317, train_f1 0.60445, train_mcc 0.37301
VAL, epoch 2, val_loss 0.57911, val_roc 0.76209, val_pr 0.68428,val_f1 0.52884, val_mcc 0.36406


 15%|█████▋                                | 3/20 [1:28:03<8:17:35, 1756.22s/it]

TEST, epoch 2, test_loss 0.57613, test_roc 0.76189, test_pr 0.68648, test_f1 0.52786, test_mcc 0.37090
TRAIN, epoch 3, train_loss 0.54902, train_roc 0.77754, train_pr 0.70961, train_f1 0.62481, train_mcc 0.40304
VAL, epoch 3, val_loss 0.55265, val_roc 0.77499, val_pr 0.69986,val_f1 0.62151, val_mcc 0.39473


 20%|███████▌                              | 4/20 [1:57:14<7:47:44, 1754.06s/it]

TEST, epoch 3, test_loss 0.54905, test_roc 0.77539, test_pr 0.70446, test_f1 0.62161, test_mcc 0.40254
TRAIN, epoch 4, train_loss 0.53777, train_roc 0.78859, train_pr 0.72428, train_f1 0.63656, train_mcc 0.42000
VAL, epoch 4, val_loss 0.55116, val_roc 0.77621, val_pr 0.70545,val_f1 0.62136, val_mcc 0.40349


 25%|█████████▌                            | 5/20 [2:26:06<7:16:34, 1746.32s/it]

TEST, epoch 4, test_loss 0.54474, test_roc 0.78012, test_pr 0.71337, test_f1 0.62375, test_mcc 0.41460
TRAIN, epoch 5, train_loss 0.52776, train_roc 0.79726, train_pr 0.73477, train_f1 0.64684, train_mcc 0.43491
VAL, epoch 5, val_loss 0.54476, val_roc 0.78204, val_pr 0.70808,val_f1 0.61973, val_mcc 0.40715


 30%|███████████▍                          | 6/20 [2:55:47<6:50:11, 1757.97s/it]

TEST, epoch 5, test_loss 0.53629, test_roc 0.78757, test_pr 0.71693, test_f1 0.62168, test_mcc 0.41973
TRAIN, epoch 6, train_loss 0.51861, train_roc 0.80426, train_pr 0.74333, train_f1 0.65460, train_mcc 0.44657
VAL, epoch 6, val_loss 0.53850, val_roc 0.78741, val_pr 0.71612,val_f1 0.62857, val_mcc 0.41530


 35%|█████████████▎                        | 7/20 [3:24:44<6:19:24, 1751.13s/it]

TEST, epoch 6, test_loss 0.52870, test_roc 0.79247, test_pr 0.72734, test_f1 0.63258, test_mcc 0.42993
TRAIN, epoch 7, train_loss 0.51447, train_roc 0.80926, train_pr 0.74952, train_f1 0.65982, train_mcc 0.45546
VAL, epoch 7, val_loss 0.53397, val_roc 0.79354, val_pr 0.72079,val_f1 0.65230, val_mcc 0.42460


 40%|███████████████▏                      | 8/20 [3:54:28<5:52:18, 1761.52s/it]

TEST, epoch 7, test_loss 0.52410, test_roc 0.79938, test_pr 0.73203, test_f1 0.65754, test_mcc 0.43820
TRAIN, epoch 8, train_loss 0.50791, train_roc 0.81432, train_pr 0.75501, train_f1 0.66517, train_mcc 0.46366
VAL, epoch 8, val_loss 0.53365, val_roc 0.79351, val_pr 0.72103,val_f1 0.64016, val_mcc 0.42568


 45%|█████████████████                     | 9/20 [4:24:00<5:23:33, 1764.90s/it]

TEST, epoch 8, test_loss 0.52265, test_roc 0.79998, test_pr 0.73179, test_f1 0.64562, test_mcc 0.44160
TRAIN, epoch 9, train_loss 0.50326, train_roc 0.81887, train_pr 0.76113, train_f1 0.67101, train_mcc 0.47119
VAL, epoch 9, val_loss 0.53185, val_roc 0.79362, val_pr 0.72326,val_f1 0.63963, val_mcc 0.42957


 50%|██████████████████▌                  | 10/20 [4:52:57<4:52:42, 1756.24s/it]

TEST, epoch 9, test_loss 0.52405, test_roc 0.79692, test_pr 0.73149, test_f1 0.64095, test_mcc 0.43844
TRAIN, epoch 10, train_loss 0.49800, train_roc 0.82262, train_pr 0.76611, train_f1 0.67539, train_mcc 0.47958
VAL, epoch 10, val_loss 0.53000, val_roc 0.79629, val_pr 0.72626,val_f1 0.65060, val_mcc 0.43397


 55%|████████████████████▎                | 11/20 [5:22:00<4:22:49, 1752.22s/it]

TEST, epoch 10, test_loss 0.52231, test_roc 0.79968, test_pr 0.73464, test_f1 0.65168, test_mcc 0.44160
TRAIN, epoch 11, train_loss 0.49721, train_roc 0.82475, train_pr 0.76952, train_f1 0.67766, train_mcc 0.48239
VAL, epoch 11, val_loss 0.53357, val_roc 0.79988, val_pr 0.72631,val_f1 0.63341, val_mcc 0.43348


 60%|██████████████████████▏              | 12/20 [5:51:21<3:53:58, 1754.78s/it]

TEST, epoch 11, test_loss 0.52359, test_roc 0.80239, test_pr 0.73583, test_f1 0.63488, test_mcc 0.44282
TRAIN, epoch 12, train_loss 0.49150, train_roc 0.82915, train_pr 0.77459, train_f1 0.68404, train_mcc 0.49113
VAL, epoch 12, val_loss 0.53142, val_roc 0.79680, val_pr 0.72635,val_f1 0.65043, val_mcc 0.43321


 65%|████████████████████████             | 13/20 [6:21:39<3:26:58, 1774.08s/it]

TEST, epoch 12, test_loss 0.51793, test_roc 0.80450, test_pr 0.73869, test_f1 0.65710, test_mcc 0.45118
TRAIN, epoch 13, train_loss 0.48864, train_roc 0.83224, train_pr 0.77809, train_f1 0.68542, train_mcc 0.49465
VAL, epoch 13, val_loss 0.52718, val_roc 0.79927, val_pr 0.72864,val_f1 0.65866, val_mcc 0.43686


 70%|█████████████████████████▉           | 14/20 [6:50:54<2:56:49, 1768.24s/it]

TEST, epoch 13, test_loss 0.51822, test_roc 0.80407, test_pr 0.73828, test_f1 0.66419, test_mcc 0.45021
TRAIN, epoch 14, train_loss 0.48192, train_roc 0.83556, train_pr 0.78248, train_f1 0.69248, train_mcc 0.50250
VAL, epoch 14, val_loss 0.52771, val_roc 0.80132, val_pr 0.73142,val_f1 0.64270, val_mcc 0.43963


 75%|███████████████████████████▊         | 15/20 [7:19:05<2:25:24, 1744.88s/it]

TEST, epoch 14, test_loss 0.51496, test_roc 0.80753, test_pr 0.74367, test_f1 0.64920, test_mcc 0.45716
TRAIN, epoch 15, train_loss 0.47900, train_roc 0.83948, train_pr 0.78722, train_f1 0.69611, train_mcc 0.50858
VAL, epoch 15, val_loss 0.53381, val_roc 0.79900, val_pr 0.72431,val_f1 0.64790, val_mcc 0.43740


 80%|█████████████████████████████▌       | 16/20 [7:47:16<1:55:14, 1728.61s/it]

TEST, epoch 15, test_loss 0.51777, test_roc 0.80660, test_pr 0.73910, test_f1 0.65540, test_mcc 0.45509
TRAIN, epoch 16, train_loss 0.47412, train_roc 0.84268, train_pr 0.79103, train_f1 0.69942, train_mcc 0.51438
VAL, epoch 16, val_loss 0.53014, val_roc 0.79745, val_pr 0.72405,val_f1 0.65215, val_mcc 0.43388


 85%|███████████████████████████████▍     | 17/20 [8:15:26<1:25:51, 1717.27s/it]

TEST, epoch 16, test_loss 0.51597, test_roc 0.80610, test_pr 0.73963, test_f1 0.66113, test_mcc 0.45310
TRAIN, epoch 17, train_loss 0.46922, train_roc 0.84652, train_pr 0.79564, train_f1 0.70480, train_mcc 0.52210
VAL, epoch 17, val_loss 0.52911, val_roc 0.80067, val_pr 0.72553,val_f1 0.64825, val_mcc 0.44019


 90%|███████████████████████████████████    | 18/20 [8:43:37<56:58, 1709.31s/it]

TEST, epoch 17, test_loss 0.51239, test_roc 0.80944, test_pr 0.74147, test_f1 0.65816, test_mcc 0.46006
TRAIN, epoch 18, train_loss 0.46572, train_roc 0.84917, train_pr 0.79931, train_f1 0.70797, train_mcc 0.52660
VAL, epoch 18, val_loss 0.53493, val_roc 0.80100, val_pr 0.72643,val_f1 0.64473, val_mcc 0.43937


 95%|█████████████████████████████████████  | 19/20 [9:11:48<28:23, 1703.77s/it]

TEST, epoch 18, test_loss 0.51895, test_roc 0.80865, test_pr 0.74184, test_f1 0.65279, test_mcc 0.45606
TRAIN, epoch 19, train_loss 0.46068, train_roc 0.85210, train_pr 0.80326, train_f1 0.71073, train_mcc 0.53229
VAL, epoch 19, val_loss 0.52872, val_roc 0.80334, val_pr 0.72993,val_f1 0.65954, val_mcc 0.44444


100%|███████████████████████████████████████| 20/20 [9:39:59<00:00, 1739.97s/it]

TEST, epoch 19, test_loss 0.51374, test_roc 0.81117, test_pr 0.74603, test_f1 0.66922, test_mcc 0.46581
 * best_acc1: tensor([74.0154], device='cuda:0'), best_roc: 0.8033370457495794, best_pr: 0.7314193522293397, best_f1: 0.6595373010876141, best_mcc: 0.4444436657029826
 * best_test_acc1: tensor([75.0710], device='cuda:0'), best_test_roc: 0.8111713437523786, best_test_pr: 0.746031695483067, best_test_f1: 0.6692231407014086, best_test_mcc: 0.4658092798187367
 * time: Time 34799.484 (34799.484)
****************************************************



