![alt text](metrics.png)

In [1]:
import torch
import torch.utils.data
from torch.utils.data import TensorDataset

import argparse
import time
import gc
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from tqdm.notebook import tqdm

# import wandb
import os

from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
torch.manual_seed(0)
torch.cuda.is_available(), torch.cuda.get_device_name(0)

(True, 'NVIDIA GeForce GTX 1650')

In [3]:
from net import GNNStack
from utils_todynet import AverageMeter, accuracy, log_msg, get_default_train_val_test_loader

In [4]:
import warnings

# Suppress all warnings
warnings.filterwarnings("ignore")

In [5]:
# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         loss = (self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

# class FocalLoss(nn.Module):
#     def __init__(self, alpha=None, gamma=2, num_classes=None):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         if alpha is None:
#             self.alpha = None
#         else:
#             # If alpha is a single float number, expand it to a tensor
#             if isinstance(alpha, float):
#                 assert num_classes is not None, "num_classes must be specified when alpha is a single float"
#                 self.alpha = torch.full((num_classes,), alpha)
#             else:
#                 self.alpha = alpha

#     def forward(self, inputs, targets):
#         ce_loss = F.cross_entropy(inputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)
#         if self.alpha is not None:
#             if self.alpha.type() != inputs.data.type():
#                 self.alpha = self.alpha.type_as(inputs.data)
#             at = self.alpha[targets]
#         else:
#             at = 1.0
#         loss = (at * (1 - pt) ** self.gamma * ce_loss).mean()
#         return loss

## aruguments

In [6]:
args = {
    'arch': 'dyGIN2d', #what other models I can put here?? dyGCN2d, dyGIN2d
    'dataset': 'Mortality', # "AtrialFibrillation" # 'Mortality', # 'MIMIC3'
    'num_layers': 2,  # the number of GNN layers  3
    'groups': 16,  # the number of time series groups (num_graphs)
    'pool_ratio': 0.1,  # the ratio of pooling for nodes
    'kern_size': [3,3],  # list of time conv kernel size for each layer [9,5,3]
    'in_dim': 32,  # input dimensions of GNN stacks
    'hidden_dim': 32,  # hidden dimensions of GNN stacks
    'out_dim': 32,  # output dimensions of GNN stacks
    'workers': 4,  # number of data loading workers
    'epochs': 30,  # number of total epochs to run
    'batch_size': 4,  # mini-batch size, this is the total batch size of all GPUs
    'val_batch_size': 4,  # validation batch size
    'lr': 0.000095,  # initial learning rate
    'weight_decay': 1e-4,  # weight decay
    'evaluate': False,  # evaluate model on validation set
    'seed': 2,  # seed for initializing training
    'gpu': 0,  # GPU id to use
    'use_benchmark': True,  # use benchmark
    'tag': 'date',  # the tag for identifying the log and model files
    'loss':'bce'
}

In [7]:
# # start a new wandb run to track this script
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="mortality",
    
#     # track hyperparameters and run metadata
#     config=args
# )

In [8]:
# train_dataset = TensorDataset(data_train, label_train)
# val_dataset   = TensorDataset(data_val, label_val)

In [9]:
# train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args['batch_size'],shuffle=True, num_workers=args['workers'], pin_memory=True)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args['val_batch_size'], shuffle=False,num_workers=args['workers'],pin_memory=True)

In [10]:
# def main():
#     # args = parser.parse_args()
    
#     # args.kern_size = [ int(l) for l in args.kern_size.split(",") ]

#     # if args.seed is not None:
#     random.seed(args['seed'])
#     torch.manual_seed(args['seed'])

#     main_work(args)

In [11]:
def main_work(args):
    
    random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    
    
    # init acc
    best_acc1 = 0
    best_roc  = 0
    best_pr   = 0
    
    if args['tag'] == 'date':
        local_date = time.strftime('%m.%d %H:%M', time.localtime(time.time()))
        args['tag'] = local_date

    # Use the 'tag' which now contains either the date or a custom tag along with the dataset name for the directory
    run_dir_name = f"{args['dataset']}_{args['tag']}"
    
    # Base directory for saving models
    base_model_save_dir = "saved_models"
    
    # Specific directory for this run
    specific_model_save_dir = os.path.join(base_model_save_dir, run_dir_name)
    os.makedirs(specific_model_save_dir, exist_ok=True)

    print(f"Models will be saved in: {specific_model_save_dir}")

    log_file = 'log/{}_gpu{}_{}_{}_exp.txt'.format(args['tag'], args['gpu'], args['arch'], args['dataset'])
    
    
    if args['gpu'] is not None:
        print("Use GPU: {} for training".format(args['gpu']))


    # dataset
    train_loader, val_loader, num_nodes, seq_length, num_classes = get_default_train_val_test_loader(args)
    
    print('features / nodes', num_nodes,'total time graphs',seq_length,'classes', num_classes)
    
    # training model from net.py
    model = GNNStack(gnn_model_type=args['arch'], num_layers=args['num_layers'], 
                     groups=args['groups'], pool_ratio=args['pool_ratio'], kern_size=args['kern_size'], 
                     in_dim=args['in_dim'], hidden_dim=args['hidden_dim'], out_dim=args['out_dim'], 
                     seq_len=seq_length, num_nodes=num_nodes, num_classes=num_classes)

    # print & log
    log_msg('epochs {}, lr {}, weight_decay {}'.format(args['epochs'], args['lr'], args['weight_decay']), log_file)
    
    log_msg(str(args), log_file)


    # determine whether GPU or not
    if not torch.cuda.is_available():
        print("Warning! Using CPU!!!")
    elif args['gpu'] is not None:
        torch.cuda.set_device(args['gpu'])

        # collect cache
        gc.collect()
        torch.cuda.empty_cache()

        model = model.cuda(args['gpu'])
        if args['use_benchmark']:
            cudnn.benchmark = True
        print('Using cudnn.benchmark.')
    else:
        print("Error! We only have one gpu!!!")


    # define loss function(criterion) and optimizer
    # class_weights = torch.tensor([0.087, 0.913], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([0.913, 0.087], dtype=torch.float).cuda(args['gpu'])
    # class_weights = torch.tensor([1.0, 22.47], dtype=torch.float).cuda(args['gpu'])
    
    
    # criterion = nn.CrossEntropyLoss(weight=class_weights).cuda(args['gpu'])
    
    criterion = nn.CrossEntropyLoss().cuda(args['gpu'])
    
    # criterion = FocalLoss(gamma=0.1)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    # validation
    if args['evaluate']:
        validate(val_loader, model, criterion, args)
        return

    # train & valid
    print('****************************************************')
    print('Dataset: ', args['dataset'])

    dataset_time = AverageMeter('Time', ':6.3f')

    loss_train = []
    acc_train = []
    loss_val = []
    acc_val = []
    epoches = []
    
    ###### 4 more lists to have values
    roc_train = []
    pr_train  = []
    
    roc_val   = []
    pr_val    = []

    end = time.time()
    for epoch in tqdm(range(args['epochs'])):
        epoches += [epoch]

        # train for one epoch
        acc_train_per, loss_train_per, output_train_per, target_train_per = train(train_loader, model, criterion, optimizer, lr_scheduler, args)
        
        acc_train += [acc_train_per]
        loss_train += [loss_train_per]
        # calculate metric
        # print(len(target_train_per),len(output_train_per))
        auc_roc_value_train = roc_auc_score(target_train_per, output_train_per)
        auc_pr_value_train = average_precision_score(target_train_per, output_train_per)
        #new code
        roc_train += [auc_roc_value_train]
        pr_train  += [auc_pr_value_train]

        msg = f'TRAIN, epoch {epoch}, train_loss {loss_train_per}, train_acc {acc_train_per}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}'

        print(f'TRAIN, epoch {epoch}, train_loss {loss_train_per:.5f}, train_roc {auc_roc_value_train:.5f}, train_pr {auc_pr_value_train:.5f}')
        log_msg(msg, log_file)

        
        # evaluate on validation set
        acc_val_per, loss_val_per, output_val_per, target_val_per = validate(val_loader, model, criterion, args)

        acc_val  += [acc_val_per]
        loss_val += [loss_val_per]
        #calculate metric
        # calculate metric
        # print(len(target_val_per),len(output_val_per))
        auc_roc_value_val = roc_auc_score(target_val_per, output_val_per)
        auc_pr_value_val = average_precision_score(target_val_per, output_val_per)
        #new code

        msg = f'VAL, epoch {epoch}, val_loss {loss_val_per}, val_acc {acc_val_per}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f}'
        
        print(f'VAL, epoch {epoch}, val_loss {loss_val_per:.5f}, val_roc {auc_roc_value_val:.5f}, val_pr {auc_pr_value_val:.5f}')
        log_msg(msg, log_file)

        
        
        # remember best acc
        best_acc1 = max(acc_val_per, best_acc1)
        
        best_roc = max(auc_roc_value_val, best_roc)
        
        best_pr  = max(auc_pr_value_val, best_pr)
        
#         wandb.log({"train_loss": loss_train_per, "train_roc": auc_roc_value_train, "train_pr": auc_pr_value_train, \
#                    "val_loss": loss_val_per, "val_roc": auc_roc_value_val, "val_pr": auc_pr_value_val, "best_val_roc": best_roc, "best_val_pr": best_pr})
#     wandb.finish()
        # Construct the filename with metrics
        acc_val_per_scalar = acc_val_per.item() if isinstance(acc_val_per, torch.Tensor) else acc_val_per
        auc_roc_value_val_scalar = auc_roc_value_val.item() if isinstance(auc_roc_value_val, torch.Tensor) else auc_roc_value_val
        auc_pr_value_val_scalar = auc_pr_value_val.item() if isinstance(auc_pr_value_val, torch.Tensor) else auc_pr_value_val
        
        # Now use these scalar values in the filename
        filename = f"model_epoch_{epoch}_acc_{acc_val_per_scalar:.3f}_roc_{auc_roc_value_val_scalar:.3f}_pr_{auc_pr_value_val_scalar:.3f}.pth"
        
        # Continue with the existing logic to save the model
        model_path = os.path.join(specific_model_save_dir, filename)
        torch.save(model.state_dict(), model_path)

    # measure elapsed time
    dataset_time.update(time.time() - end)

    # log & print the best_acc
    msg = f'\n\n * BEST_ACC: {best_acc1}\n * TIME: {dataset_time}\n'
    log_msg(msg, log_file)

    print(f' * best_acc1: {best_acc1}, best_roc: {best_roc}, best_pr: {best_pr}')
    print(f' * time: {dataset_time}')
    print('****************************************************')


    # collect cache
    gc.collect()
    torch.cuda.empty_cache()


def train(train_loader, model, criterion, optimizer, lr_scheduler, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    
    output_list = []
    target_list = [] 

    # switch to train mode
    model.train()

    for count, (data, label) in enumerate(train_loader):

        # data in cuda
        data = data.cuda(args['gpu']).type(torch.float)
        label = label.cuda(args['gpu']).type(torch.long)

        # compute output
        output = model(data)
    
        loss = criterion(output, label)

        # measure accuracy and record loss
        acc1 = accuracy(output, label, topk=(1, 1))
        
        output_np = torch.softmax(output, dim=1).detach().cpu().numpy()[:,1].tolist()
        
        target_np = label.detach().cpu().numpy().tolist()
        
        # print(output_np, target_np)
        
        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))
        
        # met_roc.update(roc, data.size(0))
        # met_pr.update(pr, data.size(0))
        output_list += output_np
        target_list += target_np

        # compute gradient and do Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    lr_scheduler.step(top1.avg)

    return top1.avg, losses.avg, output_list, target_list


def validate(val_loader, model, criterion, args):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    # met_roc = AverageMeter('ROC', ':6.2f')
    # met_pr = AverageMeter('PR', ':6.2f')
    output_list = []
    target_list = [] 
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for count, (data, label) in enumerate(val_loader):
            if args['gpu'] is not None:
                data = data.cuda(args['gpu'], non_blocking=True).type(torch.float)
            if torch.cuda.is_available():
                label = label.cuda(args['gpu'], non_blocking=True).type(torch.long)

            # compute output
            output = model(data)

            loss = criterion(output, label)
            
            output_np = torch.softmax(output, dim=1).detach().cpu().numpy()[:,1].tolist()
            target_np = label.detach().cpu().numpy().tolist()

            # measure accuracy and record loss
            acc1 = accuracy(output, label, topk=(1, 1))
            losses.update(loss.item(), data.size(0))
            top1.update(acc1[0], data.size(0))
            
            output_list += output_np
            target_list += target_np
            
            # met_roc.update(roc, data.size(0))
            # met_pr.update(pr, data.size(0))

    return top1.avg, losses.avg, output_list, target_list

def load_model(model_path, model_class, *model_args, **model_kwargs):
    """
    Load the model from a saved state dictionary.

    Parameters:
    - model_path: Path to the saved model state dictionary.
    - model_class: The class of the model to instantiate.
    - model_args: Positional arguments for the model class instantiation.
    - model_kwargs: Keyword arguments for the model class instantiation.

    Returns:
    - model: The loaded model ready for prediction.
    """
    # Instantiate the model
    model = model_class(*model_args, **model_kwargs)
    # Load the saved state dictionary
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model to evaluation mode
    return model

def predict(model, data_loader):
    """
    Predict the labels for the given data using the loaded model.

    Parameters:
    - model: The loaded model ready for prediction.
    - data_loader: DataLoader for the dataset to predict on.

    Returns:
    - predictions: Predictions for the input data.
    """
    predictions = []
    with torch.no_grad():  # No need to track gradients
        for data in data_loader:
            # Assuming your model expects data in a specific format, adjust accordingly
            if torch.cuda.is_available():
                data = data.to('cuda')
            output = model(data)
            # Convert output to probabilities and then to the desired format
            # For example, using softmax for classification
            prob = torch.softmax(output, dim=1)
            predicted_classes = prob.argmax(dim=1)
            predictions.append(predicted_classes.cpu().numpy())  # Move to CPU and convert to numpy if needed
    return predictions

In [12]:
main_work(args)

Models will be saved in: saved_models/Mortality_02.19 19:10
Use GPU: 0 for training
features / nodes 231 total time graphs 288 classes 2
Using cudnn.benchmark.
****************************************************
Dataset:  Mortality


  0%|          | 0/30 [00:00<?, ?it/s]

TRAIN, epoch 0, train_loss 0.30024, train_roc 0.67835, train_pr 0.16221
VAL, epoch 0, val_loss 0.21855, val_roc 0.87050, val_pr 0.41038
TRAIN, epoch 1, train_loss 0.22894, train_roc 0.83433, train_pr 0.38487
VAL, epoch 1, val_loss 0.20015, val_roc 0.87927, val_pr 0.48147
TRAIN, epoch 2, train_loss 0.21301, train_roc 0.86105, train_pr 0.44842
VAL, epoch 2, val_loss 0.19571, val_roc 0.88772, val_pr 0.55788


KeyboardInterrupt: 

In [13]:
#### model load

In [14]:
# Example usage
model_path = 'saved_models/Mortality_02.19 19:10/model_epoch_2_acc_92.604_roc_0.888_pr_0.558.pth'
model = load_model(model_path, GNNStack)

TypeError: GNNStack.__init__() missing 11 required positional arguments: 'gnn_model_type', 'num_layers', 'groups', 'pool_ratio', 'kern_size', 'in_dim', 'hidden_dim', 'out_dim', 'seq_len', 'num_nodes', and 'num_classes'

In [None]:
# dataset
train_loader, val_loader, num_nodes, seq_length, num_classes = get_default_train_val_test_loader(args)