In [None]:
import numpy as np
import os.path as op
import os
import matplotlib.pyplot as plt
from langouEEG import *

import mne
import pickle
from mne.datasets import sample
from mne.minimum_norm import apply_inverse_epochs, read_inverse_operator
from mne.connectivity import spectral_connectivity
from mne.viz import circular_layout, plot_connectivity_circle
import mne
from mne.datasets import eegbci
from mne.datasets import fetch_fsaverage
from mne.datasets import sample
from mne.minimum_norm import make_inverse_operator, apply_inverse
from mne.minimum_norm import write_inverse_operator

import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa
import mne
from mne.epochs import equalize_epoch_counts

import os.path as op

import numpy as np
from numpy.random import randn
from scipy import stats as stats

import mne
from mne.epochs import equalize_epoch_counts
from mne.stats import (spatio_temporal_cluster_1samp_test, spatio_temporal_cluster_test,
                       summarize_clusters_stc)
from mne.minimum_norm import apply_inverse, read_inverse_operator
from mne.datasets import sample
from tqdm import trange

# Paths and analysis configuration.
# The data root can be overridden with the LIGHT_DATA_ROOT environment
# variable; the default is the original location.
# (The MNE "sample" dataset is not used anywhere in this analysis, so it is no
# longer downloaded here.)
dataRoot = os.environ.get('LIGHT_DATA_ROOT', "/data/home/viscent/Light")

# Download (if needed) the fsaverage template; its parent dir is the
# FreeSurfer subjects_dir used for morphing and plotting.
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Which group to analyse; results and group-specific caches go to result_dir.
isMale = True
result_dir = op.join(dataRoot, 'result', 'male' if isMale else 'female')

# Template anatomy files used to build the forward model.
subject = 'fsaverage'
trans = 'fsaverage'  # MNE has a built-in fsaverage transformation
src = op.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
bem = op.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

# Load data

In [None]:
# Load every subject of the selected group, epoch each recording and pool the
# epochs of each condition across subjects.
# Subjects 7, 8, 11 and 17 are skipped for the male group and are the only
# ones kept otherwise.
FEMALE_SUBJECTS = {7, 8, 11, 17}
epochs_4F = []
epochs_RF = []
epochs_RR = []
for subject_idx in range(1, 21):
    # Keep the subject only if its group matches the one selected by isMale.
    if (subject_idx in FEMALE_SUBJECTS) == isMale:
        continue
    subject_name = f'S{subject_idx:02d}'  # e.g. 'S07', 'S15'
    with open(op.join(dataRoot, 'clean_data_av', subject_name + '_clean.lgeeg'), 'rb') as f:
        # Trusted local file; never unpickle data from untrusted sources.
        raw = pickle.load(f)
    raw.set_channel_types({'Trigger': 'stim', 'VEO': 'eog'})
    raw.set_eeg_reference(projection=True)
    events, event_dict = extractEvents(raw)
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=True,
                           exclude='bads')
    # The third return value (4R condition) is not analysed in this notebook.
    epoch_RR, epoch_RF, _epoch_4R, epoch_4F = extractEpochs(raw, events, picks)
    epochs_4F.append(epoch_4F)
    epochs_RF.append(epoch_RF)
    epochs_RR.append(epoch_RR)

if not epochs_4F:
    raise RuntimeError('No subjects selected; check isMale and FEMALE_SUBJECTS.')

epochs_4F = mne.concatenate_epochs(epochs_4F)
epochs_RF = mne.concatenate_epochs(epochs_RF)
epochs_RR = mne.concatenate_epochs(epochs_RR)
# Drop epochs (in place) so all three conditions have the same count.
equalize_epoch_counts([epochs_4F, epochs_RF, epochs_RR])

## Read (or compute) the forward solution

In [None]:
# Build the forward model once from a single subject's recording (S15) on the
# fsaverage template anatomy, and cache it on disk.
# NOTE(review): this model is later applied to epochs pooled over all
# subjects, which assumes every subject shares S15's EEG channel set — confirm.
subject_name = 'S15'
with open(op.join(dataRoot, 'clean_data_av', subject_name + '_clean.lgeeg'), 'rb') as f:
    # Trusted local file; never unpickle data from untrusted sources.
    raw = pickle.load(f)
raw.set_channel_types({'Trigger': 'stim', 'VEO': 'eog'})
raw.set_eeg_reference(projection=True)
events, event_dict = extractEvents(raw)

fwd_dir = op.join(dataRoot, 'fwd_solutions')
fname_fwd = op.join(fwd_dir, subject_name + '_fwd.lgeeg')
if not op.exists(fname_fwd):
    fwd = mne.make_forward_solution(raw.info, trans=trans, src=src,
                                    bem=bem, eeg=True, mindist=5.0, n_jobs=1)
    print(fwd)
    # The cache directory may not exist on a first run.
    os.makedirs(fwd_dir, exist_ok=True)
    mne.write_forward_solution(fname_fwd, fwd, overwrite=True)
else:
    fwd = mne.read_forward_solution(fname_fwd)

## Compute noise covariance and inverse operator

In [None]:
# Noise covariance and inverse operator, cached on disk.
# The covariance is estimated from the selected group's epochs_RR, so both
# caches live under the group-specific result_dir. Caching them directly
# under dataRoot would silently reuse the other group's covariance after
# isMale is switched.
os.makedirs(result_dir, exist_ok=True)
fname_inv = op.join(result_dir, 'inv_operators.lgeeg')
fname_cov = op.join(result_dir, 'noise_covariance.lgeeg')

cov_recomputed = False
if not op.exists(fname_cov):
    # NOTE(review): tmax=80. is the end of the covariance window in seconds;
    # confirm it matches the intended baseline segment of the RR epochs.
    noise_cov = mne.compute_covariance(
        epochs_RR, tmax=80., method=['shrunk', 'empirical'], rank=None, verbose=True)
    mne.write_cov(fname_cov, noise_cov)
    cov_recomputed = True
else:
    noise_cov = mne.read_cov(fname_cov)

# Rebuild the inverse whenever the covariance it depends on was recomputed;
# otherwise a stale cached operator would be used.
if cov_recomputed or not op.exists(fname_inv):
    inverse_operator = make_inverse_operator(
        raw.info, fwd, noise_cov, loose=0.2, depth=0.8)
    if op.exists(fname_inv):
        # Remove the stale file explicitly so this works whether or not the
        # installed write_inverse_operator supports an overwrite flag.
        os.remove(fname_inv)
    write_inverse_operator(fname_inv, inverse_operator)
else:
    inverse_operator = read_inverse_operator(fname_inv)
src = inverse_operator['src']

## Source estimate

In [None]:
# Apply the inverse operator to every epoch of both conditions and cache the
# resulting lists of SourceEstimates under result_dir/stc.
snr = 1.0  # use lower SNR for single epochs
lambda2 = 1.0 / snr ** 2
method = "dSPM"  # use dSPM method (could also be MNE or sLORETA)
stc_dir = op.join(result_dir, 'stc')
# os.path has no makedirs; os.makedirs also creates missing parents.
os.makedirs(stc_dir, exist_ok=True)


def _load_or_compute_stcs(cond_epochs, cache_name, tmin=10., tmax=15.):
    """Return the per-epoch source estimates of one condition, cached on disk.

    Parameters
    ----------
    cond_epochs : mne.Epochs
        Epochs of one condition. When the cache is missing they are cropped
        IN PLACE to [tmin, tmax] before the inverse is applied.
    cache_name : str
        File name of the pickle cache inside ``stc_dir``.
    tmin, tmax : float
        Crop window passed to ``Epochs.crop`` (seconds).

    Returns
    -------
    list
        The SourceEstimates returned by ``apply_inverse_epochs`` (or loaded
        from the cache).
    """
    cache_path = op.join(stc_dir, cache_name)
    if op.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)  # trusted local cache
    cond_epochs.crop(tmin, tmax)
    stcs = apply_inverse_epochs(cond_epochs, inverse_operator, lambda2, method,
                                pick_ori="normal", return_generator=False)
    with open(cache_path, 'wb') as f:
        pickle.dump(stcs, f)
    return stcs


stcs_40 = _load_or_compute_stcs(epochs_4F, '40_stc.lgeeg')
stcs_rand = _load_or_compute_stcs(epochs_RF, 'rand_stc.lgeeg')

In [None]:
# Gather each epoch's source-estimate data array into one numpy array per
# condition, then free the SourceEstimate lists to save memory.
stcs_40_np = np.array([stc_40.data for stc_40 in stcs_40])
stcs_rand_np = np.array([stc_rand.data for stc_rand in stcs_rand])
del stcs_40, stcs_rand

In [None]:
# Stack the two conditions on a trailing axis and move it last, so that
# X has shape (n_vertices, n_times, n_epochs, 2); axis 0 holds vertices as
# required by the sparse morph below.
X = np.stack((stcs_40_np, stcs_rand_np), axis=-1).transpose(1, 2, 0, 3)
del stcs_40_np, stcs_rand_np


n_vertices_sample, n_times, n_subjects = X.shape[:3]
# NOTE(review): n_subjects counts the pooled epochs per condition, not
# participants.
p_threshold = 0.001
t_threshold = -stats.distributions.t.ppf(p_threshold / 2., n_subjects - 1)

# Morph onto fsaverage vertices. The inverse was built on the fsaverage
# ico-5 source space, so this morph is expected to be (close to) identity.
fsave_vertices = [hemi_src['vertno'] for hemi_src in src]
morph_mat = mne.compute_source_morph(src=inverse_operator['src'],
                                     subject_to='fsaverage',
                                     spacing=fsave_vertices,
                                     subjects_dir=subjects_dir).morph_mat
print('Reshaping')
n_vertices_fsave = morph_mat.shape[0]
# Flatten all non-vertex axes so the sparse matrix product applies per column.
X = X.reshape(n_vertices_sample, -1)
print('Morphing data.')
X = morph_mat @ X  # sparse morph matrix times dense data
# Restore the 4-D layout and keep only the magnitude.
X = np.abs(X.reshape(n_vertices_fsave, n_times, n_subjects, 2))


In [None]:
# One (n_epochs, n_times, n_vertices) array per condition, the list layout
# expected by spatio_temporal_cluster_test.
X = [np.transpose(X[..., cond], (2, 1, 0)) for cond in (0, 1)]

In [None]:

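# Alternative (unused) paired contrast. It is only valid if the epochs of the two
# conditions are paired one-to-one, and it expects the 4-D array produced by the
# morphing cell, not the per-condition list built above.
# X = X[:, :, :, 0] - X[:, :, :, 1]  # make paired contrast
# X = np.transpose(X, [2, 1, 0])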

In [None]:
# Reclaim memory from the large intermediates deleted in earlier cells
# before the memory-hungry cluster test.
import gc

gc.collect(generation=2)

# Compute statistics

In [None]:
# Cluster-based permutation test between the two conditions (40 vs rand),
# followed by a plot of the significant clusters on fsaverage.
print('Computing adjacency.')
adjacency = mne.spatial_src_adjacency(src)

print('Clustering.')
# X is [cond_40, cond_rand], each (n_epochs, n_times, n_vertices). The two
# conditions contain different trials, so the test is unpaired.
# spatio_temporal_cluster_test uses a one-way F statistic by default, so the
# cluster-forming threshold must be an F value; a t value would be far too
# lenient. For two groups, F(1, n1 + n2 - 2) at p equals the square of the
# two-sided t threshold at p, which matches the intent of p_threshold.
n_obs_per_cond = [len(cond) for cond in X]
f_threshold = stats.distributions.f.ppf(1. - p_threshold, 1,
                                        sum(n_obs_per_cond) - 2)
cache_dir = op.join(dataRoot, 'cache')
os.makedirs(cache_dir, exist_ok=True)  # set_cache_dir needs an existing dir
mne.set_cache_dir(cache_dir)
T_obs, clusters, cluster_p_values, H0 = clu_40 = \
    spatio_temporal_cluster_test(X, adjacency=adjacency, n_jobs=20,
                                 threshold=f_threshold, tail=1,
                                 buffer_size=1, verbose=True)
good_cluster_inds = np.where(cluster_p_values < 0.05)[0]
print('Visualizing clusters.')

# summarize_clusters_stc takes the time step in ms. The stcs were freed
# earlier; apply_inverse_epochs gave them a step of 1 / sfreq of the epochs.
tstep = 1e3 / epochs_4F.info['sfreq']

if len(good_cluster_inds) == 0:
    # summarize_clusters_stc raises when no cluster is significant.
    print('No clusters with p < 0.05; nothing to plot.')
else:
    # Build a convenient representation of each cluster, where each cluster
    # becomes a "time point" in the SourceEstimate; the first "time point"
    # shows all clusters, weighted by duration.
    # (summarize_clusters_stc has no ``backend`` parameter.)
    stc_all_cluster_vis = summarize_clusters_stc(clu_40, tstep=tstep,
                                                 vertices=fsave_vertices,
                                                 subject='fsaverage')
    # F values are non-negative, so every cluster is drawn with a positive
    # sign: an F-test does not say which condition is larger.
    # subjects_dir is the fsaverage parent directory from the config cell.
    brain = stc_all_cluster_vis.plot(
        hemi='both', views='lateral', subjects_dir=subjects_dir,
        time_label='temporal extent (ms)', size=(800, 800),
        smoothing_steps=5, clim=dict(kind='value', pos_lims=[0, 1, 40]))
    # Save per group so the male and female runs do not overwrite each other.
    os.makedirs(result_dir, exist_ok=True)
    brain.save_image(op.join(result_dir, f'clusters{p_threshold}.png'))