In [None]:
# Imports
#%matplotlib qt
%matplotlib inline

import matplotlib.pyplot as plt
import mne
import os
import numpy as np
from scipy.stats import pearsonr
from matplotlib import cm
import scipy

from fooof import FOOOF
from fooof import FOOOFGroup

In [None]:
# Helper: pull one oscillation per channel out of the FOOOF results
# (used by the band-extraction loop in the FOOOF section).
def get_band_osc(osc_params, band_def):
    """Return the strongest oscillation whose center frequency lies within a band.

    Parameters
    ----------
    osc_params : array-like
        FOOOF oscillation parameters, one row per oscillation, with columns
        [center frequency, amplitude, bandwidth] (the column order this notebook
        unpacks into ``cfs``, ``amps`` and ``bws``). A single oscillation may be
        given as a flat length-3 sequence; an empty input means "no peaks".
    band_def : sequence of two numbers
        ``[low, high]`` band edges, both inclusive, in the same units as the
        center frequencies.

    Returns
    -------
    numpy.ndarray, shape (3,)
        The in-band oscillation with the largest amplitude, or
        ``[nan, nan, nan]`` when no oscillation falls inside the band.
    """
    # reshape(-1, 3) turns both an empty input and a single flat row into 2d.
    osc_params = np.asarray(osc_params, dtype=float).reshape(-1, 3)
    low, high = band_def

    in_band = (osc_params[:, 0] >= low) & (osc_params[:, 0] <= high)
    band_oscs = osc_params[in_band]

    if band_oscs.shape[0] == 0:
        # NaN marks "no oscillation in band"; later cells convert it with nan_to_num.
        return np.full(3, np.nan)

    # Several peaks can fall in one band: keep the highest-amplitude one so each
    # channel contributes exactly one row of results.
    return band_oscs[np.argmax(band_oscs[:, 1])]

In [None]:
import sys

In [None]:
# Set up paths

# Root folder holding the resampled EEGLAB files. Set the EEG_DATA_DIR
# environment variable to point elsewhere instead of editing this cell;
# the hard-coded path below is only the fallback.
base_path = os.environ.get('EEG_DATA_DIR',
                           '/Users/luyandamdanda/Documents/Research/EEG_Dat')

# Subject file to load. Avoid the '._*' copies (e.g. '._3502_resampled.set'):
# macOS creates these metadata stubs next to real files and they hold no EEG data.
subj_dat_fname = '3503_resampled.set'

In [None]:
# Load the subject file configured above.
# The event dictionary maps the EEGLAB marker labels to numeric codes so that
# "Start Block", "End Block" and the labelling-block markers are kept as events.
ev_dict = {
    'Start Block': 1001.,
    'End Block': 1002.,
    'Start Labelling Block': 1003.,
    'End Labelling Block': 1004,
}

full_path = os.path.join(base_path, subj_dat_fname)
eeg_dat = mne.io.read_raw_eeglab(full_path, event_id=ev_dict)

In [None]:
# Display the recording's measurement info as a sanity check of the load
eeg_dat.info

In [None]:
# Set the EEG average reference.
# NOTE(review): TP9 is only marked bad further down, so it is still part of the
# average computed here — confirm whether it should be excluded from the reference.
eeg_dat.set_eeg_reference()

In [None]:
# Browse the continuous data to look for bad channels and bad segments
eeg_dat.plot();

In [None]:
# Identify all the events and their IDs. The returned array is not stored
# (the trailing ';' also hides it); presumably this cell is run for the summary
# MNE prints. The next cell recomputes and keeps the events.
mne.find_events(eeg_dat);

In [None]:
# Quick-look epochs around each 'Start Labelling Block' marker, using MNE's
# default time window (no tmin/tmax given), for visual inspection only.
event_id = {'Start Labelling Block': 1003}
events = mne.find_events(eeg_dat)

epochs = mne.Epochs(eeg_dat, events=events, event_id=event_id)

In [None]:
# Visually inspect the quick-look epochs (used to spot bad channels)
epochs.plot();

# Marking bad channels
### Bad channels are identified by visual inspection of the plots above.


In [None]:
# TP9 was flagged as bad by visual inspection of the plots above.
eeg_dat.info['bads'] = ['TP9']

# The average reference set earlier still included TP9. Average referencing is
# linear, so re-applying it now that TP9 is marked bad leaves every channel
# referenced to the mean of the good channels only.
# NOTE(review): assumes set_eeg_reference() excludes info['bads'] from the average
# and applies the reference directly rather than as a projector — confirm for the
# installed MNE version.
eeg_dat.set_eeg_reference();

In [None]:
# Confirm which channels are currently marked bad
print('Bad channels: ', eeg_dat.info['bads'])

In [None]:
# Recompute the event array and report how many events were found
# (the format string previously had no argument, so it printed a literal '%s').
events = mne.find_events(eeg_dat)
print('Found %s events.' % len(events))
print()

# Plot the events to get an idea of the paradigm.
# Specify colors and an event_id dictionary for the legend; only the
# 'Start Labelling Block' marker (code 1003) is listed.
event_id = {'Start Labelling Block':1003}
color = {1003:'red'}

mne.viz.plot_events(events, eeg_dat.info['sfreq'], eeg_dat.first_samp, color=color,
                    event_id=event_id);

# Marking bad moments
### Identifying bad moments based on visualization

In [None]:
# Not too critical for this data: no bad segments are marked.

In [None]:
# Long analysis epochs: 5 s to 125 s after each 'Start Labelling Block' marker,
# with no baseline correction since only spectra are computed from them.
epochs = mne.Epochs(
    eeg_dat,
    events=events,
    event_id=event_id,
    tmin=5,
    tmax=125,
    baseline=None,
)

In [None]:
# Standard 10-20 electrode positions for every channel except the last one
# (presumably the stimulus/trigger channel, which has no scalp location — TODO confirm).
chs = mne.channels.read_montage('standard_1020', ch_names=epochs.ch_names[:-1])
epochs.set_montage(chs)

In [None]:
# Drop epochs MNE flags as bad. No rejection thresholds were passed to mne.Epochs,
# so NOTE(review): presumably only epochs that cannot be fully extracted are
# dropped — check epochs.drop_log to confirm.
epochs.drop_bad()

In [None]:
# Review the long labelling-block epochs that remain after drop_bad()
epochs.plot();

In [None]:
#epochs.plot_psd(fmin = 0.5, fmax = 40);

In [None]:
# Calculate PSDs
psds, freqs = mne.time_frequency.psd_welch(epochs, fmin=3., fmax=40., n_fft=500)

# Average PSDs for each channel across each rest block
avg_psds = np.mean(psds, axis=0)

In [None]:
# Get freq res of PSD
# Frequency resolution of the PSD: mean spacing between adjacent frequency bins
f_res = np.diff(freqs).mean()

In [None]:
# Log10 power spectrum of channel index 30 in the 1st epoch.
# NOTE(review): this indexes the PSD array's channel axis, which may not line up
# with epochs.ch_names if channels were excluded when computing PSDs — confirm.
fig, ax = plt.subplots()
ax.plot(freqs, np.log10(psds[0, 30, :]))
ax.set(xlabel='Frequency (Hz)', ylabel='log10 power',
       title='PSD: 1st epoch, channel index 30');

In [None]:
# Log10 power spectrum of the same channel index (30) in the 2nd epoch,
# for comparison with the 1st epoch.
fig, ax = plt.subplots()
ax.plot(freqs, np.log10(psds[1, 30, :]))
ax.set(xlabel='Frequency (Hz)', ylabel='log10 power',
       title='PSD: 2nd epoch, channel index 30');

# FOOOFing Data

In [None]:
# FOOOF input: the PSDs of the 1st epoch only, shaped (n_channels, n_freqs).
# NOTE(review): avg_psds (mean across epochs) is computed above but unused —
# confirm that a single epoch rather than the average is intended.
# No np.squeeze: psds[0] is already 2d, and squeezing would collapse the channel
# axis if only one channel were present, breaking the group fit below.
fooof_psds = psds[0, :, :]

In [None]:
# Initialize a FOOOF model and fit one example channel (row 22 of fooof_psds)
fm = FOOOF()

# Frequency range to fit. The PSDs were computed from 3 to 40 Hz, so the fit
# range is matched to them (a lower bound below 3 Hz would not add any data).
freq_range = [3, 40]

fm.model(freqs, fooof_psds[22, :], freq_range)

In [None]:
# Run FOOOF across a group of PSDs: one model per row (channel) of fooof_psds,
# with per-fit console output disabled (verbose=False).
fg = FOOOFGroup(verbose=False)
fg.fit_group(freqs, fooof_psds, freq_range)

In [None]:
# Summary plot of the group fit
fg.plot()

In [None]:
# Stack the oscillation parameters of all channels into one array.
# This loses which channel each row came from, so it is not usable for the
# per-channel maps; the per-channel loop below is used instead.
osc_dat = fg.get_all_dat('oscillations_params')
osc_dat.shape

In [None]:
# Dimensions of the FOOOF input: channels x frequency bins
n_channels, n_freq = fooof_psds.shape

In [None]:
# Number of channels fitted
n_channels

In [None]:
# One row per channel: [center frequency, amplitude, bandwidth] of the band oscillation
fooof_results = np.zeros((n_channels, 3))

In [None]:
# Alpha band (8-12 Hz): store one oscillation's parameters per channel, in channel order
for ch_ind, ch_fit in enumerate(fg.group_results):
    fooof_results[ch_ind] = get_band_osc(ch_fit.oscillations_params, [8, 12])

In [None]:
# Split the per-channel results into center frequency, amplitude and bandwidth columns
cfs, amps, bws = fooof_results.T

In [None]:
# Replace NaN entries (presumably channels with no oscillation in the band) with 0
# so every channel has a value to draw on the topographic maps.
cfs, amps, bws = [np.nan_to_num(vals) for vals in (cfs, amps, bws)]

In [None]:
# Define our oscillation bands
# Oscillation bands of interest as [name, [low, high]] pairs
bands = [
    ['Theta', [4, 8]],
    ['Alpha', [8, 12]],
]

In [None]:
# DESIRED OUTPUT:
#  1d array of length n_channels holding the osc_band frequency for each channel.
#  Note: handle channels with no oscillation in the band; try out different
#  markers for "none" (0, NaN, etc.).

# Mapping

In [None]:
# Set the standard 10-20 montage (all channels except the last) before plotting.
# NOTE(review): identical to the montage already set after epoching; redundant but harmless.
montage = mne.channels.read_montage('standard_1020', epochs.ch_names[:-1])
epochs.set_montage(montage)

In [None]:
# Topography of the alpha center frequency per channel, with a colorbar so the
# colors can be read as values.
# NOTE(review): channels without an alpha peak were presumably set to 0 by
# nan_to_num above, which stretches the color scale — consider masking them.
im = mne.viz.plot_topomap(cfs, epochs.info, cmap=cm.viridis, contours=0, show=False)[0]
plt.colorbar(im)
plt.title('Alpha (8-12 Hz) center frequency (Hz)')
plt.show()

In [None]:
# Topography of the alpha oscillation amplitude per channel, with a colorbar so
# the colors can be read as values.
im = mne.viz.plot_topomap(amps, epochs.info, cmap=cm.viridis, contours=0, show=False)[0]
plt.colorbar(im)
plt.title('Alpha (8-12 Hz) amplitude')
plt.show()

In [None]:
# Topography of the alpha oscillation bandwidth per channel, with a colorbar so
# the colors can be read as values.
im = mne.viz.plot_topomap(bws, epochs.info, cmap=cm.viridis, contours=0, show=False)[0]
plt.colorbar(im)
plt.title('Alpha (8-12 Hz) bandwidth')
plt.show()