In [7]:
import numpy as np
import pandas as pd
from mne.io import RawArray
from mne.channels import read_montage
from mne.epochs import concatenate_epochs
from mne import create_info, find_events, Epochs
from mne.viz.topomap import _prepare_topo_plot, plot_topomap
from mne.decoding import CSP

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

from glob import glob

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
    
from scipy.signal import welch
from mne import pick_types

In [8]:
from sklearn.cross_validation import cross_val_score, LeaveOneLabelOut

ImportError: No module named cross_validation

In [9]:
def creat_mne_raw_object(fname):
    """Create a mne raw instance from csv file"""
    # Read EEG file
    data = pd.read_csv(fname)
    
    # get chanel names
    ch_names = list(data.columns[1:])
    
    # read EEG standard montage from mne
    montage = read_montage('standard_1005',ch_names)

    # events file
    ev_fname = fname.replace('_data','_events')
    # read event file
    events = pd.read_csv(ev_fname)
    events_names = events.columns[1:]
    events_data = np.array(events[events_names]).T
    
    # concatenate event file and data
    data = np.concatenate((1e-6*np.array(data[ch_names]).T,events_data))        
    
    # define channel type, the first is EEG, the last 6 are stimulations
    ch_type = ['eeg']*len(ch_names) + ['stim']*6
    
    # create and populate MNE info structure
    ch_names.extend(events_names)
    info = create_info(ch_names,sfreq=500.0, ch_types=ch_type, montage=montage)
    info['filename'] = fname
    
    # create raw object 
    raw = RawArray(data,info,verbose=False)
    return raw


In [10]:
# /Users/shanmvas/Downloads/grasp-and-lift-eeg-detection/train/subj9_series3_events.csv

In [13]:
subjects = range(1,13)
auc = []
for subject in subjects:
    epochs_tot = []
    
    #eid = 'HandStart'
    fnames =  glob('../input/train/subj%d_series*_data.csv' % (subject))
    
    session = []
    y = []
    for i,fname in enumerate(fnames):
      
        # read data 
        raw = creat_mne_raw_object(fname)
        
        # pick eeg signal
        picks = pick_types(raw.info,eeg=True)
        
        # Filter data for alpha frequency and beta band
        # Note that MNE implement a zero phase (filtfilt) filtering not compatible
        # with the rule of future data.
        raw.filter(7,35, picks=picks, method='iir', n_jobs=-1, verbose=False)
        
        # get event posision corresponding to Replace
        events = find_events(raw,stim_channel='Replace', verbose=False)
        # epochs signal for 1.5 second before the movement
        epochs = Epochs(raw, events, {'during' : 1}, -2, -0.5, proj=False,
                        picks=picks, baseline=None, preload=True,
                        add_eeg_ref=False, verbose=False)
        
        epochs_tot.append(epochs)
        session.extend([i]*len(epochs))
        y.extend([1]*len(epochs))
        
        # epochs signal for 1.5 second after the movement, this correspond to the 
        # rest period.
        epochs_rest = Epochs(raw, events, {'after' : 1}, 0.5, 2, proj=False,
                        picks=picks, baseline=None, preload=True,
                        add_eeg_ref=False, verbose=False)
        
        # Workaround to be able to concatenate epochs
        epochs_rest.times = epochs.times
        
        epochs_tot.append(epochs_rest)
        session.extend([i]*len(epochs_rest))
        y.extend([-1]*len(epochs_rest))
        #concatenate all epochs
    epochs = concatenate_epochs(epochs_tot)
    
    # get data 
    X = epochs.get_data()
    y = np.array(y)
    
    # run CSP
    csp = CSP(reg='lws')
    csp.fit(X,y)
    
    # compute spatial filtered spectrum
    po = []
    for x in X:
        f,p = welch(np.dot(csp.filters_[0,:].T,x), 500, nperseg=512)
        po.append(p)
    po = np.array(po)
    
    # prepare topoplot
    _,epos,_,_,_ = _prepare_topo_plot(epochs,'eeg',None)
    
    # plot first pattern
    pattern = csp.patterns_[0,:]
    pattern -= pattern.mean()
    ix = np.argmax(abs(pattern))
    # the parttern is sign invariant.
    # invert it for display purpose
    if pattern[ix]>0:
        sign = 1.0
    else:
        sign = -1.0
    
    fig, ax_topo = plt.subplots(1, 1, figsize=(12, 4))
    title = 'Spatial Pattern'
    fig.suptitle(title, fontsize=14)
    img, _ = plot_topomap(sign*pattern,epos,axis=ax_topo,show=False)
    divider = make_axes_locatable(ax_topo)
    # add axes for colorbar
    ax_colorbar = divider.append_axes('right', size='5%', pad=0.05)
    plt.colorbar(img, cax=ax_colorbar)
    
        

IndexError: list index out of range