In [None]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.collections as collections
import pickle, os, warnings

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve
from sklearn.linear_model import LogisticRegression
from random import randint

import cv2
from torch.autograd import Function
from scipy.interpolate import interp1d

from linformer import Linformer
from vit_pytorch.efficient import ViT

warnings.filterwarnings("ignore")


# Prespecifications

task = 'classification' # either 'classification' or 'regression'
invasive = False # either True or False
multi = True # either True or False
pred_lag = 300 # 300 for 5-min, 600 for 10-min, 900 for 15-min prediction or others

# Each waveform is later reshaped to (mtx_height, mtx_width), so a fixed
# 3000-sample window is assumed and mtx_width must divide it exactly.
mtx_width = 300
mtx_height = 3000//mtx_width
patch_size = 10
transformer_last_dim = 32

if 3000 % mtx_width != 0:
    raise ValueError('mtx_width must divide 3000, got {}'.format(mtx_width))
# The ViT cuts the matrix into patch_size x patch_size patches and the
# Linformer sequence length is computed with floor division, so both sides
# must be exact multiples of patch_size or the two would disagree.
if mtx_height % patch_size != 0 or mtx_width % patch_size != 0:
    raise ValueError('patch_size {} must divide both matrix sides ({} x {})'.format(
        patch_size, mtx_height, mtx_width))

lin_depth = 24
lin_heads = 16
lin_k = 64

# Input channels: invasive multi = abp, ecg, ple, co2 (4); non-invasive
# multi = ecg, ple, co2 (3); mono = a single waveform (1).
if multi:
    channels = 4 if invasive else 3
else:
    channels = 1

cuda_number = 0 # -1 for multi GPU support
num_workers = 0
batch_size = 128
max_epoch = 200

train_ratio = 0.6 # Size for training dataset
valid_ratio = 0.1 # Size for validation dataset
test_ratio = 0.3 # Size for test dataset

# The split only uses valid_ratio and test_ratio (train is the remainder), so
# make sure train_ratio is consistent with them instead of silently ignored.
if not np.isclose(train_ratio + valid_ratio + test_ratio, 1.0):
    raise ValueError('train/valid/test ratios must sum to 1')

random_key = randint(0, 100000) # Prespecify seed number if needed
#random_key = 35593

# Seed NumPy and PyTorch too, so weight initialization and DataLoader
# shuffling are reproducible for a given random_key. Before this, only the
# case split depended on it. GPU kernels may still be nondeterministic.
np.random.seed(random_key)
torch.manual_seed(random_key)

dr_classification = 0.3 # Drop out ratio for classification model
dr_regression = 0.0 # Drop out ratio for regression model

csv_dir = './model/'+str(random_key)+'/csv/'
pt_dir = './model/'+str(random_key)+'/pt/'

# exist_ok avoids the check-then-create race of an isdir() guard.
os.makedirs(csv_dir, exist_ok=True)
os.makedirs(pt_dir, exist_ok=True)


# Establish dataset

class dnn_dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing stacked waveform channels with a target.

    Each item is ``(features, target)``, both cast with ``np.float32``.

    Channel selection:
      * invasive + multi      -> abp, ecg, ple, co2 stacked with ``np.vstack``
      * non-invasive + multi  -> ecg, ple, co2 stacked with ``np.vstack``
      * invasive + mono       -> abp only (not stacked)
      * non-invasive + mono   -> ple only (not stacked)

    ``np.vstack`` concatenates along the first axis. If every signal element
    is 2-D ``(H, W)``, the multi-channel result is ``(C*H, W)``, and callers
    must reshape it to ``(C, H, W)`` themselves.
    """

    # Signal attributes stacked (in this order) for the multi-channel models.
    _MULTI_INVASIVE = ('abp', 'ecg', 'ple', 'co2')
    _MULTI_NONINVASIVE = ('ecg', 'ple', 'co2')

    def __init__(self, abp, ecg, ple, co2, target, invasive, multi):
        # Indexable per-record containers (lists / Series), looked up by index.
        self.abp = abp
        self.ecg = ecg
        self.ple = ple
        self.co2 = co2
        self.target = target
        self.invasive = invasive
        self.multi = multi

    def __getitem__(self, index):
        use_invasive = self.invasive == True
        if self.multi == True:
            names = self._MULTI_INVASIVE if use_invasive else self._MULTI_NONINVASIVE
            rows = [np.array(getattr(self, name)[index]) for name in names]
            features = np.float32(np.vstack(rows))
        else:
            # Mono-channel: arterial pressure for invasive, photoplethysmography otherwise.
            single = self.abp if use_invasive else self.ple
            features = np.float32(np.array(single[index]))
        return features, np.float32(self.target[index])

    def __len__(self):
        return len(self.target)


#################### Vision Transformer ##################
# Linformer-backed ViT from vit_pytorch.efficient.
#
# The input matrix (channels, mtx_height, mtx_width) is cut into
# non-overlapping patch_size x patch_size patches. Linformer needs the exact
# token count: one token per patch, plus 1 (presumably the ViT class token).
#
# NOTE(review): the real input is non-square (mtx_height x mtx_width), but
# ViT receives a single image_size equal to the larger side. Confirm that the
# installed vit_pytorch version tolerates this, e.g. by slicing its positional
# embedding to the actual number of patches.
# NOTE(review): num_classes=2, but the training loop only reads the first
# output column as the prediction.

efficient_transformer = Linformer(
    dim=transformer_last_dim,
    seq_len=(mtx_height // patch_size) * (mtx_width // patch_size) + 1,
    depth=lin_depth,
    heads=lin_heads,
    k=lin_k,
)

model = ViT(
    image_size=max(mtx_width, mtx_height),
    patch_size=patch_size,
    num_classes=2,
    dim=transformer_last_dim,
    transformer=efficient_transformer,
    channels=channels,
)

###########################################################



# Read dataset

processed_dir = './processed_new/'

# Each case is stored as '<caseid>.pkl'. Only .pkl entries are considered, so
# stray files (e.g. '.DS_Store', '.ipynb_checkpoints') cannot break int().
# os.listdir() order depends on the filesystem, so the ids are sorted to make
# the seeded train/valid/test split below reproducible across machines.
# NOTE: because of the sorting, a given random_key may assign cases
# differently than in runs made before this change.
case_list = sorted(
    int(fname.split('.')[0])
    for fname in os.listdir(processed_dir)
    if fname.endswith('.pkl')
)
print ( 'N of total cases: {}'.format ( len ( case_list ) ) )

# Two-stage split: train vs (valid+test), then valid vs test, by case id.
cases = {}
cases['train'], cases['valid+test'] = train_test_split ( case_list,
                                                        test_size=(valid_ratio+test_ratio),
                                                        random_state=random_key )
cases['valid'], cases['test'] = train_test_split ( cases['valid+test'],
                                                  test_size=(test_ratio/(valid_ratio+test_ratio)),
                                                  random_state=random_key )

for phase in [ 'train', 'valid', 'test' ]:
    print ( "- N of {} cases: {}".format(phase, len(cases[phase])) )

# Load every case into one DataFrame, tagging each row with its case id.
frames = []
for caseid in case_list:
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # point processed_dir at trusted, locally produced pickles.
    with open(processed_dir + str(caseid) + '.pkl', 'rb') as handle:
        data = pickle.load(handle)
    data['caseid'] = [caseid] * len(data['abp'])
    frames.append(pd.DataFrame(data))
# Concatenate once: DataFrame.append was removed in pandas 2.0, and calling it
# in a loop re-copies the accumulated frame each iteration (quadratic).
raw_records = pd.concat(frames, ignore_index=True)

############# Remove NaN values ##############
# Drop every record whose abp / ecg / ple / co2 waveform contains a NaN.
has_nan = np.zeros(len(raw_records), dtype=bool)
for col in ['abp', 'ecg', 'ple', 'co2']:
    has_nan |= np.array([np.isnan(wave).any() for wave in raw_records[col]], dtype=bool)
raw_records = raw_records[~has_nan]

#########################################
raw_records = raw_records[(raw_records['map']>=20)&(raw_records['map']<=160)].reset_index(drop=True) # Exclude abnormal range


# Define loader and model

# Pick the target column and loss for the configured task.
if task == 'classification':
    task_target = 'hypo'
    # Expects raw logits; the model output is not passed through a sigmoid.
    criterion = nn.BCEWithLogitsLoss()
else:
    task_target = 'map'
    criterion = nn.MSELoss()

print('\n===== Task: {}, Seed: {} =====\n'.format(task, random_key))
print('Invasive: {}\nMulti: {}\nPred lag: {}\n'.format(invasive, multi, pred_lag))

# Keep only records with input_length 30 at the configured prediction lag.
records = raw_records.loc[(raw_records['input_length'] == 30)
                          & (raw_records['pred_lag'] == pred_lag)]

# Rotate the last column (presumably 'caseid') to the front.
column_order = records.columns.tolist()
records = records[column_order[-1:] + column_order[:-1]]
print('N of total records: {}'.format(len(records)))

# Assign records to phases by the case-level split.
split_records = {}
for phase in ['train', 'valid', 'test']:
    in_phase = records['caseid'].isin(cases[phase])
    split_records[phase] = records[in_phase].reset_index(drop=True)
    print('- N of {} records: {}'.format(phase, len(split_records[phase])))

print('')

signal_columns = ['abp', 'ecg', 'ple', 'co2']
ext = {
    phase: {col: split_records[phase][col] for col in signal_columns + ['hypo', 'map']}
    for phase in ['train', 'valid', 'test']
}

dataset, loader = {}, {}
epoch_loss, epoch_auc = {}, {}

for phase in ['train', 'valid', 'test']:
    features = ext[phase]

    # Fold each flat waveform into an (mtx_height, mtx_width) matrix.
    for col in signal_columns:
        features[col] = [wave.reshape(mtx_height, mtx_width) for wave in features[col]]

    dataset[phase] = dnn_dataset(features['abp'],
                                 features['ecg'],
                                 features['ple'],
                                 features['co2'],
                                 features[task_target],
                                 invasive=invasive, multi=multi)
    # Only the training loader shuffles; valid/test keep record order.
    loader[phase] = torch.utils.data.DataLoader(dataset[phase],
                                                batch_size=batch_size,
                                                num_workers=num_workers,
                                                shuffle=(phase == 'train'))
    epoch_loss[phase] = []
    epoch_auc[phase] = []

#Model development and validation

# Resolve the compute device from cuda_number. Previously:
#   * set_device(cuda_number) was called, but the model went to a hard-coded
#     "cuda:1", so cuda_number had no effect;
#   * set_device crashed on CPU-only machines and for cuda_number = -1.
use_cuda = torch.cuda.is_available()
if use_cuda and cuda_number >= 0:
    torch.cuda.set_device(cuda_number)
    device = torch.device('cuda:{}'.format(cuda_number))
elif use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

DNN = model.to(device)
if use_cuda and cuda_number < 0 and torch.cuda.device_count() > 1:
    # cuda_number = -1 requests multi-GPU: split each batch across visible GPUs.
    DNN = nn.DataParallel(DNN)

optimizer = torch.optim.Adam(DNN.parameters(), lr=0.0005)
n_epochs = max_epoch

best_loss, best_auc = float('inf'), 0.0

# CSV file stem (e.g. 'classification_noninvasive_multi_5min'). It depends only
# on the configuration, so it is built once rather than every epoch.
label_invasive = 'invasive' if invasive else 'noninvasive'
label_multi = 'multi' if multi else 'mono'
label_pred_lag = str(int(pred_lag / 60)) + 'min'
filename = task + '_' + label_invasive + '_' + label_multi + '_' + label_pred_lag

for epoch in range(n_epochs):

    start_time = time.time()

    phases = ['train', 'valid', 'test']
    target_stack = {phase: [] for phase in phases}
    output_stack = {phase: [] for phase in phases}
    current_loss = {phase: 0.0 for phase in phases}
    current_auc = {phase: 0.0 for phase in phases}

    # ---- Training pass ----
    DNN.train()
    for dnn_inputs, dnn_target in loader['train']:
        dnn_inputs, dnn_target = dnn_inputs.to(device), dnn_target.to(device)
        # The dataset vstacks channels along the row axis; restore (N, C, H, W).
        dnn_inputs = dnn_inputs.reshape(len(dnn_inputs), channels, mtx_height, mtx_width)

        optimizer.zero_grad()
        dnn_output = DNN(dnn_inputs)
        # Only the first output unit is used as the prediction (logit / value).
        loss = criterion(dnn_output[:, 0], dnn_target)
        loss.backward()
        optimizer.step()

        # Accumulate the sum of per-sample losses; averaged over the dataset below.
        current_loss['train'] += loss.item() * dnn_inputs.size(0)

    current_loss['train'] = current_loss['train'] / len(loader['train'].dataset)
    epoch_loss['train'].append(current_loss['train'])

    # ---- Evaluation passes (no gradients) ----
    DNN.eval()
    with torch.no_grad():
        for phase in ['valid', 'test']:
            for dnn_inputs, dnn_target in loader[phase]:
                dnn_inputs, dnn_target = dnn_inputs.to(device), dnn_target.to(device)
                dnn_inputs = dnn_inputs.reshape(len(dnn_inputs), channels, mtx_height, mtx_width)
                prediction = DNN(dnn_inputs)[:, 0]

                target_stack[phase].extend(dnn_target.cpu().numpy())
                output_stack[phase].extend(prediction.cpu().numpy())
                current_loss[phase] += criterion(prediction, dnn_target).item() * dnn_inputs.size(0)

            current_loss[phase] = current_loss[phase] / len(loader[phase].dataset)
            epoch_loss[phase].append(current_loss[phase])

    # ---- AUC per evaluation phase ----
    for phase in ['valid', 'test']:
        if task == 'classification':
            # roc_auc_score is rank-based, so raw logits are valid scores.
            current_auc[phase] = roc_auc_score(target_stack[phase], output_stack[phase])
        else:
            # Regression: label MAP < 65 as positive (presumably the same
            # definition as the 'hypo' target — confirm). The 1-D predicted MAP
            # is calibrated with a logistic regression fitted on the same
            # phase's data, and its probabilities are scored.
            reg_output = np.array(output_stack[phase]).reshape(-1, 1)
            reg_label = np.where(np.array(target_stack[phase]) < 65, 1, 0)  # 1-D y for sklearn
            method = LogisticRegression(solver='liblinear')
            method.fit(reg_output, reg_label)
            current_auc[phase] = roc_auc_score(reg_label, method.predict_proba(reg_output)[:, 1])
        epoch_auc[phase].append(current_auc[phase])

    # Rewrite the full learning-curve CSV every epoch.
    pd.DataFrame({'train_loss': epoch_loss['train'],
                  'valid_loss': epoch_loss['valid'],
                  'test_loss': epoch_loss['test'],
                  'valid_auc': epoch_auc['valid'],
                  'test_auc': epoch_auc['test']}).to_csv(csv_dir + filename + '.csv')

    # Best epoch: lowest valid loss (regression) or highest valid AUC (classification).
    # To checkpoint it, save DNN.state_dict() to pt_dir + filename + '_epoch_best.pt'
    # (keys are prefixed with 'module.' when DataParallel is active).
    best = ''
    if task == 'regression' and abs(current_loss['valid']) < abs(best_loss):
        best = '< ! >'
        last_saved_epoch = epoch
        best_loss = abs(current_loss['valid'])
    elif task == 'classification' and abs(current_auc['valid']) > abs(best_auc):
        best = '< ! >'
        last_saved_epoch = epoch
        best_auc = abs(current_auc['valid'])

    print ( 'Epoch [{:3d}] Train loss: {:.4f} / Valid loss: {:.4f} (AUC: {:.4f}) / Test loss: {:.4f} (AUC: {:.4f}) {}'.format
            ( epoch+1,
            current_loss['train'],
            current_loss['valid'], current_auc['valid'],
            current_loss['test'], current_auc['test'], best ) )
    print("Time: {:.4f}sec".format((time.time() - start_time)))

N of total cases: 3228
- N of train cases: 1936
- N of valid cases: 323
- N of test cases: 969

===== Task: classification, Seed: 93236 =====

Invasive: False
Multi: True
Pred lag: 300

N of total records: 25706
- N of train records: 15338
- N of valid records: 2558
- N of test records: 7810

Epoch [  1] Train loss: 0.6737 / Valid loss: 0.6365 (AUC: 0.6285) / Test loss: 0.6606 (AUC: 0.6164) < ! >
Time: 14.1301sec
Epoch [  2] Train loss: 0.6468 / Valid loss: 0.6373 (AUC: 0.6429) / Test loss: 0.6503 (AUC: 0.6353) < ! >
Time: 13.9466sec
Epoch [  3] Train loss: 0.6280 / Valid loss: 0.6191 (AUC: 0.6763) / Test loss: 0.6358 (AUC: 0.6737) < ! >
Time: 13.9762sec
Epoch [  4] Train loss: 0.6096 / Valid loss: 0.6004 (AUC: 0.7182) / Test loss: 0.6326 (AUC: 0.7180) < ! >
Time: 13.9909sec
Epoch [  5] Train loss: 0.5834 / Valid loss: 0.5763 (AUC: 0.7366) / Test loss: 0.5850 (AUC: 0.7496) < ! >
Time: 13.9866sec
Epoch [  6] Train loss: 0.5794 / Valid loss: 0.5710 (AUC: 0.7477) / Test loss: 0.5842 (AUC: