In [37]:
import numpy as np
import scipy as sp
from tqdm import tqdm
import mne

from helper_code import load_participant_data, find_participant_ids
from helper_code import eeg_channels as EEG_CHANNELS
from helper_code import label_names as LABEL_NAMES

In [14]:
data_folder = '../DEAP/Data'
task = 'regression'

data_folder = data_folder[:-1] if data_folder.endswith('/') else data_folder
participant_ids = find_participant_ids(data_folder)
print('participant_ids', participant_ids)
video_features = []
eeg_features = []
other_physio_features = list()
labels_all = list()
print('loading data...')
for participant_id in tqdm(participant_ids):
    mat = sp.io.loadmat( '{0}/P{1}/s{1}.mat'.format(data_folder, participant_id) )
    ## labels
    labels = mat['labels'][:, :] # video/trial x label (valence, arousal, dominance, liking)
    labels_all.append(labels[np.newaxis, :, :])
    ## physiological signals
    physio_data = mat['data'] # video/trial x channel x data (physiological signals)
    # EEG
    eeg_data = physio_data[:, :32, :]
    eeg_features.append(eeg_data[np.newaxis, :, :, :])
    # other physiological signals
    other_physio_data = physio_data[:, 32:, :]
    other_physio_features.append(other_physio_data[np.newaxis, :, :, :])
        
labels_all = np.vstack(labels_all)
eeg_features = np.vstack(eeg_features)
other_physio_features = np.vstack(other_physio_features)

print('\tlabels_all:', labels_all.shape)
# print('\tvideo_features:', video_features.shape)
print('\teeg_features:', eeg_features.shape)
print('\tother_physio_features:', other_physio_features.shape)

participant_ids ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32']
loading data...


100%|██████████| 32/32 [00:17<00:00,  1.82it/s]


	labels_all: (32, 40, 4)
	eeg_features: (32, 40, 32, 8064)
	other_physio_features: (32, 40, 8, 8064)


In [111]:
# Extract features from the EEG data.
def get_eeg_features(data, sampling_frequency):
    num_channels, num_samples = np.shape(data)

    if num_samples > 0:
        delta_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency,  fmin=0.5,  fmax=4.0, verbose=False)
        theta_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency,  fmin=4.0,  fmax=8.0, verbose=False)
        alpha_psd, _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency,  fmin=8.0, fmax=12.0, verbose=False)
        beta_psd,  _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=12.0, fmax=30.0, verbose=False)
        gamma_psd,  _ = mne.time_frequency.psd_array_welch(data, sfreq=sampling_frequency, fmin=30.0, fmax=45.0, verbose=False)

        delta_psd_mean = np.nanmean(delta_psd, axis=1)
        theta_psd_mean = np.nanmean(theta_psd, axis=1)
        alpha_psd_mean = np.nanmean(alpha_psd, axis=1)
        beta_psd_mean  = np.nanmean(beta_psd,  axis=1)
        gamma_psd_mean = np.nanmean(gamma_psd, axis=1)

    else:
        delta_psd_mean = theta_psd_mean = alpha_psd_mean = beta_psd_mean = gamma_psd_mean = float('nan') * np.ones(num_channels)

    features = np.array((delta_psd_mean, theta_psd_mean, alpha_psd_mean, beta_psd_mean, gamma_psd_mean)).T

    return features

In [112]:
passband = [0.1, 45.0]
sampling_frequency = 128

eeg_psd_features = []
labels_all2 = []

participant_trials =[[1, 10],
                     [1, 23],
                     [2, 2],
                     [2, 5],
                     [2, 19],
                     [2, 39],
                     [4, 23],
                     [9, 24],
                     [11, 20],
                     [13, 14],
                     [13, 26],
                     [18, 40],
                     [22, 33]]

# for i in tqdm(range(32)):
#     for j in range(40):
for (i,j) in participant_trials:
    data = eeg_features[i-1][j-1]
    # Promote the data to double precision because these libraries expect double precision.
    data = np.asarray(data, dtype=np.float64)
    # Apply a bandpass filter.
    data = mne.filter.filter_data(data, sampling_frequency, passband[0], passband[1], n_jobs=-1, verbose='error') # filter out too low and too high frequency noise
    data = get_eeg_features(data, sampling_frequency).flatten() # (128,) 128 = 32 channels x 4 psd features

    eeg_psd_features.append(data[np.newaxis, :])
    labels_all2.append(labels_all[i-1][j-1][np.newaxis, :])

eeg_psd_features = np.vstack(eeg_psd_features)
labels_all2 = np.vstack(labels_all2)

In [114]:
eeg_psd_features.shape

(13, 160)

In [86]:
participant_trials_list = []
for i, j in participant_trials:
    participant_trials_list.append( 's{:02d}_trial{:02d}'.format(i,j) )
    print( 's{:02d}_trial{:02d}'.format(i,j) )

s01_trial10
s01_trial23
s02_trial02
s02_trial05
s02_trial19
s02_trial39
s04_trial23
s09_trial24
s11_trial20
s13_trial14
s13_trial26
s18_trial40
s22_trial33


In [121]:
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
plt.style.use('default')

sort_eeg_in_id = []

l2_dist = np.linalg.norm(eeg_psd_features[:, None, :] - eeg_psd_features[None, :, :], axis=-1)
# sns.heatmap(l2_dist, cmap='viridis')
for i in range(13):
    # [print(participant_trials_list[x], end=', ') for x in np.argsort(l2_dist[i])[1:]]
    print ()
    sort_eeg_in_id.append( list(np.argsort(l2_dist[i])[1:]) )
    print( 1+np.argsort(l2_dist[i]))

np.array(sort_eeg_in_id) + 1


[ 1  2 11 10 12  7  3  4  8  6  5  9 13]

[ 2  1 11 10 12  7  3  4  8  6  5  9 13]

[ 3  4  6  5  7 10 11 12  2  1  8  9 13]

[ 4  3  6  5  7 11 10 12  2  1  8  9 13]

[ 5  6  4  3  7 10 11 12  2  1  8  9 13]

[ 6  4  3  5  7 10 11 12  2  1  8  9 13]

[ 7 12 10 11  2  1  3  4  6  8  5  9 13]

[ 8 12 10 11  1  2  7  3  4  6  5  9 13]

[ 9  7  3  4  6  5 10 11 12  2  1  8 13]

[10 11  1  2 12  7  3  4  8  6  5  9 13]

[11 10  1  2 12  7  3  4  8  6  5  9 13]

[12  1  2 10 11  7  3  4  8  6  5  9 13]

[13  5  6  4  3  7  8 12 10 11  9  2  1]


array([[ 2, 11, 10, 12,  7,  3,  4,  8,  6,  5,  9, 13],
       [ 1, 11, 10, 12,  7,  3,  4,  8,  6,  5,  9, 13],
       [ 4,  6,  5,  7, 10, 11, 12,  2,  1,  8,  9, 13],
       [ 3,  6,  5,  7, 11, 10, 12,  2,  1,  8,  9, 13],
       [ 6,  4,  3,  7, 10, 11, 12,  2,  1,  8,  9, 13],
       [ 4,  3,  5,  7, 10, 11, 12,  2,  1,  8,  9, 13],
       [12, 10, 11,  2,  1,  3,  4,  6,  8,  5,  9, 13],
       [12, 10, 11,  1,  2,  7,  3,  4,  6,  5,  9, 13],
       [ 7,  3,  4,  6,  5, 10, 11, 12,  2,  1,  8, 13],
       [11,  1,  2, 12,  7,  3,  4,  8,  6,  5,  9, 13],
       [10,  1,  2, 12,  7,  3,  4,  8,  6,  5,  9, 13],
       [ 1,  2, 10, 11,  7,  3,  4,  8,  6,  5,  9, 13],
       [ 5,  6,  4,  3,  7,  8, 12, 10, 11,  9,  2,  1]])

In [107]:
participant_trials_list_dict = {participant_trials_list[i]: i for i in range(13)}

In [115]:
sort_video = [['s01_trial23', 's13_trial14', 's13_trial26', 's18_trial40', 's02_trial39', 's04_trial23', 's02_trial05', 's02_trial19', 's02_trial02', 's11_trial20', 's09_trial24', 's22_trial33'],
['s01_trial10', 's13_trial14', 's13_trial26', 's18_trial40', 's04_trial23', 's02_trial39', 's02_trial19', 's02_trial05', 's11_trial20', 's02_trial02', 's09_trial24', 's22_trial33'],
['s02_trial05', 's02_trial39', 's02_trial19', 's04_trial23', 's11_trial20', 's09_trial24', 's13_trial26', 's13_trial14', 's22_trial33', 's01_trial23', 's01_trial10', 's18_trial40'],
['s02_trial02', 's02_trial39', 's02_trial19', 's04_trial23', 's11_trial20', 's09_trial24', 's13_trial26', 's13_trial14', 's22_trial33', 's01_trial23', 's01_trial10', 's18_trial40'],
['s02_trial39', 's02_trial05', 's02_trial02', 's04_trial23', 's11_trial20', 's13_trial26', 's09_trial24', 's13_trial14', 's22_trial33', 's01_trial23', 's01_trial10', 's18_trial40'],
['s02_trial19', 's02_trial05', 's02_trial02', 's04_trial23', 's11_trial20', 's13_trial26', 's09_trial24', 's13_trial14', 's01_trial23', 's22_trial33', 's01_trial10', 's18_trial40'],
['s09_trial24', 's11_trial20', 's02_trial39', 's02_trial19', 's13_trial26', 's02_trial05', 's02_trial02', 's13_trial14', 's18_trial40', 's01_trial23', 's22_trial33', 's01_trial10'],
['s04_trial23', 's11_trial20', 's02_trial39', 's13_trial26', 's02_trial19', 's18_trial40', 's02_trial05', 's13_trial14', 's02_trial02', 's01_trial23', 's22_trial33', 's01_trial10'],
['s04_trial23', 's09_trial24', 's13_trial26', 's02_trial39', 's13_trial14', 's02_trial19', 's02_trial05', 's02_trial02', 's18_trial40', 's01_trial23', 's22_trial33', 's01_trial10'],
['s13_trial26', 's01_trial23', 's11_trial20', 's04_trial23', 's18_trial40', 's01_trial10', 's02_trial39', 's02_trial19', 's09_trial24', 's02_trial05', 's02_trial02', 's22_trial33'],
['s13_trial14', 's11_trial20', 's04_trial23', 's01_trial23', 's02_trial39', 's18_trial40', 's09_trial24', 's02_trial19', 's01_trial10', 's02_trial05', 's02_trial02', 's22_trial33'],
['s04_trial23', 's13_trial14', 's13_trial26', 's01_trial23', 's09_trial24', 's11_trial20', 's01_trial10', 's02_trial39', 's02_trial05', 's02_trial19', 's02_trial02', 's22_trial33'],
['s02_trial19', 's04_trial23', 's02_trial05', 's02_trial02', 's02_trial39', 's11_trial20', 's09_trial24', 's13_trial26', 's13_trial14', 's18_trial40', 's01_trial23', 's01_trial10'],]
sort_video_in_id = []
for sort1 in sort_video:
    sort_video_in_id.append([participant_trials_list_dict[x] for x in sort1])
    [print(participant_trials_list_dict[x]+1, end=', ') for x in sort1]
    print()

sort_video_in_id



2, 10, 11, 12, 6, 7, 4, 5, 3, 9, 8, 13, 
1, 10, 11, 12, 7, 6, 5, 4, 9, 3, 8, 13, 
4, 6, 5, 7, 9, 8, 11, 10, 13, 2, 1, 12, 
3, 6, 5, 7, 9, 8, 11, 10, 13, 2, 1, 12, 
6, 4, 3, 7, 9, 11, 8, 10, 13, 2, 1, 12, 
5, 4, 3, 7, 9, 11, 8, 10, 2, 13, 1, 12, 
8, 9, 6, 5, 11, 4, 3, 10, 12, 2, 13, 1, 
7, 9, 6, 11, 5, 12, 4, 10, 3, 2, 13, 1, 
7, 8, 11, 6, 10, 5, 4, 3, 12, 2, 13, 1, 
11, 2, 9, 7, 12, 1, 6, 5, 8, 4, 3, 13, 
10, 9, 7, 2, 6, 12, 8, 5, 1, 4, 3, 13, 
7, 10, 11, 2, 8, 9, 1, 6, 4, 5, 3, 13, 
5, 7, 4, 3, 6, 9, 8, 11, 10, 12, 2, 1, 


[[1, 9, 10, 11, 5, 6, 3, 4, 2, 8, 7, 12],
 [0, 9, 10, 11, 6, 5, 4, 3, 8, 2, 7, 12],
 [3, 5, 4, 6, 8, 7, 10, 9, 12, 1, 0, 11],
 [2, 5, 4, 6, 8, 7, 10, 9, 12, 1, 0, 11],
 [5, 3, 2, 6, 8, 10, 7, 9, 12, 1, 0, 11],
 [4, 3, 2, 6, 8, 10, 7, 9, 1, 12, 0, 11],
 [7, 8, 5, 4, 10, 3, 2, 9, 11, 1, 12, 0],
 [6, 8, 5, 10, 4, 11, 3, 9, 2, 1, 12, 0],
 [6, 7, 10, 5, 9, 4, 3, 2, 11, 1, 12, 0],
 [10, 1, 8, 6, 11, 0, 5, 4, 7, 3, 2, 12],
 [9, 8, 6, 1, 5, 11, 7, 4, 0, 3, 2, 12],
 [6, 9, 10, 1, 7, 8, 0, 5, 3, 4, 2, 12],
 [4, 6, 3, 2, 5, 8, 7, 10, 9, 11, 1, 0]]

In [122]:
from scipy.stats import kendalltau

sort_eeg_in_id = np.array(sort_eeg_in_id)
sort_video_in_id = np.array(sort_video_in_id)

kendalltau_dist = [kendalltau(x, y)[0] for x, y in zip(sort_eeg_in_id, sort_video_in_id)]

In [123]:
kendalltau_dist

[0.6666666666666666,
 0.6969696969696969,
 0.1212121212121212,
 0.1515151515151515,
 0.0909090909090909,
 0.3636363636363636,
 -0.0909090909090909,
 -0.1212121212121212,
 -0.4242424242424242,
 0.3636363636363636,
 -0.1515151515151515,
 0.0,
 0.7575757575757575]

In [44]:
label_high_valence = 1 * (labels_all2[:, 0] > 5)
label_high_arousal = 1 * (labels_all2[:, 1] > 5)

In [63]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

n_estimators = 100
max_leaf_nodes = 10
random_state = 42

X_train, X_test, y_train, y_test = train_test_split(eeg_psd_features, label_high_valence, test_size=0.3, random_state=42)
X_tra, X_val, y_tra, y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=42)

rf = RandomForestClassifier(
        n_estimators=n_estimators, max_leaf_nodes=max_leaf_nodes, random_state=random_state).fit(X_tra, y_tra)
y_tra_pred = rf.predict(X_tra)
y_val_pred = rf.predict(X_val)
y_test_pred = rf.predict(X_test)
(y_val_pred == y_val).sum() / y_val.shape[0]
(y_test_pred == y_test).sum() / y_test.shape[0]
# (y_tra_pred == y_tra).sum() / y_tra.shape[0]


0.6041666666666666

In [41]:
eeg_psd_features

array([[ 0.56157039,  0.99566174,  0.91894346, ...,  1.26635582,
         1.59198698,  0.3557368 ],
       [ 0.69136403,  1.21474133,  1.12074504, ...,  2.08229943,
         1.9683278 ,  0.37309962],
       [ 0.64283221,  1.13670878,  1.03838757, ...,  1.75204526,
         2.31934234,  0.44637827],
       ...,
       [ 0.80967844,  1.43031135,  0.81008333, ..., 19.71596951,
         2.45146701,  0.64780662],
       [ 0.83249276,  1.47169889,  0.96139856, ...,  5.6457659 ,
         1.58424465,  0.5764151 ],
       [ 0.84376836,  1.4942198 ,  0.89558639, ..., 22.47292211,
         2.28325616,  0.58072079]])