In [None]:
import gc
import warnings

import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam
from keras.utils.np_utils import to_categorical
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import accuracy_score, ConfusionMatrixDisplay

from utils.gcn_categorical import build_model

# Suppress every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# later cells; narrow the filter if those need to be seen.
warnings.filterwarnings("ignore")

In [None]:
# Load the pickled training dictionary (signals plus per-sample metadata).
data = np.load("../data/processed/train_img.npy", allow_pickle=True).item()

# Slice bounds applied to the fMRI array: a:b cuts axis 1 and s:e cuts axis 2.
# After the axis swap below, the s:e window becomes axis 1 of `train` and the
# a:b window becomes axis 2.
a, b = 165, 246
s, e = 15, 25

fmri, rsp, ppg = data["fMRI"], data["RSP"], data["PPG"]

# Only the windowed fMRI signal is used as model input; RSP and PPG are loaded
# but not stacked into `train`. A trailing singleton axis is appended.
train = np.swapaxes(fmri[:, a:b, s:e], 1, 2)[..., np.newaxis]

# Per-sample metadata.
subject = data["subject"]
level = data["level"]
target = data["class"].astype(int)  # integer labels, no offset applied here

# Quick sanity report of shapes and the distinct label / level values.
print(f"Data shape: {train.shape}")
print(f"Subject shape: {subject.shape}")
print(f"Target shape: {target.shape}")
print(f"Level shape: {level.shape}")
print(np.unique(target))
print(np.unique(level))

In [None]:
# Load the pickled test dictionary.
data_test = np.load("../data/processed/test_img.npy", allow_pickle=True).item()

# Positions of the test samples whose class label is not NaN.
test_class = data_test["class"]
test_idx = np.nonzero(~np.isnan(test_class))[0]

# Labels of the labelled samples, shifted by +1; `level` is taken unchanged.
y_test = test_class[test_idx].astype(int) + 1
y_test_level = data_test["level"][test_idx]

# Same a:b / s:e window and axis swap as the training fMRI, applied to ALL test
# samples (not only the labelled ones) and without the trailing singleton axis.
fmri_test = np.swapaxes(data_test["fMRI"][:, a:b, s:e], 1, 2)

# RSP and PPG joined along their last axis.
bio_test = np.concatenate([data_test["RSP"], data_test["PPG"]], axis=-1)

print(f"fMRI shape: {fmri_test.shape}")
print(f"Bio shape: {bio_test.shape}")
print(f"Test target shape: {y_test.shape}")
print(np.unique(y_test))

In [None]:
# Size of the a:b window, i.e. the number of rows fed to np.corrcoef below.
ROI_N = train.shape[2]

# One correlation matrix per training sample. np.corrcoef treats each row of
# its input as a variable, so every (a:b, s:e) window yields ROI_N x ROI_N.
fmri_windows = fmri[:, a:b, s:e]
fcs = np.zeros((len(fmri), ROI_N, ROI_N))
for sample_no, window in enumerate(fmri_windows):
    fcs[sample_no] = np.corrcoef(window)

# Average over samples and persist; np.save appends ".npy", giving "FC.npy".
FCs = fcs.mean(axis=0)
np.save("FC", FCs)

In [None]:
# Cross-validated training of the GCN on pre-computed folds, followed by
# per-fold inference on the test set and a summary of the results.

# --- Hyper-parameters ---------------------------------------------------
k = 10  # (3, 5, 10, 20)
batch_size = 16
epochs = 50
l2_reg = 0.1
dp = 0.5
lr = 1e-3
num_classes = 3
frames = train.shape[1]

# --- Per-fold accumulators (re-created on every run of this cell) --------
val_acc = []     # best validation class accuracy of each fold
val_mae = []     # lowest validation level MAE of each fold
preds = []       # raw model.predict output on the test set, one per fold
test_count = []  # number of correctly classified labelled test samples
test_score = []  # classification error (1 - accuracy) on labelled test samples
accs = []        # accuracy on labelled test samples

X = train
# Labels are shifted by +1 before one-hot encoding; predictions are shifted
# back with -1 after argmax below.
y = target + 1

# Pre-computed train/validation index sets, one entry per fold.
# NOTE(review): presumably leave-one-subject-out folds (file names say "loso");
# confirm against the script that produced them.
loso_tidx = np.load("../data/loso_tidx.npy", allow_pickle=True)
loso_vidx = np.load("../data/loso_vidx.npy", allow_pickle=True)

# Test inputs and reference labels do not change between folds.
x_test = np.expand_dims(fmri_test, axis=-1)
true_classes = data_test["class"][test_idx].astype(int)

# Iterate over the saved folds themselves, so the reported fold sizes and the
# number of folds always match the indices that are actually trained on.
for i, (tidx, vidx) in enumerate(zip(loso_tidx, loso_vidx)):
    print("#" * 50)
    print(f"### Fold {i + 1}")

    X_train, X_val = X[tidx], X[vidx]
    # Fix the one-hot width explicitly: a held-out fold that lacks the highest
    # class would otherwise produce fewer columns than the model has outputs.
    y_train = to_categorical(y[tidx], num_classes=num_classes)
    y_val = to_categorical(y[vidx], num_classes=num_classes)
    y_train_level, y_val_level = level[tidx], level[vidx]

    print(f"### train size {len(tidx)}, valid size {len(vidx)}")
    print("#" * 50)

    # Build a fresh model for this fold. "FC.npy" is the averaged correlation
    # matrix written by the cell that calls np.save("FC", FCs).
    K.clear_session()
    model = build_model(
        graph_path="FC.npy",
        ROI_N=ROI_N,
        frames=frames,
        kernels=[8, 8, 16, 32, 64, 128],
        k=k,
        l2_reg=l2_reg,
        dp=dp,
        num_classes=num_classes)

    model.compile(loss={"class": "categorical_crossentropy", "level": "mae"},
                  optimizer=Adam(learning_rate=lr),
                  metrics={"class": "accuracy", "level": "mae"})

    # NOTE(review): the monitored / logged metric names ("val_class_acc",
    # "val_level_mean_absolute_error") depend on the installed Keras version;
    # verify they match the keys actually present in the training history.
    mdl_ch = ModelCheckpoint(f"../results/models/GCN_v1_f{i + 1}.h5", monitor="val_class_acc",
                             save_best_only=True, save_weights_only=True, verbose=0)
    reduce_lr = ReduceLROnPlateau(monitor="val_class_acc", factor=0.5, patience=10, min_lr=1e-6)

    # No early stopping is used: every fold trains for the full `epochs`.
    # NOTE(review): training passes (X_train, X_train) as input while
    # validation and prediction pass a single array - confirm which form
    # build_model's inputs expect.
    history = model.fit((X_train, X_train), [y_train, y_train_level],
                        shuffle=True,
                        batch_size=batch_size,
                        validation_data=(X_val, [y_val, y_val_level]),
                        epochs=epochs,
                        callbacks=[mdl_ch, reduce_lr],
                        verbose=0)

    val_acc.append(np.max(history.history["val_class_acc"]))
    val_mae.append(np.min(history.history["val_level_mean_absolute_error"]))

    # Inference on the full test set with the weights from the final epoch.
    # NOTE(review): the best-epoch weights saved by the checkpoint are not
    # reloaded here, whereas val_acc above reports the best epoch.
    Y_pred = model.predict(x_test, verbose=0)
    preds.append(Y_pred)

    # First output is the class head; undo the +1 label shift after argmax and
    # score only the labelled test samples.
    pred_classes = np.argmax(Y_pred[0], axis=1) - 1
    n_correct = np.sum(pred_classes[test_idx] == true_classes)
    test_count.append(n_correct)
    accs.append(accuracy_score(true_classes, pred_classes[test_idx]))
    test_score.append(1 - (n_correct / len(true_classes)))

    # Release the model before the next fold is built.
    del model
    gc.collect()

# --- Summary over folds --------------------------------------------------
print("#" * 100)
print(val_acc)
print("#" * 50)
print("Acc. stats:")
print(np.mean(val_acc), np.std(val_acc), np.min(val_acc), np.max(val_acc))
print("MAE stats:")
print(np.mean(val_mae), np.std(val_mae), np.min(val_mae), np.max(val_mae))
print("#" * 100)
print("Test stats:")
print(f"Correct predictions: {test_count} out of {len(y_test)}")
print("Score: ")
print(np.mean(test_score), np.std(test_score), np.min(test_score), np.max(test_score))

In [None]:
# Ensemble the folds by averaging each model output across folds.
# Every entry of `preds` is one fold's model.predict result, a sequence of
# per-output arrays whose first element holds the class scores. The outputs
# have different shapes, so they are averaged one output at a time rather
# than by stacking the whole nested list into a single array.
p = [np.mean(per_fold_output, axis=0) for per_fold_output in zip(*preds)]

# Class prediction of the ensemble on the labelled test samples, shifted back
# by -1 to match `y_test - 1`.
ensemble_classes = np.argmax(p[0][test_idx], axis=1) - 1

ConfusionMatrixDisplay.from_predictions(y_test - 1, ensemble_classes, cmap=plt.cm.Blues)
plt.grid(False)
plt.show()

In [None]:
# Phase2 test-label accuracy: mean of the per-fold accuracies collected in
# `accs` during the training loop.
np.asarray(accs).mean()