In [1]:
# Mount Google Drive so the project folder under /content/drive/MyDrive is reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [11]:
import numpy as np
import scipy as sp
from tqdm import auto, tqdm
import torch
from sklearn.model_selection import train_test_split

import sys
sys.path.append('/content/drive/MyDrive/11777-Project-xs/LGG/')
sys.path.append('/content/drive/MyDrive/11777-Project-xs/MLP/')
from prepare_data_DEAP import PrepareData
from cross_validation import CrossValidation
from train_model import *

In [3]:
# Default hyper-parameters and paths; wrapped into an `Args` object further down.
# Note: 'data_path' is overridden with the Drive location before `args` is built.
default_args = dict(
    # --- data ---
    dataset='DEAP',
    data_path='./data_preprocessed_python/',
    subjects=22,
    num_class=2,
    label_type='A',
    segment=4,
    overlap=0,
    sampling_rate=128,
    scale_coefficient=1,
    input_shape=(1, 32, 512),
    data_format='eeg',
    # --- training ---
    random_seed=2021,
    max_epoch=50,  # 200
    patient=5,  # 20
    patient_cmb=8,
    max_epoch_cmb=20,
    batch_size=64,
    learning_rate=1e-3,
    step_size=5,
    dropout=0.5,
    LS=True,
    LS_rate=0.1,
    # --- checkpoints (LGG, then MLP) ---
    save_path='/content/drive/MyDrive/11777-Project-xs/LGG/save/',
    load_path='/content/drive/MyDrive/11777-Project-xs/LGG/save/max-acc.pth',
    load_path_final='/content/drive/MyDrive/11777-Project-xs/LGG/save/final_model.pth',
    mlp_save_path='/content/drive/MyDrive/11777-Project-xs/MLP/save/',
    mlp_load_path='/content/drive/MyDrive/11777-Project-xs/MLP/save/max-acc.pth',
    mlp_load_path_final='/content/drive/MyDrive/11777-Project-xs/MLP/save/final_model.pth',
    gpu='0',
    save_model=True,
    # --- model ---
    model='LGGNet',
    pool=16,
    pool_step_rate=0.25,
    T=64,
    graph_type='hem',
    hidden=32,
)

In [4]:
######## Reproduce the result using the saved model ######
class Args:
    '''
    Minimal stand-in for an argparse.Namespace.

    Every key of `param_dict` becomes an attribute, so code written against
    `args.<name>` (e.g. PrepareData, CrossValidation) works inside the notebook.
    '''
    def __init__(self, param_dict):
        for key, value in param_dict.items():
            setattr(self, key, value)

default_args['data_path'] = '/content/drive/MyDrive/11777-Project-xs/data_preprocessed_python/'
# `reproduce` mirrors a `--reproduce` flag declared with action='store_true',
# whose parsed value is a bool. Previously the action *name* 'store_true' was
# stored, which only behaved as True because non-empty strings are truthy.
default_args['reproduce'] = True
# Create an Args object with default parameters
args = Args(default_args)


In [5]:
# One-off preprocessing of subjects 0 .. args.subjects-1 (only needs to run once).
prepdt = PrepareData(args)
sub_to_run = np.arange(args.subjects)
prepdt.run(sub_to_run, split=True, expand=True)

data:(40, 32, 7680) label:(40, 4)
Binary label generated!
The data and label are split: Data shape:(40, 15, 1, 32, 512) Label:(40, 15)
Data and label prepared!
data:(40, 15, 1, 32, 512) label:(40, 15)
----------------------
data:(40, 32, 7680) label:(40, 4)
Binary label generated!
The data and label are split: Data shape:(40, 15, 1, 32, 512) Label:(40, 15)
Data and label prepared!
data:(40, 15, 1, 32, 512) label:(40, 15)
----------------------
data:(40, 32, 7680) label:(40, 4)
Binary label generated!
The data and label are split: Data shape:(40, 15, 1, 32, 512) Label:(40, 15)
Data and label prepared!
data:(40, 15, 1, 32, 512) label:(40, 15)
----------------------
data:(40, 32, 7680) label:(40, 4)
Binary label generated!
The data and label are split: Data shape:(40, 15, 1, 32, 512) Label:(40, 15)
Data and label prepared!
data:(40, 15, 1, 32, 512) label:(40, 15)
----------------------
data:(40, 32, 7680) label:(40, 4)
Binary label generated!
The data and label are split: Data shape:(40, 

In [6]:
participant_ids = list(range(1, 23))
data_save_folder = '/content/drive/MyDrive/11777-Project-xs/data_face_seg_mean'

def delete_file(path):
    '''
    Remove a temporary file (e.g. the scratch PNG used below) if it exists.

    Args:
        path: filesystem path of the file to remove.

    A missing file is silently ignored; any other OSError (permission denied,
    path is a directory, ...) propagates to the caller.
    '''
    # `os` is otherwise only imported in a later cell; import it here so this
    # cell works on a fresh kernel without relying on a wildcard import.
    import os
    try:
        os.remove(path)
    except FileNotFoundError:
        pass

# # Only need to run once: embed every per-segment mean face frame with DeepFace.
# # Requires `plt` (matplotlib.pyplot) and `DeepFace`, neither of which is imported above.
# # NOTE(review): an earlier version saved `participant_trial_seg[0]` on every
# # iteration, so all segments of a trial received the FIRST segment's embedding.
# # It now saves `image`. Because existing '*_embed.npy' files are skipped below,
# # delete any files produced by the old version before re-running.
# for participant_id in auto.tqdm(participant_ids):
#     participant_id = '{:02d}'.format(participant_id)
#     for trial_id in auto.tqdm(range(1, 41)):
#         file_name = 's{0}_trial{1}_mean_frames.npy'.format(participant_id, trial_id)
#         if not os.path.exists(os.path.join(data_save_folder, file_name)):
#             print('No face data found for s{0}_trial{1}'.format(participant_id, trial_id))
#             continue

#         file_name2 = 's{0}_trial{1}_mean_frames_embed.npy'.format(participant_id, trial_id)
#         if os.path.exists(os.path.join(data_save_folder, file_name2)):
#             continue
#         participant_trial_seg = np.load(os.path.join(data_save_folder, file_name))
#         participant_trial_embed = []
#         for image in participant_trial_seg:
#             delete_file('curr_face_image.png')
#             plt.imsave('curr_face_image.png', np.uint8(image),
#                        cmap=plt.cm.Spectral)
#             embedding_objs = DeepFace.represent(img_path = 'curr_face_image.png', enforce_detection = False)
#             embedding = embedding_objs[0]['embedding']
#             assert isinstance(embedding, list)
#             assert len(embedding) == 2622
#             participant_trial_embed.append(np.array(embedding))

#         participant_trial_embed = np.stack(participant_trial_embed, axis=0)
#         np.save(os.path.join(data_save_folder, file_name2), participant_trial_embed)


In [7]:
import os
data_save_folder = '/content/drive/MyDrive/11777-Project-xs/data_face_seg_mean'
# Keep only the embedding files, e.g. 's01_trial1_mean_frames_embed.npy'.
face_embed_file_list = [e for e in os.listdir(data_save_folder) if 'embed.npy' in e]

# 1-based subject id -> list of 1-based trial ids that have a face embedding.
sub_trial_list = {}
for file in face_embed_file_list:
    sub, trial, _, _, _ = file.split('_')
    sub, trial = int(sub[1:]), int(trial[5:])
    sub_trial_list.setdefault(sub, []).append(trial)
# os.listdir order is arbitrary; sort so the trial lists are deterministic.
for trials in sub_trial_list.values():
    trials.sort()

for sub in range(1, 23):
    # .get: a subject with no embedding files reports 0 instead of raising KeyError.
    print('sub {}, {} trials with facial video'.format(sub, len(sub_trial_list.get(sub, []))))

sub 1, 40 trials with facial video
sub 2, 40 trials with facial video
sub 3, 39 trials with facial video
sub 4, 40 trials with facial video
sub 5, 39 trials with facial video
sub 6, 40 trials with facial video
sub 7, 40 trials with facial video
sub 8, 40 trials with facial video
sub 9, 40 trials with facial video
sub 10, 40 trials with facial video
sub 11, 37 trials with facial video
sub 12, 40 trials with facial video
sub 13, 40 trials with facial video
sub 14, 39 trials with facial video
sub 15, 40 trials with facial video
sub 16, 40 trials with facial video
sub 17, 40 trials with facial video
sub 18, 40 trials with facial video
sub 19, 40 trials with facial video
sub 20, 40 trials with facial video
sub 21, 40 trials with facial video
sub 22, 40 trials with facial video


In [None]:
def _trial_number(file_name):
    '''Return the trial id encoded in e.g. 's01_trial12_mean_frames_embed.npy' (-> 12).'''
    return int(file_name.split('_')[1][len('trial'):])


def load_mean_frames_embed(sub, data_save_folder):
    '''
    Load and stack the per-trial face embeddings of one subject.

    Args:
        sub: subject id, starting from 1 (not 0). It is compared against the
            two digits after the leading 's' of file names such as
            's01_trial1_mean_frames_embed.npy'.
        data_save_folder: directory that holds the '*embed*' .npy files.

    Returns:
        np.ndarray with one entry per trial stacked along a new axis 0, in
        ascending (numeric) trial order, so row i lines up with the i-th
        EEG trial that has a face embedding.

    Raises:
        ValueError: if no embedding file exists for `sub`.
    '''
    face_embed_file_list = os.listdir(data_save_folder)
    sub_face_embed_file_list = [e for e in face_embed_file_list if 'embed' in e and sub == int(e[1:3])]
    # Sort numerically by trial id. The previous `sorted(...)` call discarded its
    # result, leaving arbitrary os.listdir order (which misaligns face rows with
    # EEG trials); a plain lexicographic sort would also be wrong ('trial10' < 'trial2').
    sub_face_embed_file_list.sort(key=_trial_number)
    if not sub_face_embed_file_list:
        raise ValueError('No face embedding files for subject {} in {}'.format(sub, data_save_folder))
    sub_frame_embed = []
    for file in auto.tqdm(sub_face_embed_file_list):
        embed = np.load(os.path.join(data_save_folder, file))
        sub_frame_embed.append(embed)
    sub_frame_embed = np.stack(sub_frame_embed, axis=0)
    print(sub_frame_embed.shape)
    return sub_frame_embed

# Face-embedding files are named like 's01_trial1_mean_frames_embed.npy';
# build one stacked array per subject, in subject order (ids 1..22).
data_save_folder = '/content/drive/MyDrive/11777-Project-xs/data_face_seg_mean'
subject_ids = range(1, 23)
all_sub_frame_embed = []
for sub in auto.tqdm(subject_ids):
    sub_frame_embed = load_mean_frames_embed(sub, data_save_folder)
    all_sub_frame_embed += [sub_frame_embed]


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/39 [00:00<?, ?it/s]

(39, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/39 [00:00<?, ?it/s]

(39, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/37 [00:00<?, ?it/s]

(37, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/39 [00:00<?, ?it/s]

(39, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


  0%|          | 0/40 [00:00<?, ?it/s]

(40, 15, 2622)


In [8]:
import pickle
file_name = 'all_sub_frame_embed.pkl'
cache_path = os.path.join(data_save_folder, file_name)

# Cache the per-subject face embeddings: load them if the pickle exists,
# otherwise write the list built by the previous cell.
# Only unpickle files you created yourself — pickle can execute arbitrary code.
# NOTE(review): rows of each subject's array must be in ascending trial order
# to line up with the EEG trials in the training loop; if unsure how the cached
# file was produced, delete it so it is rebuilt.
if os.path.exists(cache_path):
    with open(cache_path, 'rb') as file:
        all_sub_frame_embed = pickle.load(file)
else:
    with open(cache_path, 'wb') as file:
        pickle.dump(all_sub_frame_embed, file)

In [9]:
def generate_eeg_embed(eeg_data, LGG, CUDA=True):
    '''
    Embed every trial with `LGG.get_embed` and stack the results.

    Args:
        eeg_data: iterable of per-trial numpy arrays; each one is converted to a
            float32 tensor and passed to `LGG.get_embed` as a single batch
            (presumably (segments, 1, channels, samples) — confirm with the loader).
        LGG: torch module exposing `get_embed`; it is put into eval mode.
        CUDA: if True, the model and every trial are moved to the GPU.

    Returns:
        np.ndarray stacking the per-trial embeddings along a new axis 0.
    '''
    LGG.eval()
    if CUDA:
        # Loop-invariant: move the model once instead of on every trial.
        LGG = LGG.cuda()
    eeg_embed = []
    with torch.no_grad():
        for eeg_data_trial in eeg_data:
            trial_tensor = torch.from_numpy(eeg_data_trial).float()
            if CUDA:
                trial_tensor = trial_tensor.cuda()
            trial_embed = LGG.get_embed(trial_tensor)
            eeg_embed.append(trial_embed.cpu().numpy())
    return np.stack(eeg_embed, axis=0)

In [10]:
def get_model_embed(args):
    '''
    Build an LGGNet2 configured like the trained LGGNet.

    The trained weights are later copied into this model (transfer_weight), and
    its `get_embed` output is used as the EEG features.

    Args:
        args: config object; reads model, graph_type, input_shape, num_class,
            sampling_rate, scale_coefficient, T, hidden, dropout, pool and
            pool_step_rate.

    Returns:
        A freshly initialised LGGNet2 instance.

    Raises:
        ValueError: if args.model is not 'LGGNet' (the only supported model).
            Previously this fell through to `return model` and raised
            UnboundLocalError.
    '''
    if args.model != 'LGGNet':
        raise ValueError("get_model_embed only supports args.model == 'LGGNet', got {!r}".format(args.model))
    # Presumably per-group channel counts of the local graph (their sum is the
    # channel count). The file is opened relative to the current working
    # directory and closed right after reading.
    with h5py.File('num_chan_local_graph_{}.hdf'.format(args.graph_type), 'r') as graph_file:
        idx_local_graph = list(np.array(graph_file['data']))
    channels = sum(idx_local_graph)
    input_size = (args.input_shape[0], channels, args.input_shape[2])
    model = LGGNet2(
        num_classes=args.num_class, input_size=input_size,
        sampling_rate=int(args.sampling_rate*args.scale_coefficient),
        num_T=args.T, out_graph=args.hidden,
        dropout_rate=args.dropout,
        pool=args.pool, pool_step_rate=args.pool_step_rate,
        idx_graph=idx_local_graph)

    return model

def transfer_weight(LGG, LGG_embed):
    '''
    Copy weights from `LGG` into `LGG_embed`, pairing direct child modules in order.

    Pairing stops at the shorter of the two child lists (zip semantics).
    Returns `LGG_embed`, which is modified in place.
    '''
    source_children = LGG.children()
    target_children = LGG_embed.children()
    for src_layer, dst_layer in zip(source_children, target_children):
        dst_layer.load_state_dict(src_layer.state_dict())
    return LGG_embed

In [15]:
# MLP predictor plus its train / test loops (MLP/ folder added to sys.path above).
from mlp_networks import SimpleMLP
from mlp_train_model import (
    train_simpleMLP,
    test_simpleMLP,
)

# Use the GPU only when torch can see one.
CUDA = torch.cuda.is_available()

In [16]:
CUDA = torch.cuda.is_available()
fold = 'noFold'

# Reload the cached per-subject face embeddings so this cell is self-contained.
with open(os.path.join(data_save_folder, file_name), 'rb') as file:
    all_sub_frame_embed = pickle.load(file)

# 1-based subject id -> test accuracy returned by test_simpleMLP.
test_acc_per_sub = {}

for sub in range(22):

    print('='*48)
    cv = CrossValidation(args)
    eeg_data, label = cv.load_per_subject(sub)
    # `sub` is 0-based while sub_trial_list is keyed by 1-based subject id;
    # trial ids 1..40 are turned into 0-based row indices of eeg_data.
    trial_to_remove = [i-1 for i in range(1,41) if i not in sub_trial_list[sub+1]]
    ### face embeddings
    sub_frame_embed = all_sub_frame_embed[sub]

    # Drop EEG trials that have no facial-video embedding.
    if len(trial_to_remove) > 0:
        mask = np.ones(eeg_data.shape[0], dtype=bool)
        mask[trial_to_remove] = False
        eeg_data = eeg_data[mask]
        label = label[mask]

    # EEG and face embeddings are concatenated trial-by-trial below, so both must
    # cover the same trials (and in the same order, which cannot be checked here).
    if eeg_data.shape[0] != sub_frame_embed.shape[0]:
        raise ValueError('sub {}: {} EEG trials but {} face-embedding trials'.format(
            sub+1, eeg_data.shape[0], sub_frame_embed.shape[0]))

    ### Load pre-trained LGG model weights
    ### and generate the EEG embeddings
    lgg_model = get_model(args)
    lgg_model_embed = get_model_embed(args)
    if CUDA:
        lgg_model = lgg_model.cuda()
    model_name_reproduce = 'sub' + str(sub+1) + '.pth'
    data_type = 'model_{}_{}'.format(args.data_format, args.label_type)
    experiment_setting = 'T_{}_pool_{}'.format(args.T, args.pool)
    save_path = os.path.join(args.save_path, experiment_setting, data_type)
    ensure_path(save_path)
    model_name_reproduce = os.path.join(save_path, model_name_reproduce)
    # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only runtime;
    # load_state_dict then copies the tensors onto the model's own device.
    lgg_model.load_state_dict(torch.load(model_name_reproduce, map_location='cpu'))
    lgg_model_embed = transfer_weight(lgg_model, lgg_model_embed)
    eeg_embed = generate_eeg_embed(eeg_data, lgg_model_embed, CUDA)

    ### concat eeg + face embedding
    eeg_face_embed = np.concatenate([eeg_embed, sub_frame_embed], axis=2)

    print(sub+1, trial_to_remove, sub_frame_embed.shape[0])
    # Split by trial (not by segment) so segments of one trial never leak across
    # train / val / test.
    # NOTE(review): the split is not stratified; the logged SUB:3 run shows
    # val acc=1.0 with f1=0.0, consistent with a single-class validation set.
    trial_id = np.arange(eeg_data.shape[0])
    trial_id_train, trial_id_test = train_test_split(trial_id, test_size=0.2, random_state=42)
    trial_id_tra, trial_id_val = train_test_split(trial_id_train, test_size=0.2, random_state=42)
    print(trial_id_tra.shape[0], trial_id_val.shape[0], trial_id_test.shape[0])

    # Select trials, then flatten (trial, segment) into one sample axis.
    data_train, data_val, label_train, label_val = \
        eeg_face_embed[trial_id_tra], eeg_face_embed[trial_id_val], label[trial_id_tra], label[trial_id_val]
    data_train, data_val, label_train, label_val = \
        np.concatenate(data_train, axis=0), np.concatenate(data_val, axis=0), np.concatenate(label_train, axis=0), np.concatenate(label_val, axis=0)
    print(data_train.shape, data_val.shape, label_train.shape, label_val.shape)

    data_test, label_test = \
        eeg_face_embed[trial_id_test], label[trial_id_test]
    data_test, label_test = \
        np.concatenate(data_test, axis=0), np.concatenate(label_test, axis=0)
    print(data_test.shape, label_test.shape)

    data_train = torch.from_numpy(data_train).float()
    label_train = torch.from_numpy(label_train).long()
    data_val = torch.from_numpy(data_val).float()
    label_val = torch.from_numpy(label_val).long()
    data_test = torch.from_numpy(data_test).float()
    label_test = torch.from_numpy(label_test).long()


    ### define a simple MLP as the predictor
    predictor = SimpleMLP(data_train.shape[1], data_train.shape[1] // 4, args.num_class)

    acc_val, F1_val = train_simpleMLP(model=predictor, args=args,
                                      data_train=data_train,
                                      label_train=label_train,
                                      data_val=data_val,
                                      label_val=label_val,
                                      subject = str(sub+1))

    # Restore the best checkpoint written during training ('candidate.pth').
    if CUDA:
        predictor = predictor.cuda()
    predictor.load_state_dict(torch.load(os.path.join(args.mlp_save_path, 'candidate.pth'), map_location='cpu'))

    model_name_reproduce = 'sub' + str(sub+1) + '_mlp.pth'
    data_type = 'model_{}_{}'.format(args.data_format, args.label_type)
    experiment_setting = 'T_{}_pool_{}'.format(args.T, args.pool)
    save_path = os.path.join(args.mlp_save_path, experiment_setting, data_type)
    ensure_path(save_path)
    model_name_reproduce = os.path.join(save_path, model_name_reproduce)
    print(model_name_reproduce)
    torch.save(predictor.state_dict(), model_name_reproduce)

    acc_test, pred, act = test_simpleMLP(model=predictor, args=args, data=data_test, label=label_test,
                                               reproduce=args.reproduce,
                                               subject=str(sub+1))
    test_acc_per_sub[sub+1] = acc_test

# NOTE(review): assumes test_simpleMLP returns a scalar accuracy first — confirm.
print('mean test acc over {} subjects: {:.4f}'.format(
    len(test_acc_per_sub), float(np.mean(list(test_acc_per_sub.values())))))

>>> Data:(40, 15, 1, 32, 512) Label:(40, 15)
1 [] 40
25 7 8
(375, 3166) (105, 3166) (375,) (105,)
(120, 3166) (120,)
using gpu: 0
epoch 1, loss=0.6930 acc=0.5573 f1=0.6891
epoch 1, val, loss=0.6879 acc=0.5714 f1=0.7273
ETA:0s/7s SUB:1
epoch 2, loss=0.6796 acc=0.6053 f1=0.7385
epoch 2, val, loss=0.6883 acc=0.5714 f1=0.7273
ETA:0s/5s SUB:1
epoch 3, loss=0.6702 acc=0.6000 f1=0.7475
epoch 3, val, loss=0.6862 acc=0.5714 f1=0.7273
ETA:0s/4s SUB:1
epoch 4, loss=0.6632 acc=0.5947 f1=0.7379
epoch 4, val, loss=0.6893 acc=0.5714 f1=0.7273
ETA:0s/4s SUB:1
epoch 5, loss=0.6319 acc=0.6507 f1=0.7722
epoch 5, val, loss=0.7047 acc=0.5619 f1=0.7195
ETA:0s/3s SUB:1
epoch 6, loss=0.5994 acc=0.7280 f1=0.8061
epoch 6, val, loss=0.7784 acc=0.5619 f1=0.7195
ETA:0s/3s SUB:1
epoch 7, loss=0.6001 acc=0.6587 f1=0.7322
epoch 7, val, loss=0.8919 acc=0.5619 f1=0.7195
ETA:0s/3s SUB:1
epoch 8, loss=0.5991 acc=0.6960 f1=0.7500
epoch 8, val, loss=0.8274 acc=0.5714 f1=0.7239
ETA:0s/3s SUB:1
epoch 9, loss=0.5786 acc=0.693

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


ETA:0s/4s SUB:3
epoch 3, loss=0.5840 acc=0.7528 f1=0.0220
epoch 3, val, loss=0.3419 acc=1.0000 f1=0.0000
ETA:0s/4s SUB:3
epoch 4, loss=0.5894 acc=0.7500 f1=0.0000
epoch 4, val, loss=0.3725 acc=1.0000 f1=0.0000
ETA:0s/4s SUB:3
epoch 5, loss=0.5929 acc=0.7500 f1=0.0000
epoch 5, val, loss=0.3847 acc=1.0000 f1=0.0000


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


ETA:0s/3s SUB:3
epoch 6, loss=0.5779 acc=0.7472 f1=0.0000
epoch 6, val, loss=0.3273 acc=1.0000 f1=0.0000
ETA:0s/3s SUB:3
epoch 7, loss=0.5756 acc=0.7500 f1=0.0000
epoch 7, val, loss=0.4161 acc=1.0000 f1=0.0000
ETA:0s/3s SUB:3
epoch 8, loss=0.5622 acc=0.7444 f1=0.0000
epoch 8, val, loss=0.3045 acc=1.0000 f1=0.0000


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


ETA:0s/3s SUB:3
epoch 9, loss=0.5676 acc=0.7500 f1=0.0000
epoch 9, val, loss=0.4622 acc=1.0000 f1=0.0000
ETA:0s/3s SUB:3
epoch 10, loss=0.5576 acc=0.7500 f1=0.0000
epoch 10, val, loss=0.3016 acc=1.0000 f1=0.0000
ETA:0s/3s SUB:3
epoch 11, loss=0.5446 acc=0.7500 f1=0.0000
epoch 11, val, loss=0.3819 acc=1.0000 f1=0.0000


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


ETA:0s/3s SUB:3
epoch 12, loss=0.5384 acc=0.7528 f1=0.0220
epoch 12, val, loss=0.3595 acc=1.0000 f1=0.0000
ETA:0s/3s SUB:3
epoch 13, loss=0.5422 acc=0.7528 f1=0.0220
epoch 13, val, loss=0.4310 acc=1.0000 f1=0.0000
ETA:1s/3s SUB:3
epoch 14, loss=0.5239 acc=0.7611 f1=0.0851
epoch 14, val, loss=0.4757 acc=0.9714 f1=0.0000
ETA:1s/3s SUB:3


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


epoch 15, loss=0.5017 acc=0.7861 f1=0.2667
epoch 15, val, loss=0.5537 acc=0.7810 f1=0.0000
ETA:1s/3s SUB:3
epoch 16, loss=0.5405 acc=0.7917 f1=0.4526
epoch 16, val, loss=0.3874 acc=1.0000 f1=0.0000
ETA:1s/3s SUB:3
epoch 17, loss=0.5009 acc=0.7917 f1=0.3363
epoch 17, val, loss=0.5296 acc=0.8381 f1=0.0000
ETA:1s/3s SUB:3
epoch 18, loss=0.5025 acc=0.7750 f1=0.2957
epoch 18, val, loss=0.6421 acc=0.6000 f1=0.0000
ETA:1s/3s SUB:3
epoch 19, loss=0.4807 acc=0.7944 f1=0.4789
epoch 19, val, loss=0.3640 acc=0.9905 f1=0.0000
ETA:1s/3s SUB:3
epoch 20, loss=0.4670 acc=0.8194 f1=0.4961
epoch 20, val, loss=0.5308 acc=0.7810 f1=0.0000
ETA:1s/3s SUB:3
epoch 21, loss=0.4490 acc=0.8556 f1=0.6623
epoch 21, val, loss=0.5133 acc=0.7905 f1=0.0000
early stopping
/content/drive/MyDrive/11777-Project-xs/MLP/save/T_64_pool_16/model_eeg_A/sub3_mlp.pth
using gpu: 0
>>> Test:  loss=0.5422 acc=0.7417 f1=0.0606
>>> Data:(40, 15, 1, 32, 512) Label:(40, 15)
4 [] 40
25 7 8
(375, 3166) (105, 3166) (375,) (105,)
(120, 3166