In [None]:
import os,sys
import mne
import utils  #my code
import json
import matplotlib.pyplot as plt
import numpy as np
%load_ext autoreload
%autoreload 2

# Root folder with the recordings: the $DATA_DUSS environment variable when it is defined,
# otherwise the local default directory.
data_dir = os.environ.get('DATA_DUSS', '/home/demitau/data')

# Recording to process (subject_medcondition_task); 'S01_on_hold' is the alternative.
rawname_ = 'S01_off_hold'
rawname = '{}_resample_raw.fif'.format(rawname_)
fname_full = os.path.join(data_dir, rawname)
    

In [None]:
# read file -- resampled to 256 Hz,  Elekta MEG, EMG, LFP, EOG channels
# NOTE(review): the second positional argument of read_raw_fif is `allow_maxshield`;
# None (falsy) is passed here -- confirm this is intended.
raw = mne.io.read_raw_fif(fname_full, None)

In [None]:
# apparently tfr_morlet does not use baseline info from Epochs itself
# MNE baseline (located in the rescale function) does not get rid of artifact high values
# the type of baseline correction that Jan referred to is called 'percent' in MNE

In [None]:
# interactive (Qt) figure windows for the plots that follow
%matplotlib qt

In [None]:
# Look for visual alpha and motor mu

In [None]:
#raw.times[-1]

In [None]:
# Alpha/mu band (8-13 Hz) Morlet TFR over consecutive 30 s epochs covering the whole recording.
# NOTE(review): this runs before the channel-type fix-up cell below, so the LFP/EMG channels
# still carry their default types here (which that cell says are wrong).
epdur = 30
endrec = raw.times[-1]
new_events = mne.make_fixed_length_events(raw, start=0, stop=endrec, duration=epdur)
epochs = mne.Epochs(raw,new_events, tmin=0,tmax = epdur, baseline=None)

freqs = np.linspace(8, 13, num=10)

# sh=2
# cfq = 15
# freqs = np.logspace(*np.log10([cfq-sh,cfq+sh ]), num=6)
# n_cycles = freqs / 2 means every wavelet spans n_cycles / freq = 0.5 s
n_cycles = freqs / 2.  # different number of cycle per frequency
power = mne.time_frequency.tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True,
                        return_itc=False, decim=3, n_jobs=10)

In [None]:
#power.plot_topomap(sensors=True, contours=10, tmin=None, tmax=None, fmin=8, fmax=12);

In [None]:
#power.plot_topo(title='Average power', sensors=True, contours=5);

In [None]:
# NOTE(review): the epochs start at tmin=0, so the baseline window [-1e-10, 0] covers at most
# the sample at t=0 -- confirm this is the intended 'percent' reference.
power.plot_topo(baseline=[-1e-10,0], mode='percent', title='Average power');

In [None]:
# Reassign the channel types that are detected wrongly by default (LFP and EMG).
# Tags are tried in order and the first match wins, like an if/elif chain.
_chtype_rules = (('_old', 'emg'), ('_kil', 'misc'), ('LFP', 'bio'))  # 'bio' could also be stim, ecog or eeg
for i, chn in enumerate(raw.ch_names):
    for tag, newtype in _chtype_rules:
        if tag in chn:
            raw.set_channel_types({chn: newtype})
            print(i, chn )
            break

# Channel indices grouped by their (updated) type.
bt = mne.io.pick.channel_indices_by_type(raw.info)
miscchans, gradchans, magchans = bt['misc'], bt['grad'], bt['mag']
eogchans, emgchans, biochans = bt['eog'], bt['emg'], bt['bio']

_typed_groups = [('miscchans', miscchans), ('gradchans', gradchans), ('magchans', magchans),
                 ('eogchans', eogchans), ('emgchans', emgchans), ('biochans', biochans)]
for label, chans in _typed_groups:
    print(label, len(chans))
# Sanity check: channels counted above vs. all channels (they differ if a type is not listed).
print(sum(len(chans) for _, chans in _typed_groups), len(raw.ch_names))
print(len(raw.info['bads']))

In [None]:
# Per-subject metadata (bad MEG channels, tremor side, ...) is kept in a separate JSON file.
with open('subj_info.json') as info_json:
    gen_subj_info = json.load(info_json)

# Mark the bad channels listed for this subject / medication condition / task.
subj, medcond, task = utils.getParamsFromRawname(rawname_)
badchlist = gen_subj_info[subj]['bad_channels'][medcond][task]
raw.info['bads'] = badchlist
print('bad channels are ', badchlist)

In [None]:
# main tremor side of the current subject (display only)
gen_subj_info[subj]['tremor_side']

#The event list contains three columns. The first column corresponds to sample number. To convert this to seconds, you should divide the sample number by the used sampling frequency. The second column is reserved for the old value of the trigger channel at the time of transition, but is currently not in use. The third column is the trigger id (amplitude of the pulse).

## load tremor labels

In [None]:
# Load the hand-labelled tremor intervals for the main tremor side and for the non-main side
# ("nms"), then merge them. (`utils` is already imported in the first cell; %autoreload keeps
# it up to date.)
trem_times_fn = 'trem_times_tau.json'
with open(trem_times_fn) as jf:
    trem_times_byhand = json.load(jf)
trem_times_nms_fn = 'trem_times_tau_nms.json'
with open(trem_times_nms_fn) as jf:
    trem_times_nms_byhand = json.load(jf)

tremIntervalJan, artif         = utils.unpackTimeIntervals(trem_times_byhand, mainSide = True,
                                                           gen_subj_info=gen_subj_info, skipNotLoadedRaws=0)
tremIntervalJan_nms, artif_nms = utils.unpackTimeIntervals(trem_times_nms_byhand, mainSide = False,
                                                           gen_subj_info=gen_subj_info, skipNotLoadedRaws=0)

# Merge the non-main-side artifact entries of the current raw into `artif`.
for rawn in [rawname_]:
    if rawn not in artif_nms:
        continue
    if rawn in artif:
        artif[rawn].update(artif_nms[rawn])
    else:
        artif[rawn] = artif_nms[rawn]

# For every raw, take the intervals of the side opposite to the main tremor side from the nms
# file. medcond/task are not needed here; they are unpacked into throwaway names so that the
# current raw's `medcond` and `task` (set when reading subj_info.json) are not overwritten
# with the values of whichever raw this loop visits last.
for rawn in tremIntervalJan:
    sind_str, _medcond_unused, _task_unused = utils.getParamsFromRawname(rawn)
    maintremside = gen_subj_info[sind_str]['tremor_side']
    opside = utils.getOppositeSideStr(maintremside)
    if rawn in tremIntervalJan_nms:
        tremIntervalJan[rawn][opside] = tremIntervalJan_nms[rawn][opside]


mvtTypes = ['tremor', 'no_tremor', 'unk_activity']

# Parameters for utils.processJanIntervals (units presumably seconds -- confirm in utils).
plotTremNegOffset = 2.
plotTremPosOffset = 2.
maxPlotLen = 6   # for those interval that are made for plotting, not touching intervals for stats
addIntLenStat = 5
plot_time_end = 150

timeIntervalPerRaw_processed = utils.processJanIntervals(tremIntervalJan, maxPlotLen, addIntLenStat,
                          plotTremNegOffset, plotTremPosOffset, plot_time_end, mvtTypes=mvtTypes)

In [None]:
# NOTE: `sind_str` is the subject of the *last* raw iterated in the previous cell,
# not necessarily the current `subj`.
sind_str

In [None]:
# Intervals of the current raw for the main and the non-main tremor side.
# Each entry is a tuple (begin, end, interval type string).
maintremside = gen_subj_info[subj]['tremor_side']
nonmaintremside = utils.getOppositeSideStr(maintremside)
intervals = timeIntervalPerRaw_processed[rawname_][maintremside]
intervals_nms = timeIntervalPerRaw_processed[rawname_][nonmaintremside]

import globvars as gv

def _interval_indices_by_type(ivals, int_types):
    """Map each interval type to the list of indices of `ivals` that have that type.

    Types are visited in the order of `int_types`; types without any matching
    interval are left out of the result.
    """
    result = {}
    for itype in int_types:
        inds = [i for i, (_beg, _end, it) in enumerate(ivals) if it == itype]
        if inds:
            result[itype] = inds
    return result

ivalis = _interval_indices_by_type(intervals, gv.gparams['intTypes'])          # main side
ivalis_nms = _interval_indices_by_type(intervals_nms, gv.gparams['intTypes'])  # non-main side

print('Main tremor side here is ',maintremside)

display('all intervals:' ,intervals)
display('intervals by type:', ivalis )

# Interval type -> MNE event id, used later with events_from_annotations / Epochs.
# ('endseg': 2 used to be in this mapping, but that assignment was immediately overwritten.)
annotation_desc_2_event_id = {'middle_full':0, 'no_tremor':1}

# Attach the main-side intervals to `raw` as MNE annotations (onset, duration, type).
onset = [ivl[0] for ivl in intervals]
duration = [ivl[1] - ivl[0] for ivl in intervals]
description = [ivl[2] for ivl in intervals]
annot = mne.Annotations(onset, duration, description)
raw.set_annotations(annot)

In [None]:
# interval indices by type for the non-main tremor side
ivalis_nms

#The event list contains three columns. The first column corresponds to sample number. To convert this to seconds, you should divide the sample number by the used sampling frequency. The second column is reserved for the old value of the trigger channel at the time of transition, but is currently not in use. The third column is the trigger id (amplitude of the pulse).

In [None]:
#just to make it faster: when fastTest is set, crop `raw` in place to the first
# fastTest_dataCropTime seconds. Later cells (ICA cache file name, 300 s event windows)
# assume this crop.
fastTest = 1
if fastTest:
    fastTest_dataCropTime = 300
    #fastTest_dataCropTime = 100
    raw.crop(tmin=0, tmax=fastTest_dataCropTime)
# load the (possibly cropped) data into memory
raw.load_data()

In [None]:
# scratch: docs of Raw.plot
help(raw.plot)

## Maxwell filter

In [None]:
#Maxwell filter (SSS), using the fine-calibration and cross-talk files stored in data_dir
fine_cal_file  = os.path.join(data_dir,  'sss_cal.dat')
crosstalk_file = os.path.join(data_dir,  'ct_sparse.fif')
raw_sss = mne.preprocessing.maxwell_filter(raw, cross_talk=crosstalk_file,
                                           calibration=fine_cal_file, coord_frame='meg')

In [None]:
%matplotlib qt

In [None]:
#%matplotlib inline
#fig,axs = plt.subplots(2,1,figsize=(10,5))
# Butterfly view (2 s window) of the MEG channels before vs. after Maxwell filtering
raw.copy().pick(['meg']).plot(duration=2, butterfly=True);
raw_sss.copy().pick(['meg']).plot(duration=2, butterfly=True);

In [None]:
# free memory held by objects that are no longer referenced
import gc; gc.collect()

In [None]:
#raw_sss.info

# Notch

In [None]:
# Butterworth 4th order -- for stopping power grid noise
# NOTE(review): the notch cell below calls notch_filter with only freqs and picks, i.e. MNE's
# default filter design; confirm whether a 4th-order Butterworth is still wanted.

In [None]:
#help(raw_sss.notch_filter)

In [None]:
# Notch out the power-line frequency (50 Hz) and its harmonics strictly below the Nyquist
# frequency sfreq / 2 (128 Hz for the 256 Hz data, i.e. 50 and 100 Hz). This is derived from
# the data instead of hard-coding 128, so a different sampling rate never asks for a notch
# at or above Nyquist. `np` is imported in the first cell.
freqsToKill = np.arange(50, raw_sss.info['sfreq'] / 2., 50)  # harmonics of 50
raw_sss.notch_filter(freqsToKill, picks=['meg','bio', 'emg'])

In [None]:
#help(raw.plot_psd)

In [None]:
#tfr = AverageTFR(epochs.info, con, times, freqs, len(epochs))
#tfr.plot_topo(fig_facecolor='w', font_color='k', border='k')

In [None]:
# last interval of the current raw (used as the no-tremor reference below)
intervals[-1]

In [None]:
# Tremor reference interval, picked by position; a, b are its start/end times (seconds).
# The assert guards against the interval list having changed.
intind = 2
ival = intervals[intind]; display( intervals[intind] )
a, b, it = ival
assert it == 'middle_full'

In [None]:
# No-tremor reference interval: the last one in the list; a0, b0 are its start/end times.
intind = -1
ival = intervals[intind]; display( intervals[intind] )
a0, b0, it0 = ival
assert it0 == 'no_tremor'

In [None]:
#some PSD to look at: MEG PSD up to 40 Hz within the tremor interval [a, b], for raw vs.
# raw_sss (Maxwell- and notch-filtered)
%matplotlib inline
#%debug
#ax = plt.gca()
raw.plot_psd(picks='meg' ,     fmax=40, tmin=a, tmax=b);
raw_sss.plot_psd(picks='meg' , fmax=40, tmin=a, tmax=b);

In [None]:
#some PSD to look at: the same comparison within the no-tremor interval [a0, b0]
%matplotlib inline
#%debug
#ax = plt.gca()
raw.plot_psd(picks='meg' ,     fmax=40, tmin=a0, tmax=b0 );
raw_sss.plot_psd(picks='meg' , fmax=40, tmin=a0, tmax=b0 );

In [None]:
#advise: plot topomap,   no special pattern -- weird

In [None]:
#raw_sss.plot_psd_topomap(ch_type='grad', normalize=True)

In [None]:
# scratch: docs of Raw.pick
help(raw.pick)

In [None]:
#ri2 = raw.pick(['meg','misc','bio'])
#mne.io.pick.channel_indices_by_type(ri2.info)

In [None]:
import gc; gc.collect()

In [None]:
#some PSD to look at: two gradiometers, the LFP ('bio') channels and the EMG channels of raw_sss.
# NOTE(review): gradchans/biochans/emgchans are indices computed from raw.info; this assumes
# raw_sss keeps raw's channel order -- confirm.
%matplotlib inline
#%debug
#ax = plt.gca()
raw_sss.plot_psd(picks=gradchans[:2] );
raw_sss.plot_psd(picks=biochans);
raw_sss.plot_psd(picks=emgchans);

## Question: does filter order/type matter?

In [None]:
# scratch: docs of Raw.filter
help(raw.filter)

In [None]:
raw_sss.info

# Look per frequency

In [None]:
# I need to create single epoch from the data apparently

In [None]:
new_events

In [None]:
# Consecutive 20 s epochs over the first 300 s of the Maxwell-filtered data
dur = 20
new_events = mne.make_fixed_length_events(raw_sss, start=0, stop=300, duration=dur)
epochs = mne.Epochs(raw_sss,new_events, tmin=0,tmax = dur, baseline=None)

In [None]:
freqs

In [None]:
#epochs.info['subject_info']

In [None]:
# scratch arithmetic: presumably samples per epoch / 256 Hz (20 s * 256 + 1 endpoint = 5121)
5121/256

In [None]:
epochs.get_data().shape

In [None]:
intervals

In [None]:
#raw.info['bads']

In [None]:
#mne.viz.plot_tfr_topomap

# define frequencies of interest (linearly spaced; the log-spaced alternative is commented out)
#freqs = np.logspace(*np.log10([6, 48]), num=40)
freqs = np.linspace(6, 48, num=40)

# sh=2
# cfq = 15
# freqs = np.logspace(*np.log10([cfq-sh,cfq+sh ]), num=6)
# n_cycles = freqs / 2 means every wavelet spans n_cycles / freq = 0.5 s
n_cycles = freqs / 2.  # different number of cycle per frequency
power = mne.time_frequency.tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True,
                        return_itc=False, decim=3, n_jobs=10)

In [None]:
%matplotlib qt

In [None]:
#power.__file__

In [None]:
#help(mne.time_frequency.tfr_morlet)

In [None]:
#help(power.plot_topo)

In [None]:
mne.time_frequency.__file__

In [None]:
#%matplotlib qt
#power.plot_topo(baseline=(-0.5, 0), mode='logratio', title='Average power');
#power.plot_topo(baseline=(-1e-5, 0), mode='logratio', title='Average power');
# NOTE(review): with baseline=None no baseline correction is applied, so mode='logratio'
# presumably has no effect here -- confirm in the AverageTFR.plot_topo docs.
power.plot_topo(baseline=None, mode='logratio', title='Average power');


# fig, axis = plt.subplots(1, 2, figsize=(7, 4))
# power.plot_topomap(ch_type='grad', tmin=0.5, tmax=1.5, fmin=8, fmax=12,
#                    baseline=(-0.5, 0), mode='logratio', axes=axis[0],
#                    title='Alpha', show=False)
# power.plot_topomap(ch_type='grad', tmin=0.5, tmax=1.5, fmin=13, fmax=25,
#                    baseline=(-0.5, 0), mode='logratio', axes=axis[1],
#                    title='Beta', show=False)
# mne.viz.tight_layout()

In [None]:
# High-pass at 2 Hz (an earlier note here said 1 Hz); apparently this helps artifact removal.
# Works on a copy, so raw_sss itself is not high-passed.
filt_raw = raw_sss.copy()
filt_raw = filt_raw.load_data().filter(l_freq=2., h_freq=None)  # advise: maybe 0.5

In [None]:
# Keep MEG, LFP ('bio'), EMG and EOG channels; drop misc channels and channels marked bad.
# pick_types modifies filt_raw in place (the next cell copies before picking for that reason).
filt_raw.pick_types(meg=True, misc=False, bio=True, emg=True, eog=True, exclude='bads')

In [None]:
# look at not-meg part of the data
notmeg = filt_raw.copy().pick_types(meg=False, bio=True, emg=True, eog=True)
#notmeg.plot(duration=10, n_channels=100, remove_dc=True);

In [None]:
# what is a projection?
# gradient compensation?

In [None]:
#help(filt_raw.plot)

In [None]:
#help(filt_raw)

In [None]:
# interactive figures embedded in the notebook
%matplotlib notebook

In [None]:
#%matplotlib qt
#%matplotlib inline
#%matplotlib TkAgg

#import matplotlib as mpl
#mpl.use('TkAgg')
#mpl.use('Qt5Cairo')
#['GTK3Agg', 'GTK3Cairo', 'MacOSX', 'nbAgg', 'Qt4Agg', 'Qt4Cairo', 'Qt5Agg', 'Qt5Cairo', 
# 'TkAgg', 'TkCairo', 'WebAgg', 'WX', 'WXAgg', 'WXCairo', 'agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template']
# look at not-meg part of the data
# butterfly gives strange pic

#filt_raw.plot(duration=20,n_channels=10, bad_color='yellow', scalings='auto', butterfly=0);

## Question: what is the right way to use ICA? Do I use it together with, before, or after create_ecg_epochs / create_eog_epochs?

In [None]:
# info of the high-passed, channel-picked data that the ICA below is fitted on
filt_raw.info

## Question: how should I use these params for ICA?
n_components  int | float | None

    Number of principal components (from the pre-whitening PCA step) that are passed to the ICA algorithm during fitting. If int, must not be larger than max_pca_components. If float between 0 and 1, the number of components with cumulative explained variance less than n_components will be used. If None, max_pca_components will be used. Defaults to None; the actual number used when executing the ICA.fit() method will be stored in the attribute n_components_ (note the trailing underscore).
max_pca_components  int | None

    Number of principal components (from the pre-whitening PCA step) that are retained for later use (i.e., for signal reconstruction in ICA.apply(); see the n_pca_components parameter). If None, no dimensionality reduction occurs and max_pca_components will equal the number of channels in the mne.io.Raw, mne.Epochs, or mne.Evoked object passed to ICA.fit().
n_pca_components  int | float | None

    Total number of components (ICA + PCA) used for signal reconstruction in ICA.apply(). At minimum, at least n_components will be used (unless modified by ICA.include or ICA.exclude). If n_pca_components > n_components, additional PCA components will be incorporated. If float between 0 and 1, the number is chosen as the number of PCA components with cumulative explained variance less than n_components (without accounting for ICA.include or ICA.exclude). If int or float, n_components_ ≤ n_pca_components ≤ max_pca_components must hold. If None, max_pca_components will be used. Defaults to None.


You can impose an optional dimensionality reduction at this step by specifying max_pca_components. From the retained Principal Components (PCs), the first n_components are then passed to the ICA algorithm (n_components may be an integer number of components to use, or a fraction of explained variance that used components should capture).

After visualizing the Independent Components (ICs) and excluding any that capture artifacts you want to repair, the sensor signal can be reconstructed using the ICA object’s apply() method. By default, signal reconstruction uses all of the ICs (less any ICs listed in ICA.exclude) plus all of the PCs that were not included in the ICA decomposition (i.e., the “PCA residual”). If you want to reduce the number of components used at the reconstruction stage, it is controlled by the n_pca_components parameter (which will in turn reduce the rank of your data; by default n_pca_components = max_pca_components resulting in no additional dimensionality reduction)

Because filtering is a linear operation, the ICA solution found from the filtered signal can be applied to the unfiltered signal (see [2] for more information), so we’ll keep a copy of the unfiltered Raw object around so we can apply the ICA solution to it later.

In [None]:
from mne.preprocessing import ICA

# The ICA is cached on disk; the 'fast_' tag keeps fits on the cropped (fastTest) data apart.
addstr = 'fast_' if fastTest else ''
icafname = '{}_{}resampled-ica.fif.gz'.format(rawname_,addstr)
icafname_full = os.path.join(data_dir,icafname)
print(icafname_full)
loadICA = 1   # reuse the cached ICA when the file exists
saveICA = 0   # write a newly fitted ICA to icafname_full

_use_cached_ica = bool(loadICA) and os.path.exists(icafname_full)
if _use_cached_ica:
    ica = mne.preprocessing.read_ica(icafname_full)
else:
    # float n_components: keep the PCA components explaining 95% of the variance; fixed seed
    ica = ICA(n_components = 0.95, random_state=0).fit(filt_raw)
    if saveICA:
        ica.save(icafname_full)
#ica = ICA(n_components = 20, random_state=0).fit(filt_raw)

In [None]:
import gc; gc.collect()

In [None]:
#components = ica.get_components()
#components.shape

In [None]:
# ICA source time courses of filt_raw (a Raw-like object with one "channel" per component)
icacomp = ica.get_sources(filt_raw)

In [None]:
icacomp.ch_names

## Should I mark components that make large excursions as invalid?

In [None]:
intervals

In [None]:
filt_raw.plot_psd()

In [None]:
%matplotlib qt

In [None]:
#%matplotlib notebook
# Interactive browser of the ICA sources over raw_sss; the components marked here are what the
# next cells report as ica.exclude (presumably by clicking them -- confirm).
ica_inds = ica.plot_sources(raw_sss)

In [None]:
print('Components that I have found by hand', ica.exclude)

In [None]:
ica.exclude

In [None]:
# names of the excluded components as they appear in icacomp
[icacomp.ch_names[i] for i in ica.exclude]

In [None]:
# Look at variance evolution of components

In [None]:
#EOG127, EOG128, EMG061_old - 64

In [None]:
#ica.plot_components();  #topography

## Plot Component info MNE 

In [None]:
nonexcluded = list(  set( range( len(ica.ch_names) ) ) - set(ica.exclude) )
sorted(nonexcluded)

In [None]:
s = ','.join(map(str,ica.exclude) )
exclStr = 'excl_' + s
exclStr

In [None]:
#%%capture
from matplotlib.backends.backend_pdf import PdfPages

# MNE property plots for the excluded components plus the first 5 kept ones,
# one PDF page per component.
compinds = ica.exclude + nonexcluded[:5]

with PdfPages('{}_ica_components_{}.pdf'.format(rawname_,exclStr)) as pdf:
    figs = mne.viz.plot_ica_properties(ica,filt_raw,compinds, show=False )
    for fig in figs:
        pdf.savefig(fig)
        # Close this particular figure. A bare plt.close() only closes the *current* figure,
        # which need not be `fig`, so the other figures would stay open and pile up in memory.
        plt.close(fig)

In [None]:
# interactive Qt windows again
%matplotlib qt

## Component info by hand

In [None]:
import pandas as pd

# Window of ICA source data inspected by hand in the next cell:
# [time_to_look_start, time_to_look_end] seconds -> sample indices t0, t1,
# with a rolling-std window of windowsz_sec seconds (windowsz samples).
windowsz_sec = 3
time_to_look_dur = 15
time_to_look_start = 0; time_to_look_end = time_to_look_dur + time_to_look_start
t0,t1 = icacomp.time_as_index([time_to_look_start,time_to_look_end])
windowsz = int( windowsz_sec * icacomp.info['sfreq'] )
#stds = pd.rolling_std(chdata, windowsz)

In [None]:
%matplotlib notebook
#ind = 0
ww = 5; hh = 3; 
compIndsToShow = ica.exclude + [1,2,3]
nc = 3; nr = len(compIndsToShow) + 2 + 2*2
fig,axs = plt.subplots(nrows=nr, ncols = nc, figsize = (nc*ww,nr*hh))
ylim_var_min = 0.5
ylim_var_max = 2

for axind,ind in enumerate(compIndsToShow ):
    chdata,chtimes = icacomp[ind]; chdata=chdata.flatten()
    chtimes = chtimes[t0:t1]
    chdata = chdata[t0:t1]    
    stds = pd.Series(chdata).rolling(windowsz).std()
    ax = axs[axind,0]
    ax.plot(chtimes,chdata)
    ax.set_title('Raw component {}'.format(ind))
    ax.set_xlabel('sec')
    
    ax = axs[axind,1]
    ax.plot(chtimes,stds)
    vartot = np.std(chdata)
    ax.set_title('Component {} variance  (window size = {:.2f}s), tot = {:.3f}'.format(ind, windowsz_sec, vartot))
    ax.set_xlabel('sec')
    ax.set_ylim(ylim_var_min,ylim_var_max)
    
    ax = axs[axind,2]
    ax.psd(chdata, Fs = int(icacomp.info['sfreq']))
    
for axind,ind in enumerate(compIndsToShow ):
    if ind in ica.exclude:
        continue
    for ii in range(nc):
        axs[axind,ii].set_facecolor('lightgreen')

for chni,chn in enumerate(['EOG127', 'EOG128', 'EMG061_old', 'EMG062_old', 'EMG063_old', 'EMG064_old'] ):
    chdata,chtimes = filt_raw[chn]; chdata=chdata.flatten()
    chtimes = chtimes[t0:t1]
    chdata = chdata[t0:t1]  
    
    ax = axs[nr-1-chni,0]
    ax.set_title(chn)
    ax.plot(chtimes,chdata)
    
    ax = axs[nr-1-chni,2]
    ax.psd(chdata, Fs = int(icacomp.info['sfreq']))

plt.tight_layout()
plt.savefig('components_raw_var_spec.pdf')        
        
ica.plot_components(compIndsToShow);

plt.savefig('components_top.pdf')

In [None]:
# scratch: inspect the last axes / figure of the previous cell
ax

In [None]:
fig.show()

In [None]:
# NOTE: permanently removes the top-left panel from `fig`
ax = axs[0,0]
ax.remove()

In [None]:
# NOTE(review): attempt to move an axes into a new figure. Figure.axes returns a list, so
# appending to it presumably does not register the axes with `fg`, and add_axes expects an
# axes created in that same figure -- this cell likely fails; confirm.
fg = plt.figure()
ax = axs[0,0]
fg.axes.append(ax)
fg.add_axes(ax)

###  Q: What is a "bad" epoch?

In [None]:
# ECG-locked epochs of filt_raw (mne.preprocessing.create_ecg_epochs)
ecg_epochs = mne.preprocessing.create_ecg_epochs(filt_raw)

In [None]:
ecg_epochs

In [None]:
ecg_epochs.plot_image(combine='median')

In [None]:
ecg_epochs_av = ecg_epochs.average()

In [None]:
import gc; gc.collect()

## Q: should I apply it to raw or to high-passed raw?  
## Q2: if my ICA was fitted on one data object, what does it mean to take "sources from" another?

In [None]:
# trying to find ECG-related ICA components directly from raw_sss
ecg_inds, scores = ica.find_bads_ecg(raw_sss)

In [None]:
print(ecg_inds)

#what do image parts mean?
if len(ecg_inds):
    ica.plot_properties(raw_sss,picks=ecg_inds)

In [None]:
# the same search, but based on the ECG epochs
ecg_inds2, scores2 = ica.find_bads_ecg(ecg_epochs)

In [None]:
ecg_inds2, scores2

In [None]:
#eog_inds, scores = ica.find_bads_eog(raw)

# EOG-related components, scored on EOG-locked epochs of raw_sss.
# NOTE: this rebinds `scores`, which held the ECG scores until now.
eog_epochs = mne.preprocessing.create_eog_epochs(raw_sss)  # get single EOG trials
eog_inds, scores = ica.find_bads_eog(eog_epochs)

In [None]:
bt = mne.io.pick.channel_indices_by_type(raw.info)
gradchans = bt['grad']

# NOTE(review): channel indices come from raw.info; assumes eog_epochs keeps raw's channel order.
eog_epochs.plot_image(combine='mean',  picks=gradchans[8:10])

In [None]:
# |EOG score| of every ICA component against the first EOG channel, sorted in descending order.
# The tick labels are the component indices, so the plot shows *which* components score highest
# (previously the ticks were only the ranks 0..n-1). `np` is imported in the first cell.
sortinds = np.abs(scores)[0,:].argsort()[::-1]
pr = np.abs(scores)[0,sortinds]
fig_eogsc, ax_eogsc = plt.subplots(figsize=(12,3))
ax_eogsc.plot(pr)
ax_eogsc.set_xticks(np.arange(len(pr)))
ax_eogsc.set_xticklabels(sortinds)
ax_eogsc.set(xlabel='ICA component (sorted by |score|)', ylabel='|score|',
             title='EOG scores, first EOG channel');

In [None]:
# Summary of the EOG search: selected components, the top-4 sorted scores and the score array
# sizes (len(scores[1]) assumes at least two EOG channels).
print('(highest) eog-related componenents',eog_inds)
print('scores of the highes components, one of the EOGs ',pr[:4])
print(len(scores), len(scores[0]), len(scores[1]), scores[-1][0] )

In [None]:
# collect eog - related components?
import numpy as np
ica.plot_scores(scores, exclude=eog_inds,  labels='eog');
for i in range(len(scores)):   # we have two channels
    show_picks = np.abs(scores)[i,:].argsort()[::-1][:5]  # zeroth is largest
    print(show_picks, show_picks.shape)

In [None]:
#ica.exclude = []
print(ica.exclude, eog_inds, ecg_inds)
# Keep a *copy* of the hand-picked exclusions. A bare alias would change along with any later
# in-place edit of that list (e.g. `ica.exclude += ...` before ica.exclude is rebound).
exclude_bak = list(ica.exclude)

In [None]:
# interactive Qt windows for the topography plots
%matplotlib qt

In [None]:
# Topographies of the first (up to) 25 components; asking for more components than were
# fitted (ica.n_components_) is an error.
ica.plot_components(range(min(25, ica.n_components_)), colorbar=True);

In [None]:
help(ica.plot_components)

In [None]:
ica.plot_components(ica.exclude,colorbar=1, axis = []);

In [None]:
# TODO: bin components in 100 ms windows and plot their variance over time

In [None]:
# Compare the signal with and without each candidate exclusion set (first hand-picked,
# ECG-related, EOG-related). ica.exclude is cleared first; the next cell rebuilds it.
ica.exclude = []
%matplotlib inline
ee = [exclude_bak[:1], ecg_inds, eog_inds]
for toexclude in ee:
    if len(toexclude) == 0:
        continue
    ica.plot_overlay(filt_raw, exclude=toexclude, picks='meg', title='exclude {}'.format(toexclude) )
# ica.plot_overlay(filt_raw, exclude=ecg_inds, picks='grad')
# ica.plot_overlay(filt_raw, exclude=eog_inds, picks='meg')

In [None]:
#ica.plot_sources(raw, show_picks, exclude=eog_inds)
#ica.plot_components(eog_inds, colorbar=True)

# Final exclusion list: at most n_max_eog EOG-related components, the ECG-related ones and the
# first hand-picked component (from exclude_bak). Order is kept and duplicates are dropped,
# since the same component can be flagged by more than one criterion.
n_max_eog = 2
eog_inds = eog_inds[:n_max_eog]
ica.exclude = list(dict.fromkeys(list(eog_inds) + list(ecg_inds) + list(exclude_bak[:1])))
print(ica.exclude)

#### Q: Should I reject something?
#### Q: Is it better to apply find_bads_eog to epochs or to the entire raw?

# Reconstruct

In [None]:
# Apply the ICA cleaning (removing the components in ica.exclude) to a *copy*. ica.apply works
# in place, so applying it to filt_raw (as before) cleaned the input and left `reconst_raw` --
# the object saved later as "afterICA" -- uncleaned. filt_raw stays the uncleaned reference for
# the comparison plots below.
reconst_raw = filt_raw.copy()
ica.apply(reconst_raw)

In [None]:
# intended comparison: ICA-reconstructed data (reconst_raw) vs. its high-passed input (filt_raw)
reconst_raw.plot(duration=20,n_channels=10, bad_color='yellow', scalings='auto');

In [None]:
filt_raw.plot(duration=20,n_channels=10, bad_color='yellow', scalings='auto');

## Compare PSD

In [None]:
#%matplotlib qt
%matplotlib inline

In [None]:
#some PSD to look at: MEG PSD of the high-passed data
#%debug
#plt.figure()
# fig,axs = plt.subplots(4,2)
# filt_raw.plot_psd(picks='meg' , ax = axs[:2,0] );
# filt_raw.plot_psd(picks='bio', ax=axs[2,0]);
# filt_raw.plot_psd(picks='emg', ax=axs[3,0]);

#making 4x2  would not work because of bug in mne, the use ax_list[0].get_figure(), 
# but it gives list of axes if we have two columns
#ax = plt.gca()
filt_raw.plot_psd(picks='meg'  );
#filt_raw.plot_psd(picks='bio');
#filt_raw.plot_psd(picks='emg');


In [None]:
# Are the channels marked bad in raw still present after pick_types(exclude='bads')?
# (raw.info['bads'][0] below raises IndexError when there are no bad channels.)
raw.info['bads']

In [None]:
raw.info['bads'][0] in filt_raw.ch_names

In [None]:
raw.info['bads'][0] in reconst_raw.ch_names

In [None]:
filt_raw.ch_names

In [None]:
help(reconst_raw.plot_psd)

In [None]:
#some PSD to look at: gradiometer PSD (up to 20 Hz) of reconst_raw, drawn into the current axes
#%debug
ax = plt.gca()
#mne.time_frequency.psd
reconst_raw.plot_psd(picks='grad', fmax=20, ax=ax );
#reconst_raw.plot_psd(picks='bio');
#reconst_raw.plot_psd(picks='emg');

In [None]:
# which channel types remain in reconst_raw
type(reconst_raw.get_channel_types())

In [None]:
set( reconst_raw.get_channel_types() )

In [None]:
# indices and names of the MEG channels (gradiometers and magnetometers)
meg_chis = np.where ( [ a in ['grad','mag'] for a in reconst_raw.get_channel_types() ] )[0]

meg_chnames = np.array(reconst_raw.ch_names)[meg_chis]

#for chn in meg_chnames
# Amplitude histogram of one MEG channel.
chn = 'MEG0113'
chd, times = reconst_raw[chn]   # chd has shape (1, n_times)
# Flatten first: plt.hist treats each *column* of a 2-D array as a separate dataset, which for
# a (1, n_times) array means n_times one-sample datasets instead of one histogram.
plt.hist(chd.ravel(), bins=100);

In [None]:
chd

In [None]:
# no-tremor (a0, b0) and tremor (a, b) reference intervals
a0,b0

In [None]:
a,b

In [None]:
# One long epoch over the analysed 0-300 s of the ICA-cleaned data, followed by a broadband
# (3-90 Hz) Morlet TFR of it for the band-power topomaps below.
strec = 0
endrec = 300   # matches fastTest_dataCropTime
epdur = endrec
new_events = mne.make_fixed_length_events(reconst_raw, start=strec, stop=endrec, duration=epdur)
# Epoch the object the events were made from -- the cleaned reconst_raw, not the unprocessed raw.
epochs = mne.Epochs(reconst_raw,new_events, tmin=0,tmax = epdur, baseline=None)

freqs = np.linspace(3, 90, num=100)

# sh=2
# cfq = 15
# freqs = np.logspace(*np.log10([cfq-sh,cfq+sh ]), num=6)
# n_cycles = freqs / 2 means every wavelet spans n_cycles / freq = 0.5 s
n_cycles = freqs / 2.  # different number of cycle per frequency
power = mne.time_frequency.tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True,
                        return_itc=False, decim=3, n_jobs=10)

In [None]:
# second 'middle_full' interval (begin, end); the type string is discarded
a1,b1,uuu_ = intervals[ ivalis['middle_full'][1] ]

In [None]:
intervals

In [None]:
# Quick single topomap before the full grid below. The time window and band are set explicitly:
# these names used to exist only as loop variables of the next cell, so this cell failed on a
# fresh kernel. The values are the ones that loop leaves behind (the 'trem2' interval and the
# gamma band).
timin, timax = a1, b1
fbmin, fbmax = 30, 90
power.plot_topomap(sensors=True, contours=8, tmin=timin, tmax=timax,
                   fmin=fbmin, fmax=fbmax, colorbar=True, size=40, res=100, show=0, sphere=np.array([0,0,0,1]));

### Plot power band concentrations

In [None]:
# Grid of band-power topomaps from `power`: rows = time intervals (no tremor, two tremor
# intervals), columns = frequency bands in Hz. Saved as '<rawname_>_sensor_bandpow_concentr.pdf'.
timeints = [ ('nonotrem',a0,b0), ('trem', a,b), ('trem2', a1,b1)]
fbs = [ ('tremfreq', 3,9), ('alpha/mu', 8,13),  ('beta', 15,30), ('gamma', 30,90)  ]
nc = len(fbs); nr = len(timeints)
ww = 4; hh = 3
# passed as `sphere`; presumably (x, y, z, radius) of the head model -- confirm in MNE docs
headsph = np.array([0,0,0,0.9])
fig,axs = plt.subplots( nrows = nr, ncols = nc, figsize= (nc*ww, nr*hh))
for i,ti in enumerate(timeints):
    tiname,timin,timax = ti
    for j,fb in enumerate(fbs):
        fbname, fbmin, fbmax = fb
        ax = axs[i,j]
        ttl = 'epoch type: {};  band: {}'.format(tiname,fbname)
        ax.set_title(ttl)
        power.plot_topomap(sensors=True, contours=8, tmin=timin, tmax=timax, 
                           fmin=fbmin, fmax=fbmax, axes=ax, colorbar=True, size=40, res=100, show=0, sphere=headsph);
        #plt.gcf().suptitle('{} : {}'.format(tiname,fbname))
#plt.tight_layout()
plt.savefig('{}_sensor_bandpow_concentr.pdf'.format(rawname_ ))

In [None]:
help(reconst_raw.plot)

In [None]:
help(reconst_raw.filter)

In [None]:
# High-pass the EMG channels at 10 Hz.
# NOTE: Raw.filter works in place, so the reconst_raw saved further below contains these
# high-passed EMG channels.
reconst_raw.filter(picks='emg', l_freq=10, h_freq=None)

In [None]:
#left: EMG063/EMG064 in blue; then, in a new figure, EMG061/EMG062 in green (presumably right)
leftEMG, times = reconst_raw[['EMG063_old', 'EMG064_old']]
plt.plot(times,leftEMG.T, c='b', alpha=0.5)
plt.figure()
rightEMG, times = reconst_raw[['EMG061_old', 'EMG062_old']]
plt.plot(times,rightEMG.T, c='g', alpha=0.5)

In [None]:
# intervals and main tremor side, to relate them to the EMG traces above
intervals

In [None]:
#reconst_raw[.plot(chn=['emg'])]

In [None]:
gen_subj_info[subj]['tremor_side']

# Save reconstructed

In [None]:
# NOTE(review): filt_raw.pick_types(..., exclude='bads') dropped the bad channels, so reconst_raw
# may no longer list them in info['bads'] and this assert may fail -- confirm.
assert reconst_raw.info['bads'] == raw.info['bads']
savename = rawname_ + '_resample_afterICA_raw.fif'
print(savename)

In [None]:
# overwrites an existing file with the same name
reconst_raw.save(os.path.join(data_dir, savename), overwrite=True )

In [None]:
subj

In [None]:
# NOTE: this only displays the bound method; nothing is saved here
reconst_raw.save

## Q: what is the best way to remove the power-grid noise component?
For some reason my experiments with ICA did not help to remove the power-grid noise.

In [None]:
# sensor-level analysis
# maybe perform ICA on TFA of LFP
# save updated file

In [None]:
#event_id = {'Tremor/Left/main': 1, 'Tremor/Right/other': 2}
#epochs has to be same length, so it won't work for me directly

In [None]:
print(description)

In [None]:
# attach the tremor-interval annotations (built from `intervals`) to the cleaned data
reconst_raw.set_annotations(annot)

In [None]:
#reconst_raw.plot(duration=10);

In [None]:
#help(ica.find_bads_eog)

In [None]:
# events from the annotations whose description is a key of annotation_desc_2_event_id
events_tremrel, _ = mne.events_from_annotations(reconst_raw, event_id=annotation_desc_2_event_id, chunk_duration=None)

In [None]:
#help(mne.events_from_annotations)

In [None]:
# events = []
# event_types = { 'tremor_start':0  }

# intType = 'middle_full'
# inds = []
# for ind in ivalis[intType]:
#     a,b,it = intervals[ind]
#     #inds += [i]
#     abinind = reconst_raw.time_as_index(a)[0]
#     events += [ [abinind,0, event_types['tremor_start']]  ]
        
# events

In [None]:
# NOTE: event_types is defined only in the commented-out code above, so this raises NameError
# on a fresh kernel.
event_types

In [None]:
# scratch: docs of mne.Epochs
help(mne.Epochs)

In [None]:
ivalis

In [None]:
# 3 s epochs starting at each tremor-related annotation of reconst_raw. They use the events
# derived from those annotations (`events_tremrel`): the name `events` is defined only in a
# later cell (regularly sampled events) or in commented-out code, so using it here either
# failed or silently produced epochs of the wrong kind.
# NOTE(review): 'middle_full' maps to event id 0 -- confirm MNE accepts 0 as an event id.
a0,b0,it = intervals[ ivalis['no_tremor'][0] ]  # for baseline
epochs_trem_st = mne.Epochs(reconst_raw, events_tremrel, event_id=annotation_desc_2_event_id,
                            tmin=0, tmax=3, baseline=None, preload=True)  # baseline should be inside epoch interval

In [None]:
epochs = epochs_trem_st

In [None]:
# Consecutive 20 s epochs over the first 300 s.
# NOTE(review): built from raw_sss (before the 2 Hz high-pass and the ICA), unlike
# epochs_trem_st, which uses reconst_raw -- confirm this is intended.
dur = 20
events_regsampled = mne.make_fixed_length_events(raw_sss, start=0, stop=300, duration=dur)
epochs_regsampled = mne.Epochs(raw_sss,events_regsampled, tmin=0,tmax = dur, baseline=None)

In [None]:
events = events_regsampled

In [None]:
mne.viz.plot_events(events,sfreq=reconst_raw.info['sfreq']);

In [None]:
# mne.viz.plot_events(events, event_id={'middle_full':0, 'no_tremor':1},
#                     sfreq=reconst_raw.info['sfreq']);

In [None]:
help(epochs.plot_psd)

In [None]:
epochs

In [None]:
fmax = 48   # upper PSD frequency (Hz) for the plots below

In [None]:
# average PSD of `epochs` (epochs_trem_st at this point)
epochs.plot_psd(area_mode=None, fmin=2, fmax=fmax, show=False,
                               average=True, spatial_colors=False, n_jobs=10);

In [None]:
gc.collect()

In [None]:
# NOTE: `colors` and `tpi` are defined only in the next cell; this fails on a fresh kernel.
colors[tpi]

### Plot PSD for tremor and no_tremor

In [None]:
# Average PSD per epoch type (one line per type, coloured by its event id), drawn into the two
# axes `axs` (presumably one per MEG sensor type -- confirm); the legend is on the first axes.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
_,axs = plt.subplots(2)

epochTypes = sorted(annotation_desc_2_event_id.keys())
for ep in   epochTypes:
    #print(ep)
    
    # the subset's event_id holds only the selected type
    evid = epochs_trem_st[ep].event_id
    tp =  list( evid.keys() )[0] 
    tpi =  evid[tp]
    
    epochs_trem_st[ep].plot_psd(area_mode=None, color=colors[tpi], ax=axs, 
                               fmin=2, fmax=fmax, show=False,
                               average=True, spatial_colors=False)
    ax = axs[0]
    
    ax.lines[-1].set_label(tp)
    ax.set(title='', xlabel='Frequency (Hz)')

ax.legend()
#ax.legend(ax.lines, list(annotation_desc_2_event_id.values())) 
#ax.legend(ax.lines, list(annotation_desc_2_event_id.values())) 

In [None]:
len(ax.lines)

In [None]:
#ax.lines[3].__dict__

In [None]:
#help( epochs_trem_st.average )

In [None]:
# per-epoch PSDs with spatially coloured channels
epochs_trem_st.plot_psd(fmax=fmax, spatial_colors=1, average=False);

In [None]:
epochs_trem_st.plot_psd_topomap(ch_type='grad', normalize=True)

In [None]:
#reconst_raw.ch_names

In [None]:
#psds_welch_mean.shape

In [None]:
# Welch PSD with mean vs. median averaging across segments, for the regularly sampled 20 s
# epochs (use epochs_trem_st here for the tremor-locked epochs instead).
from mne.time_frequency import psd_welch
from mne.time_frequency import tfr_morlet   # used by the TFR cells below

epochs = epochs_regsampled

kwargs = dict(fmin=2, fmax=48, n_jobs=1)
psds_welch_mean, freqs_mean     = psd_welch(epochs, average='mean', **kwargs)
psds_welch_median, freqs_median = psd_welch(epochs, average='median', **kwargs)

# Convert power to dB scale.
psds_welch_mean = 10 * np.log10(psds_welch_mean)
psds_welch_median = 10 * np.log10(psds_welch_median)

# Plot the PSD of a single sensor in a single epoch (epo_idx is 0-based: 1 is the second epoch).
ch_name = 'MEG0113'
ch_idx = epochs.info['ch_names'].index(ch_name)
epo_idx = 1

_, ax = plt.subplots()
ax.plot(freqs_mean, psds_welch_mean[epo_idx, ch_idx, :], color='k',
        ls='-', label='mean of segments')
ax.plot(freqs_median, psds_welch_median[epo_idx, ch_idx, :], color='k',
        ls='--', label='median of segments')

ax.set(title='Welch PSD ({}, Epoch {})'.format(ch_name, epo_idx),
       xlabel='Frequency (Hz)', ylabel='Power Spectral Density (dB)')
ax.legend(loc='upper right')
plt.show()


In [None]:
#help(tfr_morlet)

In [None]:
#help(power.plot_topo)

In [None]:
# define frequencies of interest (linearly spaced; the log-spaced alternative is commented out)
#%matplotlib qt
#freqs = np.logspace(*np.log10([2, 48]), num=20)
freqs = np.linspace( 2, 48, 20)
n_cycles = freqs / 2.  # different number of cycle per frequency
# also returns the inter-trial coherence (itc)
power, itc = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True,
                        return_itc=True, decim=3, n_jobs=10)

In [None]:
# Average TFR power with no baseline correction: a topographic overview,
# followed by one channel in detail.
# To baseline-correct instead: bl = [-0.5, 0]; mode = 'logratio'
bl = None
mode = None
ch_pick = 82  # channel index shown in detail
power.plot_topo(baseline=bl, mode=mode, title='Average power')
power.plot([ch_pick], baseline=bl, mode=mode, title=power.ch_names[ch_pick])

In [None]:
# Side-by-side gradiometer topomaps of alpha and beta power over 0.5-1.5 s.
# NOTE(review): `bl` is None at this point; presumably MNE then skips the
# baseline step entirely, so mode='logratio' has no effect -- confirm, and set
# `bl` if a real log-ratio map is wanted.
fig, axis = plt.subplots(1, 2, figsize=(7, 4))
bands = [('Alpha', (8, 12)), ('Beta', (13, 25))]
for band_ax, (band_name, (f_lo, f_hi)) in zip(axis, bands):
    power.plot_topomap(ch_type='grad', tmin=0.5, tmax=1.5, fmin=f_lo, fmax=f_hi,
                       baseline=bl, mode='logratio', axes=band_ax,
                       title=band_name, show=False)
mne.viz.tight_layout()
plt.show()

In [None]:
# Joint TFR plot with topomaps at (time s, freq Hz) = (0.5, 10) and (1.3, 8),
# baseline-corrected by subtracting the mean over -0.5..0 s.
# NOTE(review): confirm `epochs` actually spans -0.5..2 s; otherwise the
# baseline window and tmin fall outside the data.
power.plot_joint(baseline=(-0.5, 0), mode='mean', tmin=-.5, tmax=2,
                 timefreqs=[(.5, 10), (1.3, 8)])


In [None]:
# Interactive epoch browser; block=True halts execution until the window is closed.
epochs.plot(block=True)

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help(epochs.plot_image)

In [None]:
# Epoch-by-time image for a single channel (pick index 278) with an interactive colormap.
# NOTE(review): 278 is a magic channel index; picking by channel name would be clearer.
epochs.plot_image(278, cmap='interactive')

In [None]:
# Per-channel epoch images laid out topographically, smoothed across epochs (sigma=2).
# NOTE(review): vmin/vmax = +-250 are in MNE's scaled display units -- confirm
# which unit applies for these channel types.
epochs.plot_topo_image(vmin=-250, vmax=250, title='ERF images', sigma=2.,
                       fig_facecolor='w', font_color='k')

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help(reconst_raw.apply_function)

## Read source-reconstruction data

In [None]:
# Source-reconstruction output for the S01 off/hold recording.
# NOTE(review): the subject/condition is hard-coded here, independently of
# `rawname_` -- keep the two in sync when switching recordings.
srcname = 'srcd_S01_off_hold_HirschPt2011_test.mat'
srcname_full = os.path.join(data_dir,srcname)

In [None]:
# The .mat file is opened with h5py, so presumably it is a MATLAB v7.3
# (HDF5-based) file.
import h5py
# Opened read-only and left open: later cells dereference object references
# through `srcf`. Call srcf.close() when finished.
srcf = h5py.File(srcname_full,'r')

In [None]:
# Top-level struct of the source file. The number of sources is taken as the
# second dimension of avg/mom (each entry is dereferenced per source below).
src = srcf['source_data'] 
nsrc = src['avg']['mom'].shape[1]
print(nsrc, src['avg']['mom'].shape)

In [None]:
#srcf[ preref[0,1] ]

In [None]:
# Load the time series of one source (index src_ind).
src_ind = 0
preref = src['avg']['mom']
# Entries of `mom` are HDF5 object references; indexing the file with one
# returns the referenced dataset.
ref = preref[0, src_ind]
srcdata = srcf[ref]
print(srcdata.shape)

In [None]:
# Sampling rate of the raw recording (the loaded file was resampled to 256 Hz).
raw.info['sfreq']

In [None]:
# Preview: first column of the source data as a (1, n) array.
srcdata[:,0][None,:]

In [None]:
# Morlet time-frequency power of the first column of the selected source signal.
freqs = np.arange(2,48,2)   # 2, 4, ..., 46 Hz (the stop value 48 is excluded)
# tfr_array_morlet expects (n_epochs, n_channels, n_times), hence [None, None, :].
# The sampling rate is passed as a float: int() would silently truncate a
# non-integer rate and mis-scale every wavelet.
# NOTE(review): the raw recording's rate is used for the source signal --
# confirm the source data shares that sampling rate.
Sxx = mne.time_frequency.tfr_array_morlet(srcdata[:,0][None,None,:], sfreq=float(raw.info['sfreq']),
                                 freqs=freqs, n_cycles=freqs * 0.75,
                                 output='avg_power')
print(Sxx.shape)


In [None]:
# tfr_array_morlet(output='avg_power') returns (n_channels, n_freqs, n_times)
# per the MNE docs; only one channel was passed in, so keep its 2-D
# (n_freqs, n_times) map. The ndim guard makes re-running this cell harmless.
if Sxx.ndim == 3:
    Sxx = Sxx[0,:,:]

In [None]:
# Inspection: spectrogram shape after squeezing.
Sxx.shape

In [None]:
# Time axis of the source data; assumes source_data/time is stored as an
# (n_times, 1) column -- TODO confirm.
srctimes = src['time'][:,0]

In [None]:
# Inspection (duplicate of the earlier check): spectrogram shape.
Sxx.shape

In [None]:
# Inspection: length of the source time axis; should match Sxx's second dimension.
srctimes.shape

In [None]:
# MNE tutorial to look at bands
#https://mne.tools/stable/auto_examples/time_frequency/plot_time_frequency_global_field_power.html#sphx-glr-auto-examples-time-frequency-plot-time-frequency-global-field-power-py

In [None]:
# Signature reminder: mne.baseline.rescale(data, times, baseline, mode), with mode
# among 'mean', 'ratio', 'logratio', 'percent', 'zscore', 'zlogratio' (all used in this notebook).
help( mne.baseline )

In [None]:
# Take the first 'no_tremor' interval as the baseline window (a, b);
# the third element of the interval tuple is ignored here.
ivalinds = ivalis['no_tremor']
a,b,_ = intervals[ivalinds[0]]

print(a,b)

### No baseline correction

In [None]:
%matplotlib inline
# Uncorrected source spectrogram over the whole recording.
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title('Source power, no baseline correction');

In [None]:
# Column range of Sxx covering 0-50 s. The indices are looked up in `srctimes`,
# the time axis they are used to slice, so they stay valid even if the source
# data does not share reconst_raw's sampling rate or start time.
bin1,bin2 = np.searchsorted(srctimes, [0, 50])

In [None]:
%matplotlib inline
import matplotlib as mpl
# Log colour scale for the 0-50 s window. The limits come from the data that is
# actually drawn; LogNorm requires vmin > 0, so vmin is the smallest strictly
# positive power value in the window.
Sxx_win = Sxx[:,bin1:bin2]
norm = mpl.colors.LogNorm(vmin=np.min(Sxx_win[Sxx_win > 0]), vmax=np.max(Sxx_win))
plt.figure(figsize=(15,5))
plt.pcolor(srctimes[bin1:bin2],freqs, Sxx_win, norm=norm)

In [None]:
%matplotlib inline
# Baseline mode 'ratio': each frequency row divided by its mean over (a, b)
# (see mne.baseline.rescale).
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'ratio')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title("Baseline mode 'ratio', baseline {}-{} s".format(a, b));

In [None]:
# Baseline mode 'mean': each frequency row minus its mean over (a, b)
# (see mne.baseline.rescale).
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'mean')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title("Baseline mode 'mean', baseline {}-{} s".format(a, b));

In [None]:
# Baseline mode 'logratio': log10 of each value divided by the row's mean over
# (a, b) (see mne.baseline.rescale).
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'logratio')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title("Baseline mode 'logratio', baseline {}-{} s".format(a, b));

In [None]:
# Baseline mode 'zscore': subtract the row's baseline mean over (a, b) and
# divide by its baseline standard deviation (see mne.baseline.rescale).
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'zscore')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title("Baseline mode 'zscore', baseline {}-{} s".format(a, b));

In [None]:
# Baseline mode 'zlogratio': log-ratio to the baseline mean over (a, b),
# divided by the standard deviation of the log10 baseline (see mne.baseline.rescale).
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'zlogratio')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title("Baseline mode 'zlogratio', baseline {}-{} s".format(a, b));

In [None]:
# Restrict the spectrogram to the no-tremor interval (a, b). The indices are
# looked up in `srctimes`, the same time axis passed to mne.baseline.rescale
# for this interval, so they index Sxx's columns directly.
abin,bbin = np.searchsorted(srctimes, [a, b])
Sxx_notrem = Sxx[:,abin:bbin]
print(Sxx.shape, Sxx_notrem.shape)

In [None]:
import utils
# Per-frequency reference levels from the no-tremor segment. Presumably
# (min, max, mean) after outlier removal at threshold `thr` -- TODO confirm
# in utils.calcNoutMMM_specgram.
mn_nout, mx_nout, me_nout = utils.calcNoutMMM_specgram(Sxx_notrem, thr=1e-2)
# Relative change w.r.t. the first reference level, broadcast along time.
Sxx_mc = np.divide(Sxx - mn_nout[:, np.newaxis], mn_nout[:, np.newaxis])
plt.figure(figsize=(15, 5))
plt.pcolor(srctimes, freqs, Sxx_mc)

In [None]:
import utils
# Same normalization as the no-tremor version, but with reference levels taken
# from the whole spectrogram. Presumably (min, max, mean) after outlier
# removal -- TODO confirm in utils.calcNoutMMM_specgram.
mn_nout, mx_nout, me_nout = utils.calcNoutMMM_specgram(Sxx, thr=1e-2)
Sxx_mc = np.divide(Sxx - mn_nout[:, np.newaxis], mn_nout[:, np.newaxis])
plt.figure(figsize=(15, 5))
plt.pcolor(srctimes, freqs, Sxx_mc)

In [None]:
import matplotlib as mpl

In [None]:
import utils
# Normalize by reference levels from the no-tremor segment (semantics of the
# returned triple are an assumption -- see utils.calcNoutMMM_specgram).
mn_nout, mx_nout, me_nout = utils.calcNoutMMM_specgram(Sxx_notrem, thr=1e-2)
Sxx_mc = np.divide(Sxx - mn_nout[:, np.newaxis], mn_nout[:, np.newaxis])
# Colour limits from utils.getSpecEffMax: presumably outlier-robust lower/upper
# values at threshold `thr` -- TODO confirm.
mn_mc, mx_mc = utils.getSpecEffMax(Sxx_mc, thr=1e-2)
lo_lim, hi_lim = np.min(mn_mc), np.max(mx_mc)
norm = mpl.colors.Normalize(vmin=lo_lim, vmax=hi_lim)
plt.figure(figsize=(15, 5))
plt.pcolor(srctimes, freqs, Sxx_mc, norm=norm)

In [None]:
import utils
# Normalize by reference levels from the whole spectrogram (semantics of the
# returned triple are an assumption -- see utils.calcNoutMMM_specgram).
mn_nout, mx_nout, me_nout = utils.calcNoutMMM_specgram(Sxx, thr=1e-2)
Sxx_mc = np.divide(Sxx - mn_nout[:, np.newaxis], mn_nout[:, np.newaxis])
# Colour limits from utils.getSpecEffMax: presumably outlier-robust lower/upper
# values at threshold `thr` -- TODO confirm.
mn_mc, mx_mc = utils.getSpecEffMax(Sxx_mc, thr=1e-2)
lo_lim, hi_lim = np.min(mn_mc), np.max(mx_mc)
norm = mpl.colors.Normalize(vmin=lo_lim, vmax=hi_lim)
plt.figure(figsize=(15, 5))
plt.pcolor(srctimes, freqs, Sxx_mc, norm=norm)

In [None]:
import scipy
from  scipy.signal import welch


In [None]:
# TODO:
# - get many sources
# - alpha gradient anterior-posterior
# - lateralization of MC activity
# - power spectrum lookup
# - topography in sensor space for narrow frequency bands -- differs between
#   motor alpha (mu rhythm) and visual alpha (lower than mu rhythm) -- at least in some subjects

In [None]:
# Welch PSD of every second source (indices 0, 2, 4, 6), overlaid on a log scale.
# Other selections tried: range(3), [1, 3, 5].
# `f` and `Pxx` from the last iteration are reused by the next cell.
ax = plt.gca()
for src_ind in range(0,8,2):
    ref = preref[0, src_ind]
    srcdata = srcf[ref]
    f,Pxx = welch(srcdata[:,0][None,None,:], fs = raw.info['sfreq'])
    ax.plot(f,Pxx[0,0,:], label='{}'.format(src_ind))
# Axis settings do not depend on the loop, so apply them once.
ax.set_xlim(0,60)
ax.set_yscale('log')
ax.set(title='Welch PSD of source signals', xlabel='Frequency (Hz)', ylabel='PSD')
ax.legend(title='source index')

In [None]:
# Re-plot only the PSD left over from the last iteration of the source loop above.
ax = plt.gca()
last_psd = Pxx[0, 0, :]
ax.plot(f, last_psd)
ax.set(xlim=(0, 60), yscale='log');

In [None]:
# Time-integrated spectrum: Sxx summed over its time axis, per frequency.
ax = plt.gca()
ax.plot(freqs, Sxx.sum(axis=1))
ax.set(xlim=(0, 60), yscale='log');

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help( scipy.signal.welch )

In [None]:
#help( mne.concatenate_raws )

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help( mne.parallel )

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help( mne.stats )

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help( mne.viz.plot_csd )

In [None]:
#help(mne.what)

In [None]:
# NOTE(review): duplicate of the earlier 'mean' baseline cell (same inputs,
# same plot) -- candidate for removal.
Sxx_bmod = mne.baseline.rescale(Sxx,srctimes,(a,b), 'mean')
plt.figure(figsize=(15,5))
plt.pcolor(srctimes,freqs, Sxx_bmod)

In [None]:
# create source estimate by hand
# NOTE(review): `signal1`, `signal2`, `sfreq` and `sind_str` are not defined in
# the visible cells -- confirm they are set earlier, otherwise this raises NameError.
# Assumption to confirm: one vertex per hemisphere ([lh], [rh]), one signal each.
vertices = [[146374], [33830]]  # need to find right vertices somehow

# Construct SourceEstimates that describe the signals at the cortical level.
data = np.vstack((signal1, signal2))
stc_signal = mne.SourceEstimate(
    data, vertices, tmin=0, tstep=1. / sfreq, subject=sind_str)

In [None]:


# fmax =48
# a,b,it = intervals[ ivalis[intType][0] ]
# r = reconst_raw.plot_psd(tmin=a, tmax=b, show=False, n_jobs=1, fmax=fmax)
# a,b,it = intervals[ ivalis[intType][1] ]
# r = reconst_raw.plot_psd(tmin=a, tmax=b, show=False, n_jobs=1, fmax=fmax)

In [None]:
# NOTE(review): `r` is only assigned in the commented-out lines of the previous
# cell; this raises NameError unless `r` was set by an earlier cell.
type(r)

In [None]:
# Exploration: docstring lookup -- remove from the final notebook.
help( raw.plot_psd )

In [None]:
from mne.viz import plot_alignment, set_3d_title
# 3-D view of the MEG helmet and sensor positions in the 'meg' coordinate frame:
# no head surfaces, digitization points or EEG electrodes; no head<->MRI trans.
plot_alignment(
    raw.info,
    trans=None,
    dig=False,
    eeg=False,
    surfaces=[],
    meg=['helmet', 'sensors'],
    coord_frame='meg',
    verbose=True,
)