In [70]:
import pyedflib
import numpy as np
import scipy
from scipy import signal, misc
from matplotlib import pyplot as plt 
from glob import glob
import os
import pandas as pd
from scipy.signal import butter, filtfilt, iirnotch
from sklearn.metrics import confusion_matrix
from scipy.stats import sem
from scipy import mean
import seaborn as sns
from scipy import stats

from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from sklearn.metrics import classification_report, f1_score

import tensorflow.keras
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import TensorBoard

from __future__ import print_function
from scipy.stats import norm
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dense, Lambda, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras.datasets import mnist
from tensorflow.keras.constraints import max_norm
from sklearn.manifold import TSNE
from tensorflow.keras import models
from tensorflow.keras.layers import Concatenate, AveragePooling2D, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Input, Reshape, Flatten

In [71]:
# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# or 
GPU = 0
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[GPU], True)

In [72]:
T = 800 # Time point in each segment (freq. rate x 4s ec [4 sec/window])
C = 3
D = 6 # Dim (fs bands)
FS = 200 # Hz

In [73]:
def segment_data(Folder_name, label=0.):
    all_data_segment = glob('../LFP_Bank/'+Folder_name+'/*_200Hz.npy')

    data_all = []
    for indivi_file_segment in all_data_segment:
        print (indivi_file_segment)
        load_data = np.load(indivi_file_segment, allow_pickle=True)
        X, Y = load_data.shape
        N = int(Y/T) # Number of segment(s)
        data_new = []
        for i in range (N):
            data = load_data[:, i*T:((i+1)*T)]
            data_new.append(data)
        data_output = np.array(data_new)
        data_all.append(data_output)
        #print('=================================================================')
        #print(data_output)
        #print('=================================================================')
        #print('=============')
        #print(data_output.shape)
        #print('=============')
        #print('==============================================================================================================')
    X = np.array(data_all)
    y = np.full((X.shape[0], X.shape[1]), label)
    return X, y

def butter_bandpass_filter(data, lowcut, highcut, fs, order):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    y = filtfilt(b, a, data)
    return y

In [74]:
def triplet_loss(margin = 1.0):
    def inner_triplet_loss_objective(y_true, y_pred):
        labels = y_true
        embeddings = y_pred
        return tfa.losses.triplet_semihard_loss(y_true = labels, y_pred = embeddings, margin = margin)
    return inner_triplet_loss_objective

def dummy_loss(margin = 1.0):
    def inner(y_true, y_pred):
        return 0
    return inner

In [75]:
def build(num_class, F1=16, F2=8, P=(1, 10)):
    'encoder'
    encoder_input = Input(shape=(D, T, C))
    x = layers.Conv2D(F1, P, activation='elu', padding='same')(encoder_input)
    x = BatchNormalization(axis=3, epsilon=1e-05, momentum=0.1)(x)
    x = layers.AveragePooling2D(P, padding='same')(x)
    x = layers.Conv2D(F2, P, activation='elu', padding='same')(x)
    x = BatchNormalization(axis=3, epsilon=1e-05, momentum=0.1)(x)
    x = layers.AveragePooling2D(P, padding='same')(x)
#     x = layers.Conv2D(8, (1, 50), activation='elu', padding='same')(x)
#     x = layers.AveragePooling2D((1, 2), padding='same')(x)
    x = Flatten()(x)
    encoder_output = Dense(FS)(x)
    ### the end of the layers
    encoder        = models.Model(inputs=encoder_input, outputs=encoder_output, name='encoder')
    encoder.summary()
    
    'decoder'
    decoder_input = Input(shape=(FS), name='decoder_input')
    x = Dense(D*(T//100)*F2, activation='elu')(decoder_input) 
    x = Reshape((D, T//100, F2))(x)
    x = layers.Conv2D(F2, P, activation='elu', padding='same')(x)
    x = layers.UpSampling2D(P)(x)
    x = layers.Conv2D(F1, P, activation='elu', padding='same')(x)
    x = layers.UpSampling2D(P)(x)
#     x = layers.Conv2D(16, (1, 10), activation='elu', padding='same')(x)
#     x = layers.UpSampling2D((1, 1))(x)
    decoder_output = layers.Conv2D(C, P, activation='elu', padding='same')(x)
    # the end of the layers 
    decoder        = models.Model(inputs=decoder_input, outputs=decoder_output, name='decoder')
    decoder.summary()
    
    'Build the computation graph for training'
    latent  = encoder(encoder_input)
    train_xr= decoder(latent)
    z       = Dense(units=num_class, activation='softmax', kernel_constraint=max_norm(0.5), name='classifier')(latent)

    return models.Model(inputs = [encoder_input], outputs = [train_xr, z],  name = 'AE')

In [76]:
# all_folders = ["00_Blank Syrup", '00_Control group',"00_PEG+Saline (control)", '01_PEG+1mg METH', '02_PEG+5mg METH',
#                '03_Morphine I (5 mg, ip)', '04_Morphine II (15 mg, ip)', '05_Cannabis project (THC 50 mg, ip)',
#                '06_MDMA project (10 mg, ip)', '07_L-DOPA project (25 mg in dw, po)', '08_Ephedrine group (10 mg, ip)',
#                '09_Pseudoephedrine group (50 mg, po)', '010_Pseudoephedrine group (100 mg, po)',
#                '011_Ketamine project (50 mg, ip)', '012_Lorazepam (1 mg, po)', '013_Lorazepam (5 mg, po)',
#                '014_Fluoxetine project (20 mg, po)', '015_KT alkaloid (60 mg, po)', '016_KT alkaloid (80 mg, po)',
#                '017_KT alkaloid (212 mg, cont equal to 10 mg per kg mitragynine, po)',
#                '018_Kratom (water extract) (cont equal to 10 mg per kg mitragynine, po)',
#                '019_Kratom Syrup (contained 10 mg mitragynine, po)', '020_Haloperidol (0_5 mg, po)',
#                '021_Haloperidol (1 mg, po)', '022_Haloperidol+Saline', '023_Haloperidol+5mg METH',
#                '024_Morphine II (15 mg, ip)+Naloxone (20 mg, ip)', '025_Jasmine project']

In [77]:
# X, y = final_output, np.hstack(([0]*10, [1]*10, [2]*10, [3]*8, [4]*9, [5]*8, [6]*7, [7]*12, [8]*7, [9]*12, [10]*12, [11]*7,
#                                 [12]*10, [13]*7, [14]*9, [15]*7, [16]*8, [17]*10, [18]*9, [19]*8, [20]*9, [21]*7, [22]*7,
#                                 [23]*10, [24]*10, [25]*9))

In [78]:
# all_folders = ['00_PEG+Saline (control)', 
#                '02_PEG+5mg METH',
#                '04_Morphine II (15 mg, ip)',
#                '05_Cannabis project (THC 50 mg, ip)',
#                '06_MDMA project (10 mg, ip)',
#                '07_L-DOPA project (25 mg in dw, po)', 
#                '011_Ketamine project (50 mg, ip)', 
#                '013_Lorazepam (5 mg, po)',
#                '014_Fluoxetine project (20 mg, po)',
#                '015_KT alkaloid (60 mg, po)']
# #                '019_Kratom Syrup (contained 10 mg mitragynine, po)',
# #                '025_Jasmine project']

In [79]:
all_folders = ['00_Control group',
               '04_Morphine II (15 mg, ip)']

In [80]:
all_segmented_data = []
all_label = []
for i, fname in enumerate(all_folders):
    print (i)
    X_, y_ = segment_data(Folder_name=fname, label=i)
    all_segmented_data.append(X_)
    all_label.append(y_)
final_output = np.concatenate(np.array(all_segmented_data), axis = 0) # concatenate only axis 0
final_label = np.concatenate(np.array(all_label), axis = 0) # concatenate only axis 0
final_output.shape
final_label.shape

0
../LFP_Bank/00_Control group/2016-08-23--B10--Baseline_200Hz.npy
../LFP_Bank/00_Control group/2016-08-24--B2--Baseline_200Hz.npy
../LFP_Bank/00_Control group/2016-08-24--B3--Baseline_200Hz.npy
../LFP_Bank/00_Control group/2016-08-23--B6--Baseline_200Hz.npy
../LFP_Bank/00_Control group/170759--B6--[Saline]_200Hz.npy
../LFP_Bank/00_Control group/070759--B2--[Saline]_200Hz.npy
../LFP_Bank/00_Control group/080759--B4--[Saline]_200Hz.npy
../LFP_Bank/00_Control group/120759--B8--[Saline]_200Hz.npy
../LFP_Bank/00_Control group/150759--B10--[Saline]_200Hz.npy
../LFP_Bank/00_Control group/160759--B12--[Saline]_200Hz.npy
1
../LFP_Bank/04_Morphine II (15 mg, ip)/090158--N12--[1R4Mor15]_200Hz.npy
../LFP_Bank/04_Morphine II (15 mg, ip)/210258--TN1--[4R5Mor15]_200Hz.npy
../LFP_Bank/04_Morphine II (15 mg, ip)/300158--N9--[2R4Mor15]_200Hz.npy
../LFP_Bank/04_Morphine II (15 mg, ip)/131157--N15--[1R4Mor15]_200Hz.npy
../LFP_Bank/04_Morphine II (15 mg, ip)/271157--N13--[2R4Mor15]_200Hz.npy
../LFP_Bank/0

  final_output = np.concatenate(np.array(all_segmented_data), axis = 0) # concatenate only axis 0
  final_label = np.concatenate(np.array(all_label), axis = 0) # concatenate only axis 0


(19, 450)

In [81]:
X, y = final_output, final_label

In [82]:
X_new = np.zeros((X.shape[0], X.shape[1], D, X.shape[2], X.shape[3]))
for i in range (X.shape[0]):
#     print(i)
    for j in range (X.shape[1]):
#         print(j)
        delta = butter_bandpass_filter(data=X[i, j], lowcut=1, highcut=4, fs=FS, order=5)
        theta = butter_bandpass_filter(data=X[i, j], lowcut=4, highcut=9, fs=FS, order=5)
        alpha = butter_bandpass_filter(data=X[i, j], lowcut=9, highcut=13, fs=FS, order=5)
        beta = butter_bandpass_filter(data=X[i, j], lowcut=13, highcut=30, fs=FS, order=5)
        gamma_I = butter_bandpass_filter(data=X[i, j], lowcut=30, highcut=45, fs=FS, order=5)
        gamma_II = butter_bandpass_filter(data=X[i, j], lowcut=60, highcut=95, fs=FS, order=5)
        X_new[i, j] = np.array([delta, theta, alpha, beta, gamma_I, gamma_II])

In [83]:
# save test data out
np.save('./data/X_meth_control.npy', X_new, allow_pickle=True)
np.save('./data/y_meth_control.npy', y, allow_pickle=True)

In [84]:
def history_values(history):
    print(history.history.keys())
    val_classifier_accuracy = history.history['val_classifier_accuracy']
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    decoder_loss = history.history['decoder_loss']
    val_decoder_loss = history.history['val_decoder_loss']
    
    classifier_loss = history.history['classifier_loss']
    val_classifier_loss = history.history['val_classifier_loss']
    return val_classifier_accuracy, loss, val_loss, decoder_loss, val_decoder_loss, classifier_loss, val_classifier_loss

def plot_loss(history):
    # list all data in history
    print(history.history.keys())
    # summarize history for accuracy
    plt.plot(history.history['val_classifier_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Number of traning iteration')
    plt.legend(['Classifier accuracy'], loc = 'lower right')
    plt.show()

    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Total loss')
    plt.ylabel('Loss')
    plt.xlabel('Number of traning iteration')
    plt.legend(['Train', 'Validation'], loc = 'upper right')
    plt.show()
    
    plt.plot(history.history['decoder_loss'])
    plt.plot(history.history['val_decoder_loss'])
    plt.title('Mean squared error loss')
    plt.ylabel('Loss')
    plt.xlabel('Number of traning iteration')
    plt.legend(['Decoder_loss', 'Val_decoder_loss'], loc = 'upper right')
    plt.show() 
    
    plt.plot(history.history['classifier_loss'])
    plt.plot(history.history['val_classifier_loss'])
    plt.title('Cross-entropy loss')
    plt.ylabel('Loss')
    plt.xlabel('Number of traning iteration')
    plt.legend(['Classifier_loss', 'Val_classifier_loss'], loc = 'upper right')
    plt.show() 
    
def plot_tsne(decomposed_data, y_test):
    f, ax = plt.subplots(figsize = (10,8))
    steps=1
    for label in np.unique(y_test):
        if label == 0:
            text_label = "Con"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='black', 
                       linewidths=0.4, color='gray', marker='o', label = str(text_label))
        if label == 2:
            text_label = "Meth"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='firebrick',
                       linewidths=0.4, color='red', marker='o', label = str(text_label))
        if label == 1:
            text_label = "Mor"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='sienna',
                       linewidths=0.4, color='sandybrown', marker='o', label = str(text_label))
        if label == 3:
            text_label = "THC"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='tan',
                       linewidths=0.4, color='moccasin', marker='o', label = str(text_label))
        if label == 4:
            text_label = "MDMA"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='olivedrab',
                       linewidths=0.4, color='chartreuse', marker='o', label = str(text_label))
        if label == 5:
            text_label = "L-DOPA"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='seagreen',
                       linewidths=0.4, color='mediumspringgreen', marker='o', label = str(text_label))
        if label == 6:
            text_label = "Keta"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='darkcyan',
                       linewidths=0.4, color='darkturquoise', marker='o', label = str(text_label))
        if label == 7:
            text_label = "Lora"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='navy', 
                       linewidths=0.4, color='blue', marker='o', label = str(text_label))
        if label == 8:
            text_label = "Fluox"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='darkorchid',
                       linewidths=0.4, color='plum', marker='o', label = str(text_label))
        if label == 9:
            text_label = "AlkaKT"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='mediumvioletred',
                       linewidths=0.4, color='palevioletred', marker='o', label = str(text_label))
        if label == 10:
            text_label = "SyraKT"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='teal',
                       linewidths=0.4, color='cyan', marker='o', label = str(text_label))
        if label == 11:
            text_label = "Jas"
            decomposed_class = decomposed_data[label == y_test]
            ax.scatter(decomposed_class[::steps, 1], decomposed_class[::steps,0], s = 200, edgecolors='indigo',
                       linewidths=0.4, color='lightpink', marker='o', label = str(text_label))
    ax.legend()
    plt.show()
    

In [85]:
# Set channels last system
K.set_image_data_format('channels_last')
log_path="logs_3"
model_name="AE"
if not os.path.exists(log_path):
    os.makedirs(log_path)

# Set folder for saving model
model_path = "model"
if not os.path.exists(model_path):
    os.makedirs(model_path)

In [88]:
def fit_model(X, y, epochs, batch_size):
    
    X = X.reshape(-1,D,C,T)
    X = np.swapaxes(X, 2, 3)
    
    yl = y.reshape(-1)
    
    ys = np.arange(y.shape[0])
    ys = np.repeat([ys], y.shape[1], axis=0)
    ys = np.swapaxes(ys, 0, 1)
    ys = ys.reshape(-1)
            
    skf = StratifiedKFold(n_splits = 2, random_state = 42, shuffle = True)
    skf.get_n_splits(X, ys)
    print(skf)
    
    fold=0
    y_true_all, y_pred_all, scores_all, clas_report_all  = [], [], [], []
    val_clas_acc_all, loss_all, val_loss_all, dec_loss_all = [], [], [], []
    val_dec_loss_all, clas_loss_all, val_clas_loss_all = [], [], []
    
    for learn_index, test_index in skf.split(X, ys):
        print("LEARN:", learn_index, "TEST:", test_index)
        X_learn, X_test = X[learn_index], X[test_index]
#         print('X_learn:', X_learn)
        y_learn, y_test = yl[learn_index], yl[test_index]
#         print('y_learn:', y_learn)
        ys_learn, ys_test = ys[learn_index], ys[test_index]
#         print('ys_learn:', ys_learn)

        skf_train = StratifiedKFold(n_splits = 5, random_state = 42, shuffle = True)
        skf_train.get_n_splits(X_learn, ys_learn)
        print(skf_train)
        
        y_true_fold, y_pred_fold, scores_fold, clas_report_fold = [], [], [], []
        val_clas_acc_fold, loss_fold, val_loss_fold, dec_loss_fold = [], [], [], []
        val_dec_loss_fold, clas_loss_fold, val_clas_loss_fold = [], [], []
        
        for train_index, val_index in skf_train.split(X_learn, ys_learn):
            print("TRAIN:", train_index, "VAL:", val_index)
            X_train, X_val = X_learn[train_index], X_learn[val_index]
#             print('X_train:', X_train)
#             print('X_val:', X_val)
            y_train, y_val = y_learn[train_index], y_learn[val_index]
#             print('y_train:', y_train)
            ys_train, ys_val = ys_learn[train_index], ys_learn[val_index]
#             print('ys_train:', ys_train)
            
            y_train_dummy = np.zeros_like(y_train)
            y_val_dummy = np.zeros_like(y_val)
            y_test_dummy = np.zeros_like(y_test)
            
            fold += 1
            
            # reset model
            weights_dir = log_path+'/'+model_name+'_out_weights'+str(fold)+'.h5'
            num_class = len(np.unique(y))
            model = build(num_class = num_class)
            model.summary()
            optimizer = Adam(learning_rate = 1e-4, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-08)
            model.compile(optimizer = optimizer, 
                          loss = ['mean_squared_error', 'sparse_categorical_crossentropy'],
                          metrics = ['accuracy'], loss_weights = [1., 1.])
              
            # set callbacks
            checkpointer  = ModelCheckpoint(monitor='val_loss',
                                            filepath=log_path+'/'+model_name+'_out_weights'+str(fold)+'.h5', 
                                            verbose=1, save_best_only=True, save_weight_only=True)
            csv_logger    = CSVLogger(log_path+'/'+model_name+'_out_log'+str(fold)+'.log')
            reduce_lr     = ReduceLROnPlateau(monitor='val_loss', patience=5, factor=0.5, mode='min',
                                              verbose=1, min_lr=1e-5)
            es            = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20) 
                            # patience = save data at n ecoch followed by 10 bad epoch and stop 
            
            # fit model
            history = model.fit(x = X_train, y = [X_train, y_train], 
                            epochs = epochs, # n of cycle for training 
                            shuffle = True, 
                            batch_size = batch_size, 
                            validation_data = (X_val, [X_val, y_val]),
                            callbacks=[checkpointer, csv_logger, reduce_lr, es])
            
            val_clas_acc, loss, val_loss, dec_loss, val_dec_loss, clas_loss, val_clas_loss = history_values(history)
            
            val_clas_acc_fold.append(val_clas_acc)
            loss_fold.append(loss)
            val_loss_fold.append(val_loss)
            dec_loss_fold.append(dec_loss)
            val_dec_loss_fold.append(val_dec_loss)
            clas_loss_fold.append(clas_loss)
            val_clas_loss_fold.append(val_clas_loss)
            
            def build_test():
                model = build(num_class = num_class)
                encoder_input = model.layers[0].output
                encoder = model.layers[1]
                decoder = model.layers[2]
                classifier = model.layers[3]
                latent = encoder(encoder_input)
                train_xr = decoder(latent)
                z = classifier(latent)
                model = models.Model(inputs = [encoder_input], outputs = [latent, train_xr, z],  name = 'AE')
                model.compile(optimizer = optimizer, 
                              loss = [triplet_loss(margin=1.), 'mean_squared_error', 'sparse_categorical_crossentropy'],
                              metrics = ['accuracy'], loss_weights = [0., 1., 1.])
                return model
            
            model = build_test()
            model.load_weights(weights_dir)
            
            # saving model
            model.save(f'{model_path}/00_6_autoencoder_[fit]_balanced_class-6-bands-wo-loss-test_2.h5')
            
            latent, train_xr, z = model.predict(X_test, batch_size = batch_size)
            
            y_true_fold.append(y_test)
            y_pred = np.argmax(z, axis=1)
            y_pred_fold.append(y_pred)
            
            print(model.evaluate(x = X_test, y = [y_test_dummy, X_test, y_test]))
            print(classification_report(y_test, y_pred, output_dict=True))
            
            scores = model.evaluate(x = X_test, y = [y_test_dummy, X_test, y_test])
            clas_report = classification_report(y_test, y_pred, output_dict=True)
        
            scores_fold.append(scores)
            clas_report_fold.append(clas_report)
                           
            tsne = TSNE(n_components = 2, random_state = 42)
            decomposed_data = tsne.fit_transform(latent)
            
            # plot loss
            plot_loss(history)
            plot_tsne(decomposed_data, y_test)
            
            K.clear_session()
            
        val_clas_acc_all.append(val_clas_acc_fold)
        val_clas_acc_all_data = np.array(val_clas_acc_all)
        loss_all.append(loss_fold)
        loss_all_data = np.array(loss_all)
        val_loss_all.append(val_loss_fold)
        val_loss_all_data = np.array(val_loss_all)
        dec_loss_all.append(dec_loss_fold)
        dec_loss_all_data = np.array(dec_loss_all)
        val_dec_loss_all.append(val_dec_loss_fold)
        val_dec_loss_all_data = np.array(val_dec_loss_all)
        clas_loss_all.append(clas_loss_fold)
        clas_loss_all_data = np.array(clas_loss_all)
        val_clas_loss_all.append(val_clas_loss_fold)
        val_clas_loss_all_data = np.array(val_clas_loss_all)
        
        y_true_all.append(y_true_fold)
        y_true_all_data = np.array(y_true_all)
        
        y_pred_all.append(y_pred_fold)
        y_pred_all_data = np.array(y_pred_all)
        
        scores_all.append(scores_fold)
        scores_all_data = np.array(scores_all)
        
        clas_report_all.append(clas_report_fold)
        clas_report_all_data = np.array(clas_report_all)
    
    return (val_clas_acc_all_data, loss_all_data, val_loss_all_data, dec_loss_all_data, val_dec_loss_all_data,
            clas_loss_all_data, val_clas_loss_all_data, y_true_all_data, y_pred_all_data, scores_all_data,
            clas_report_all_data, train_xr, z, fold)

In [None]:
(val_clas_acc, loss, val_loss, dec_loss, val_dec_loss, clas_loss, val_clas_loss, y_true_all_data, y_pred_all_data, 
 scores_all_data, clas_report_all_data, train_xr, z, fold) = fit_model(X = X_new, y = y, epochs = 71, batch_size = 128)

StratifiedKFold(n_splits=2, random_state=42, shuffle=True)
LEARN: [   0    1    4 ... 8542 8543 8547] TEST: [   2    3    5 ... 8546 8548 8549]
StratifiedKFold(n_splits=5, random_state=42, shuffle=True)
TRAIN: [   1    2    3 ... 4272 4273 4274] VAL: [   0   11   13   14   21   32   55   56   60   62   63   67   76   77
   78   79   80   88   93   97   98  110  113  116  117  120  121  126
  128  137  139  141  145  150  153  162  163  170  171  175  193  200
  203  217  221  226  227  244  248  250  253  254  265  282  284  287
  288  289  292  293  294  298  302  306  313  323  324  328  339  343
  349  350  353  354  358  360  363  364  367  368  379  381  409  411
  415  421  423  428  432  447  471  479  481  484  493  496  500  503
  510  512  515  519  524  528  531  537  541  542  545  555  565  567
  575  583  584  588  591  593  597  608  611  618  619  625  638  640
  641  642  644  650  653  654  656  658  669  678  683  684  690  692
  707  710  712  713  714  719  721  72

In [None]:
# Report classifier_accuracy
val_classifier_accuracy = scores_all_data[:,:,-1].reshape(-1)*100
mean_classifier_accuracy = mean(val_classifier_accuracy) 
sem_classifier_accuracy = sem(val_classifier_accuracy)
 
print("Mean:", mean_classifier_accuracy, "SEM:", sem_classifier_accuracy)

In [None]:
# # Export classification_report
# xx = []
# for i in range (fold):
#     a = clas_report_all_data.reshape(-1)[i].get('macro avg')
#     x = a.get("f1-score")
#     xx.append(x)

In [None]:
val_clas_loss.shape

In [None]:
def plot_acc_loss_mean(val_clas_acc, loss, val_loss, dec_loss, val_dec_loss, clas_loss, val_clas_loss):
    val_clas_accu = val_clas_acc.reshape(-1, val_clas_acc.shape[2])
    val_clas_accu = np.swapaxes(val_clas_accu, 0, 1)
    val_clas_aver = np.average(val_clas_accu, axis=1)*100
    val_clas_sem = stats.sem(val_clas_accu, axis=1)*100
    
    losss = loss.reshape(-1, loss.shape[2])
    losss = np.swapaxes(losss, 0, 1)
    loss_aver = np.average(losss, axis=1)
    loss_sem = stats.sem(losss, axis=1)
    
    val_losss = val_loss.reshape(-1, val_loss.shape[2])
    val_losss = np.swapaxes(val_losss, 0, 1)
    val_loss_aver = np.average(val_losss, axis=1)
    val_loss_sem = stats.sem(val_losss, axis=1)
    
    dec_losss = dec_loss.reshape(-1, dec_loss.shape[2])
    dec_losss = np.swapaxes(dec_losss, 0, 1)
    dec_loss_aver = np.average(dec_losss, axis=1)
    dec_loss_sem = stats.sem(dec_losss, axis=1)
    
    val_dec_losss = val_dec_loss.reshape(-1, val_dec_loss.shape[2])
    val_dec_losss = np.swapaxes(val_dec_losss, 0, 1)
    val_dec_loss_aver = np.average(val_dec_losss, axis=1)
    val_dec_loss_sem = stats.sem(val_dec_losss, axis=1)
    
    clas_losss = clas_loss.reshape(-1, clas_loss.shape[2])
    clas_losss = np.swapaxes(clas_losss, 0, 1)
    clas_loss_aver = np.average(clas_losss, axis=1)
    clas_loss_sem = stats.sem(clas_losss, axis=1)
    
    val_clas_losss = val_clas_loss.reshape(-1, val_clas_loss.shape[2])
    val_clas_losss = np.swapaxes(val_clas_losss, 0, 1)
    val_clas_loss_aver = np.average(val_clas_losss, axis=1)
    val_clas_loss_sem = stats.sem(val_clas_losss, axis=1)

    
    sns.set()
    # x = np.arange(len(val_classifier_average))
    x = np.arange(71) # limit x axis
    plt.figure(figsize=(10, 7.3))
    plt.plot(x, val_clas_aver, 'b-', label='Classifier accuracy')
    plt.fill_between(x, val_clas_aver - val_clas_sem, val_clas_aver + val_clas_sem, color='b', alpha=0.2)
    plt.title('Model accuracy', fontdict=dict(weight='bold'))
    plt.ylabel('Accuracy', fontdict=dict(weight='bold'))
    plt.xlabel('Number of traning iteration', fontdict=dict(weight='bold'))
    plt.legend()
    plt.show()
    
    plt.figure(figsize=(10, 7.3))
    plt.plot(x, loss_aver, 'b-', label='Training')
    plt.fill_between(x, loss_aver - loss_sem, loss_aver + loss_sem, color='b', alpha=0.2)
    plt.plot(x, val_loss_aver, 'r--', label='Validation')
    plt.fill_between(x, val_loss_aver - val_loss_sem, val_loss_aver + val_loss_sem, color='r', alpha=0.2)
    plt.title('Total loss', fontdict=dict(weight='bold'))
    plt.ylabel('Loss', fontdict=dict(weight='bold'))
    plt.xlabel('Number of traning iteration', fontdict=dict(weight='bold'))
    plt.legend()
    plt.show()

    plt.figure(figsize=(10, 7.3))
    plt.plot(x, dec_loss_aver, 'b-', label='Training')
    plt.fill_between(x, dec_loss_aver - dec_loss_sem, dec_loss_aver + dec_loss_sem, color='b', alpha=0.2)
    plt.plot(x, val_dec_loss_aver, 'r--', label='Validation')
    plt.fill_between(x, val_dec_loss_aver - val_dec_loss_sem, val_dec_loss_aver + val_dec_loss_sem, color='r', alpha=0.2)
    plt.title('Mean square error loss', fontdict=dict(weight='bold'))
    plt.ylabel('Loss', fontdict=dict(weight='bold'))
    plt.xlabel('Number of traning iteration', fontdict=dict(weight='bold'))
    plt.legend()
    plt.show()
    
    plt.figure(figsize=(10, 7.3))
    plt.plot(x, clas_loss_aver, 'b-', label='Training')
    plt.fill_between(x, clas_loss_aver - clas_loss_sem, clas_loss_aver + clas_loss_sem, color='b', alpha=0.2)
    plt.plot(x, val_clas_loss_aver, 'r--', label='Validation')
    plt.fill_between(x, val_clas_loss_aver - val_clas_loss_sem, val_clas_loss_aver + val_clas_loss_sem, color='r', alpha=0.2)
    plt.title('Cross-entropy loss', fontdict=dict(weight='bold'))
    plt.ylabel('Loss', fontdict=dict(weight='bold'))
    plt.xlabel('Number of traning iteration', fontdict=dict(weight='bold'))
    plt.legend()
    plt.show()

In [None]:
plot_acc_loss_mean(val_clas_acc, loss, val_loss, dec_loss, val_dec_loss, clas_loss, val_clas_loss)

In [None]:
y_true = y_true_all_data.reshape(-1)
y_pred = y_pred_all_data.reshape(-1)

In [None]:
# y_true = []
# for i in range(fold):
#     for y in y_true_all_data.reshape(-1)[i]:
#         y_true.append(y)
#         y_true_all = np.array(y_true)
        
# y_pred = []import matplotlib.pyplot as plt
# for i in range(fold):
#     for y in y_pred_all_data.reshape(-1)[i]:
#         y_pred.append(y)
#         y_pred_all = np.array(y_pred)

In [None]:
actual_data = y_true
predicted_data = y_pred
cm = confusion_matrix(actual_data, predicted_data)
# print(cm)
plt.figure(figsize=(10, 7.3))
ax = sns.heatmap(cm, annot=True, fmt='g',cmap="Blues_r",annot_kws={'size':9});
ax.set_title('Confusion Matrix', fontdict=dict(weight='bold'));
ax.set_xlabel('Predicted class', fontdict=dict(weight='bold'))
ax.set_ylabel('Actual class', fontdict=dict(weight='bold'));
ax.yaxis.set_ticklabels(['Con', 'Mor'])
ax.xaxis.set_ticklabels(['Con', 'Mor'])
# ax.yaxis.set_ticklabels(['Con', 'Meth', 'Mor','THC','MDMA','L-DOPA', 'Keta','Lora','Fluox','AlkaKT'])
# ax.xaxis.set_ticklabels(['Con', 'Meth', 'Mor','THC','MDMA','L-DOPA', 'Keta','Lora','Fluox','AlkaKT'])
plt.show()

In [None]:
cm = confusion_matrix(y_true, y_pred)
# print(cm)
plt.figure(figsize=(10, 7.3))
ax = sns.heatmap(cm/np.sum(cm), annot=True, fmt='.2%', cmap="Blues_r",annot_kws={'size':9})
ax.set_title('Confusion Matrix', fontdict=dict(weight='bold'));
ax.set_xlabel('Predicted class', fontdict=dict(weight='bold'))
ax.set_ylabel('Actual class', fontdict=dict(weight='bold'));
ax.yaxis.set_ticklabels(['Con', 'Mor'])
ax.xaxis.set_ticklabels(['Con', 'Mor'])
# ax.yaxis.set_ticklabels(['Con', 'Meth', 'Mor','THC','MDMA','L-DOPA', 'Keta','Lora','Fluox','AlkaKT'])
# ax.xaxis.set_ticklabels(['Con', 'Meth', 'Mor','THC','MDMA','L-DOPA', 'Keta','Lora','Fluox','AlkaKT'])
plt.show()

In [None]:
# T = np.array([ y  for y in y_true_all_data.reshape(-1)[i] for i in range(45)])
# P = np.array([ y  for y in y_pred_all_data.reshape(-1)[i] for i in range(45)])

In [None]:
# y_true = []
# for i in range(36):
#     for y in y_true_all_data.reshape(-1)[i]:
#         y_true.append(y)
#         y_true_all = np.array(y_true)
        
# y_pred = []
# for i in range(4):
#     for y in y_pred_all_data.reshape(-1)[i]:
#         y_pred.append(y)
#         y_pred_all = np.array(y_true)

In [None]:
# import numpy as np
# from sklearn.model_selection import StratifiedKFold
# X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7],[8, 8], [9, 9], [10, 10], [11, 11]])
# y = np.array([0, 0,0,1,1,1, 1,1,1,1,1, 1])
# ys = np.array([0,0,0,1,1,1,2,2,2,3,3,3])
# skf = StratifiedKFold(n_splits=2)
# skf.get_n_splits(X, ys)
# print(skf)
# StratifiedKFold(n_splits=2, random_state=42, shuffle=True)
# for learn_index, test_index in skf.split(X, ys):
#     print("LEARN:", learn_index, "TEST:", test_index)
#     X_learn, X_test = X[learn_index], X[test_index]
#     print('X_learn:', X_learn)
#     y_learn, y_test = y[learn_index], y[test_index]
#     print('y_learn:', y_learn)
#     ys_learn, ys_test = ys[learn_index], ys[test_index]
#     print('ys_learn:', ys_learn)
    
#     skf_train = StratifiedKFold(n_splits=2, random_state=42, shuffle=True)
#     skf_train.get_n_splits(X_learn, ys_learn)
#     print(skf_train)
#     for train_index, val_index in skf_train.split(X_learn, ys_learn):
#         print("TRAIN:", train_index, "VAL:", val_index)
#         X_train, X_val = X_learn[train_index], X_learn[val_index]
#         print('X_train:', X_train)
#         print('X_val:', X_val)
#         y_train, y_val = y_learn[train_index], y_learn[val_index]
#         print('y_train:', y_train)
#         ys_train, ys_val = ys_learn[train_index], ys_learn[val_index]
#         print('ys_train:', ys_train)