In [None]:
# Imports
#%matplotlib qt
%matplotlib inline

import matplotlib.pyplot as plt
import mne
import os
import numpy as np
from scipy.stats import pearsonr
from matplotlib import cm
import scipy

from fooof import FOOOF
from fooof import FOOOFGroup

In [None]:
# NOTE(review): exploratory cell - shows the help text of FOOOFGroup.save via the
# IPython `?` operator. Development leftover; `fg` is re-created before fitting
# further down, so this cell can be removed.
fg = FOOOFGroup()
fg.save?

In [None]:
# NOTE(review): `sys` is not used anywhere else in the visible notebook - this and
# the `sys.path` inspection below look like debugging leftovers.
import sys

In [None]:
# Display the module search path (debugging aid only; not needed for the analysis).
sys.path

In [None]:
# Set up paths

# Directory holding the EEG data. Set the EEG_DATA_PATH environment variable to point
# at your own copy instead of editing this notebook; the default below is the
# original author's local path.
base_path = os.environ.get('EEG_DATA_PATH',
                           '/Users/luyandamdanda/Documents/Research/EEG_Dat')

# Subject recording to analyse (EEGLAB .set file inside base_path)
subj_dat_fname = '3503_resampled.set'

In [None]:
# Read in subject listed above

# Event dictionary: maps EEGLAB event names to the numeric codes MNE should use for
# them (task-block and labelling-block start/end markers).
# NOTE(review): 'End Labelling Block' is an int (1004) while the others are floats -
# harmless but inconsistent. The `event_id` argument of read_raw_eeglab only exists
# in older MNE releases; newer versions use annotations + mne.events_from_annotations.
ev_dict = {'Start Block': 1001., 'End Block': 1002., 'Start Labelling Block':1003., 'End Labelling Block':1004}

full_path = os.path.join(base_path, subj_dat_fname)
eeg_dat = mne.io.read_raw_eeglab(full_path, event_id=ev_dict)

In [None]:
# Summary of the recording metadata (channels, sampling rate, ...)
eeg_dat.info

In [None]:
# Set an average EEG reference (called with MNE's default reference channels).
# NOTE(review): depending on the MNE version this may be added as a projector rather
# than applied to the data directly - confirm for the installed version.
eeg_dat.set_eeg_reference()

In [None]:
# Visual inspection of the raw data (used to pick the bad channels further down)
eeg_dat.plot();

In [None]:
# Identify all the events and their IDs (quick look only; the result is discarded
# and the events are found again in the next cell)
mne.find_events(eeg_dat);

In [None]:
# Preliminary epochs around each 'Start Labelling Block' event using MNE's default
# epoch window; the analysis epochs are re-created later with tmin=5, tmax=125.
events = mne.find_events(eeg_dat)
event_id = {'Start Labelling Block':1003}

epochs = mne.Epochs(eeg_dat, events=events, event_id=event_id)

In [None]:
# Visual inspection of the preliminary epochs
epochs.plot();

# Marking bad channels
### Bad channels are identified by visual inspection of the plots above.


In [None]:
# Mark TP9 as bad, based on visual inspection of the plots above.
eeg_dat.info['bads'] = ['TP9']

In [None]:
# Confirm which channels are currently marked as bad
print('Bad channels: ', eeg_dat.info['bads'])

In [None]:
events = mne.find_events(eeg_dat)
# Report how many events were found (the original format string was never filled in)
print('Found %s events.' % len(events))
print()

# Plot the events to get an idea of the paradigm
# Specify colors and an event_id dictionary for the legend.
# Plotting only 'Start Labelling Block' events (code 1003)
event_id = {'Start Labelling Block':1003}
color = {1003:'red'}

mne.viz.plot_events(events, eeg_dat.info['sfreq'], eeg_dat.first_samp, color=color,
                    event_id=event_id);

# Marking bad moments
### Bad moments (time segments) are identified by visual inspection.

In [None]:
# Marking bad moments is not too critical for this data.

In [None]:
# Analysis epochs: 5 s to 125 s after each 'Start Labelling Block' event (MNE tmin/tmax
# are in seconds), without baseline correction. `events` and `event_id` come from the
# event-plotting cell above.
epochs = mne.Epochs(eeg_dat, events=events, event_id=event_id, tmin = 5, tmax = 125,
                    baseline = None)

In [None]:
# Standard 10-20 electrode positions for every channel except the last one.
# NOTE(review): assumes the last channel is a non-EEG (e.g. stim) channel - confirm.
# `read_montage` was removed in newer MNE (use mne.channels.make_standard_montage).
chs = mne.channels.read_montage('standard_1020', epochs.ch_names[:-1])
epochs.set_montage(chs)

In [None]:
# Drop bad epochs. NOTE(review): no rejection thresholds are given, so presumably this
# mainly removes epochs whose 5-125 s window does not fit in the recording - confirm.
epochs.drop_bad()

In [None]:
# Visual inspection of the analysis epochs
epochs.plot();

In [None]:
# Optional quick PSD overview - uncomment to use.
#epochs.plot_psd(fmin = 0.5, fmax = 40);

In [None]:
# Calculate PSDs (Welch) between 3 and 40 Hz.
# psds is indexed below as [epoch, channel, freq].
# NOTE(review): psd_welch is deprecated/removed in recent MNE (Epochs.compute_psd).
psds, freqs = mne.time_frequency.psd_welch(epochs, fmin=3., fmax=40., n_fft=500)

# Average PSDs for each channel across the labelling-block epochs
# NOTE(review): avg_psds is not used by the FOOOF fit below, which uses only the
# first epoch (psds[0]) - confirm which one is intended.
avg_psds = np.mean(psds, axis=0)

In [None]:
avg_psds.shape

In [None]:
freqs.shape

In [None]:
psds.shape

In [None]:
# Get freq res of PSD (mean spacing between frequency bins, in Hz).
# Not used further in the visible notebook.
f_res = np.mean(np.diff(freqs))

In [None]:
# Log10 power spectrum of channel index 30 in the 1st epoch
plt.plot(freqs, np.log10(psds[0, 30, :]))
# The 1st PSD

In [None]:
# Log10 power spectrum of channel index 30 in the 2nd epoch
plt.plot(freqs, np.log10(psds[1, 30, :]))
# The 2nd PSD

# Fitting FOOOF models to the PSDs

In [None]:
# PSDs of the first epoch only, indexed as [channel, freq].
# NOTE(review): the channel-wise average across epochs (avg_psds) is computed above but
# not used here - confirm that fitting only the first epoch is intended.
fooof_psds = np.squeeze(psds[0,:,:])

In [None]:
fooof_psds.shape

In [None]:
# Initializing FOOOF model
fm = FOOOF()

# Setting frequency range
# NOTE(review): the PSD was only computed from 3 Hz (fmin=3.), so the effective lower
# edge of the fit is 3 Hz, not 2 Hz.
freq_range = [2, 40]

# Single-channel check on channel index 22 (fm.model presumably fits and reports/plots)
fm.model(freqs, fooof_psds[22 , :], freq_range)

In [None]:
# Run FOOOF across a group of PSDs (one model per channel)
fg = FOOOFGroup(verbose=False)
fg.fit_group(freqs, fooof_psds, freq_range)

In [None]:
# Overview plot of the group fit
fg.plot()

In [None]:
# Column 1 of the background (aperiodic) parameters across channels.
# NOTE(review): presumably the slope for a fixed-mode fit ([offset, slope]) - confirm
# against the installed FOOOF version.
sls = fg.get_all_dat('background_params', 1)

In [None]:
sls

In [None]:
sls.shape

In [None]:
# All oscillation parameters stacked across channels. This loses the information about
# which channel each oscillation came from, so per-channel extraction (get_band_osc,
# below) is used instead.
osc_dat = fg.get_all_dat('oscillations_params')
osc_dat.shape

In [None]:
n_channels, n_freq = fooof_psds.shape

In [None]:
n_channels

In [None]:
# One row per channel, filled in below with the [center, power, bandwidth] returned
# by get_band_osc for that channel.
fooof_results = np.zeros(shape = [n_channels,3])

In [None]:
# NOTE: get_band_osc calls get_highest_power_osc, which is defined in the next cell.
def get_band_osc(osc_params, band_def, ret_one=True):
    """Search for oscillations whose center frequency falls within a band.

    Parameters
    ----------
    osc_params : 2d array
        Oscillation parameters, from FOOOF, shape [n_oscs, 3], with rows
        [center frequency, power, bandwidth]. A single flat row of 3 values
        is also accepted.
    band_def : [float, float]
        Lower and upper edge of the band of interest (both inclusive).
    ret_one : bool, optional, default: True
        If True and several oscillations fall in the band, return only the
        highest power one (selected by ``get_highest_power_osc``).
        If False, return every oscillation found in the band.

    Returns
    -------
    osc_out : array
        [center frequency, power, bandwidth] of the selected oscillation.
        If ``ret_one`` is False and several oscillations match, a 2d array
        with one row per matching oscillation. If ``osc_params`` is empty or
        nothing falls in the band, ``array([nan, nan, nan])``.
    """
    osc_params = np.atleast_2d(np.asarray(osc_params))

    # Catch & return if empty. Check the size, not truthiness: `np.all` is also
    # False for any zero-valued parameter and is True for an empty array.
    if osc_params.size == 0:
        return np.array([np.nan, np.nan, np.nan])

    # Mask of oscillations whose center frequency lies inside the band
    osc_inds = (osc_params[:, 0] >= band_def[0]) & (osc_params[:, 0] <= band_def[1])

    # Number of oscillations within the specified range
    n_oscs = np.count_nonzero(osc_inds)

    # No oscillation in the band
    if n_oscs == 0:
        return np.array([np.nan, np.nan, np.nan])

    band_oscs = osc_params[osc_inds, :]

    # If results > 1 and ret_one, return the highest power oscillation
    #   Note: see omegamappin/om/meg/single.py - _get_single_osc_power function.
    if n_oscs > 1 and ret_one:
        band_oscs = get_highest_power_osc(band_oscs)

    # Squeeze so a single oscillation comes back as a flat [cen, power, bw] array
    return np.squeeze(band_oscs)


In [None]:
def get_highest_power_osc(band_oscs):
    """Select the highest power oscillation from a set of oscillations.

    Parameters
    ----------
    band_oscs : 2d array
        Oscillation parameters, shape [n_oscs, 3], with rows
        [center frequency, power, bandwidth] (e.g. the oscillations found
        within one band). A single flat row of 3 values is also accepted.

    Returns
    -------
    osc : 1d array
        [center frequency, power, bandwidth] of the row with the largest
        power (the first such row if there are ties), or
        ``array([nan, nan, nan])`` if ``band_oscs`` is empty.
    """
    band_oscs = np.atleast_2d(np.asarray(band_oscs))

    # Catch & return if empty. Check the size, not truthiness: `np.all` is also
    # False for any zero-valued parameter and is True for an empty array.
    if band_oscs.size == 0:
        return np.array([np.nan, np.nan, np.nan])

    # Column 1 holds the oscillation power
    high_ind = np.argmax(band_oscs[:, 1])
    return band_oscs[high_ind, :]

In [None]:
# Spot check: oscillation parameters found for channel index 6
fg.group_results[6].oscillations_params

In [None]:
# Spot check: strongest 10-12 Hz oscillation for channel index 6
get_band_osc(fg.group_results[6].oscillations_params,[10,12])

In [None]:
# NOTE(review): scratch cell, not used by the analysis - can be removed.
np.array([[1,2],[3,4]])

In [None]:
# NOTE(review): duplicates the first line of the next cell - can be removed.
dat = np.array([[10, 1, 1.8],[14, 2, 4],[12, 3, 2]])

In [None]:
# Sanity check: the row with the largest power (column 1) is selected
dat = np.array([[10, 1, 1.8],[14, 2, 4],[12, 3, 2]])
assert np.array_equal(get_highest_power_osc(dat),[12, 3, 2])

In [None]:
# Sanity checks for get_band_osc: single match, no match (NaNs), and all matches
dat = np.array([[10, 1, 1.8],[14, 2, 4]])
assert np.array_equal(get_band_osc(dat,[10, 12]),[10, 1, 1.8])
assert np.all(np.isnan(get_band_osc(dat, [4, 8]))) 
assert np.array_equal(get_band_osc(dat, [10, 14], ret_one=False),[[10, 1, 1.8],[14, 2, 4]])

In [None]:
# For each channel, store the [center, power, bandwidth] of the strongest 8-12 Hz
# (alpha) oscillation; channels without one get NaNs.
# NOTE(review): the band edges are hard-coded and duplicate the 'Alpha' entry of
# `bands` defined further down.
for i, ch_dat in enumerate(fg.group_results):
    fooof_results[i,:] = get_band_osc(ch_dat.oscillations_params, [8, 12])

In [None]:
fooof_results.shape

In [None]:
# Split the per-channel results into center frequencies, powers and bandwidths
cfs = fooof_results[:, 0]
amps = fooof_results[:, 1]
bws = fooof_results[:,2]

In [None]:
cfs

In [None]:
amps

In [None]:
bws

In [None]:
# Replace NaN (no alpha oscillation on that channel) with 0 so the topomaps can be drawn.
# NOTE(review): 0 cannot be told apart from a real value and stretches the colour scale;
# consider a different marker for "no oscillation".
cfs = np.nan_to_num(cfs)
amps = np.nan_to_num(amps)
bws = np.nan_to_num(bws)

In [None]:
# Define our oscillation bands
bands = [['Theta', [4, 8]], ['Alpha', [8, 12]]]

In [None]:
# DESIRED OUTPUT:
#  1d array, len of n_channels with osc_band freq for each channel
#  Note: deal with the band in channel: try out different markers for none (0, nan, etc.)

# Mapping alpha-band oscillation parameters onto the scalp

In [None]:
# NOTE(review): repeats the montage already set on `epochs` in the epoching section
# (all channels except the last one); `read_montage` is removed in newer MNE.
montage = mne.channels.read_montage('standard_1020', epochs.ch_names[:-1])
epochs.set_montage(montage)

In [None]:
# Topomap of the alpha center frequency per channel (0 = no alpha oscillation found).
# NOTE(review): `cfs` has one value per channel in the PSD - check that this matches
# the channels in `epochs.info` (TP9 is marked bad).
mne.viz.plot_topomap(cfs, epochs.info, cmap=cm.viridis, contours=0);

In [None]:
# Topomap of the alpha oscillation power per channel (0 = none found)
mne.viz.plot_topomap(amps, epochs.info, cmap=cm.viridis, contours=0);

In [None]:
# Topomap of the alpha oscillation bandwidth per channel (0 = none found)
mne.viz.plot_topomap(bws, epochs.info, cmap=cm.viridis, contours=0);