In [1]:
# Report the kernel's starting directory and mount Google Drive.
# (This cell does not change the working directory.)
import os
from google.colab import drive

# Print the directory the kernel is currently running in
print("Current Working Directory:", os.getcwd())

# Mount Google Drive at /content/gdrive
drive.mount('/content/gdrive')

Current Working Directory: /content
Mounted at /content/gdrive


In [2]:
# Pinned to the version recorded in this cell's output so re-runs install the same release;
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install dill==0.3.8

Collecting dill
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: dill
Successfully installed dill-0.3.8


In [3]:
# Pinned to the version recorded in this cell's output so re-runs install the same release;
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install dnc==1.1.0

Collecting dnc
  Downloading dnc-1.1.0-py3-none-any.whl (20 kB)
Collecting flann (from dnc)
  Downloading flann-1.6.13-py3-none-any.whl (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->dnc)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->dnc)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->dnc)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->dnc)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->dnc)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->dnc)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux

In [4]:
# !pip install virtualenv

In [5]:
# https://saturncloud.io/blog/how-to-install-a-library-permanently-in-colab/
# !virtualenv /content/gdrive/MyDrive/GAMENet/virtual_env/

In [6]:
# !source /content/gdrive/MyDrive/GAMENet/virtual_env/bin/activate; pip install dnc

In [7]:
import sys

# Location of the site-packages directory of the virtual environment kept on Google Drive.
VENV_SITE_PACKAGES = "/content/gdrive/MyDrive/GAMENet/virtual_env/lib/python3.10/site-packages"

# Append it to the interpreter's search path so packages installed there can be imported.
sys.path.append(VENV_SITE_PACKAGES)

In [8]:
import dnc

In [9]:
# Mount Google Drive and switch the working directory to the project's code folder.
# The relative paths used further down ("../data/...", "saved/...") are resolved
# against the directory selected here.
import os
from google.colab import drive

# Print the directory the kernel is running in before the switch
print("Current Working Directory:", os.getcwd())

# Mount Google Drive at /content/gdrive
drive.mount('/content/gdrive')

# Change the working directory to the folder that holds the baseline code
path = "/content/gdrive/My Drive/GAMENet/code_/baseline"
os.chdir(path)

# Print the directory again to confirm the change took effect
print("Working Directory:", os.getcwd())

Current Working Directory: /content
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Working Directory: /content/gdrive/.shortcut-targets-by-id/1HvUJwbm1gmi_iRV21oB-A5JZClBwM9Eb/GAMENet/code_/baseline


In [10]:
import models
import util

In [11]:
import torch
import torch.nn as nn
from sklearn.metrics import jaccard_score, roc_auc_score, precision_score, f1_score, average_precision_score
import numpy as np
import dill
import time
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import os
import torch.nn.functional as F
import random
from collections import defaultdict

import sys
sys.path.append("..")
from models import Leap
from util import llprint, sequence_metric, sequence_output_process, ddi_rate_score, get_n_params

# Seed every random number generator this notebook can draw from so runs are repeatable.
# fine_tune() resamples the training set with `random.choice`, so seeding torch alone is
# not enough; numpy is seeded as well in case the imported helpers draw from it.
SEED = 1203
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Name of the sub-directory of `saved/` that holds this model's checkpoints
model_name = 'Leap'
# Checkpoint file name (inside saved/<model_name>/) that main() loads when TEST is True
resume_name = ''



def eval(model, data_eval, voc_size, epoch):
    """Evaluate ``model`` on ``data_eval`` and print/return aggregate metrics.

    Metrics are computed one admission at a time with ``sequence_metric`` and
    then averaged over all admissions. The DDI rate is computed once, over the
    predicted label lists of every item in ``data_eval``, with
    ``ddi_rate_score``.

    Args:
        model: Network called as ``model(adm)``; its output is converted to a
            NumPy array and indexed as a 2-D array.
        data_eval: Sized iterable whose items are sequences of admissions.
            ``adm[2]`` is used as indices into a vector of length
            ``voc_size[2]``.
        voc_size: Indexable of vocabulary sizes; only ``voc_size[2]`` is read.
        epoch: Epoch number, used only in the progress message.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1)``; every entry
        after ``ddi_rate`` is the mean of the per-admission values.
    """
    model.eval()

    # Per-admission metric values, averaged once the whole set has been processed
    aggregated_metrics = {
        'ja': [], 'prauc': [], 'avg_p': [], 'avg_r': [], 'avg_f1': []
    }
    records = []
    med_cnt = 0
    visit_cnt = 0
    med_voc_size = voc_size[2]

    # Inference only: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            y_gt = []
            y_pred = []
            y_pred_prob = []
            y_pred_label = []

            for adm in patient:
                # Multi-hot ground-truth vector built from the indices in adm[2]
                y_gt_tmp = np.zeros(med_voc_size)
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                output_logits = model(adm).detach().cpu().numpy()
                out_list, sorted_predict = sequence_output_process(
                    output_logits, [med_voc_size, med_voc_size + 1]
                )

                y_pred_label.append(sorted(sorted_predict))
                # Average over the first axis after dropping the last two columns
                # (presumably the two special-token columns -- TODO confirm against the model).
                y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

                # Multi-hot prediction vector built from the indices in out_list
                y_pred_tmp = np.zeros(med_voc_size)
                y_pred_tmp[out_list] = 1
                y_pred.append(y_pred_tmp)

                visit_cnt += 1
                med_cnt += len(sorted_predict)

            records.append(y_pred_label)

            # Score each admission on its own: every call receives a batch of exactly one
            # admission, so label lists of different lengths are never stacked into one array.
            for gt, pred, pred_prob, label in zip(y_gt, y_pred, y_pred_prob, y_pred_label):
                adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
                    np.array([gt]), np.array([pred]), np.array([pred_prob]), [label]
                )
                aggregated_metrics['ja'].append(adm_ja)
                aggregated_metrics['prauc'].append(adm_prauc)
                aggregated_metrics['avg_p'].append(adm_avg_p)
                aggregated_metrics['avg_r'].append(adm_avg_r)
                aggregated_metrics['avg_f1'].append(adm_avg_f1)

            llprint(f'\rEval--Epoch: {epoch}, Step: {step}/{len(data_eval)}')

    # Calculate and print the mean of the aggregated metrics
    ddi_rate = ddi_rate_score(records, path='../data/ddi_A_final.pkl')
    final_metrics = {metric: np.mean(values) for metric, values in aggregated_metrics.items()}
    llprint(f'\tDDI Rate: {ddi_rate:.4f}, Jaccard: {final_metrics["ja"]:.4f}, PRAUC: {final_metrics["prauc"]:.4f}, '
            f'AVG_PRC: {final_metrics["avg_p"]:.4f}, AVG_RECALL: {final_metrics["avg_r"]:.4f}, '
            f'AVG_F1: {final_metrics["avg_f1"]:.4f}\n')

    # Guard the average so an evaluation set without admissions does not raise ZeroDivisionError
    avg_med = med_cnt / visit_cnt if visit_cnt else 0.0
    print(f'avg med {avg_med:.4f}')

    return ddi_rate, final_metrics['ja'], final_metrics['prauc'], final_metrics['avg_p'], final_metrics['avg_r'], final_metrics['avg_f1']

def main():
    """Train the Leap model, or evaluate a saved checkpoint when ``TEST`` is True.

    Loads the records and vocabularies from ``../data``, splits the records
    into train (first 2/3), test and eval parts (the remainder split in half),
    and then either:

    * ``TEST`` is True: loads ``saved/<model_name>/<resume_name>`` and runs
      ``eval`` on the test split, or
    * ``TEST`` is False: trains for ``EPOCH`` epochs with Adam, runs ``eval``
      on the eval split after every epoch, writes one checkpoint per epoch,
      and finally writes ``history.pkl`` and ``final.model``.

    All output files go to ``saved/<model_name>/`` relative to the current
    working directory.
    """
    save_dir = os.path.join("saved", model_name)
    os.makedirs(save_dir, exist_ok=True)

    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'
    # Fall back to the CPU when no GPU is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # NOTE: dill is pickle-based and can execute code while loading; only load trusted files.
    with open(data_path, 'rb') as data_file:
        data = dill.load(data_file)
    with open(voc_path, 'rb') as voc_file:
        voc = dill.load(voc_file)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    # First 2/3 of the records for training; the rest is split in half into test and eval
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

    EPOCH = 30
    LR = 0.0002
    TEST = False
    # Class index appended to every target sequence to mark its end
    END_TOKEN = voc_size[2] + 1

    model = Leap(voc_size, device=device)
    if TEST:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
        model.load_state_dict(torch.load(os.path.join(save_dir, resume_name), map_location=device))

    model.to(device=device)
    print('parameters', get_n_params(model))

    optimizer = Adam(model.parameters(), lr=LR)

    if TEST:
        eval(model, data_test, voc_size, 0)
    else:
        history = defaultdict(list)
        for epoch in range(EPOCH):
            loss_record = []
            start_time = time.time()
            model.train()
            for step, patient in enumerate(data_train):
                for adm in patient:
                    # Target is the sequence in adm[2] followed by the end token
                    loss_target = adm[2] + [END_TOKEN]
                    output_logits = model(adm)
                    loss = F.cross_entropy(output_logits, torch.LongTensor(loss_target).to(device))

                    loss_record.append(loss.item())

                    optimizer.zero_grad()
                    # NOTE(review): retain_graph=True is kept from the original code; it is only
                    # needed if the model reuses graph state across calls -- confirm and drop if not.
                    loss.backward(retain_graph=True)
                    optimizer.step()

                llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_train)))

            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(model, data_eval, voc_size, epoch)
            history['ja'].append(ja)
            history['ddi_rate'].append(ddi_rate)
            history['avg_p'].append(avg_p)
            history['avg_r'].append(avg_r)
            history['avg_f1'].append(avg_f1)
            history['prauc'].append(prauc)

            end_time = time.time()
            elapsed_time = (end_time - start_time) / 60
            # Remaining time is estimated from the duration of the epoch that just finished
            llprint('\tEpoch: %d, Loss1: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n' % (
                epoch,
                np.mean(loss_record),
                elapsed_time,
                elapsed_time * (EPOCH - epoch - 1) / 60))

            # Passing a path lets torch.save open and close the file itself
            torch.save(model.state_dict(),
                       os.path.join(save_dir, 'Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate)))
            print('')

        with open(os.path.join(save_dir, 'history.pkl'), 'wb') as history_file:
            dill.dump(history, history_file)
        # Weights after the last epoch
        torch.save(model.state_dict(), os.path.join(save_dir, 'final.model'))

def fine_tune(fine_tune_name=''):
    """Fine-tune a saved Leap checkpoint with a loss weighted by Jaccard and DDI presence.

    Loads ``saved/<model_name>/<fine_tune_name>``, then makes one pass over a
    training set resampled with replacement. For every admission the loss is
    ``-jaccard * K * mean(log_softmax(logits))``, where ``jaccard`` compares the
    predicted set with ``adm[2]`` and ``K`` is 1 when any pair of predicted
    items is marked in the DDI matrix, else 0. After each training item for
    which at least one admission had ``K == 1``, the model is evaluated on the
    test split and a checkpoint is written. The final weights are saved as
    ``saved/<model_name>/final.model``.

    Args:
        fine_tune_name: File name of the checkpoint to start from, relative to
            ``saved/<model_name>/``.
    """
    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'
    # Fall back to the CPU when no GPU is available (same rule as main())
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # NOTE: dill is pickle-based and can execute code while loading; only load trusted files.
    with open(data_path, 'rb') as data_file:
        data = dill.load(data_file)
    with open(voc_path, 'rb') as voc_file:
        voc = dill.load(voc_file)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
    with open('../data/ddi_A_final.pkl', 'rb') as ddi_file:
        ddi_A = dill.load(ddi_file)

    # Same split as main(): first 2/3 for training, first half of the remainder for testing
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))

    save_dir = os.path.join("saved", model_name)
    model = Leap(voc_size, device=device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(os.path.join(save_dir, fine_tune_name), map_location=device))
    model.to(device)

    # EPOCH only feeds the remaining-time estimate below; the loop itself runs a single pass
    EPOCH = 30
    LR = 0.0001

    optimizer = Adam(model.parameters(), lr=LR)
    for epoch in range(1):
        loss_record = []
        start_time = time.time()
        # Resample the training set with replacement, keeping its original size
        random_train_set = [random.choice(data_train) for _ in range(len(data_train))]
        for step, patient in enumerate(random_train_set):
            # eval() below switches the model to eval mode, so re-enable training every step
            model.train()
            K_flag = False
            for adm in patient:
                target = adm[2]
                output_logits = model(adm)
                out_list, sorted_predict = sequence_output_process(
                    output_logits.detach().cpu().numpy(), [voc_size[2], voc_size[2] + 1])

                inter = set(out_list) & set(target)
                union = set(out_list) | set(target)
                # An empty union (nothing predicted and nothing prescribed) scores 0
                # instead of dividing by zero.
                jaccard = len(inter) / len(union) if union else 0

                # K is 1 when any pair of predicted items is flagged in the DDI matrix
                has_ddi = any(ddi_A[i][j] == 1 for i in out_list for j in out_list)
                K = 1 if has_ddi else 0
                if has_ddi:
                    # Remember that this training item produced an interacting prediction,
                    # regardless of where in out_list the pair was found.
                    K_flag = True

                loss = -jaccard * K * torch.mean(F.log_softmax(output_logits, dim=-1))

                loss_record.append(loss.item())

                optimizer.zero_grad()
                # NOTE(review): retain_graph=True is kept from the original code; it is only
                # needed if the model reuses graph state across calls -- confirm and drop if not.
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_train)))

            if K_flag:
                ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(model, data_test, voc_size, epoch)

                end_time = time.time()
                elapsed_time = (end_time - start_time) / 60
                llprint('\tEpoch: %d, Loss1: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n' % (
                    epoch,
                    np.mean(loss_record),
                    elapsed_time,
                    elapsed_time * (EPOCH - epoch - 1) / 60))

                # Passing a path lets torch.save open and close the file itself
                torch.save(model.state_dict(),
                           os.path.join(save_dir, 'fine_Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate)))
                print('')

    # Weights after fine-tuning
    torch.save(model.state_dict(), os.path.join(save_dir, 'final.model'))



if __name__ == '__main__':
    # Run the full training loop; executing this cell starts training immediately.
    main()
    # To fine-tune from a saved checkpoint instead, call fine_tune() with the checkpoint's
    # file name inside saved/<model_name>/, e.g.:
    # fine_tune(fine_tune_name='Epoch_26_JA_0.4465_DDI_0.0723.model')

parameters 437012
Eval--Epoch: 0, Step: 1058/1059	DDI Rate: 0.0749, Jaccard: 0.3503, PRAUC: 0.5662, AVG_PRC: 0.5240, AVG_RECALL: 0.5306, AVG_F1: 0.5099
avg med 14.4344
	Epoch: 0, Loss1: 3.2017, One Epoch Time: 6.20m, Appro Left Time: 2.99h

Eval--Epoch: 1, Step: 1058/1059	DDI Rate: 0.0640, Jaccard: 0.3643, PRAUC: 0.5650, AVG_PRC: 0.5229, AVG_RECALL: 0.5617, AVG_F1: 0.5239
avg med 15.3778
	Epoch: 1, Loss1: 2.9200, One Epoch Time: 5.45m, Appro Left Time: 2.54h

Eval--Epoch: 2, Step: 1058/1059	DDI Rate: 0.0549, Jaccard: 0.3755, PRAUC: 0.5607, AVG_PRC: 0.5684, AVG_RECALL: 0.5399, AVG_F1: 0.5338
avg med 13.6428
	Epoch: 2, Loss1: 2.8123, One Epoch Time: 4.99m, Appro Left Time: 2.25h

Eval--Epoch: 3, Step: 1058/1059	DDI Rate: 0.0412, Jaccard: 0.3785, PRAUC: 0.5583, AVG_PRC: 0.5533, AVG_RECALL: 0.5577, AVG_F1: 0.5356
avg med 14.5204
	Epoch: 3, Loss1: 2.7471, One Epoch Time: 5.14m, Appro Left Time: 2.23h

Eval--Epoch: 4, Step: 1058/1059	DDI Rate: 0.0441, Jaccard: 0.3797, PRAUC: 0.5610, AVG_PRC:

KeyboardInterrupt: 