In [None]:
# Author: Jacob Woessner <woessner.jacob@gmail.com>

## HOW TO RUN
- Set the path to the folder with the `.cnt` data files in the variable `import_cnt`
- Set project name in the variable `project_name`
- Optional: Set home directory in the variable `home_dir` where files will be generated (default is the location of this notebook)
- Hit `Run All` in the tab
- The script automatically populates all the folders and files in the home directory
<div class="alert alert-block alert-info">
<b>Note:</b> The script should skip files that have already been processed so if an error occurs you should be able to fix it and hit Run All without having to start from the beginning.
</div>

## INPUT:
- Resting EEG files in CNT format. Each file name must start with the subject ID followed by an underscore (e.g. 1234_resting1.cnt)
<div class="alert alert-block alert-warning">
<b>Warning:</b> The script concatenates all resting EEG files that share a subject ID (e.g. 1234_resting1.cnt, 1234_resting2.cnt, 1234_resting3.cnt). Make sure all files of a subject use the same ID, and that eyes-closed and eyes-open recordings do not share an ID; otherwise they will be concatenated together.
</div>



## OUTPUT:
- Log file named `log_<project_name>_<date>_<time>.txt` in the home directory (all printed output and errors)
- PSD files (JSON format) in the `psd_<project_name>` folder in the home directory
- Individual FOOOF csv files in the `fooof_<project_name>` folder in the home directory
- A combined FOOOF csv file with the aperiodic parameters, fit error and peak parameters for all subjects in the home directory
- A combined FOOOF func csv file with the aperiodic, periodic and original spectrum values
- A wide-format SPSS file combining the FOOOF csv and the FOOOF func csv
- Analysis of the data (fooof plots, raincloud plots, and statistics)

<div class="alert alert-block alert-warning">
<b>Warning:</b> As of January 1st, 2023, there may be a bug in MNE that appends empty data to CNT files. A check for this bug (10 consecutive zero values in the first channel) is included in the Calculate PSD cell, but it is commented out by default.
</div>


## Packages

In [None]:
import datetime
import json
import os
import re
import sys

import matplotlib
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import ptitprince as pt
import pyreadstat
import seaborn as sns
from fooof import FOOOF
from fooof.sim.gen import gen_aperiodic

# Select the non-interactive Agg backend: every figure in this notebook is written to a PNG
# file and then closed rather than displayed.
matplotlib.use(backend='Agg')

## File paths

In [None]:
# Folder containing the resting-state .cnt files
import_cnt = 'your/path/here/'

# Project name string; it is embedded in the output file, folder and column names
project_name = '[project_name]'

# Home directory: the log file and the output folders below are created here
# (defaults to the working directory of the kernel, normally the notebook's folder)
home_dir = os.getcwd()

# Redirect print statements and tracebacks to a log file (comment out the two sys.* lines to
# print to the notebook instead).
# - '%H-%M' rather than '%H:%M': ':' is not allowed in Windows file names, and the other
#   timestamps in this notebook already use '-'.
# - buffering=1 makes the text file line-buffered, so each logged line reaches disk right away
#   and the log is not lost if the kernel dies mid-run.
now = datetime.datetime.now()
log_file = open(
    os.path.join(home_dir, f'log_{project_name}_{now.strftime("%Y-%m-%d_%H-%M")}.txt'),
    'a',
    buffering=1,
    encoding='utf-8',
)
sys.stdout = log_file
sys.stderr = log_file


def make_dir_path(path):
    """Create ``path`` (including missing parents) if it does not exist and return it unchanged."""
    os.makedirs(path, exist_ok=True)
    return path


# Output folders; exist_ok makes re-running the cell harmless
psd_path = make_dir_path(os.path.join(home_dir, f'psd_{project_name}'))
fooof_path = make_dir_path(os.path.join(home_dir, f'fooof_{project_name}'))
analysis_path = make_dir_path(os.path.join(home_dir, f'analysis_{project_name}'))

## Processing Settings

In [None]:
# --- FOOOF fit ---
freq_range = [1, 55]   # frequency range of the fit in Hz, [low, high]
mode = 'fixed'         # aperiodic mode passed to FOOOF ('fixed', or 'knee' to also fit a knee)

# --- Epoching and spectral estimation ---
epoch_dur = 2.048      # epoch length in seconds; the FFT length is sfreq * epoch_dur samples
baseline = (-0.2, 0)   # baseline window in s; only used by a commented-out apply_baseline call
psd_method = 'welch'   # method passed to Epochs.compute_psd

# --- Channel selection ---
fooof_channels = 'all'                    # 'all', or a list of channel names to keep
bad_fooof_channels = ['X', 'Y', 'BLANK']  # channels removed before computing the PSD


## Standard Bands

In [None]:
# Standard EEG frequency bands in Hz as [low, high]; a peak is assigned to a band when
# low <= centre frequency < high
bands = dict(
    delta=[1, 4],
    theta=[4, 8],
    alpha=[8, 13],
    beta=[13, 30],
    gamma=[30, 50],
)

## Analysis Settings

In [None]:
# Channels used for the plots and statistics of the Analysis section.
# Give them AS A LIST of lower-case strings, e.g. ['fz', 'cz', 'pz']
analysis_channels = [
    'fcz',
    'cz',
    'pz',
    'f8',
    'f7',
]


## Calculate PSD

In [None]:


# Every entry of the import folder, logged for reference
file_list = os.listdir(import_cnt)
print(file_list)

# Resting-state recordings, sorted. The extension test is case-insensitive so that files
# exported as '.CNT' are not silently skipped.
cnt_list = sorted(name for name in file_list if name.lower().endswith('.cnt'))

# Unique, sorted subject ids: the part of each file name before the first underscore
subject_list = sorted({name.split('_')[0] for name in cnt_list})

In [None]:
import_path = import_cnt
export_path = psd_path

# Accumulators initialised by earlier versions of this cell; nothing below reads them.
fooof_dict = {}
spec_dict = {}
freq_dict = {}
spec_list = []
epoch_num_list = []


def _load_rest_epochs(file_path):
    """Load one CNT recording, apply the channel selection and cut it into fixed-length epochs.

    Returns ``(epochs, n_fft)``, where ``n_fft`` is the number of samples in one epoch, or
    ``None`` when the file cannot be read. The failure is logged and the caller skips the file.
    """
    try:
        raw = mne.io.read_raw_cnt(file_path, preload=True, data_format='int32')
    except Exception as err:
        print(f'ERROR: {file_path} failed to load. Please check data. ({err!r})')
        return None

    # Select channels
    if fooof_channels != 'all':
        raw.pick_channels(fooof_channels)
    # Remove bad channels. Only drop names that are present: drop_channels raises for unknown
    # names, e.g. when the selection above already excluded them.
    present_bad = [ch for ch in bad_fooof_channels if ch in raw.ch_names]
    if present_bad:
        raw.drop_channels(present_bad)

    # # Condition to check for the empty-data CNT bug (disabled)
    # if file_path.endswith('cnt'):
    #     # Check if there are any occurrences of 10 consecutive zero values
    #     data = raw.to_data_frame(picks=[raw.info['ch_names'][0]])
    #     has_consecutive_zeros = any(
    #         data['value'].rolling(window=10).sum() == 0)
    #     if has_consecutive_zeros:
    #         print(
    #             f'ERROR: {file_path} has 10 consecutive zero values. Please check data.')
    #         return None

    # FFT length = number of samples in one epoch
    n_fft = int(raw.info['sfreq'] * epoch_dur)
    epochs = mne.make_fixed_length_epochs(raw, epoch_dur, preload=True)
    # epochs.apply_baseline(baseline=baseline)
    return epochs, n_fft


existing_psd_files = set(os.listdir(export_path))
for i in subject_list:
    psd_file = f'{i}_{project_name}_psd.json'
    if psd_file in existing_psd_files:
        print(f'{i}_{project_name}_psd.json already exists. Skipping...')
        continue

    # Exact match on the subject id (the text before the first '_'). A substring test would
    # also pull in other subjects' recordings, e.g. subject '12' would match '1234_resting1.cnt'.
    subject_files = [j for j in cnt_list if j.split('_')[0] == i]

    one_two_epoch_list = []
    var_n_fft = None
    for j in subject_files:
        loaded = _load_rest_epochs(os.path.join(import_path, j))
        if loaded is not None:
            epoch_rest, var_n_fft = loaded
            one_two_epoch_list.append(epoch_rest)

    if not one_two_epoch_list:
        # Nothing usable for this subject; concatenate_epochs([]) would abort the whole run
        print(f'ERROR: no usable CNT files for subject {i}. Skipping...')
        continue

    # Pool the epochs of all recordings of this subject and average their spectra
    epochs = mne.concatenate_epochs(one_two_epoch_list)
    psd = epochs.compute_psd(method=psd_method, n_fft=var_n_fft).average()

    # JSON record with the averaged spectrum (one row per channel) and its metadata
    psd_dict = {
        'spec': psd.get_data().tolist(),
        'freq': psd.freqs.tolist(),
        'epoch_num': len(epochs),
        'subject': i,
        'ch_names': psd.ch_names,
    }
    with open(os.path.join(export_path, psd_file), 'w') as f:
        json.dump(psd_dict, f)

## FOOOF Analysis

In [None]:


# PSD json files written by the "Calculate PSD" cell, processed in sorted order
file_list = sorted(os.listdir(psd_path))
print(file_list)

# Exact names of the per-subject CSVs that already exist. A set gives exact membership; the old
# test against str(os.listdir(...)) was a substring search, so subject '34' counted as done
# whenever '1234_<project>.csv' existed.
existing_fooof_files = set(os.listdir(fooof_path))


def _fooof_frames(fm, freqs, spectrum, subject, ch_name, epoch_num):
    """Build the two per-channel output tables from a fitted FOOOF model.

    Returns ``(df_fooof, df_fooof_func)``:
    - ``df_fooof``: a single row with the aperiodic parameters, r_squared, error and the
      centre frequency / power / bandwidth of every peak, grouped by band.
    - ``df_fooof_func``: one row per frequency bin with the periodic component (pc), aperiodic
      component (apc) and original spectrum (psd).
    """
    # Remove 0 Hz (the first bin) from the spectrum
    freqs_no_0 = freqs[1:]
    spectrum_no_0 = spectrum[1:]

    # gen_aperiodic yields log10 power, hence 10** to compare with the linear spectrum
    ap_spec = 10 ** gen_aperiodic(freqs_no_0, fm.aperiodic_params_)
    p_spec = spectrum_no_0 - ap_spec

    df_fooof_func = pd.DataFrame({
        'pc': p_spec,
        'apc': ap_spec,
        'psd': spectrum_no_0,
        'freqs': freqs_no_0,
        'subject': subject,
        'channel': ch_name,
        'epoch_num': epoch_num,
        # 0..n-1 bin counter, used later to select bins for the wide export
        'num': np.arange(len(freqs_no_0)),
    })

    # aperiodic_params_ is [offset, exponent] in 'fixed' mode, [offset, knee, exponent] in 'knee'
    row = {'ap_offset': fm.aperiodic_params_[0]}
    if mode == 'knee':
        row['ap_knee'] = fm.aperiodic_params_[1]
    row['ap_exponent'] = fm.aperiodic_params_[2 if mode == 'knee' else 1]
    row['r_squared'] = fm.r_squared_
    row['error'] = fm.error_
    row['subject'] = subject
    row['channel'] = ch_name

    # Peak parameters per band: <band>_CF_k, <band>_PW_k, <band>_BW_k for the k-th peak whose
    # centre frequency lies in [low, high)
    peaks = pd.DataFrame(fm.peak_params_, columns=['CF', 'PW', 'BW'])
    for band, (low, high) in bands.items():
        in_band = peaks[peaks['CF'].between(low, high, inclusive='left')].reset_index(drop=True)
        for k, peak in in_band.iterrows():
            row[f'{band}_CF_{k}'] = peak['CF']
            row[f'{band}_PW_{k}'] = peak['PW']
            row[f'{band}_BW_{k}'] = peak['BW']

    return pd.DataFrame([row]), df_fooof_func


df_fooof_list = []
df_fooof_func_list = []
for file in file_list:
    if not file.endswith('.json'):
        continue

    sub_num = file.split('_')[0]
    params_name = f"{sub_num}_{project_name}.csv"
    func_name = f"{sub_num}_{project_name}_func.csv"

    if func_name in existing_fooof_files and params_name in existing_fooof_files:
        print(f'{sub_num}_{project_name}_func.csv and {sub_num}_{project_name}.csv already exists. Skipping...')
        # index_col=0 restores the index written by to_csv instead of adding an 'Unnamed: 0' column
        df_fooof_func_list.append(pd.read_csv(os.path.join(fooof_path, func_name), index_col=0))
        df_fooof_list.append(pd.read_csv(os.path.join(fooof_path, params_name), index_col=0))
        continue

    with open(os.path.join(psd_path, file)) as f:
        data = json.load(f)
    freqs = np.array(data['freq'])

    fooof_sub_list = []
    fooof_func_sub_list = []
    for index, spectrum in enumerate(data['spec']):
        ch_name = data['ch_names'][index]
        spectrum = np.array(spectrum)
        fm = FOOOF(aperiodic_mode=mode, verbose=False)
        try:
            fm.fit(freqs, spectrum, freq_range=freq_range)
            print(f'Fit: {file} at channel {ch_name} \n')
        except Exception as err:
            # A failed fit only loses this channel: log it and continue with the others
            print(f'Error in {file} at channel {ch_name} \n{err!r}')
            continue
        df_fooof, df_fooof_func = _fooof_frames(
            fm, freqs, spectrum, data['subject'], ch_name, data['epoch_num'])
        fooof_sub_list.append(df_fooof)
        fooof_func_sub_list.append(df_fooof_func)

    if not fooof_sub_list:
        # pd.concat([]) would raise; leave no CSV so the subject is retried on the next run
        print(f'Error in {file}: no channel could be fit, nothing written. \n')
        continue

    # Per-subject CSVs (written only for json files, so other files cannot clobber them)
    fooof_sub = pd.concat(fooof_sub_list, ignore_index=True)
    fooof_func_sub = pd.concat(fooof_func_sub_list, ignore_index=True)
    fooof_sub.to_csv(os.path.join(fooof_path, params_name))
    fooof_func_sub.to_csv(os.path.join(fooof_path, func_name))
    df_fooof_list.append(fooof_sub)
    df_fooof_func_list.append(fooof_func_sub)

if not df_fooof_list:
    raise RuntimeError(f'No FOOOF results: no PSD json file in {psd_path} could be fit')

# Concatenate all subjects
df_fooof = pd.concat(df_fooof_list, ignore_index=True)
df_fooof_func = pd.concat(df_fooof_func_list, ignore_index=True)

# Export the combined tables to the home directory, timestamped per run
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
df_fooof.to_csv(os.path.join(home_dir, f"fooof_comb_{project_name}_{now}.csv"))
df_fooof_func.to_csv(os.path.join(home_dir, f"fooof_func_comb_{project_name}_{now}.csv"))

## Convert to wide

In [None]:
# Combined FOOOF tables are written to the home directory, so list home_dir (not the process
# working directory). Keep full paths so the reads below work from any working directory.
file_list = sorted(os.listdir(home_dir))
print(file_list)

# Only this project's runs: the names are '<stem>_<project_name>_<YYYY-MM-DD_HH-MM>.csv', so
# after sorting the newest run of this project is last. Without the project prefix, another
# project's newer file in the same folder could be selected.
fooof_file_list = [os.path.join(home_dir, name) for name in file_list
                   if name.startswith(f'fooof_comb_{project_name}_') and name.endswith('.csv')]
fooof_func_file_list = [os.path.join(home_dir, name) for name in file_list
                        if name.startswith(f'fooof_func_comb_{project_name}_') and name.endswith('.csv')]

In [None]:
# Load the last entry of each sorted file list (with timestamped names, the most recent run)
df_fooof = pd.read_csv(fooof_file_list[-1])
df_fooof_func = pd.read_csv(fooof_func_file_list[-1])

# Discard the 'Unnamed: …' columns produced by the row index that to_csv writes
df_fooof = df_fooof.loc[:, [not col.startswith('Unnamed') for col in df_fooof.columns]]
df_fooof_func = df_fooof_func.loc[:, [not col.startswith('Unnamed') for col in df_fooof_func.columns]]


In [None]:
# Display the per-channel aperiodic and peak parameter table for a visual check
df_fooof

In [None]:
# Display the per-frequency pc / apc / psd table for a visual check
df_fooof_func

In [None]:
# Frequency label per bin, rounded to 2 decimals with an 'H' suffix (e.g. '1.46H');
# it becomes the last part of the wide column names
df_fooof_func['freq'] = df_fooof_func['freqs'].round(2).astype(str) + 'H'

# Keep only the first 120 frequency bins (num = 0..119) for the wide export
max_wide_bins = 120
df_select = df_fooof_func[df_fooof_func['num'] < max_wide_bins]

# Convert offset to microvolts: the offset is log10 power, and scaling power by 1e12
# (V^2 -> uV^2) adds exactly 12 in log space. Adding 12 directly avoids the lossy
# 10**x -> log10 round trip.
df_fooof_microvolts = df_fooof.apply(lambda x: x + 12 if 'offset' in x.name else x, axis=0)

# Parameter columns to spread per channel: everything except the identifier columns
fooof_cols = df_fooof.columns.to_list()
fooof_cols.remove('subject')
fooof_cols.remove('channel')
fooof_cols

In [None]:

# One row per subject and one column per (parameter, channel) pair
wide_df_fooof_microvolts = df_fooof_microvolts.pivot(
    index=['subject'],
    columns=['channel'],
    values=fooof_cols,
)

# Flatten the two column levels into '<project>_<parameter>_<channel>'
value = wide_df_fooof_microvolts.columns.get_level_values(0)
channel = wide_df_fooof_microvolts.columns.get_level_values(1)
wide_df_fooof_microvolts.columns = [f'{project_name}_{v}_{c}' for v, c in zip(value, channel)]

In [None]:

# One row per subject; columns are (value, channel, frequency label) triples
wide_df_fooof_func = df_select.pivot(
    index=['subject'],
    columns=['channel', 'freq'],
    values=['psd', 'apc', 'pc', 'epoch_num'],
)

# Flatten the three column levels into '<project>_<value>_<channel>_<freq>'
value = wide_df_fooof_func.columns.get_level_values(0)
channel = wide_df_fooof_func.columns.get_level_values(1)
num = wide_df_fooof_func.columns.get_level_values(2)
wide_df_fooof_func.columns = [
    f'{project_name}_{v}_{c}_{f}' for v, c, f in zip(value, channel, num)
]

In [None]:
# Scale the spectral columns (psd, apc, pc) by 1e12 (V^2/Hz -> uV^2/Hz) and keep the epoch_num
# columns, which are counts, unscaled. Columns are matched on their exact '<project>_<value>_'
# prefix. The old unanchored regex '(psd)|(apc)|(pc)' also matched epoch_num columns, scaling
# and duplicating them, whenever project_name or a channel label contained 'pc' or 'psd'.
spectral_prefixes = tuple(f'{project_name}_{kind}_' for kind in ('psd', 'apc', 'pc'))
spectral_cols = [col for col in wide_df_fooof_func.columns if col.startswith(spectral_prefixes)]
epoch_num_cols = [col for col in wide_df_fooof_func.columns
                  if col.startswith(f'{project_name}_epoch_num_')]
wide_fooof_func_microvolts = pd.concat(
    [wide_df_fooof_func[spectral_cols] * 1e12, wide_df_fooof_func[epoch_num_cols]], axis=1
)

# Move subject from the index into a regular column
wide_fooof_func_microvolts.reset_index(inplace=True)

In [None]:
# epoch_num was repeated for every channel/frequency; each copy holds the subject's single epoch
# count, so keep the first copy only and give it the plain name 'epoch_num'
epoch_cols = list(filter(lambda col: 'epoch_num' in col, wide_fooof_func_microvolts.columns))
wide_fooof_func_microvolts = wide_fooof_func_microvolts.drop(columns=epoch_cols[1:])

epoch_num_name = list(filter(lambda col: 'epoch_num' in col, wide_fooof_func_microvolts.columns))
wide_fooof_func_microvolts = wide_fooof_func_microvolts.rename(
    columns={epoch_num_name[0]: 'epoch_num'}
)

In [None]:
# Display the wide spectral table (one row per subject) for a visual check
wide_fooof_func_microvolts

In [None]:
def _index_by_subject(frame):
    """Return ``frame`` indexed by its 'subject' values cast to str.

    The current index is first cast to str and moved into a column. A named 'subject' index
    becomes the 'subject' column; an unnamed RangeIndex becomes an extra 'index' column.
    Then 'subject' is made the index and cast to str. Note that the index of the frame passed
    in is replaced in place by its str version.
    """
    frame.index = frame.index.astype(str)
    frame = frame.reset_index()
    frame = frame.set_index('subject')
    frame.index = frame.index.astype(str)
    return frame


# Give both wide tables the same str 'subject' index so they align when concatenated
wide_df_fooof_microvolts = _index_by_subject(wide_df_fooof_microvolts)
wide_fooof_func_microvolts = _index_by_subject(wide_fooof_func_microvolts)

In [None]:
# Spectral columns first, then the parameter columns, aligned on the subject index
wide_df = pd.concat([wide_fooof_func_microvolts, wide_df_fooof_microvolts], axis=1)

# Shorten 'offset' -> 'off' and 'exponent' -> 'exp' in the column names
wide_df = wide_df.rename(
    columns=lambda name: name.replace('offset', 'off').replace('exponent', 'exp')
)

# Bring subject back as a regular column
wide_df = wide_df.reset_index()
wide_df

In [None]:
# Display the final wide table that is exported to SPSS below
wide_df

In [None]:

# Export the wide table as an SPSS .sav file in the home directory, timestamped per run
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
sav_file = os.path.join(home_dir, f"final_fooof_{project_name}_{now}.sav")
pyreadstat.write_sav(wide_df, sav_file)

## Analysis
- Freq and FOOOF plots
- Raincloud plots
- Stats

In [None]:

# Timestamp used in the names of the analysis figures
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")

# Sorted listing of the home directory. Full paths are kept because the files are read in the
# next cell, which must not depend on the working directory. Only this project's files are
# selected; their names end in a %Y-%m-%d_%H-%M timestamp, so the newest run sorts last.
file_list = sorted(os.listdir(home_dir))
print(file_list)

fooof_file_list = [os.path.join(home_dir, name) for name in file_list
                   if name.startswith(f'fooof_comb_{project_name}_') and name.endswith('.csv')]
fooof_func_file_list = [os.path.join(home_dir, name) for name in file_list
                        if name.startswith(f'fooof_func_comb_{project_name}_') and name.endswith('.csv')]
spss_file_list = [os.path.join(home_dir, name) for name in file_list
                  if name.startswith(f'final_fooof_{project_name}_') and name.endswith('.sav')]


In [None]:
# Fail with a clear message instead of an IndexError on [-1]
if not (fooof_file_list and fooof_func_file_list and spss_file_list):
    raise FileNotFoundError(
        f'Combined FOOOF csv/sav files for project {project_name!r} not found in {home_dir}')

# Newest combined tables and the newest wide SPSS export
df_fooof = pd.read_csv(fooof_file_list[-1])
df_fooof_func = pd.read_csv(fooof_func_file_list[-1])
final_fooof = pd.read_spss(spss_file_list[-1])

# Identifier columns are not plotted
final_fooof = final_fooof.drop(columns=['subject', 'index'])

# Convert offset to microvolts: offset is log10 power, so scaling power by 1e12 adds exactly 12
df_fooof = df_fooof.apply(lambda x: x + 12 if 'offset' in x.name else x, axis=0)

# Convert psd, apc and pc from V^2/Hz to uV^2/Hz
func_power_cols = ['psd', 'apc', 'pc']
df_fooof_func[func_power_cols] = df_fooof_func[func_power_cols] * 1e12


def _to_numeric_if_possible(column):
    """Return ``column`` converted to numbers, or unchanged if any value cannot be parsed.

    Same result as ``pd.to_numeric(column, errors='ignore')``, whose 'ignore' option is
    deprecated in recent pandas.
    """
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


final_fooof = final_fooof.apply(_to_numeric_if_possible)

# Compare channel names in lower case everywhere
wanted_channels = [ch.lower() for ch in analysis_channels]
final_fooof.columns = final_fooof.columns.str.lower()
df_fooof_func['channel'] = df_fooof_func['channel'].astype(str).str.lower()
df_fooof['channel'] = df_fooof['channel'].astype(str).str.lower()

df_fooof_func = df_fooof_func.loc[df_fooof_func['channel'].isin(wanted_channels)]
df_fooof = df_fooof.loc[df_fooof['channel'].isin(wanted_channels)]

# Wide column names are '<project>_<parameter>_<channel>' for the aperiodic/peak parameters and
# '<project>_<value>_<channel>_<freq>' for the spectra, so a channel is followed either by '_'
# or by the end of the name. The old pattern '(_<ch>_)' required a trailing '_', which silently
# dropped every offset / exponent / r_squared / error / peak column from the raincloud plots.
channels_regex = '|'.join(f'_{re.escape(ch)}(?:_|$)' for ch in wanted_channels)
final_fooof = final_fooof.filter(regex=channels_regex)



In [None]:
def _save_fooof_plot(group, subject, channel, offset, knee, exponent, logy):
    """Plot psd, apc and pc against frequency for one subject/channel and save it as a PNG.

    The linear-power figure is saved as 'fooof_<subject>_<channel>_plot_<now>.png' and the
    log-power figure (``logy=True``) as 'fooof_<subject>_<channel>_logy_<now>.png' in
    analysis_path.
    """
    fig, ax = plt.subplots()
    ax = group.plot(x='freqs', y=['psd', 'apc', 'pc'], ax=ax,
                    title=f'Subject: {subject} Channel: {channel}', logy=logy)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power (uV^2/Hz)')
    ax.legend(['PSD', 'APC', 'PC'])
    ax.set_xlim([0, 55])
    # Aperiodic parameters as a caption at the bottom of the figure
    fig.text(0.5, -0.01, f'Offset: {offset} Knee: {knee} Exponent: {exponent}', ha='center', fontsize=12)

    fig.tight_layout()
    kind = 'logy' if logy else 'plot'
    fig.savefig(os.path.join(analysis_path, f'fooof_{subject}_{channel}_{kind}_{now}.png'),
                dpi=300, bbox_inches='tight')
    # Close this specific figure; one is created per subject, channel and scale
    plt.close(fig)


# One linear and one log-power plot per (subject, channel)
for (subject, channel), group in df_fooof_func.groupby(['subject', 'channel']):
    params = df_fooof.loc[(df_fooof['subject'] == subject) & (df_fooof['channel'] == channel)].iloc[0]
    offset = round(float(params['ap_offset']), 2)
    # 'ap_knee' only exists when the fit used the 'knee' mode
    knee = round(float(params['ap_knee']), 2) if mode == 'knee' else 'NA'
    exponent = round(float(params['ap_exponent']), 2)
    for logy in (False, True):
        _save_fooof_plot(group, subject, channel, offset, knee, exponent, logy)

In [None]:


# Values further than z_score standard deviations from the column mean are treated as outliers
z_score = 3


def _save_raincloud(values, title, file_stub):
    """Draw a horizontal raincloud of ``values`` with a mean/median/SD caption and save it.

    The figure is written to analysis_path as 'fooof_<file_stub>_<now>.png'.
    """
    fig = plt.figure()
    ax = pt.RainCloud(data=values, orient='h')
    ax.set_title(title)

    mean = round(float(values.mean()), 2)
    median = round(float(values.median()), 2)
    sd = round(float(values.std()), 2)
    # Place descriptive statistics at the bottom of the figure
    fig.text(.5, 0.02, f'Mean: {mean} Median: {median} SD: {sd}', ha='center', fontsize=12)

    fig.savefig(os.path.join(analysis_path, f'fooof_{file_stub}_{now}.png'), dpi=300)
    plt.close(fig)


# Two raincloud plots per column of final_fooof: all values, and values without outliers
for column in final_fooof.columns:
    values = final_fooof[column]
    low = values.mean() - z_score * values.std()
    high = values.mean() + z_score * values.std()
    # NaN values are kept, because comparisons with NaN are False
    data_no_out = values[~((values < low) | (values > high))]

    _save_raincloud(values, column, f'{column}_raincloud')
    _save_raincloud(data_no_out, f'{column}_noout', f'{column}_raincloud_noout')


    