In [None]:
# Old name of the notebook: OmissionsMEGAnalysis_PredictabilityLevel

In [None]:
import mne
from mne import find_events
from mne.decoding import Vectorizer, SlidingEstimator, cross_val_multiscore

import numpy as np
import os
import random
random.seed(42)
import warnings
warnings.filterwarnings('ignore')

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split, StratifiedKFold, RepeatedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score, precision_recall_fscore_support, balanced_accuracy_score

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
#%matplotlib tk
import pickle
from scipy.stats import sem
from scipy.stats import wilcoxon

import import_ipynb
from CommonFunctions import loadData, extractDataAndLabels, concatNonEmpty, trainAndTest_MVPA, plot_MVPA

In [None]:
# Sensor type to analyse; it is embedded in the epoch and result file names.
sensors = 'mag'

# File paths (relative, Windows-style roots; the epoch window is appended below).
# Raw strings are used so that "\D", "\M", "\F", "\R" and "\C" are not parsed as
# (invalid) escape sequences; the resulting text is exactly the same as before.
meg_MainFolder = r"..\Data\MEG_Data\Data="
figures_MainFolder = r"..\Figures\MVPA\Data="
results_MainFolder = r"..\Results\Data="
classifiers_MainFolder = r"..\Classifiers\Data="

# Epoch time window; it selects the "Data=<tmin>_<tmax>" folders built below.
tmin, tmax = -0.1, 0.6

# Decide the time limit based on the time range of the data.
# tlim stays 0 when tmin matches neither of the two known windows.
tlim = {-0.8: 140, -0.1: 70}.get(tmin, 0)

print('tmin = ', tmin)
print('tmax = ', tmax)

# Shared "<tmin>_<tmax>\" suffix that turns every root into its window-specific folder.
window_suffix = str(tmin) + '_' + str(tmax) + '\\'

dataFolder = meg_MainFolder + window_suffix
print('Data folder: ', dataFolder)

figuresFolder = figures_MainFolder + window_suffix
print('Figures folder: ', figuresFolder)

results_folder = results_MainFolder + window_suffix
print('Results folder: ', results_folder)

clsfFolder = classifiers_MainFolder + window_suffix
print('Classifiers folder: ', clsfFolder)



In [None]:
# Optional file-name suffix for the two longer epoch windows; every other
# (tmin, tmax) combination uses no suffix.
filename_ext = {
    (-0.8, 0.6): '-elongated',
    (-0.8, 1): '-elongated_2',
}.get((tmin, tmax), '')

print('filename ext: ', filename_ext)

In [None]:
# Subject identifiers, kept as strings because they are concatenated into file names.
s_ids = [str(subject_number) for subject_number in (
    13, 16, 17, 18, 21, 23, 26, 28, 29, 30, 31, 32,
    33, 34, 35, 36, 38, 39, 40, 41, 42,
)]

print('Number of subjects: ', len(s_ids))

#### Train classifiers

In [None]:
# Train and test the classifiers for every subject.
# group_results collects one array per subject (np.asarray of the value returned by
# trainAndTest_MVPA), in the order of s_ids.
group_results = []

for s_id in s_ids:

    print('------------ ' + s_id + '------------ ')
    # The epochs file is named differently depending on the subject number:
    # IDs below 23 carry the subject ID in the file name, the others use a fixed 'block' prefix.
    if int(s_id) < 23:
        fname = dataFolder+'S'+s_id+'\\'+s_id+'_2_tsss_mc_trans_'+sensors+'_nobase-epochs_afterICA'+filename_ext+'_manually_AR.fif'
    else: 
        fname = dataFolder+'S'+s_id+'\\block_2_tsss_mc_trans_'+sensors+'_nobase-epochs_afterICA'+filename_ext+'_manually_AR.fif'

    # Check whether a resampled copy ('<name>_resampled.fif') already exists on disk
    if os.path.isfile(fname[:-4]+'_resampled.fif'): 
        print('Already resampled data!')
        epochs = loadData(s_id, sensors, fname[:-4]+'_resampled.fif', resampled=True)
    
    # If not, load the original file with resampled=False
    # (presumably loadData then performs the resampling -- confirm in CommonFunctions)
    else:
        print('Data will be resampled!')
        epochs = loadData(s_id, sensors, fname, resampled=False)
    
    print('Number of epochs: ', len(epochs))
    # Split the data by condition.
    # NOTE(review): extractDataAndLabels appears to return two sequences (data, labels) with
    # one entry per requested event name, in the order given -- the indexing below relies on
    # that; confirm in CommonFunctions.
    
    # Real-sound data and labels (index 0, 1, 2 -> the '_8', '_9', '_10' event names)
    data_real_living, labels_real_living = extractDataAndLabels(epochs, ['living_real_8', 'living_real_9', 'living_real_10'])
    data_real_object, labels_real_object = extractDataAndLabels(epochs, ['object_real_8', 'object_real_9', 'object_real_10'])

    # Living omission data and labels (indices 0-5 -> 8_corr, 8_incorr, 9_corr, 9_incorr, 10_corr, 10_incorr)
    data_omission_living, labels_omission_living = extractDataAndLabels(epochs, ['living_omission_8_corr', 'living_omission_8_incorr', 'living_omission_9_corr', 'living_omission_9_incorr', 'living_omission_10_corr', 'living_omission_10_incorr'])

    # Object omission data and labels (same index layout as the living omissions)
    data_omission_object, labels_omission_object = extractDataAndLabels(epochs, ['object_omission_8_corr', 'object_omission_8_incorr', 'object_omission_9_corr', 'object_omission_9_incorr', 'object_omission_10_corr', 'object_omission_10_incorr'])

    
    # Real sounds, 80% predictability: living and object trials pooled
    data_real_8 = concatNonEmpty([data_real_living[0], data_real_object[0]])
    labels_real_8 = concatNonEmpty([labels_real_living[0], labels_real_object[0]])

    # Real sounds, 90% predictability
    data_real_9 = concatNonEmpty([data_real_living[1], data_real_object[1]])
    labels_real_9 = concatNonEmpty([labels_real_living[1], labels_real_object[1]])

    # Real sounds, 100% predictability
    data_real_10 = concatNonEmpty([data_real_living[2], data_real_object[2]])
    labels_real_10 = concatNonEmpty([labels_real_living[2], labels_real_object[2]])

    # All predictability levels together (this is the set passed first to trainAndTest_MVPA below)

    data_real_all = concatNonEmpty([data_real_8, data_real_9, data_real_10])
    labels_real_all = concatNonEmpty([labels_real_8, labels_real_9, labels_real_10])


    # Omission trials: '_corr' and '_incorr' events of both categories are pooled per level
    # 80% predictability
    data_omission_8= concatNonEmpty([data_omission_living[0], data_omission_living[1], data_omission_object[0], data_omission_object[1]])
    labels_omission_8 = concatNonEmpty([labels_omission_living[0], labels_omission_living[1], labels_omission_object[0], labels_omission_object[1]])

    # 90% predictability
    data_omission_9 = concatNonEmpty([data_omission_living[2], data_omission_living[3], data_omission_object[2], data_omission_object[3]])
    labels_omission_9 = concatNonEmpty([labels_omission_living[2], labels_omission_living[3], labels_omission_object[2], labels_omission_object[3]])

    # 100% predictability
    data_omission_10 = concatNonEmpty([data_omission_living[4], data_omission_living[5], data_omission_object[4], data_omission_object[5]])
    labels_omission_10 = concatNonEmpty([labels_omission_living[4], labels_omission_living[5], labels_omission_object[4], labels_omission_object[5]])


    ## MVPA 
    
    # NOTE: The triple-quoted string below is disabled code for running the classification
    # separately on each predictability level. It is only a string literal and is never
    # executed; the active classification on all levels together starts right after its
    # closing quotes.
    # NOTE(review): the disabled code refers to `filep_figures`, which is not defined in any
    # cell visible in this notebook -- it would need updating before being re-enabled.
    '''
    results = []

    outputfilename_8 = results_folder+s_id + "_" + sensors + "_results_8_predLevel"
    bestParametersFile_8 = results_folder + s_id + "_" + sensors + "_bestParametes_8_predLevel.txt"
    clsfFile = clsfFolder + "_8_predLevel.pkl"
    results_8 = trainAndTest_MVPA(data_real_8, labels_real_8, [data_omission_8], [labels_omission_8], outputfilename_8, bestParametersFile_8, tlim, clsfFile, nFolds=5)
    results.append(results_8)

    plotname_8 = filep_figures + 'MVPA_S' + s_id+'_' + sensors + '_8_predLevel.png'
    plot_MVPA(results_8, epochs.times[:tlim], tlim, plotname_8)

    outputfilename_9 = results_folder+s_id + "_" + sensors + "_results_9_predLevel"
    bestParametersFile_9 = results_folder + s_id + "_" + sensors + "_bestParametes_9_predLevel.txt"
    clsfFile = clsfFolder + "_9_predLevel.pkl"
    results_9 = trainAndTest_MVPA(data_real_9, labels_real_9, [data_omission_9], [labels_omission_9], outputfilename_9, bestParametersFile_9, tlim, clsfFile, nFolds=5)
    results.append(results_9) 

    plotname_9 = filep_figures + 'MVPA_S' + s_id+'_' + sensors + '_9_predLevel.png'
    plot_MVPA(results_9,  epochs.times[:tlim], tlim, plotname_9)


    outputfilename_10 = results_folder+s_id + "_" + sensors + "_results_10_predLevel"
    bestParametersFile_10 = results_folder + s_id + "_" + sensors + "_bestParametes_10_predLevel.txt"
    clsfFile = clsfFolder + "_10_predLevel.pkl"

    results_10 = trainAndTest_MVPA(data_real_10, labels_real_10, [data_omission_10], [labels_omission_10], outputfilename_10, bestParametersFile_10, tlim, clsfFile, nFolds=5)
    results.append(results_10) 

    plotname_10 = filep_figures + 'MVPA_S' + s_id+'_' + sensors + '_10_predLevel.png'
    plot_MVPA(results_10,  epochs.times[:tlim], tlim, plotname_10)



    outputfilename_all = results_folder+s_id + "_" + sensors + "_results_all_predLevel_linearKernel"
    bestParametersFile_all = results_folder + s_id + "_" + sensors + "_bestParameters_all_linearKernel.txt"
    clsfFile = clsfFolder + "all_predLevel.pkl"
    print(clsfFile)

    '''
    # Classification on all predictability levels together: the pooled real-sound data is
    # passed first and the three omission sets (80/90/100%) are passed as separate lists.
    # NOTE(review): results_folder already ends with a backslash, so the paths below contain
    # a doubled separator before 'S<id>'.
    outputfilename_all = results_folder+ '\\S' + s_id + '\\' + s_id + "_" + sensors + "_results_all_predLevel"
    print(outputfilename_all)
    bestParametersFile_all = results_folder + '\\S' + s_id + '\\' + s_id + "_" + sensors + "_bestParameters_all.txt"
    clsfFile = clsfFolder + s_id + "\\all_predLevel.pkl"
    results_all = trainAndTest_MVPA(data_real_all, labels_real_all,
                                    [data_omission_8, data_omission_9, data_omission_10], 
                                    [labels_omission_8, labels_omission_9, labels_omission_10], 
                                    outputfilename_all, bestParametersFile_all, tlim, clsfFile, nFolds=5,
                                    bestParamsFound=False)
    
    
    #results.append(results_all) 

    # Per-subject figure; only the first tlim time samples are passed to plot_MVPA
    plotname_all = figuresFolder + 'S' + s_id + '\\MVPA_S' + s_id+'_' + sensors + '_all_predLevel.png'
    print(plotname_all)
    plot_MVPA(results_all, epochs.times[:tlim], tlim, plotname_all)
    
    group_results.append(np.asarray(results_all))
    
    # Drop the per-category arrays before the next subject (presumably to release memory)
    del data_real_living, data_real_object, data_omission_living, data_omission_object
    

#### Load results if you have them already

In [None]:
# Load one subjects epoch data to use times
# Only the time axis of these epochs is used by the group-level cells below, so a single
# subject's resampled file is enough.
s_id = '17'
fname = ''.join([
    dataFolder, 'S', s_id, '\\', s_id, '_2_tsss_mc_trans_', sensors,
    '_nobase-epochs_afterICA', filename_ext, '_manually_AR_resampled.fif',
])
epochs = mne.read_epochs(fname)

In [None]:
# Rebuild group_results from the per-subject result files instead of re-training.
group_results = []

for s_id in s_ids:
    outputfilename = f"{results_folder}\\S{s_id}\\{s_id}_{sensors}_results_all_predLevel.npy"
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load files produced by
    # this pipeline.
    res_tmp = np.load(outputfilename, allow_pickle=True)
    print('Shape of loaded results: ', res_tmp.shape)
    group_results.append(res_tmp)

In [None]:
# Convert the list of per-subject results into one array so it can be indexed as
# [subject, entry]; the cell then displays the shape of the column holding entry 0.
group_results = np.asarray(group_results)
group_results[:,0].shape

In [None]:
# Display entry 0 of every subject's results (used below as the real-sound scores).
group_results[:,0]

In [None]:
# Display the last entry of every subject's results (used in the "CV scores" section).
group_results[:,-1]

#### Check results if there are 0s

#print('is there any 0: ', np.where(np.sum(group_results, axis = 1) == 0))
group_real = np.zeros((len(s_ids),tlim))  # one row per subject, tlim columns; NOTE(review): these statements are repeated in the next cell -- confirm whether this copy is still needed
for i in range(group_results[:,0].shape[0]): # use only index 0, as it holds the real-sound results; the remaining entries are omissions and CV
    group_real[i,:] = group_results[i][0]
print(group_real.shape)  

group_avg_real = np.nanmean(group_real, axis=0)  # mean over subjects, ignoring NaNs
group_std_real = sem(group_real, axis = 0, nan_policy='omit')/2  # half the standard error of the mean over subjects

In [None]:
# Sanity check: report where entry 0 of the per-subject results compares equal to 0.
print('is there any 0: ', np.where(group_results[:,0] == 0))

# Subject-by-time matrix filled with entry 0 of every subject's results, i.e. the
# real-sound scores (the remaining entries hold the omission and CV results).
group_real = np.zeros((len(s_ids),tlim))
for i, subject_results in enumerate(group_results):
    group_real[i, :] = subject_results[0]
print(group_real.shape)

# Across-subject mean (NaNs ignored) and half the standard error of the mean,
# later drawn as the band around the mean.
group_avg_real = np.nanmean(group_real, axis=0)
group_std_real = sem(group_real, axis=0, nan_policy='omit') / 2

In [None]:
# Persist the subject-by-time score matrix; np.save appends '.npy' to the given name.
# (The bare .shape expression is not the last statement of the cell, so it displays nothing.)
group_real.shape
np.save(results_folder+'real_sounds_testResults_pointBypointAnalysis', group_real)

### Apply stats to compare against chance level

In [None]:
# Significance threshold applied to the per-time-point p-values in the cells below.
p_threshold = 0.004
print('p value threshold: ', p_threshold)

# Test only the samples that exist in BOTH the time axis and the score matrix.
# group_real is allocated with tlim columns, while epochs.times can be longer (the
# training loop slices it with [:tlim]); looping over len(epochs.times) alone would
# then index past the last column of group_real.
nt = min(len(epochs.times), group_real.shape[1])
preal = np.ones(nt)
chance = 0.5
for t in range(nt):
    # One-sided Wilcoxon signed-rank test across subjects (rows of group_real):
    # are the scores at this time sample greater than chance?
    preal[t] = wilcoxon(group_real[:, t] - chance, alternative='greater').pvalue

print('\nP values: ', preal)

# Peak of the group-average curve and the time sample(s) at which it occurs.
max_score = np.max(group_avg_real)
time_max_score = epochs.times[np.where(group_avg_real == max_score)[0].tolist()]
print('\nMaximum score of ' + str(max_score)+' achieved at ' + str(time_max_score))


In [None]:
from mne.stats import bonferroni_correction

# Bonferroni-correct the per-time-point p-values (function defaults are used);
# r is the first returned value and p_corrected the corrected p-values.
r, p_corrected = bonferroni_correction(preal)
print('P values before the correction: ', preal)
print('P values after the correction: ', p_corrected)

In [None]:
# Display the threshold currently in effect before selecting significant samples.
p_threshold

In [None]:
# Time samples whose uncorrected p-value is at or below the threshold (printed only;
# the name is reused for the corrected selection right after).
times_sig = epochs.times[np.where(preal <= p_threshold)[0]]
print('\nSignificant time points before correction: ', times_sig)

# Same selection on the Bonferroni-corrected p-values; this is the value of times_sig
# that remains after the cell has run.
times_sig = epochs.times[np.where(p_corrected <= p_threshold)[0]]
print('\nSignificant time points after correction: ', times_sig)


#### Plot the scores

In [None]:

# Group-level figure: mean score over time with an error band, the chance level and
# a marker line at y=0.49 for every significant time period.
color = np.array((27,120,55))/256
color_chanceLevel = np.array((169,169,169))/256

al = 0.2
font_size = 14

# Plot only the samples present in both the time axis and the score vectors, so the
# figure can be drawn even when epochs.times is longer than the scored range.
n_plot = min(len(epochs.times), len(group_avg_real))
plot_times = epochs.times[:n_plot]
plot_avg = group_avg_real[:n_plot]
plot_err = group_std_real[:n_plot]

fig = plt.figure(num=None, figsize=(6, 4), dpi=150)
ax = fig.add_subplot(1,1,1)

ax.plot(plot_times, plot_avg, color=color, linewidth=4, label='Real Sounds')
# Symmetric band of +/- group_std_real around the mean.
ax.fill_between(plot_times, plot_avg - plot_err, plot_avg + plot_err, color=color, alpha=al)

ax.set_xlabel('Time(s)', fontsize=font_size)
ax.set_ylabel('AUC', fontsize=font_size)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Spacing (in seconds) at which two significant samples count as consecutive.
# NOTE(review): hard-coded; assumes the data is sampled every 10 ms -- confirm.
timeInterval = 0.01
print('TIme between two time points: ', timeInterval)

# Group the significant time points into [start, end] periods of consecutive samples.
times_sig_periods = []
if len(times_sig) > 0:
    print('Significant time points exist!')

    start = times_sig[0]
    for previous_time, current_time in zip(times_sig[:-1], times_sig[1:]):
        # A gap larger than one sample closes the running period and opens a new one.
        if round(current_time, 2) != round(previous_time + timeInterval, 2):
            times_sig_periods.append([start, previous_time])
            start = current_time
    # Always close the last period, including a final isolated sample (which was
    # previously dropped).
    times_sig_periods.append([start, times_sig[-1]])

    print('Significant time intervals: ', times_sig_periods)

    for period_start, period_end in times_sig_periods:
        if period_start == period_end:
            # A single sample needs an explicit marker; a one-point line draws nothing.
            ax.plot(period_start, 0.49, marker='.', color=color)
        else:
            ax.hlines(xmin=period_start, xmax=period_end, y=0.49, color=color, linestyle='-')

# Chance line, axis limits, legend and saving apply whether or not anything is
# significant, so the figure is complete and written to disk in both cases.
ax.hlines(xmin=epochs.times[0], xmax=epochs.times[-1], y=0.5, color=color_chanceLevel, linestyle='--', label='Chance')
ax.set_ylim(0.4, 0.78)
ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
# NOTE(review): the file name is appended to figures_MainFolder (which ends in 'Data='),
# not to figuresFolder -- confirm that this location is intended.
fig.savefig(figures_MainFolder+'Group_Level_real_sounds_timePointByTimePoint_bonfr_crr_p='+str(p_threshold)+'.png',  bbox_inches='tight')
   


### CV scores

In [None]:
# Sanity check: report where the last entry of the per-subject results compares equal to 0.
print('is there any 0: ', np.where(group_results[:,-1] == 0))

# Subject-by-time matrix built from the LAST entry of every subject's results, averaged
# over its first axis (presumably the cross-validation folds -- confirm).
# Note that this overwrites the real-sound matrix stored under the same name.
group_real = np.zeros((len(s_ids),tlim))
for i, subject_results in enumerate(group_results):
    group_real[i, :] = np.mean(subject_results[-1], axis=0)
print(group_real.shape)

# Across-subject mean (NaNs ignored) and half the standard error of the mean.
group_avg_real = np.nanmean(group_real, axis=0)
group_std_real = sem(group_real, axis=0, nan_policy='omit') / 2

In [None]:
# Significance threshold applied to the per-time-point p-values of the CV scores.
p_threshold = 0.001
print('p value threshold: ', p_threshold)

# Test only the samples that exist in BOTH the time axis and the score matrix.
# group_real is allocated with tlim columns, while epochs.times can be longer (the
# training loop slices it with [:tlim]); looping over len(epochs.times) alone would
# then index past the last column of group_real.
nt = min(len(epochs.times), group_real.shape[1])
preal = np.ones(nt)
chance = 0.5
for t in range(nt):
    # One-sided Wilcoxon signed-rank test across subjects (rows of group_real):
    # are the scores at this time sample greater than chance?
    preal[t] = wilcoxon(group_real[:, t] - chance, alternative='greater').pvalue

print('\nP values: ', preal)

# Peak of the group-average curve and the time sample(s) at which it occurs.
max_score = np.max(group_avg_real)
time_max_score = epochs.times[np.where(group_avg_real == max_score)[0].tolist()]
print('\nMaximum score of ' + str(max_score)+' achieved at ' + str(time_max_score))


In [None]:
from mne.stats import bonferroni_correction

# Bonferroni-correct the per-time-point p-values of the CV scores (function defaults
# are used); r is the first returned value and p_corrected the corrected p-values.
r, p_corrected = bonferroni_correction(preal)
print('P values before the correction: ', preal)
print('P values after the correction: ', p_corrected)

In [None]:
# Time samples whose uncorrected p-value is at or below the threshold.
times_sig_wo_corr = epochs.times[np.where(preal <= p_threshold)[0]]
print('\nSignificant time points before correction: ', times_sig_wo_corr)

# Time samples that stay at or below the threshold after the Bonferroni correction.
times_sig = epochs.times[np.where(p_corrected <= p_threshold)[0]]
print('\nSignificant time points after correction: ', times_sig)


In [None]:

# Group-level figure for the CV scores: mean score over time with an error band, the
# chance level and a marker line at y=0.49 for every significant time period.
color = np.array((27,120,55))/256
color_chanceLevel = np.array((169,169,169))/256

al = 0.2
font_size = 14

# Plot only the samples present in both the time axis and the score vectors, so the
# figure can be drawn even when epochs.times is longer than the scored range.
n_plot = min(len(epochs.times), len(group_avg_real))
plot_times = epochs.times[:n_plot]
plot_avg = group_avg_real[:n_plot]
plot_err = group_std_real[:n_plot]

fig = plt.figure(num=None, figsize=(6, 4), dpi=150)
ax = fig.add_subplot(1,1,1)

ax.plot(plot_times, plot_avg, color=color, linewidth=4, label='Real Sounds')
# Symmetric band of +/- group_std_real around the mean.
ax.fill_between(plot_times, plot_avg - plot_err, plot_avg + plot_err, color=color, alpha=al)

ax.set_xlabel('Time(s)', fontsize=font_size)
ax.set_ylabel('AUC', fontsize=font_size)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Spacing (in seconds) at which two significant samples count as consecutive.
# NOTE(review): hard-coded; assumes the data is sampled every 10 ms -- confirm.
timeInterval = 0.01
print('TIme between two time points: ', timeInterval)

# Group the significant time points into [start, end] periods of consecutive samples.
times_sig_periods = []
if len(times_sig) > 0:
    print('Significant time points exist!')

    start = times_sig[0]
    for previous_time, current_time in zip(times_sig[:-1], times_sig[1:]):
        # A gap larger than one sample closes the running period and opens a new one.
        if round(current_time, 2) != round(previous_time + timeInterval, 2):
            times_sig_periods.append([start, previous_time])
            start = current_time
    # Always close the last period, including a final isolated sample (which was
    # previously dropped).
    times_sig_periods.append([start, times_sig[-1]])

    print('Significant time intervals: ', times_sig_periods)

    for period_start, period_end in times_sig_periods:
        if period_start == period_end:
            # A single sample needs an explicit marker; a one-point line draws nothing.
            ax.plot(period_start, 0.49, marker='.', color=color)
        else:
            ax.hlines(xmin=period_start, xmax=period_end, y=0.49, color=color, linestyle='-')

# Chance line, axis limits, legend and saving apply whether or not anything is
# significant, so the figure is complete and written to disk in both cases.
ax.hlines(xmin=epochs.times[0], xmax=epochs.times[-1], y=0.5, color=color_chanceLevel, linestyle='--', label='Chance')

ax.set_ylim(0.4, 0.78)
ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
# NOTE(review): the file name is appended to figures_MainFolder (which ends in 'Data='),
# not to figuresFolder -- confirm that this location is intended.
fig.savefig(figures_MainFolder+'CV_Group_Level_real_sounds_timePointByTimePoint_bonfr_crr_p='+str(p_threshold)+'.png',  bbox_inches='tight')
   
