In [70]:
import numpy as np
import matplotlib.pylab as plt
import scipy.io
import random
from sklearn.model_selection import train_test_split
from tensorflow.keras import models, layers
from tensorflow import keras
from sklearn.metrics import accuracy_score
from tensorflow.keras.utils import to_categorical
import absl.logging
import tensorflow as tf


def mish(x):
    """Mish activation: ``x * tanh(softplus(x))`` where ``softplus(x) = ln(1 + e^x)``.

    Smooth, non-monotonic alternative to ReLU; approaches the identity for large
    positive inputs and approaches 0 for large negative inputs.
    """
    smooth_relu = tf.math.softplus(x)
    gate = tf.math.tanh(smooth_relu)
    return x * gate

absl.logging.set_verbosity(absl.logging.ERROR)  # silence absl warnings printed when saving .h5 models

# Fixed seed so the random per-fold validation draws below are reproducible and the
# split recorded in fold_val_indices.mat can be regenerated.
SEED = 42
keras.utils.set_random_seed(SEED)  # seeds Python `random`, NumPy and TensorFlow RNGs

mat_file = scipy.io.loadmat("data/PTSDvsHC.mat")

# Author's note: 90 subjects, 62 electrode channels, 1001 time steps. The data were
# already epoched (1001 samples per epoch) and averaged over those epochs.
dataset = mat_file['dataset']

subject_test_acc = []

N_CHANNELS = 62
N_TIMES = 1001

X = np.asarray(dataset)
# The network expects (subjects, time, channels, 1). `reshape` only re-labels the
# existing memory order, so if the file stores (subjects, channels, time) a plain
# reshape would silently interleave channels with time samples. Transpose instead.
if X.shape[1:] == (N_CHANNELS, N_TIMES):
    X = X.transpose(0, 2, 1)
elif X.shape[1:] != (N_TIMES, N_CHANNELS):
    raise ValueError(
        f"unexpected dataset shape {X.shape}; expected "
        f"(n, {N_CHANNELS}, {N_TIMES}) or (n, {N_TIMES}, {N_CHANNELS})"
    )
X = X[..., np.newaxis]  # trailing singleton "image channel" axis for Conv2D
y = mat_file['Y']
y = to_categorical(y, num_classes = 2)

# Class membership by position: assumes subjects are stored sorted by class,
# presumably 0-50 healthy controls and 51-89 the PTSD group — confirm against Y.
hc_indices = list(range(0, 51))
pg_indices = list(range(51, 90))

fold_val_index_list = []
fold_accuracy = []

input_shape = (N_TIMES, N_CHANNELS, 1)

def _build_model(input_shape):
    """Build and compile the two-branch CNN used in every LOSO fold.

    Branch 1 convolves along the first axis of the input (time); branch 2 does the
    same on the axis-swapped input (channels first). The two 32-filter feature maps
    are combined by a per-filter matrix product into a (channels x channels) map,
    which is convolved, globally average-pooled and fed to a 2-way softmax.

    Args:
        input_shape: shape of one sample, ``(time, channels, 1)``.

    Returns:
        A compiled ``keras.Model`` (Adam, categorical cross-entropy, accuracy metric).
    """
    n_channels = input_shape[1]

    model_input = layers.Input(shape = input_shape)
    # (50, 1) kernel slides along the time axis only -> temporal filters.
    br1 = layers.Conv2D(32, (50, 1), padding = 'same')(model_input)
    br1 = layers.Activation(mish)(br1)  # (time, channels, 32)

    swapped = layers.Permute((2, 1, 3))(model_input)  # (channels, time, 1)
    br2 = layers.Conv2D(32, (n_channels, 1), padding = 'same')(swapped)
    br2 = layers.Activation(mish)(br2)  # (channels, time, 32)

    # Put the filter axis first so tf.matmul multiplies filter-by-filter:
    # (32, channels, time) @ (32, time, channels) -> (32, channels, channels).
    br1 = layers.Permute((3, 1, 2))(br1)  # (32, time, channels)
    br2 = layers.Permute((3, 1, 2))(br2)  # (32, channels, time)
    br_mul = layers.Lambda(lambda x: tf.matmul(x[0], x[1]))([br2, br1])

    # Filters back to the last axis so Conv2D / pooling treat them as channels.
    br_mul = layers.Permute((2, 3, 1))(br_mul)  # (channels, channels, 32)
    br_mul = layers.Conv2D(16, (3, 3), padding = 'same')(br_mul)
    br_mul = layers.Activation(mish)(br_mul)
    br_mul = layers.GlobalAveragePooling2D()(br_mul)
    end = layers.Flatten()(br_mul)

    prediction = layers.Dense(2, activation = 'softmax')(end)
    model = keras.Model(inputs = model_input, outputs = prediction)
    model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model


n_subjects = len(X)
for i in range(n_subjects):  # leave-one-subject-out: subject i is this fold's test set
    # A fresh model is built every fold; clear Keras' global state first so memory
    # does not accumulate across all folds.
    keras.backend.clear_session()

    test_index = [i]

    # 5 random validation subjects per class, never including the test subject.
    hc_val = random.sample([idx for idx in hc_indices if idx not in test_index], 5)
    pg_val = random.sample([idx for idx in pg_indices if idx not in test_index], 5)
    validation_indices = hc_val + pg_val

    held_out = set(test_index) | set(validation_indices)
    training_indices = [idx for idx in range(n_subjects) if idx not in held_out]

    fold_val_index_list.append({
        'fold' : i,
        'test_index' : test_index,
        'validation_indices' : validation_indices
    })

    X_train = X[training_indices]
    y_train = y[training_indices]
    X_test = X[test_index]
    y_test = y[test_index]
    X_val = X[validation_indices]
    y_val = y[validation_indices]

    model = _build_model(input_shape)
    history = model.fit(X_train, y_train, validation_data = (X_val, y_val), epochs=100, batch_size=1, verbose=2)
    predictions = model.predict(X_test)

    predicted_classes = np.argmax(predictions, axis=1)
    actual_classes = np.argmax(y_test, axis=1)

    # sklearn signature is accuracy_score(y_true, y_pred).
    acc = accuracy_score(actual_classes, predicted_classes)
    fold_accuracy.append(acc)

    print(f"fold {i} test accuracy : {acc}\n")

    subject_test_acc.append(acc)
# Persist each fold's test/validation subject indices, then report the LOSO mean accuracy.
scipy.io.savemat('fold_val_indices.mat', dict(fold_val_index_list=fold_val_index_list))
mean_test_acc = np.mean(subject_test_acc)
print(f"loso mean accuracy : {mean_test_acc}")

Epoch 1/100
79/79 - 4s - 54ms/step - accuracy: 0.5570 - loss: 14.3025 - val_accuracy: 0.5000 - val_loss: 18.5316
Epoch 2/100
79/79 - 3s - 39ms/step - accuracy: 0.4051 - loss: 4.7281 - val_accuracy: 0.6000 - val_loss: 0.5529
Epoch 3/100
79/79 - 3s - 39ms/step - accuracy: 0.6203 - loss: 1.4448 - val_accuracy: 0.5000 - val_loss: 1.1092
Epoch 4/100
79/79 - 3s - 39ms/step - accuracy: 0.6835 - loss: 0.6452 - val_accuracy: 0.5000 - val_loss: 0.8351
Epoch 5/100
79/79 - 3s - 39ms/step - accuracy: 0.6456 - loss: 0.5255 - val_accuracy: 0.6000 - val_loss: 0.7572
Epoch 6/100
79/79 - 3s - 39ms/step - accuracy: 0.7089 - loss: 0.4923 - val_accuracy: 0.7000 - val_loss: 0.7306
Epoch 7/100
79/79 - 3s - 38ms/step - accuracy: 0.7595 - loss: 0.4557 - val_accuracy: 0.5000 - val_loss: 0.6774
Epoch 8/100
79/79 - 3s - 39ms/step - accuracy: 0.7722 - loss: 0.4457 - val_accuracy: 0.6000 - val_loss: 0.6429
Epoch 9/100
79/79 - 3s - 38ms/step - accuracy: 0.8354 - loss: 0.3858 - val_accuracy: 0.6000 - val_loss: 0.9059

KeyboardInterrupt: 