## Script for Morbi EEG project ERP analysis

In [None]:
%matplotlib inline
import os
import mne
from mne.time_frequency import tfr_morlet, psd_multitaper, psd_welch, psd_array_multitaper
from mne.viz import plot_topomap
from scipy.stats import ttest_1samp, ttest_ind, ttest_rel, f_oneway
from mne.stats import fdr_correction, f_mway_rm, permutation_cluster_test
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from jupyterthemes import jtplot
jtplot.style(theme='grade3')

# Input folder holding the EEGLAB epoch files and output folder for figures
data_path = 'E:/Bilingual_Morphology_Project/Data/EEG_prep'
plot_path = 'E:/Bilingual_Morphology_Project/Results/Plot'
# Subject numbers (5..34) and marker codes (2..9) that appear in the file names
sub_list = [*range(5, 35)]
mark_list = [*range(2, 10)]

## Section 1: Meta data import and channel location

In [None]:
# Custom channel locations shared by every recording
montage_file = 'E:/Bilingual_Morphology_Project/Scripts/morbi.loc'
montage = mne.channels.read_custom_montage(montage_file)

# One reference epochs object, used later wherever only its `info` is needed
demo = mne.read_epochs_eeglab('E:/Bilingual_Morphology_Project/Data/EEG_prep/S31_09.set')
demo.set_montage(montage)

# eeg_meta[subject][marker] -> epochs read from 'S<subject>_0<marker>.set'
eeg_meta = {}
for sub in sub_list:
    eeg_meta[sub] = {}
    for mark in mark_list:
        set_file = os.path.join(data_path, f'S{sub}_0{mark}.set')
        epochs = mne.read_epochs_eeglab(set_file)
        epochs.set_montage(montage)
        eeg_meta[sub][mark] = epochs

In [None]:
# Channel indexing
# Channel names, taken from subject 5 / marker 2
ch_names = eeg_meta[5][2].ch_names
print(ch_names)
# Positional indices 0..30, one per channel
ch_idx = [*range(31)]
# Lookup table: channel name -> positional index
ch_num = {name: idx for name, idx in zip(ch_names, ch_idx)}
print(ch_num['PO9'])

## Section 2: ERP analysis

### Functions for ERP analysis

In [None]:
# Function for plotting erp wave of single condition
def plot_erp(erp, times, con_labels, title):
    """Plot the across-subject mean of one condition with a shaded error band.

    Parameters
    ----------
    erp : array-like, shape (n_subjects, n_times)
        One waveform per subject; statistics are taken over the first axis.
    times : array-like, shape (n_times,)
        X coordinates of the samples (the axis is labelled in milliseconds).
    con_labels : str
        Legend entry for the plotted curve.
    title : object
        Figure title; passed through ``str()``.

    Notes
    -----
    The band is mean +/- SD / sqrt(n_subjects), with SD computed using
    ``ddof=0``. The y-axis is fixed to 6 ... -8 with negative values upwards.
    """
    n_subjects = np.shape(erp)[0]
    avg = np.average(erp, axis=0)
    err = np.std(erp, axis=0, ddof=0) / np.sqrt(n_subjects)

    label_font = {'family': 'Arial', 'weight': 'bold', 'size': 25}
    title_font = {'family': 'Arial', 'weight': 'bold', 'size': 30}

    # Axes styling: only the bottom/left spines, drawn thick and black
    plt.figure()
    ax = plt.axes()
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('bottom', 'left'):
        ax.spines[side].set_color('black')
        ax.spines[side].set_linewidth(3)
    ax.invert_yaxis()
    plt.tick_params(direction='in', length=10, width=3, labelsize=20)
    plt.xlim(-200, 800)
    plt.ylim(6, -8)
    plt.grid()

    # Mean waveform and error band
    plt.fill_between(times, avg + err, avg - err, alpha=0.2)
    plt.plot(times, avg, alpha=0.9, label=con_labels, lw=3)
    plt.axhline(y=0, color="black", lw=2)
    plt.axvline(x=0, color="black", linestyle="--", lw=2)
    plt.xlabel('Time (ms)', fontdict=label_font)
    # Raw string: '\m' is not a valid escape sequence in a normal string literal
    plt.ylabel(r'Amplitude ($\mu$V)', fontdict=label_font)
    plt.title(str(title), fontdict=title_font)
    plt.legend(loc='best', prop={'family': 'Arial', 'size': 20})
    plt.show()

In [None]:
def plot_erp_comparison(erp1, erp2, times, con_labels=('Condition1', 'Condition2'), p_threshold=0.05):
    """Plot two paired conditions and shade the samples where they differ.

    A paired t-test is run at every time sample, the p-values are
    FDR-corrected, and every sample that survives the correction at
    ``p_threshold`` is marked with a grey vertical line.

    Parameters
    ----------
    erp1, erp2 : array-like, shape (n_subjects, n_times)
        Per-subject waveforms of the two conditions (same subjects, same order).
    times : array-like, shape (n_times,)
        X coordinates of the samples (the axis is labelled in milliseconds).
    con_labels : sequence of two str
        Legend entries for ``erp1`` and ``erp2``.
    p_threshold : float
        Alpha level passed to the FDR correction.
    """
    n_subjects = np.shape(erp1)[0]
    avg1 = np.average(erp1, axis=0)
    avg2 = np.average(erp2, axis=0)
    err1 = np.std(erp1, axis=0, ddof=0) / np.sqrt(n_subjects)
    err2 = np.std(erp2, axis=0, ddof=0) / np.sqrt(n_subjects)

    # Statistics: paired t-test per sample, then FDR multiple-comparison correction
    t_vals, p_vals = ttest_rel(erp1, erp2, axis=0)
    rejects, p_fdr_corrected = fdr_correction(p_vals, alpha=p_threshold)

    # Delineate the significant windows using the corrected decision
    # (not the raw p-values), so that `p_threshold` is actually honoured
    for t_sig in np.asarray(times)[rejects]:
        plt.axvline(x=t_sig, color='grey', alpha=0.2)

    plt.fill_between(times, avg1 + err1, avg1 - err1, alpha=0.2, label=con_labels[0])
    plt.fill_between(times, avg2 + err2, avg2 - err2, alpha=0.2, label=con_labels[1])
    plt.plot(times, avg1, alpha=0.9)
    plt.plot(times, avg2, alpha=0.9)
    plt.axhline(y=0, color="black", lw=2)
    plt.axvline(x=0, color="black", linestyle="--", lw=2)
    plt.xlim(-200, 800)
    plt.gca().invert_yaxis()
    plt.xlabel('Time (ms)')
    # Raw string: '\m' is not a valid escape sequence in a normal string literal
    plt.ylabel(r'Amplitude ($\mu$V)')
    plt.legend()
    plt.show()

### Obtain the data array and ERP

In [None]:
# Single-trial data per subject and condition. Each condition pools two marker
# codes along the trial axis and is scaled by 10**6 (presumably volts ->
# microvolts, matching the µV axis labels of the plots -- confirm).
cp_meta, cc_meta, ep_meta, ec_meta = {}, {}, {}, {}
for sub in sub_list:
    by_mark = eeg_meta[sub]
    cp_meta[sub] = np.concatenate([by_mark[2].get_data(), by_mark[3].get_data()], axis=0) * 10**6  # Chinese priming condition
    cc_meta[sub] = np.concatenate([by_mark[4].get_data(), by_mark[5].get_data()], axis=0) * 10**6  # Chinese control condition
    ep_meta[sub] = np.concatenate([by_mark[6].get_data(), by_mark[7].get_data()], axis=0) * 10**6  # English priming condition
    ec_meta[sub] = np.concatenate([by_mark[9].get_data(), by_mark[8].get_data()], axis=0) * 10**6  # English control condition


# ERP data structure: [n_channels, n_sub, n_times]
cp_erp, cc_erp, ep_erp, ec_erp = (np.zeros([31, 30, 500]) for _ in range(4))
erp_and_trials = ((cp_erp, cp_meta), (cc_erp, cc_meta), (ep_erp, ep_meta), (ec_erp, ec_meta))
# For every channel and subject, the ERP is the mean over that subject's trials;
# subject numbers start at 5, hence the `sub - 5` row offset
for ch in ch_idx:
    for sub in sub_list:
        for erp_arr, trials in erp_and_trials:
            erp_arr[ch, sub - 5, :] = trials[sub][:, ch, :].mean(axis=0)


### Plot Demo

In [None]:
# Time axis: -200 ... 798 in steps of 2 (500 samples)
times = np.arange(-200, 800, 2)
# Demo figure: channel index 4, Chinese priming, all subjects
plot_erp(cp_erp[4], times, con_labels='Chinese Priming', title='')

## Section 3: Mass-univariate ANOVA for ERP

### 3.1 Single channel ERP

In [None]:
factor_levels = [2, 2]
times = np.arange(-200, 800, 2)

# Conditions in the order required by the 2 x 2 design (A1B1, A1B2, A2B1, A2B2):
# factor A = language (Chinese / English), factor B = priming / control
erp_by_condition = [
    ('Chinese priming', cp_erp),
    ('Chinese control', cc_erp),
    ('English priming', ep_erp),
    ('English control', ec_erp),
]

# savefig does not create missing folders, so make sure the target exists
singch_dir = os.path.join(plot_path, 'singch_erp')
os.makedirs(singch_dir, exist_ok=True)

label_font = {'family': 'Arial', 'weight': 'bold', 'size': 25}

for ch in ch_idx:
    # Stack to (n_subjects, n_conditions, n_times) for this channel
    data_combine = np.stack([erp[ch] for _, erp in erp_by_condition], axis=1)

    # f_mway_rm for repeated measure ANOVA analysis
    # Main effect of A
    f_main_A, p_main_A = f_mway_rm(data_combine, factor_levels, effects='A')
    # Main effect of B
    f_main_B, p_main_B = f_mway_rm(data_combine, factor_levels, effects='B')
    # Interaction effect
    f_inter, p_interaction = f_mway_rm(data_combine, factor_levels, effects='A:B')
    # FDR correction (p_* now hold the corrected p-values)
    rejects_A, p_main_A = fdr_correction(p_main_A, alpha=0.05)
    rejects_B, p_main_B = fdr_correction(p_main_B, alpha=0.05)
    rejects_inter, p_interaction = fdr_correction(p_interaction, alpha=0.05)

    # Plotting parameters
    plt.figure()
    ax = plt.axes()
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('bottom', 'left'):
        ax.spines[side].set_color('black')
        ax.spines[side].set_linewidth(3)
    ax.invert_yaxis()
    plt.tick_params(direction='in', length=10, width=3, labelsize=20)
    plt.xlim(-200, 800)
    plt.ylim(6, -8)
    plt.grid()

    # Plot the grand-average waveform of each condition
    for label, erp in erp_by_condition:
        plt.plot(times, np.average(erp[ch, :, :], axis=0), label=label, lw=3)
    plt.axhline(y=0, color="black", lw=2)
    plt.axvline(x=0, color="black", linestyle="--", lw=2)
    plt.xlabel('Time (ms)', fontdict=label_font)
    # Raw string: '\m' is not a valid escape sequence in a normal string literal
    plt.ylabel(r'Amplitude ($\mu$V)', fontdict=label_font)
    plt.legend(loc='best', prop={'family': 'Arial', 'size': 20})

    # Annotate the significant time-window over the WHOLE epoch (all samples,
    # not only the first 250): blue = main effect A, black = main effect B,
    # red = interaction
    for p_corrected, color in ((p_main_A, 'blue'), (p_main_B, 'black'), (p_interaction, 'red')):
        for t_sig in times[p_corrected < 0.05]:
            plt.axvline(x=t_sig, ymin=0, color=color, alpha=0.3)

    plt.savefig(os.path.join(singch_dir, ch_names[ch] + '.png'))
    # Figures are only written to disk; close to free memory
    plt.close()

Significant electrodes
* P250 (220–300 ms): F4, F8, O1, O2, Oz, P7, PO9, PO10

### 3.2 ROI based : Right frontal P250

In [None]:
factor_levels = [2, 2]
# The ERP arrays hold 500 samples per epoch, i.e. a 2 ms step from -200 to 798
# (a 4 ms / 250-sample axis does not match them)
times = np.arange(-200, 800, 2)
# right frontal ROI index
rf_roi = [1,5,6,9,10]

# Conditions in the order required by the 2 x 2 design (A1B1, A1B2, A2B1, A2B2):
# factor A = language (Chinese / English), factor B = priming / control
roi_conditions = [
    ('Chinese priming', cp_erp),
    ('Chinese control', cc_erp),
    ('English priming', ep_erp),
    ('English control', ec_erp),
]

# Data averaging: mean over the ROI channels -> (n_subjects, n_times) per condition
roi_erps = [np.average(erp[rf_roi, :, :], axis=0) for _, erp in roi_conditions]

# Stack to (n_subjects, n_conditions, n_times) without hard-coding the sizes
data_combine = np.stack(roi_erps, axis=1)

# f_mway_rm for repeated measure ANOVA analysis
# Main effect of A
f_main_A, p_main_A = f_mway_rm(data_combine, factor_levels, effects='A')
# FDR correction
rejects_A, p_main_A = fdr_correction(p_main_A, alpha=0.05)

# Plotting parameters
plt.figure(figsize=(30,10))
ax = plt.axes()
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
for side in ('bottom', 'left'):
    ax.spines[side].set_color('black')
    ax.spines[side].set_linewidth(2)
ax.invert_yaxis()
plt.tick_params(direction='in',length=10)
plt.xlim(-200,800)
plt.grid()

# Plot the grand average (mean over subjects) of each ROI waveform
for (label, _), roi_erp_cond in zip(roi_conditions, roi_erps):
    plt.plot(times, np.average(roi_erp_cond, axis=0), label=label, linewidth=2)
plt.axhline(y=0, color="black")
plt.axvline(x=0, color="black", linestyle="--")
plt.xlabel('Time (ms)')
# Raw string: '\m' is not a valid escape sequence in a normal string literal
plt.ylabel(r'Amplitude ($\mu$V)')
plt.legend(fontsize=16,loc='best')
#plt.title('Right Frontal')

# Annotate the significant time-window (disabled)
# for i in range(len(times)):
#     if p_main_A[i] < 0.05:
#         plt.axvline(x=times[i], ymin=0.01, ymax=0.1, color='blue', alpha=1)

plt.savefig(os.path.join(plot_path, 'ROI_rightfrontal.png'))
plt.close()
#plt.show()

### 3.4 Topoplot

In [None]:
# Display the measurement info attached to the reference epochs object
demo.info

In [None]:
# Plot the scalp map of the English priming ERP at sample index 225,
# averaged over subjects
topo = ep_erp[:, :, 225].mean(axis=1)
fig, ax = plt.subplots()
im, _ = plot_topomap(topo, demo.info, axes=ax, sensors='ko', show=False, sphere=0.13)
# Extra axes for the colorbar, given as [x_start, y_start, width, height]
cbar_ax = fig.add_axes([0.95, 0.1, 0.04, 0.9])
clb = fig.colorbar(im, cax=cbar_ax)
#plt.savefig(os.path.join(plot_dir + 'p250_ec.png'),bbox_inches='tight',dpi=300,pad_inches=0.1)
#plt.close()

In [None]:
# One scalp map of the English control ERP per 50-sample window
# (window starts 0, 50, ..., 450), averaged over samples and then subjects
for i in range(0, 500, 50):
    window = ec_erp[:, :, i:i + 50]
    topo = window.mean(axis=2).mean(axis=1)
    fig, ax = plt.subplots()
    im, _ = plot_topomap(topo, demo.info, axes=ax, cmap='RdBu_r', show=False, sphere=0.13)
    # Extra axes for the colorbar, given as [x_start, y_start, width, height]
    cbar_ax = fig.add_axes([0.95, 0.1, 0.04, 0.9])
    clb = fig.colorbar(im, cax=cbar_ax)

In [None]:
# Plot the N400 topo
# N400 window: 360-410 ms. The ERP arrays hold 500 samples at a 2 ms step
# starting at -200 ms, so the window is converted from ms to sample indices
# rather than hard-coded (140:152 would only be right for a 4 ms step).
n400_start, n400_stop = np.searchsorted(np.arange(-200, 800, 2), [360, 410])
# Mean over the window samples, then over subjects -> one value per channel
n400_topo = np.average(np.average(ec_erp[:, :, n400_start:n400_stop], axis=2), axis=1)
fig, ax = plt.subplots()
# `eeg_meta` is keyed by subject (5..34) and holds dicts, so it has no `.info`;
# use the reference epochs' info for the sensor layout, as the other topomaps do
im, _ = plot_topomap(n400_topo, demo.info, axes=ax, cmap='RdBu_r', show=False, sphere=0.13)
# Extra axes for the colorbar, given as [x_start, y_start, width, height]
ax_x_start, ax_x_width, ax_y_start, ax_y_height = 0.95, 0.04, 0.1, 0.9
cbar_ax = fig.add_axes([ax_x_start, ax_y_start, ax_x_width, ax_y_height])
clb = fig.colorbar(im, cax=cbar_ax)
# Join folder and file name as separate components so the separator is not lost
plt.savefig(os.path.join(plot_path, 'n400_ec.png'), bbox_inches='tight', dpi=300, pad_inches=0.1)
plt.close()

In [None]:
# The ERP arrays hold 500 samples at a 2 ms step starting at -200 ms, so the
# component windows are converted from ms to sample indices instead of being
# hard-coded (105:120 and 140:152 would only be right for a 4 ms step).
erp_times = np.arange(-200, 800, 2)
# Condition order of the stacked output: Chinese priming, Chinese control,
# English priming, English control
erp_conditions = (cp_erp, cc_erp, ep_erp, ec_erp)

## P250 mean amplitude extraction (220-280 ms), right frontal ROI
p250_start, p250_stop = np.searchsorted(erp_times, [220, 280])
# Per condition: mean over ROI channels, then over window samples -> one value
# per subject; the four conditions are then concatenated end to end
rf_p250 = np.concatenate(
    [np.average(np.average(erp[rf_roi, :, p250_start:p250_stop], axis=0), axis=1)
     for erp in erp_conditions],
    axis=0)

## N400 mean amplitude extraction (360-410 ms)
# NOTE(review): `fc_roi` is not defined in the visible part of this notebook;
# it must be set to the intended ROI channel indices before running this cell.
n400_start, n400_stop = np.searchsorted(erp_times, [360, 410])
fc_n400 = np.concatenate(
    [np.average(np.average(erp[fc_roi, :, n400_start:n400_stop], axis=0), axis=1)
     for erp in erp_conditions],
    axis=0)

# One column per component; rows are subjects stacked condition after condition
roi_erp = {'rf_p250':rf_p250, 'fc_n400':fc_n400}
roi_erp = pd.DataFrame(roi_erp)
roi_erp.to_csv('F:/3_Projects/1_Morbi_EEG_project/EEG/Statistics/ERP/Morbi_ERP_Stat.csv')

In [None]:
# Subject x condition table of N400 (360-410 ms) ROI mean amplitudes.
# The ERP arrays hold 500 samples at a 2 ms step starting at -200 ms, so the
# window is converted from ms to sample indices (140:152 would only be right
# for a 4 ms step).
# NOTE(review): `fc_roi` is not defined in the visible part of this notebook;
# it must be set to the intended ROI channel indices before running this cell.
n400_start, n400_stop = np.searchsorted(np.arange(-200, 800, 2), [360, 410])
# Columns: Chinese priming, Chinese control, English priming, English control
test = np.stack(
    [np.average(np.average(erp[fc_roi, :, n400_start:n400_stop], axis=0), axis=1)
     for erp in (cp_erp, cc_erp, ep_erp, ec_erp)],
    axis=1)

In [None]:
# Repeated-measures ANOVA on the `test` table: main effect of factor A only
f_main_A, p_main_A = f_mway_rm(test, factor_levels, effects='A')
# Display the resulting p-value(s)
p_main_A