## [Working Title]

Created by Toomas Erik Anijärv on 07.04.2024

This notebook demonstrates the EEG processing used for the publication, using one of the participants as an example.

You are free to use this or any other code from this repository for your own projects and publications. Citation of or reference to the repository is not required, but would be much appreciated (see README.md for more details).

In [None]:
# Import packages
import mne, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import matplotlib.ticker as ticker
import seaborn as sns
from autoreject import (get_rejection_threshold, AutoReject)
from fooof import FOOOF
from fooof.plts.spectra import plot_spectra, plot_spectra_shading

# Set the project root as the working directory so the relative 'Data/...' and 'Results/' paths
# used below resolve (presumably also needed for the local `basic` / `signal_processing` imports).
# Set the EEG_PYLINE_ROOT environment variable instead of editing this machine-specific default.
project_root = os.environ.get('EEG_PYLINE_ROOT', '/Users/aliciacampbell/Documents/GitHub/EEG-pyline')
os.chdir(project_root)

# Only let MNE report messages at 'error' level
mne.set_log_level('error')

# Import functions
import basic.arrange_data as arrange
import signal_processing.pre_process as prep
import signal_processing.spectral_analysis as spectr
import signal_processing.erp_analysis as erpan

In [None]:
# Folder where to get the raw EEG files
raw_folder = 'Data/Raw/'

# Folder where to export the clean epochs files
clean_folder = 'Data/Clean/'

# Folder where to save the results and plots
results_folder = 'Results/'

# Experiment identifiers. The sub-folder (study/timepoint/task) and the condition label
# (e.g. for output file names) are both derived from them, so the two cannot drift apart.
exp_study = 'LEISURE'
exp_timepoint = 'T1'
exp_task = 'SART'
exp_folder = f'{exp_study}/{exp_timepoint}/{exp_task}/'
exp_condition = f'{exp_task}_{exp_timepoint}'

### SPECTRAL ANALYSIS: APERIODIC + THETA ACTIVITY

In [None]:
def convert_flat_spectr_amplitude(fm, flat_spectr_scale='linear'):
    """Return a fitted model's flattened spectrum on the chosen amplitude scale.

    Parameters
    ----------
    fm : FOOOF
        Fitted FOOOF model; its ``_spectrum_flat`` attribute (spectrum minus aperiodic fit,
        on a log10 scale) is converted.
    flat_spectr_scale : {'linear', 'log'}
        'linear' converts back with ``10 ** x``; 'log' returns the values unchanged.

    Returns
    -------
    flatten_spectrum : array-like
        The flattened spectrum on the requested scale.
    flat_spectr_ylabel : str
        A matching y-axis label for plotting.

    Raises
    ------
    ValueError
        If ``flat_spectr_scale`` is neither 'linear' nor 'log'.
    """
    if flat_spectr_scale == 'linear':
        return 10 ** fm._spectrum_flat, 'Amplitude (uV\u00b2/Hz)'
    if flat_spectr_scale == 'log':
        return fm._spectrum_flat, 'Log-normalised amplitude'
    # Without this, an unknown scale used to fail with an obscure UnboundLocalError
    raise ValueError("flat_spectr_scale must be 'linear' or 'log', got {!r}".format(flat_spectr_scale))

def plot_fooof_spectra(fm, flat_spectr_scale, psd_params, bands=None, abs_bp=None, rel_bp=None, plot_rich=True):
    """Plot a fitted FOOOF model (left) next to its flattened spectrum (right).

    Parameters
    ----------
    fm : FOOOF
        Fitted FOOOF model.
    flat_spectr_scale : {'linear', 'log'}
        Amplitude scale of the flattened spectrum (passed to convert_flat_spectr_amplitude).
    psd_params : dict
        PSD parameters; only ``psd_params['fminmax']`` is used here, as the x-axis limits.
    bands : dict, optional
        Mapping of band name -> (low, high) frequency; only the first entry is shaded on the
        flattened-spectrum panel. ``None`` or an empty dict shades nothing.
    abs_bp, rel_bp : float, optional
        Absolute / relative band power; annotated only when both are given and plot_rich is true.
    plot_rich : bool
        If true, annotate fit error, R², aperiodic exponent and offset (and band power).

    Returns
    -------
    fig, axs
        The matplotlib figure and its array of two axes.
    """
    # Plot styles
    data_kwargs = {'color' : 'black', 'linewidth' : 1.4, 'label' : 'Original'}
    model_kwargs = {'color' : 'red', 'linewidth' : 1.4, 'alpha' : 0.75, 'label' : 'Full model'}
    aperiodic_kwargs = {'color' : 'blue', 'linewidth' : 1.4, 'alpha' : 0.75,
                        'linestyle' : 'dashed', 'label' : 'Aperiodic model'}
    flat_kwargs = {'color' : 'black', 'linewidth' : 1.4}

    # Flattened spectrum on the chosen amplitude scale
    flatten_spectrum, flat_spectr_ylabel = convert_flat_spectr_amplitude(fm, flat_spectr_scale)

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=100)

    # Left panel: original spectrum with the full model fit and the aperiodic fit
    plot_spectra(fm.freqs, fm.power_spectrum, ax=axs[0], **data_kwargs)
    plot_spectra(fm.freqs, fm.fooofed_spectrum_, ax=axs[0], **model_kwargs)
    plot_spectra(fm.freqs, fm._ap_fit, ax=axs[0], **aperiodic_kwargs)
    axs[0].set_xlim(psd_params['fminmax'])
    axs[0].grid(linewidth=0.2)
    axs[0].set_xlabel('Frequency (Hz)')
    axs[0].set_ylabel('Log-normalised power (log$_{10}$[µV\u00b2/Hz])')
    axs[0].set_title('Spectrum model fit')
    axs[0].legend()

    # Right panel: flattened spectrum (i.e., minus aperiodic fit)
    plot_spectra(fm.freqs, flatten_spectrum, ax=axs[1], **flat_kwargs)
    if bands:
        # Only the first band of the mapping is highlighted
        band_range = next(iter(bands.values()))
        axs[1].axvspan(band_range[0], band_range[1], facecolor='green', alpha=0.2)
    axs[1].set_xlim(psd_params['fminmax'])
    axs[1].grid(linewidth=0.2)
    axs[1].set_xlabel('Frequency (Hz)')
    axs[1].set_ylabel(flat_spectr_ylabel)
    axs[1].set_title('Flattened spectrum')

    # Optionally print the exported fit variables onto the figure (positions in figure coordinates)
    if plot_rich:
        axs[0].annotate('Error: ' + str(np.round(fm.get_params('error'), 4)) +
                        '\nR\u00b2: ' + str(np.round(fm.get_params('r_squared'), 4)),
                        (0.1, 0.16), xycoords='figure fraction', color='red', fontsize=8.5)
        axs[0].annotate('Exponent: ' + str(np.round(fm.get_params('aperiodic_params','exponent'), 4)) +
                        '\nOffset: ' + str(np.round(fm.get_params('aperiodic_params','offset'), 4)),
                        (0.19, 0.16), xycoords='figure fraction', color='blue', fontsize=8.5)
        # `is not None` (not `!= None`) so numpy values/arrays do not cause ambiguous truth values
        if abs_bp is not None and rel_bp is not None:
            axs[1].annotate('Absolute BP: '+str(np.round(abs_bp, 4))+'\nRelative BP: '+str(np.round(rel_bp, 4)),
                            (0.69, 0.16), xycoords='figure fraction', color='green', fontsize=8.5)

    return fig, axs

def ev_to_df(df_ch_ev, i, fm, ch, ev, subject_names,
             bands=None, abs_bp=None, rel_bp=None, stimname='stimulus'):
    """Write one FOOOF fit (plus its labels) into row ``i`` of a results dataframe.

    Parameters
    ----------
    df_ch_ev : pandas.DataFrame
        Dataframe to fill; modified in place (rows/columns are added as needed) and returned.
    i : hashable
        Row label to write; also indexes ``subject_names``.
    fm : FOOOF
        Fitted model providing exponent, offset, R² and error via ``get_params``.
    ch, ev : str
        Channel name and event identifier stored in the 'Channel' / 'Event' columns.
    subject_names : sequence of str
        ``subject_names[i]`` is stored in the 'Subject' column.
    bands : dict, optional
        Band name -> range mapping; when non-empty, '<first band> absolute power' and
        '<first band> relative power' columns are written.
    abs_bp, rel_bp : float, optional
        Band power values; ``None`` is stored as NaN.
    stimname : str
        Stored in the 'Type' column.

    Returns
    -------
    pandas.DataFrame
        The same ``df_ch_ev`` object.
    """
    # `is None` rather than `== None`, which is elementwise (and ambiguous) for numpy arrays
    if abs_bp is None:
        abs_bp = np.nan
    if rel_bp is None:
        rel_bp = np.nan

    # Insertion order here fixes the column order of newly created columns
    row_values = {
        'Exponent': fm.get_params('aperiodic_params', 'exponent'),
        'Offset': fm.get_params('aperiodic_params', 'offset'),
    }
    if bands:
        band_name = next(iter(bands))
        row_values['{} absolute power'.format(band_name)] = abs_bp
        row_values['{} relative power'.format(band_name)] = rel_bp
    row_values['R_2'] = fm.get_params('r_squared')
    row_values['Error'] = fm.get_params('error')
    row_values['Channel'] = ch
    row_values['Event'] = ev
    row_values['Type'] = stimname
    row_values['Subject'] = subject_names[i]

    # Write only row i: assigning whole columns would overwrite the labels of every other row
    for column, value in row_values.items():
        df_ch_ev.loc[i, column] = value
    return df_ch_ev

In [None]:
# EEG channels to analyse, in subplot order ('Fp1' and 'Fp2' are left out)
channels = [
    'AF3', 'F3', 'FC1', 'AF4', 'F4', 'FC2', 'Fz',
    'F7', 'FC5', 'T7', 'C3', 'F8', 'FC6', 'T8', 'C4', 'Cz',
    'CP5', 'P3', 'P7', 'CP6', 'P4', 'P8', 'CP1', 'CP2', 'Pz',
    'PO3', 'PO4', 'O1', 'O2', 'Oz',
]


def _welch_psd_params(window_duration, tminmax):
    """Welch PSD settings for one time window; only window length and time span differ."""
    # A fresh dict (and fresh 'fminmax' list) per call, so the windows never share mutable state
    return {'method': 'welch', 'fminmax': [1, 30], 'window': 'hamming',
            'window_duration': window_duration, 'window_overlap': 0, 'zero_padding': 3,
            'tminmax': tminmax}


# Power spectra estimation parameters for the pre-event and post-event windows
psd_params = {
    'pre': _welch_psd_params(window_duration=0.5, tminmax=[-0.5, 0]),
    'post': _welch_psd_params(window_duration=0.8, tminmax=[0, 0.8]),
}

# FOOOF (specparam) model settings
fooof_params = {'peak_width_limits': [1, 12], 'max_n_peaks': float('inf'), 'min_peak_height': 0.225,
                'peak_threshold': 2.0, 'aperiodic_mode': 'fixed'}

# Line styles for the per-channel plots: data, full model and aperiodic fit
data_kwargs = dict(color='black', linewidth=1.4)
model_kwargs = dict(color='red', linewidth=1.4, alpha=0.75)
aperiodic_kwargs = dict(color='blue', linewidth=1.4, alpha=0.75, linestyle='dashed')

# Amplitude scale of the flattened spectra: 'linear' or 'log'
flat_spectr_scale = 'linear'

# plot_rich: show extra fit information on the model plots; savefig: save those plots
plot_rich, savefig = True, True

# Event codes used to select epochs, mapped to readable stimulus names
event_list = {
    '4': 'GO trial',
    '8': 'NO-GO trial',
}

In [None]:
# Find the clean epoch files (and the matching subject names) in the experiment sub-folder
dir_inprogress = os.path.join(clean_folder, exp_folder)
file_dirs, subject_names = arrange.read_files(dir_inprogress, "_clean-epo.fif")

# Fail early with a clear message instead of an IndexError later in the processing loop
assert len(file_dirs) > 0, 'No "_clean-epo.fif" files found in {}'.format(dir_inprogress)

In [None]:
# Pre-create the results folders
arrange.create_results_folders(exp_folder=exp_folder, results_folder=results_folder, fooof=True)

# Subjects to process: capped for quick test runs (set max_subjects = None to process everyone).
# The cap never exceeds the number of files found, so a short folder no longer raises IndexError.
max_subjects = 5
n_subjects = len(file_dirs) if max_subjects is None else min(max_subjects, len(file_dirs))

# Subplot grid with one panel per channel (8 columns, as many rows as needed)
n_grid_cols = 8
n_grid_rows = (len(channels) + n_grid_cols - 1) // n_grid_cols


def fooof_subplot(fm, ax, ch, xlim, data_kwargs, model_kwargs, aperiodic_kwargs):
    """Plot one channel's spectrum, full FOOOF model and aperiodic fit into `ax`.

    The title lists the channel and the fit parameters. It is drawn in bold red when the fit
    looks questionable: exponent <= 0, error > 0.075 or R² < 0.8.
    """
    exp = fm.get_params('aperiodic_params', 'exponent')
    off = fm.get_params('aperiodic_params', 'offset')
    R2 = fm.get_params('r_squared')
    error = fm.get_params('error')

    questionable_fit = exp <= 0 or error > 0.075 or R2 < 0.8
    title_color = 'red' if questionable_fit else 'black'
    title_weight = 'bold' if questionable_fit else 'normal'

    plot_spectra(fm.freqs, fm.power_spectrum, ax=ax, **data_kwargs)
    plot_spectra(fm.freqs, fm.fooofed_spectrum_, ax=ax, **model_kwargs)
    plot_spectra(fm.freqs, fm._ap_fit, ax=ax, **aperiodic_kwargs)
    ax.set_xlim(xlim)
    ax.grid(linewidth=0.2)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
    ax.set_title(f'{ch}\nexp={np.round(exp, 3)}, off={np.round(off, 3)}\nerror={np.round(error, 3)}, R2={np.round(R2, 3)}',
                 fontdict={'color': title_color, 'weight': title_weight})


def make_channel_grid(title):
    """Create a titled figure with one subplot per channel; leftover grid panels are hidden."""
    fig, axs = plt.subplots(n_grid_rows, n_grid_cols, figsize=(20, 10), squeeze=False)
    fig.suptitle(title)
    axs = axs.ravel()
    for unused_ax in axs[len(channels):]:
        unused_ax.set_visible(False)
    return fig, axs


def fit_fooof(freqs, spectrum, fminmax):
    """Fit a FOOOF model with the notebook-wide `fooof_params` over the frequency range `fminmax`."""
    fm = FOOOF(**fooof_params, verbose=True)
    fm.fit(freqs, spectrum, fminmax)
    return fm


# One record per subject x event x channel x spectrum type; turned into df_ch at the end
fit_records = []
for i in range(n_subjects):
    subject = subject_names[i]
    # Read the clean epochs of this subject from disk
    epochs = mne.read_epochs(fname=os.path.join(dir_inprogress, f'{subject}_clean-epo.fif'),
                             verbose=False)

    for ev, evname in event_list.items():
        print('{} ({}) for {} ({}/{})'.format(evname, ev, subject, i + 1, n_subjects))

        # Only the epochs of the current event
        epochs_ev = epochs[ev]

        # Welch PSDs of the pre-event and post-event windows
        psds_pre, freqs_pre = spectr.calculate_psd(epochs_ev, subject, **psd_params['pre'], verbose=False, plot=False)
        psds_post, freqs_post = spectr.calculate_psd(epochs_ev, subject, **psd_params['post'], verbose=False, plot=False)

        # All-channel PSD dataframes, averaged over axis 0 (presumably the epochs)
        df_psds_pre = (arrange.array_to_df(subject, epochs_ev, psds_pre.mean(axis=0))
                       .reset_index().drop(columns='Subject'))
        df_psds_post = (arrange.array_to_df(subject, epochs_ev, psds_post.mean(axis=0))
                        .reset_index().drop(columns='Subject'))

        fig_pre, axs_pre = make_channel_grid(f'{subject} - {evname} - pre')
        fig_post, axs_post = make_channel_grid(f'{subject} - {evname} - post')
        fig_post_diff, axs_post_diff = make_channel_grid(f'{subject} - {evname} - post-ERP')

        for c, ch in enumerate(channels):
            # This channel's PSD for the pre-event and post-event windows
            psds_pre_ch = df_psds_pre[ch].to_numpy()
            psds_post_ch = df_psds_post[ch].to_numpy()

            # PSD of the ERP: the event epochs averaged in the time domain for this channel
            evoked_ev_ch = epochs_ev.average(picks=ch)
            psds_post_erp_ch = spectr.calculate_psd(evoked_ev_ch, subject, **psd_params['post'],
                                                    verbose=False, plot=False)[0][0]

            # Subtract the ERP PSD from the post-event PSD -> difference PSD
            psds_post_diff_ch = psds_post_ch - psds_post_erp_ch

            # Fit FOOOF to each spectrum and plot it into its channel panel
            fits = (
                (fit_fooof(freqs_pre, psds_pre_ch, psd_params['pre']['fminmax']),
                 axs_pre, psd_params['pre'], 'Mean pre'),
                (fit_fooof(freqs_post, psds_post_ch, psd_params['post']['fminmax']),
                 axs_post, psd_params['post'], 'Mean post'),
                (fit_fooof(freqs_post, psds_post_diff_ch, psd_params['post']['fminmax']),
                 axs_post_diff, psd_params['post'], 'Mean post-ERP'),
            )
            for fm, grid_axs, params, fit_type in fits:
                fooof_subplot(fm, grid_axs[c], ch, params['fminmax'],
                              data_kwargs, model_kwargs, aperiodic_kwargs)
                fit_records.append({
                    'Subject': subject,
                    'Channel': ch,
                    'Type': fit_type,
                    'Event': ev,
                    'Exponent': fm.get_params('aperiodic_params', 'exponent'),
                    'Offset': fm.get_params('aperiodic_params', 'offset'),
                    'R_2': fm.get_params('r_squared'),
                    'Error': fm.get_params('error'),
                })

        # Lay out after the panels are filled so the multi-line titles are taken into account
        event_figs = (fig_pre, fig_post, fig_post_diff)
        for fig in event_figs:
            fig.tight_layout(pad=5.0)
        plt.show()
        for fig in event_figs:
            plt.close(fig)

# Fit parameters of every processed subject/event/channel/spectrum type.
# NOTE: kept in memory only - nothing is written to disk here (the savefig flag is not applied);
# export df_ch (e.g. with df_ch.to_excel) when needed.
df_ch = pd.DataFrame(fit_records,
                     columns=['Subject', 'Channel', 'Type', 'Event', 'Exponent', 'Offset', 'R_2', 'Error'])

In [None]:
# Sanity check on the last processed channel from the loop above: its post-event PSD must
# have one value per frequency bin (FOOOF's fit relies on matching shapes)
assert psds_post_ch.shape == freqs_post.shape, (psds_post_ch.shape, freqs_post.shape)
psds_post_ch.shape