In [None]:
import os
import seaborn as sns
import warnings
import numpy as np
import scipy
import pandas as pd
import mne
import matplotlib.pyplot as plt
from fooof import FOOOF
from fooof.plts.spectra import plot_spectrum

# Set the current working directory to be the project main folder.
# The EEG_PYLINE_DIR environment variable (if set) overrides the default location,
# so the notebook can run on another machine without editing this cell.
os.chdir(os.environ.get('EEG_PYLINE_DIR', '/Users/aliciacampbell/Documents/GitHub/EEG-pyline'))

import basic.arrange_data as arrange
import signal_processing.spectral_analysis as spectr

In [None]:
### DEFINE ###
# Data folders (relative paths; resolved against the current working directory)
raw_folder, clean_folder, spectra_folder = 'Data/Raw/', 'Data/Clean/', 'Data/Spectra/'

# Output folders
results_folder = 'Results/'
savefinal_folder = 'Results/LEISURE/'

# Folder that receives the per-channel FOOOF figures and the summary CSV
plot_folder = 'Results/fooof_ch_plots'

# Experiment sub-folder (inside the data/results folders) and its label used in file names
exp_folder = 'LEISURE/T1/EC'
exp_condition = 'EC_T1'

In [None]:
### Define
# Frequency band of interest as {name: [lower, upper]}
bands = {'Alpha': [7, 14]}

# Channels that are fitted and exported (one FOOOF model per channel)
channels = [
    'Fp1', 'AF3', 'F3', 'FC1', 'Fp2', 'AF4', 'F4', 'FC2', 'Fz',
    'F7', 'FC5', 'T7', 'C3', 'F8', 'FC6', 'T8', 'C4', 'Cz',
    'CP5', 'P3', 'P7', 'CP6', 'P4', 'P8', 'CP1', 'CP2', 'Pz',
    'PO3', 'PO4', 'O1', 'O2', 'Oz',
]

# Spectrum used for the individual band search and the right-hand plot;
# one of 'linear_normal', 'log_normal', 'linear_flat', 'log_flat'
ind_spectr_type = 'linear_flat'

# Switches: annotate plots with the fitted values / save figures / create spectrum folders
plot_rich = True
savefig = True
savespectrum = False

# Keyword values forwarded to spectr.calculate_psd
psd_params = {
    'method': 'welch',
    'fminmax': [1, 30],
    'window': 'hamming',
    'window_duration': 2.5,
    'window_overlap': 0.5,
    'zero_padding': 39,
}

# Keyword values forwarded to the FOOOF constructor
fooof_params = {
    'peak_width_limits': [1, 12],
    'max_n_peaks': float("inf"),
    'min_peak_height': 0.225,
    'peak_threshold': 2.0,
    'aperiodic_mode': 'fixed',
}

# Human-readable tag describing the PSD settings (used as a folder name)
spectrum_name = (
    f"{psd_params['method']}_{psd_params['fminmax'][0]}-{psd_params['fminmax'][1]}Hz"
    f"_WIN={psd_params['window_duration']}s_{psd_params['window']}"
    f"_OL={psd_params['window_overlap'] * 100}%"
    f"_ZP={psd_params['zero_padding'] * psd_params['window_duration']}s"
)

# Global seaborn look for the figures: 'muted' colour palette on the 'whitegrid' style
sns.set_palette('muted')
sns.set_style("whitegrid")

# Keyword arguments for the individual plot elements
# (left panel: original spectrum, full model fit, aperiodic fit)
data_kwargs = dict(color='black', linewidth=1.4, label='Original')
model_kwargs = dict(color='red', linewidth=1.4, alpha=0.75, label='Full model')
aperiodic_kwargs = dict(color='blue', linewidth=1.4, alpha=0.75,
                        linestyle='dashed', label='Aperiodic model')
# (right panel: spectrum line and the dashed guide lines through the peak)
flat_kwargs = dict(color='black', linewidth=1.4)
hvline_kwargs = dict(color='blue', linewidth=1.0, linestyle='dashed', alpha=0.75)

In [None]:
# Per-subject, per-channel FOOOF pipeline:
#   1. read each subject's clean epochs and compute the PSD,
#   2. fit a FOOOF model to every channel in `channels`,
#   3. locate the individual band peak on the spectrum selected by `ind_spectr_type`,
#   4. optionally save a two-panel diagnostic figure per channel,
#   5. collect all measures in `df` and write them to '<plot_folder>/fooof_ch.csv'.

# Set the directory in progress and find all FIF (clean EEG) files in there
dir_inprogress = os.path.join(clean_folder, exp_folder)
file_dirs, subject_names = arrange.read_files(dir_inprogress, '_clean-epo.fif')
arrange.create_results_folders(exp_folder=exp_folder, results_folder=results_folder, fooof=True)

# Fail early on an unsupported spectrum type; otherwise `flatten_spectrum` would be
# undefined (or silently reused from an earlier iteration) inside the channel loop
valid_spectr_types = ('linear_flat', 'log_flat', 'linear_normal', 'log_normal')
if ind_spectr_type not in valid_spectr_types:
    raise ValueError('ind_spectr_type must be one of {}, got {!r}'.format(
        valid_spectr_types, ind_spectr_type))

if savespectrum:
    # Same folders as before; exist_ok replaces the try/except FileExistsError pairs
    for parent_folder in (spectra_folder, results_folder):
        os.makedirs('{}/{}/{}'.format(parent_folder, exp_folder, spectrum_name), exist_ok=True)


def _axis_fraction(value, lower, upper):
    """Return `value` as a 0-1 fraction of the axis span [lower, upper] (clipped)."""
    if upper == lower:
        return 0.0
    return min(max((value - lower) / (upper - lower), 0.0), 1.0)


# Label of the band shown in the plot legends (first key of `bands`)
band_label = list(bands.keys())[0]

# One frame per subject; concatenated once after the loop instead of growing `df` in place
subject_frames = []
for i in range(len(file_dirs)):
    print(f'running {subject_names[i]}... ({i+1}/{len(file_dirs)}) -->')
    # Read in the clean EEG data
    epochs = mne.read_epochs(fname='{}/{}_clean-epo.fif'.format(dir_inprogress, subject_names[i]),
                             verbose=False)

    # Calculate Welch's power spectrum density
    [psds, freqs] = spectr.calculate_psd(epochs, subject_names[i], method=psd_params['method'],
                                         fminmax=psd_params['fminmax'], window=psd_params['window'],
                                         window_duration=psd_params['window_duration'],
                                         window_overlap=psd_params['window_overlap'],
                                         zero_padding=psd_params['zero_padding'],
                                         verbose=False, plot=False)

    # Create all-channels PSD dataframe (PSD averaged over the first axis of `psds`;
    # presumably the epochs axis - confirm against spectr.calculate_psd)
    df_psds = arrange.array_to_df(subject_names[i], epochs, psds.mean(axis=(0))).\
                        reset_index().drop(columns='Subject')

    # Per-channel result containers, filled in the channel loop below
    exp = np.zeros(shape=(len(channels)))
    off = np.zeros(shape=(len(channels)))
    cf = np.zeros(shape=(len(channels)))
    abs_bp = np.zeros(shape=(len(channels)))
    r2 = np.zeros(shape=(len(channels)))
    error = np.zeros(shape=(len(channels)))

    # loop through all channels
    for c, ch in enumerate(channels):
        print(f'channel {ch}... ({c+1}/{len(channels)})')
        # choose PSD for the channel
        psds_ch = df_psds[ch].to_numpy()

        # fit fooof
        fm = FOOOF(**fooof_params, verbose=True)
        fm.fit(freqs, psds_ch, psd_params['fminmax'])

        # log-linear conversion based on the chosen amplitude scale
        if ind_spectr_type == 'linear_flat':
            flatten_spectrum = 10 ** fm._spectrum_flat
            flat_spectr_ylabel = 'Flattened power (µV\u00b2/Hz)'
        elif ind_spectr_type == 'log_flat':
            flatten_spectrum = fm._spectrum_flat
            flat_spectr_ylabel = 'Flattened log10-transformed power'
        elif ind_spectr_type == 'linear_normal':
            flatten_spectrum = psds_ch
            flat_spectr_ylabel = 'Power (µV\u00b2/Hz)'
        else:  # 'log_normal' (validated before the loop)
            flatten_spectrum = np.log10(psds_ch)
            flat_spectr_ylabel = 'Log10-transformed power'

        # get measures for the channel
        # NOTE(review): return values look like (peak frequency, peak power, band limits,
        # absolute band power, relative band power) - confirm against spectr.find_ind_band
        cf[c], pw, bw, abs_bp[c], rel_bp = spectr.find_ind_band(flatten_spectrum, freqs,
                                                                bands['Alpha'], bw_size=6)
        exp[c] = fm.get_params('aperiodic_params', 'exponent')
        off[c] = fm.get_params('aperiodic_params', 'offset')
        r2[c] = fm.get_params('r_squared')
        error[c] = fm.get_params('error')

        # Left panel: power spectrum model + aperiodic fit
        fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=100)
        plot_spectrum(fm.freqs, fm.power_spectrum,
                      ax=axs[0], **data_kwargs)
        plot_spectrum(fm.freqs, fm.fooofed_spectrum_,
                      ax=axs[0], **model_kwargs)
        plot_spectrum(fm.freqs, fm._ap_fit,
                      ax=axs[0], **aperiodic_kwargs)
        axs[0].set_xlim(psd_params['fminmax'])
        axs[0].grid(linewidth=0.2)
        axs[0].set_xlabel('Frequency (Hz)')
        axs[0].set_ylabel('Log10-transformed power')
        axs[0].set_title('Original power spectrum with model fit')
        axs[0].legend()

        # Right panel: flattened spectrum plot (i.e., minus aperiodic fit)
        plot_spectrum(fm.freqs, flatten_spectrum,
                      ax=axs[1], **flat_kwargs)
        axs[1].plot(cf[c], pw, '*', color='blue', label='{} peak'.format(band_label))
        axs[1].set_xlim(psd_params['fminmax'])

        # Leave 10 % headroom above the peak; for a negative peak value (possible on a
        # log scale) multiplying by 1.1 would put the limit *below* the peak
        y_top = pw * 1.1 if pw >= 0 else pw * 0.9
        # Linear spectra start at zero, log spectra keep the current lower limit
        y_bottom = 0 if ind_spectr_type in ('linear_flat', 'linear_normal') else None
        axs[1].set_ylim([y_bottom, y_top])

        # axvline/axhline take their extents as axes fractions (0-1), so convert the
        # peak position using the limits that are actually in effect on this axis
        x_lower, x_upper = axs[1].get_xlim()
        y_lower, y_upper = axs[1].get_ylim()
        axs[1].axvline(x=cf[c], ymin=0, ymax=_axis_fraction(pw, y_lower, y_upper),
                       **hvline_kwargs)
        axs[1].axhline(y=pw, xmin=0, xmax=_axis_fraction(cf[c], x_lower, x_upper),
                       **hvline_kwargs)
        axs[1].axvspan(bw[0], bw[1], alpha=0.1, color='green', label='{} band'.format(band_label))
        axs[1].grid(linewidth=0.2)
        axs[1].set_xlabel('Frequency (Hz)')
        axs[1].set_ylabel(flat_spectr_ylabel)
        axs[1].set_title('Power spectrum with individual alpha')
        axs[1].legend()

        # If true, plot all the exported variables on the plots
        if plot_rich:
            axs[0].annotate('Error: ' + str(np.round(error[c], 4)) +
                            '\nR\u00b2: ' + str(np.round(r2[c], 4)),
                            (0.1, 0.16), xycoords='figure fraction', color='red', fontsize=8.5)
            axs[0].annotate('Exponent: ' + str(np.round(exp[c], 4)) +
                            '\nOffset: ' + str(np.round(off[c], 4)),
                            (0.19, 0.16), xycoords='figure fraction', color='blue', fontsize=8.5)
            axs[1].text(cf[c]+1, pw, 'IAF: '+str(np.round(cf[c], 4)),
                        verticalalignment='top', color='blue', fontsize=8.5)
            axs[1].annotate('BW: '+str(np.round(bw[0], 4))+' - '+str(np.round(bw[1], 4))+
                            '\nIABP: '+str(np.round(abs_bp[c], 4)),
                            (0.75, 0.16), xycoords='figure fraction', color='green', fontsize=8.5)

        plt.suptitle('{} region ({})'.format(ch, subject_names[i]))
        plt.tight_layout()
        if savefig:
            os.makedirs(os.path.join(plot_folder, subject_names[i]), exist_ok=True)
            plt.savefig(fname='{}/{}/{}_{}_{}_fooof.png'.format(plot_folder, subject_names[i],
                                                                exp_condition, subject_names[i], ch), dpi=100)
        # Close this specific figure so figures do not pile up in memory across the loops
        plt.close(fig)

    # define dict with all the measures and their names
    vars_dict = {'exponent': exp,
                 'offset': off,
                 'a_cf': cf,
                 'a_abs_bp': abs_bp,
                 'r2': r2,
                 'error': error}

    # One row per measure (columns: subject, measure, then one column per channel)
    measure_rows = []
    for vname, var in vars_dict.items():
        df_var_temp = pd.DataFrame(var, index=channels).T
        df_var_temp.insert(0, 'subject', subject_names[i])
        df_var_temp.insert(1, 'measure', vname)
        measure_rows.append(df_var_temp)
    df_var = pd.concat(measure_rows, axis=0)

    subject_frames.append(df_var)

# Single concatenation; an empty frame is kept as the result when no files were found
df = pd.concat(subject_frames) if subject_frames else pd.DataFrame()

display(df)

# The folder is otherwise only created when figures are saved, so make sure it exists
os.makedirs(plot_folder, exist_ok=True)
df.to_csv(os.path.join(plot_folder, 'fooof_ch.csv'))
