## Packages

In [None]:
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from scipy import signal as sg
from fooof import FOOOF
import os
import json
from fooof.sim.gen import gen_aperiodic
import pyreadstat
import datetime

## File paths

In [None]:
# Path to the folder holding the raw .cnt recordings
import_cnt = '[INSERT PATH TO CNT FILES]'

# Project name string; it is embedded in every output folder and file name
project_name = '[INSERT PROJECT NAME]'

# Home directory: the working directory at the time this cell is run
home_dir = os.getcwd()

# Output folder for the per-subject PSD files.
# exist_ok=True creates the folder only when it is missing, without a
# separate existence check beforehand.
psd_path = f'{home_dir}/psd_{project_name}'
os.makedirs(psd_path, exist_ok=True)

# Output folder for the per-channel FOOOF result files
fooof_path = f'{home_dir}/fooof_{project_name}'
os.makedirs(fooof_path, exist_ok=True)

## Settings

In [None]:
# FFT length handed to the Welch PSD computation (n_fft)
var_n_fft = 2 ** 10

# Lower and upper limit of the frequency range the FOOOF model is fitted on
freq_range = [1, 55]

# Aperiodic mode passed to FOOOF; the code downstream distinguishes 'knee'
# (offset, knee, exponent) from any other value (offset, exponent)
mode = 'knee'

## Calculate PSD

In [None]:


# Everything found in the CNT import folder
file_list = os.listdir(import_cnt)
print(file_list)

# Recordings only, in alphabetical order
cnt_list = sorted(name for name in file_list if name.endswith('.cnt'))

# Unique subject identifiers: the part of each recording name before the
# first underscore, de-duplicated and sorted
subject_list = sorted({name.split('_')[0] for name in cnt_list})

In [None]:
# Compute one averaged Welch PSD per subject and store it as
# <subject>_<project_name>_psd.json in the PSD folder. Subjects whose PSD
# file already exists are skipped, so the cell can be re-run to resume.
import_path = import_cnt
export_path = psd_path

# Containers kept from the original notebook; this cell does not fill them.
fooof_dict = {}
spec_dict = {}
freq_dict = {}

spec_list = []
epoch_num_list = []
for i in subject_list:
    psd_file = os.path.join(export_path, f'{i}_{project_name}_psd.json')
    if os.path.exists(psd_file):
        print('file already exists')
        continue

    # Collect the epochs of every recording that belongs to this subject
    one_two_epoch_list = []
    for j in cnt_list:
        # Compare the filename prefix exactly, mirroring how subject_list is
        # derived. A plain substring test would also match other subjects
        # whose identifier merely contains this one (e.g. "S1" in "S10_...").
        if j.split('_')[0] != i:
            continue
        file_path = os.path.join(import_path, j)
        print(file_path)
        raw_cnt = mne.io.read_raw_cnt(file_path, data_format='int32', preload=True)

        # Add a 'BAD_1' annotation that starts where the last existing
        # annotation ends and lasts 100000 s.
        # NOTE(review): presumably this excludes everything recorded after
        # the last annotation from epoching (MNE rejects segments whose
        # annotation description starts with "bad") - confirm.
        # Assumes the recording has at least one annotation; onsets[-1]
        # raises IndexError otherwise.
        onsets = raw_cnt.annotations.onset
        durations = raw_cnt.annotations.duration
        last_annotation = onsets[-1] + durations[-1]
        raw_cnt.annotations.append(onset=last_annotation, duration=100000, description='BAD_1')

        # Fixed-length epochs of 2.048 s, baseline-corrected over the whole epoch
        epoch_rest = mne.make_fixed_length_epochs(raw_cnt, 2.048, preload=True)
        epoch_rest.apply_baseline(baseline=(None, None))
        one_two_epoch_list.append(epoch_rest)
    epochs = mne.concatenate_epochs(one_two_epoch_list)
    # NOTE: a minimum-epoch check (skip and log subjects with fewer than 29
    # epochs) was sketched here in the original notebook but is disabled.

    # Welch PSD per epoch, averaged over epochs
    psd = epochs.compute_psd(method='welch', n_fft=var_n_fft).average()

    # Plain-Python container so the result can be serialised as JSON
    psd_dict = {}
    psd_dict['spec'] = psd.get_data().tolist()
    psd_dict['freq'] = psd.freqs.tolist()
    psd_dict['epoch_num'] = len(epochs)
    psd_dict['subject'] = i
    psd_dict['ch_names'] = psd.ch_names

    # Export psd data as json
    with open(psd_file, 'w') as f:
        json.dump(psd_dict, f)

In [None]:


# Fit a FOOOF model to every channel of every exported PSD file, write one
# parameter file and one spectrum file per channel into the FOOOF folder,
# and finally export two combined CSV tables into the home directory.

# All files in the PSD folder, in a stable (alphabetical) order
file_list = os.listdir(psd_path)
print(file_list)
file_list.sort()

df_fooof_list = []
df_fooof_func_list = []
spec_list = []
# Frequency bands as [lower, upper); the upper edge is excluded below
bands = {'delta': [1, 4],
        'theta': [4, 8],
        'alpha': [8, 13],
        'beta': [13, 30],
        'gamma': [30, 50]}


def get_peak_nums(df, bands):
    """Return a dict mapping each band name to the number of rows in *df*
    whose 'CF' value lies in that band (lower edge included, upper excluded)."""
    peak_nums = {}
    for band in bands:
        peak_nums[band] = len(df[df['CF'].between(bands[band][0], bands[band][1], inclusive='left')])
    return peak_nums


loop = 0
for i in file_list:
    # Only the exported PSD JSON files are processed
    if not i.endswith('.json'):
        continue
    loop += 1

    with open(os.path.join(psd_path, i)) as f:
        data = json.load(f)
    freqs = np.array(data['freq'])
    print(i)
    file = i

    # Frequency axis without its first point (0 Hz). It is derived once per
    # file and kept separate from `freqs`, so that `freqs` keeps the same
    # length as every channel spectrum passed to the fit.
    freqs_no_dc = freqs[1:]

    for index, spectrum in enumerate(data['spec']):
        spectrum = np.array(spectrum)
        channel_name = data['ch_names'][index]

        # Create FOOOF object
        fm = FOOOF(verbose=False, aperiodic_mode=mode)
        print(channel_name)

        # If the fit raises, record the channel in the log file and skip it
        try:
            fm.fit(freqs, spectrum, freq_range=freq_range)
        except Exception as err:
            with open(f'{home_dir}/log_{project_name}.txt', 'a') as log_file:
                log_file.write(f'Error in {file} at channel {channel_name}: {err!r} \n')
            continue

        # Remove first frequency point at 0 Hz
        spectrum_no_dc = spectrum[1:]

        # Aperiodic component in linear power and what remains after
        # subtracting it from the measured spectrum
        ap_spec = 10**gen_aperiodic(freqs_no_dc, fm.aperiodic_params_)
        p_spec = spectrum_no_dc - ap_spec

        # Long-format table: one row per frequency point for this channel
        df_fooof_func = pd.DataFrame()
        df_fooof_func['pc'] = p_spec
        df_fooof_func['apc'] = ap_spec
        df_fooof_func['psd'] = spectrum_no_dc
        df_fooof_func['freqs'] = freqs_no_dc
        df_fooof_func['subject'] = data['subject']
        df_fooof_func['channel'] = channel_name
        df_fooof_func['epoch_num'] = data['epoch_num']
        # Running number 0..len-1 of the frequency points
        df_fooof_func['num'] = np.arange(0, len(freqs_no_dc))

        # One-row table with the model parameters for this channel.
        # In 'knee' mode the aperiodic parameters are (offset, knee, exponent),
        # otherwise (offset, exponent).
        df_fooof = pd.DataFrame()
        df_fooof['ap_offset'] = [fm.aperiodic_params_[0]]
        if mode == 'knee':
            df_fooof['ap_knee'] = [fm.aperiodic_params_[1]]
            df_fooof['ap_exponent'] = [fm.aperiodic_params_[2]]
        else:
            df_fooof['ap_exponent'] = [fm.aperiodic_params_[1]]
        df_fooof['r_squared'] = [fm.r_squared_]
        df_fooof['error'] = [fm.error_]
        df_fooof['subject'] = data['subject']
        df_fooof['channel'] = channel_name

        # Peak parameters (centre frequency, power, bandwidth) as a table
        params_df = pd.DataFrame(fm.peak_params_, columns=['CF', 'PW', 'BW'])
        peak_nums = get_peak_nums(params_df, bands)

        # One column triple per peak found in each band, numbered from 0.
        # The index is called peak_idx so it does not overwrite the file
        # loop variable `i`.
        for band, band_limits in bands.items():
            in_band = params_df['CF'].between(band_limits[0], band_limits[1], inclusive='left')
            band_df = params_df[in_band].reset_index(drop=True)
            for peak_idx in range(peak_nums[band]):
                df_fooof[f'{band}_CF_{peak_idx}'] = [band_df.loc[peak_idx, 'CF']]
                df_fooof[f'{band}_PW_{peak_idx}'] = [band_df.loc[peak_idx, 'PW']]
                df_fooof[f'{band}_BW_{peak_idx}'] = [band_df.loc[peak_idx, 'BW']]

        df_fooof_list.append(df_fooof)
        df_fooof_func_list.append(df_fooof_func)

        # Per-channel export.
        # NOTE: to_json() already returns a JSON string, so json.dump stores
        # it as one JSON-encoded string; kept as-is to preserve the format.
        with open(os.path.join(fooof_path, f"{file[:-5]}_{channel_name}_{project_name}_func.json"), 'w') as f:
            json.dump(df_fooof_func.to_json(), f)
        with open(os.path.join(fooof_path, f"{file[:-5]}_{channel_name}_{project_name}.json"), 'w') as f:
            json.dump(df_fooof.to_json(), f)

# Concatenate all dataframes
df_fooof = pd.concat(df_fooof_list)
df_fooof_func = pd.concat(df_fooof_func_list)

now = datetime.datetime.now()
now = now.strftime("%Y-%m-%d_%H-%M")

# Export to csv. The files are written into home_dir explicitly, which is
# the same folder the existence check inspects.
fooof_csv = os.path.join(home_dir, f"fooof_comb_{project_name}_{now}.csv")
if not os.path.exists(fooof_csv):
    df_fooof.to_csv(fooof_csv)
fooof_func_csv = os.path.join(home_dir, f"fooof_func_comb_{project_name}_{now}.csv")
if not os.path.exists(fooof_func_csv):
    df_fooof_func.to_csv(fooof_func_csv)

## Convert to wide

In [None]:
# Locate the combined FOOOF CSV exports in the home directory.
import_path = home_dir
os.chdir(import_path)

# Everything in the home directory, printed before and after sorting
file_list = os.listdir()
print(file_list)
file_list.sort()
print(file_list)

# Combined parameter tables and combined spectrum tables of this project.
# Matching on the full prefix and the .csv suffix keeps the two kinds apart
# and leaves out other entries whose name merely contains 'fooof_' (the
# fooof_<project> folder, the final .sav export, the other CSV kind).
# The timestamp in the name sorts chronologically, so [-1] is the newest.
fooof_prefix = f'fooof_comb_{project_name}_'
fooof_func_prefix = f'fooof_func_comb_{project_name}_'
fooof_file_list = [i for i in file_list if i.startswith(fooof_prefix) and i.endswith('.csv')]
fooof_func_list = [i for i in file_list if i.startswith(fooof_func_prefix) and i.endswith('.csv')]

In [None]:
os.chdir(import_path)

# Channels removed from both tables before reshaping
excluded_channels = ['M2', 'EKG', 'HEOG', 'VEOG']

# Load the last entry of each file list
df_fooof = pd.read_csv(fooof_file_list[-1])
df_fooof_func = pd.read_csv(fooof_func_list[-1])

# The CSVs were written with their index, which comes back as 'Unnamed: 0'
df_fooof, df_fooof_func = (
    frame.drop(columns=['Unnamed: 0']) for frame in (df_fooof, df_fooof_func)
)

# Keep only the rows whose channel is not in the exclusion list
df_fooof = df_fooof.loc[~df_fooof['channel'].isin(excluded_channels)]
df_fooof_func = df_fooof_func.loc[~df_fooof_func['channel'].isin(excluded_channels)]

In [None]:
# Show the FOOOF parameter table as the cell output
df_fooof

In [None]:
# Show the per-frequency spectrum table as the cell output
df_fooof_func

In [None]:
# Work on an independent copy so that adding a column does not write into a
# filtered view of another frame (avoids pandas' SettingWithCopyWarning)
df_fooof_func = df_fooof_func.copy()

# Frequency label used in the wide column names: rounded to 2 decimals,
# as text, with an 'H' suffix
df_fooof_func['freq'] = df_fooof_func['freqs'].round(2).astype(str) + 'H'

# Keep only the first 120 frequency points of every channel
df_select = df_fooof_func[df_fooof_func['num'] < 120]

# Convert offset to microvolts. The original expression was
# log10(10**x * 10**12), i.e. scaling the linear power by 1e12 on a log10
# scale; that equals x + 12, computed here directly to avoid the round trip
# through 10**x.
df_fooof_microvolts = df_fooof.copy()
for col in df_fooof_microvolts.columns:
    if 'offset' in col:
        df_fooof_microvolts[col] = df_fooof_microvolts[col] + 12

# Value columns for the wide table: every column except the two identifiers
fooof_cols = df_fooof.columns.to_list()
fooof_cols.remove('subject')
fooof_cols.remove('channel')
fooof_cols

In [None]:

# One row per subject; one column per (parameter, channel) pair
wide_df_fooof_microvolts = df_fooof_microvolts.pivot(
    index=['subject'],
    columns=['channel'],
    values=fooof_cols,
)

# The two levels of the resulting column index: parameter name and channel
value = wide_df_fooof_microvolts.columns.get_level_values(0)
channel = wide_df_fooof_microvolts.columns.get_level_values(1)

# Flatten the column index to '<project_name>_<parameter>_<channel>'
wide_df_fooof_microvolts.columns = [
    project_name + '_' + param + '_' + chan
    for param, chan in zip(value, channel)
]

In [None]:

# One row per subject; one column per (measure, channel, frequency label)
wide_df_fooof_func = df_select.pivot(
    index=['subject'],
    columns=['channel', 'freq'],
    values=['psd', 'apc', 'pc', 'epoch_num'],
)

# The three levels of the resulting column index; `num` holds the
# frequency labels
value = wide_df_fooof_func.columns.get_level_values(0)
channel = wide_df_fooof_func.columns.get_level_values(1)
num = wide_df_fooof_func.columns.get_level_values(2)

# Flatten the column index to '<project_name>_<measure>_<channel>_<freq>'
wide_df_fooof_func.columns = [
    project_name + '_' + measure + '_' + chan + '_' + label
    for measure, chan, label in zip(value, channel, num)
]

In [None]:
# Multiply the psd / apc / pc columns by 10^12 and leave the epoch_num
# columns unscaled.
# Columns are selected by their exact '<project_name>_<measure>_' prefix.
# A pattern search over the whole name would also hit the project name
# itself, which could scale and duplicate the epoch_num columns.
scaled_prefixes = tuple(f'{project_name}_{measure}_' for measure in ('psd', 'apc', 'pc'))
epoch_prefix = f'{project_name}_epoch_num_'

scaled_columns = [col for col in wide_df_fooof_func.columns if col.startswith(scaled_prefixes)]
epoch_columns = [col for col in wide_df_fooof_func.columns if col.startswith(epoch_prefix)]

wide_fooof_func_microvolts = pd.concat(
    [wide_df_fooof_func[scaled_columns] * 1e12, wide_df_fooof_func[epoch_columns]],
    axis=1,
)

# Reset index so that 'subject' becomes a regular column
wide_fooof_func_microvolts.reset_index(inplace=True)

In [None]:
# The epoch count is repeated in one column per channel/frequency; keep
# only the first of these columns and drop the rest.
epoch_cols = [col for col in wide_fooof_func_microvolts.columns if 'epoch_num' in col]
wide_fooof_func_microvolts = wide_fooof_func_microvolts.drop(epoch_cols[1:], axis=1)

# Rename the remaining epoch column to plain 'epoch_num'.
# The rename mapping needs the column name itself as key (a list is not
# hashable), and is skipped when no such column is present.
epoch_num_name = [col for col in wide_fooof_func_microvolts.columns if 'epoch_num' in col]
if epoch_num_name:
    wide_fooof_func_microvolts = wide_fooof_func_microvolts.rename(
        columns={epoch_num_name[0]: 'epoch_num'}
    )

In [None]:
# Show the wide spectrum table as the cell output
wide_fooof_func_microvolts

In [None]:
# Give both wide tables the same index: the subject identifier as string,
# so they can be joined column-wise on it.

# The parameter table is already indexed by subject (from the pivot);
# only the index type is changed.
wide_df_fooof_microvolts.index = wide_df_fooof_microvolts.index.astype(str)

# The spectrum table carries 'subject' as a regular column after its earlier
# reset_index. It is moved into the index directly; resetting the index
# again first would add the old row numbers as an extra 'index' column that
# would end up in the final export. The guard keeps the cell re-runnable.
if 'subject' in wide_fooof_func_microvolts.columns:
    wide_fooof_func_microvolts = wide_fooof_func_microvolts.set_index('subject')
wide_fooof_func_microvolts.index = wide_fooof_func_microvolts.index.astype(str)

In [None]:
# Join spectrum columns and parameter columns side by side on the index
wide_df = pd.concat([wide_fooof_func_microvolts, wide_df_fooof_microvolts], axis=1)

# Shorten two words in the column names, in this order
for long_form, short_form in (('offset', 'off'), ('exponent', 'exp')):
    wide_df.columns = wide_df.columns.str.replace(long_form, short_form)

# Turn the index back into a regular column and show the result
wide_df = wide_df.reset_index()
wide_df

In [None]:
# Show the final wide table as the cell output
wide_df

In [None]:

# Timestamp (minute resolution) that makes the export file name unique
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")

# Write the final wide table as an SPSS .sav file into the home directory
sav_file = os.path.join(home_dir, f"final_fooof_{project_name}_{now}.sav")
pyreadstat.write_sav(wide_df, sav_file)