In [None]:
import os,sys
import mne
import utils  #my code
import json
import matplotlib.pyplot as plt
import numpy as np
import h5py
%load_ext autoreload
%autoreload 2

# Data directory: $DATA_DUSS when that environment variable is set,
# otherwise a user-specific fallback location.
data_dir = os.environ.get('DATA_DUSS', '/home/demitau/data')

# Recording to analyse (alternative: 'S01_on_hold').
rawname_ = 'S01_off_hold'
rawname = '{}_resample_raw.fif'.format(rawname_)
fname_full = os.path.join(data_dir, rawname)

# Read the file -- resampled to 256 Hz; Elekta MEG, EMG, LFP, EOG channels.
raw = mne.io.read_raw_fif(fname_full, None)

In [None]:
# Annotations for this recording live next to the raw file as '<rawname>_anns.txt'.
anns_fn = '{}_anns.txt'.format(rawname_)
anns_fn_full = os.path.join(data_dir, anns_fn)
anns = mne.read_annotations(anns_fn_full)
raw.set_annotations(anns)

In [None]:
# Per-subject metadata (bad MEG channels, tremor side, ...) comes from a separate JSON file.
with open('subj_info.json') as info_json:
    gen_subj_info = json.load(info_json)

# Bad channels are stored per subject, then per medication condition, then per task.
subj, medcond, task = utils.getParamsFromRawname(rawname_)
badchlist = gen_subj_info[subj]['bad_channels'][medcond][task]
raw.info['bads'] = badchlist
print('bad channels are ', badchlist)

In [None]:
import utils
# Hand-marked tremor intervals: one file for the main tremor side and one ("_nms")
# unpacked with mainSide=False, i.e. for the non-main side.
trem_times_fn = 'trem_times_tau.json'
with open(trem_times_fn) as jf:
    trem_times_byhand = json.load(jf)
trem_times_nms_fn = 'trem_times_tau_nms.json'
with open(trem_times_nms_fn) as jf:
    trem_times_nms_byhand = json.load(jf)

tremIntervalJan, artif = utils.unpackTimeIntervals(
    trem_times_byhand, mainSide=True, gen_subj_info=gen_subj_info, skipNotLoadedRaws=0)
tremIntervalJan_nms, artif_nms = utils.unpackTimeIntervals(
    trem_times_nms_byhand, mainSide=False, gen_subj_info=gen_subj_info, skipNotLoadedRaws=0)

# Merge the non-main-side artifact intervals of the current recording into `artif`.
if rawname_ in artif_nms:
    if rawname_ in artif:
        artif[rawname_].update(artif_nms[rawname_])
    else:
        artif[rawname_] = artif_nms[rawname_]

# For every recording, take the opposite-side intervals from the non-main-side file.
# Loop variables are underscore-prefixed so they do NOT overwrite the globals
# `medcond` / `task`, which describe the current recording (set from rawname_).
for _rawn in tremIntervalJan:
    _sind_str, _medcond, _task = utils.getParamsFromRawname(_rawn)
    _opside = utils.getOppositeSideStr(gen_subj_info[_sind_str]['tremor_side'])
    if _rawn in tremIntervalJan_nms:
        tremIntervalJan[_rawn][_opside] = tremIntervalJan_nms[_rawn][_opside]


mvtTypes = ['tremor', 'no_tremor', 'unk_activity']

plotTremNegOffset = 2.
plotTremPosOffset = 2.
maxPlotLen = 6   # only for intervals made for plotting; intervals used for stats are untouched
addIntLenStat = 5
plot_time_end = 150

timeIntervalPerRaw_processed = utils.processJanIntervals(
    tremIntervalJan, maxPlotLen, addIntLenStat,
    plotTremNegOffset, plotTremPosOffset, plot_time_end, mvtTypes=mvtTypes)

In [None]:
# for current raw
maintremside = gen_subj_info[subj]['tremor_side']
print(maintremside)

In [None]:
raw_lfponly = raw.copy()
#raw_lfponly.crop(0,300)
raw_lfponly.load_data()

In [None]:
# Hand whose tremor is analysed ('L' or 'R'); LFP channels are taken from the
# opposite hemisphere.
hand_side = 'R'
brain_side = 'R' if hand_side == 'L' else 'L'

chis = mne.pick_channels_regexp(raw.ch_names, 'LFP{}.*'.format(brain_side))
chnames_lfp = [raw.ch_names[idx] for idx in chis]
chnames_lfp

In [None]:
raw_lfponly.pick_channels(   chnames_lfp  )

raw_lfponly.ch_names

In [None]:
rawdata = raw_lfponly.get_data()

rawdata.shape

In [None]:
raw_lfponly.get_channel_types()

In [None]:
y = {}
for chname in raw_lfponly.ch_names:
    y[chname] = 'eeg'
raw_lfponly.set_channel_types(y)

In [None]:
freqsToKill = np.arange(50, 128, 50) # harmonics of 50
raw_lfponly.notch_filter(freqsToKill, picks=['eeg'])

In [None]:
raw_lfponly.plot_psd();

In [None]:
for ind in range(rawdata.shape[0] ):
    plt.hist(rawdata[ind], bins=100, alpha=0.7, label='{}'.format(raw_lfponly.ch_names[ind]))

plt.legend()

In [None]:
# I want 256 window sz
cf =  256/ ( 5/(2*np.pi) * 256  ) 
cf

In [None]:
# strec = 0
# endrec = raw_lfponly.times[-1]
# epdur = endrec
# events_one = mne.make_fixed_length_events(raw_lfponly, start=strec, stop=endrec, duration=epdur)
# epochs_one = mne.Epochs(raw_lfponly,events_one, tmin=0,tmax = epdur, baseline=None)

#tfr_array_morlet
min_freq = 3
freq_step = 2
freqs = np.arange(min_freq,100,freq_step)
#freq2cycles_mult = 0.75
freq2cycles_mult = cf  # 1.2566370614359172
tfrres = mne.time_frequency.tfr_array_morlet(raw_lfponly.get_data()[None,:], 
                                             raw.info['sfreq'], freqs, freqs * freq2cycles_mult, n_jobs=10)
tfrres = tfrres[0]

In [None]:
%matplotlib qt

In [None]:
dat_ = np.abs( tfrres[0] )
plt.figure()
import matplotlib as mpl
norm = mpl.colors.LogNorm(vmin = np.min(dat_), vmax = np.max(dat_))
plt.pcolormesh(raw.times, freqs, dat_, norm=norm )

In [None]:
tfrres.shape

In [None]:
tfres_ = tfrres.reshape(3*len(freqs), tfrres.shape[-1]).T

In [None]:
plt.figure()
N =  tfres_.shape[0] 
nshow = 20
for ind in range( 0, N, N//nshow ):
    plt.hist( np.abs( tfres_[ind] ) , bins=100, alpha=0.7, label='{}'.format(ind))

#plt.legend()

In [None]:
Xfull = np.abs( tfres_ )[256:-256]  # to avaoid edge artifact due to wavelet computation

In [None]:
skip = 30
X = np.abs( Xfull[::skip] )
Xtimes = raw_lfponly.times[256:-256:skip]
X.shape

In [None]:
anns.description

In [None]:
ivalis = {}
anns = raw_lfponly.annotations
for i,an in enumerate(anns ):
    descr = an['description']
    if descr not in ivalis:
        ivalis[descr] = []
    tpl = an['onset'], an['onset']+ an['duration'], descr
    ivalis[descr] += [ tpl  ]

In [None]:
hand_side

In [None]:
tremcolor = 'r'
nontremcolor = 'g'
mvtcolor = 'm'  #c,y

#hsfc = hand_side
hsfc = 'L'; print('Using not hand side (perhabs) for coloring')
annot_colors = { 'trem_{}'.format(hsfc): tremcolor  }
annot_colors[ 'no_tremor_{}'.format(hsfc) ] = nontremcolor
annot_colors[ 'mvt_{}'.format(hsfc) ] = mvtcolor
#annot_colors[ 'no_tremor_{}'.format(hand_side) ] = nontremcolor
#annot_colors[ 'no_tremor_{}'.format(hand_side) ] = nontremcolor

colors =  np.array(  [nontremcolor] * len(Xtimes) )

for an in anns:
    for descr in annot_colors:
        if an['description'] == descr:
            col = annot_colors[descr]
    
            start = an['onset']
            end = start + an['duration']
            inds = np.where((Xtimes >= start)* (Xtimes <= end)  )[0]
            colors[inds] = [col]

# postcolor = 'blue'
# precolor = 'yellow'

# predur = 3
# postdur = 3

# inds = np.where((Xtimes >= tremend - predur)* (Xtimes <= tremend )  )[0]
# colors[inds] = [precolor]

# inds = np.where((Xtimes >= tremend)* (Xtimes <= tremend + postdur )  )[0]
# colors[inds] = [postcolor]

In [None]:
colors

In [None]:
colors.shape

In [None]:
# X = Xfull
# X.shape

## Check whether the ICA components show anything obviously anomalous

In [None]:
from sklearn.decomposition import FastICA  # needed here: FastICA is otherwise first imported in the next cell
help(FastICA)

In [None]:
from sklearn.decomposition import FastICA
ica = FastICA(n_components=30)
S_ = ica.fit_transform(X)  

In [None]:
sh = 0.09
%matplotlib qt
#descrs = ['trem']
for i in range(S_.shape[-1]):
    plt.plot(Xtimes, S_[:,i] + i*sh)
    #for i in ivalis

In [None]:
import gc; gc.collect()

In [None]:
X.shape

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=60)
pca.fit(Xfull)
#pca.fit(X)

In [None]:
np.sum(pca.explained_variance_)

In [None]:
print('total explained variance proportion', np.sum(pca.explained_variance_ratio_) )
print(pca.explained_variance_ratio_[:10])

In [None]:
pcapts = pca.transform(X)

In [None]:
pcapts.shape

In [None]:
import matplotlib as mpl

In [None]:
colors

In [None]:
plt.scatter(pcapts[:,0], pcapts[:,1], c=colors.tolist())

#legel_trem = mpl.patches.Patch(facecolor=tremcolor, edgecolor=None, label='trem')

legel_trem = mpl.lines.Line2D([0], [0], marker='o', color='w', label='trem', 
                              markerfacecolor=tremcolor, markersize=8)
legel_notrem = mpl.lines.Line2D([0], [0], marker='o', color='w', label='notrem', 
                              markerfacecolor=nontremcolor, markersize=8)
legel_mvt = mpl.lines.Line2D([0], [0], marker='o', color='w', label='mvt', 
                              markerfacecolor=mvtcolor, markersize=8)
# legel_preend = mpl.lines.Line2D([0], [0], marker='o', color='w', label='preend', 
#                               markerfacecolor=precolor, markersize=8)
# legel_postend = mpl.lines.Line2D([0], [0], marker='o', color='w', label='postend', 
#                               markerfacecolor=postcolor, markersize=8)



#legend_elements = [legel_trem, legel_notrem, legel_preend, legel_postend]
legend_elements = [legel_trem, legel_notrem, legel_mvt]

# Create the figure
#fig, ax = plt.subplots()
plt.legend(handles=legend_elements)

#plt.show()


In [None]:
XX = pcapts


import numpy as np
from sklearn.manifold import TSNE
#X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])

In [None]:
import gc; gc.collect()

In [None]:
import time

time.time()


In [None]:
def run_tsne(p):
    """Fit one 2-D t-SNE embedding for a single parameter combination.

    Takes a single tuple so it can be passed directly to ``Pool.map``.

    Parameters
    ----------
    p : tuple
        ``(pi, si, XX, seed, perplex_cur, lrate)``: perplexity index, seed index,
        data matrix, random seed, perplexity and learning rate.

    Returns
    -------
    tuple
        ``(pi, si, X_embedded, seed, perplex_cur, lrate)``: the input tuple with
        the data replaced by its embedding, so results can be placed on a grid.
    """
    pi, si, XX, seed, perplex_cur, lrate = p
    started = time.time()
    embedder = TSNE(n_components=2, random_state=seed,
                    perplexity=perplex_cur, learning_rate=lrate)
    embedding = embedder.fit_transform(XX)

    elapsed = time.time() - started
    msg = 'comnputed in {:.3f}s: perplexity = {};  lrate = {}; seed = {}'
    print(msg.format(elapsed, perplex_cur, lrate, seed))

    return pi, si, embedding, seed, perplex_cur, lrate

In [None]:
# t-SNE parameter grid (a wider grid tried before: perplexities 5, 10, 30, 40, 50
# with seeds range(5)).
lrate = 200.
perplex_values = [5, 30, 50]
seeds = range(2)

# One task per (perplexity, seed) pair, perplexity-major. The indices pi/si travel
# with each task so the results can be placed on the plot grid afterwards.
args = [(pi, si, XX.copy(), seed, perplex_cur, lrate)
        for pi, perplex_cur in enumerate(perplex_values)
        for si, seed in enumerate(seeds)]

In [None]:
import multiprocessing as mpr
# Leave two cores free, but always use at least one worker: Pool(0) raises ValueError.
ncores = max(1, mpr.cpu_count() - 2)
print('Starting {} workers on {} cores'.format(len(args), ncores))
# The context manager tears the worker processes down even if a task raises;
# map() has already collected every result before the pool is closed.
with mpr.Pool(ncores) as pool:
    r = pool.map(run_tsne, args)

In [None]:
#cols = [colors, colors2, colors3]
cols = [colors]

colind = 0
nr = len(seeds)
nc = len(perplex_values)
ww = 5; hh=5
fig,axs = plt.subplots(ncols =nc, nrows=nr, figsize = (nc*ww, nr*hh))
# for pi,pv in enumerate(perplex_values):
#     for si,sv in enumerate(seeds):
for tpl in r:
    pi,si,X_embedded, seed, perplex_cur, lrate = tpl
    ax = axs[si,pi]
    #X_embedded = res[si][pi]
    ax.scatter(X_embedded[:,0], X_embedded[:,1], c = cols[colind], s=1)
    ax.set_title('perplexity = {};  lrate = {}; seed = {}'.format(perplex_cur, lrate, seed))

axs[0,0].legend(handles=legend_elements)
plt.savefig('tSNE_LFP{}_trem_minFreq={}.pdf'.format(brain_side,min_freq))

In [None]:
X.shape