In [24]:
from load import load_runs, conv_runs_to_epochs, extract_summary_stats
from ml import eval_all, get_pipe,  predict_new
import numpy as np
import mne
from joblib import dump, load

In [35]:
def train_and_save(params, runs, save_name,
                   use_summary=False, model='rf'):
    """Fit a pipeline on every epoch built from ``runs`` and save it with its settings.

    Parameters
    ----------
    params : dict
        Keyword options forwarded unchanged to ``conv_runs_to_epochs``.
        The dict is NOT modified; the extra bookkeeping keys are written
        to a private copy, so the same dict can be reused across calls.
    runs
        Recorded runs, passed straight to ``conv_runs_to_epochs``.
    save_name
        Destination handed to ``dump`` for the ``(pipeline, params)`` tuple.
    use_summary : bool, default False
        If True, fit on ``extract_summary_stats(epochs)`` instead of the
        raw epoch data.
    model : default 'rf'
        Identifier passed to ``get_pipe`` to build the pipeline.

    Returns
    -------
    None
        The result is written to ``save_name``.
    """
    # Load as epochs w/ or w/o ica
    epochs, ica_obj = conv_runs_to_epochs(runs, tmin=-2, tmax=1.9, **params)

    # Get epochs as data
    data = epochs.get_data()

    # Grab labels (last column of the events array)
    labels = epochs.events[:, -1]

    # Record the extra settings on a copy: writing them into the caller's dict
    # would leak 'ica_obj' / 'use_summary' / 'epoch_len' into the next
    # conv_runs_to_epochs(**params) call made with that same dict.
    saved_params = dict(params)
    saved_params['ica_obj'] = ica_obj
    saved_params['use_summary'] = use_summary
    # Taken from the raw data, before any summary-stat replacement below
    saved_params['epoch_len'] = data.shape[-1]

    # If use summary, use summary stats instead of raw
    if use_summary:
        data = extract_summary_stats(epochs)

    # Fit pipe on all available data
    pipe = get_pipe(model)
    pipe.fit(data, labels)

    # Save pipeline and params
    dump((pipe, saved_params), save_name)
    
def test_eval(runs, params):
    
    epochs, ica_obj = conv_runs_to_epochs(runs, tmin=-2, tmax=1.9, **params)

    # Get scores both ways
    labels = epochs.events[:, -1]
    eval_all(extract_summary_stats(epochs), labels, verbose=True)
    eval_all(data=epochs.get_data(), labels=labels)
    
    return epochs

In [26]:
#runs = load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\2_2_first')
#runs = load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\2_2_new')
#runs = load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\2_2_latest')
#runs += load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\2_2_8_ch')


In [36]:
# Recorded runs for the 7-color experiment
runs = load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\7_colors')

# Options forwarded as keyword arguments to conv_runs_to_epochs (via **params)
params = {
    'l_freq': None,
    'h_freq': 10,
    'auto_reject': 'ar',
    'ica': [2],
    'set_average_ref': True,
    'apply_baseline': False,
    'drop_ref_ch': True,
    'crop': True,
}

In [37]:
# Keep the returned epochs: the exploratory cells further down read a
# notebook-level `epochs`, which no other visible cell defines.
epochs = test_eval(runs, params)
# Bare last expression so the epochs summary is still displayed as before
epochs

logistic: 0.1402116402116402
rf: 0.1402116402116402
logistic: 0.1349206349206349
rf: 0.1349206349206349


0,1
Number of events,29
Events,#2CFCFC: 2 #844CD4: 2 #FC5C04: 3 #FCFB20: 4 Blue: 5 Green: 5 Red: 8
Time range,0.000 – 1.904 sec
Baseline,off


In [30]:
#train_and_save(params.copy(), runs, save_name='test_7_rf_raw.pkl', use_summary=False, model='rf')
#train_and_save(params.copy(), runs, save_name='test_7_log.pkl', use_summary=True, model='logistic')

In [31]:
# Reload the (pipeline, params) tuple that train_and_save writes with dump().
# The file must already exist on disk: the train_and_save calls that would
# create 'test_7_rf_raw.pkl' are commented out in this notebook.
p_and_params1 = load('test_7_rf_raw.pkl')

In [34]:
# Predict on columns 1000:3000 of run index 7 using the reloaded pipeline.
# NOTE(review): the 1000:3000 window is hand-picked — confirm it is compatible
# with the 'epoch_len' stored in the saved params.
predict_new(runs[7][:, 1000:3000], [p_and_params1])

('Red',
 {'Red': 0.22,
  'Green': 0.12,
  'Blue': 0.12,
  '#FC5C04': 0.12,
  '#844CD4': 0.14,
  '#2CFCFC': 0.17,
  '#FCFB20': 0.11})

In [None]:
# NOTE(review): `stop` is not defined in any visible cell, so running this cell
# raises NameError — presumably a deliberate guard that halts "Run All" before
# the exploratory cells that follow. Confirm before removing.
stop

## Use the rest of the notebook for plotting / random exploration of the data

In [None]:
# Overlay the per-condition averages for the three named primary colors.
# Keys are the lower-cased event names, in the order red, blue, green.
mne.viz.plot_compare_evokeds({
    condition.lower(): epochs[condition].average()
    for condition in ('Red', 'Blue', 'Green')
})

In [None]:
import matplotlib.pyplot as plt

def get_start(run_data):
    """Return the first index at which the last row of ``run_data`` equals 1.

    The last row (``run_data[-1]``) is treated as the event-label row.
    Raises ``IndexError`` if that row contains no value equal to 1.
    """
    marker_row = run_data[-1]
    # np.where(cond) yields a tuple of index arrays; take the first axis
    hit_positions = np.where(marker_row == 1)[0]
    return hit_positions[0]

def plot_raw_run(run_data, event_labels=None, figsize=(10, 60)):
    """Plot each channel of a run, one subplot per channel, from the first event on.

    Parameters
    ----------
    run_data
        If ``event_labels`` is None, a 2-D array whose last row holds the
        event labels (every other row is plotted as a channel). Otherwise
        an object exposing the channel array as ``run_data._data``.
    event_labels : optional
        Event-label row stacked under ``run_data._data`` when supplied.
    figsize : tuple, default (10, 60)
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    None
    """
    if event_labels is not None:
        data = np.vstack([run_data._data, event_labels])
    else:
        data = run_data

    # First sample where the label row equals 1; everything before is skipped
    s = get_start(data)

    # Last row is the label row, so it is not plotted
    n_ch = len(data) - 1

    # squeeze=False keeps `axes` a 2-D array even when n_ch == 1; with the
    # default, a single subplot comes back as a bare Axes and cannot be indexed.
    fig, axes = plt.subplots(n_ch, figsize=figsize, squeeze=False)

    for ch, ax in enumerate(axes[:, 0]):
        ax.plot(data[ch, s:])
        
# NOTE(review): this rebinds the notebook-level `runs` (loaded earlier from the
# 7_colors data) to the 2_2_latest recordings; re-running earlier cells after
# this one will silently use this data instead.
runs = load_runs('C:\\Users\\Sage\\Desktop\\predict_color\\data\\2_2_latest')
# Plot run index 4; no event_labels are passed, so the run's own last row is
# used to locate the start.
plot_raw_run(runs[4])

In [None]:
# Frequencies 1, 4, 7, ..., 28 (presumably Hz — confirm against the MNE docs)
frequencies = np.arange(1, 30, 3)
# Morlet-wavelet time-frequency power of `epochs`; inter-trial coherence is
# not requested (return_itc=False), so a single power object is returned.
power = mne.time_frequency.tfr_morlet(epochs, n_cycles=2, return_itc=False,
                                      freqs=frequencies, decim=3)
power.plot()

In [None]:
# Parameter sets that scored above 40 (presumably % accuracy — confirm) on the 'latest' data

array([{'l_freq': 12, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 50, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 1, 'h_freq': 50, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 50, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 10, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 10, 'auto_reject': 'ar', 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 5, 'h_freq': 10, 'auto_reject': None, 'ica': [2], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 4, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': 4, 'auto_reject': 'ar', 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 4, 'h_freq': 8, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 4, 'h_freq': 8, 'auto_reject': None, 'ica': [2], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 8, 'h_freq': 13, 'auto_reject': None, 'ica': [0], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 8, 'h_freq': 13, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 13, 'h_freq': 20, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 13, 'h_freq': 20, 'auto_reject': None, 'ica': [1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 13, 'h_freq': 20, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 1, 'h_freq': None, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 1, 'h_freq': None, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 1, 'h_freq': None, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 5, 'h_freq': None, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 5, 'h_freq': None, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 5, 'h_freq': None, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 20, 'h_freq': None, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 20, 'h_freq': None, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 20, 'h_freq': None, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': None, 'h_freq': 50, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': None, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': None, 'h_freq': 20, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': None, 'h_freq': 20, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': None, 'h_freq': 10, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': None, 'h_freq': 10, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': None, 'h_freq': 10, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 20, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 20, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': None, 'ica': [0], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': None, 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 10, 'h_freq': 30, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': None, 'ica': [0], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 30, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 40, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': 40, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True},
       {'l_freq': 40, 'h_freq': 50, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': None, 'h_freq': None, 'auto_reject': None, 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': False, 'crop': True},
       {'l_freq': None, 'h_freq': None, 'auto_reject': 'ar', 'ica': [0, 1], 'set_average_ref': True, 'apply_baseline': True, 'drop_ref_ch': True, 'crop': True}],
      dtype=object)


# all new

{'l_freq': 8, 'h_freq': 13, 'auto_reject': None, 'ica': None, 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True}
logistic: 0.4215024741340531
rf: 0.3425551057130005
    

{'l_freq': 4, 'h_freq': 8, 'auto_reject': 'ar', 'ica': None, 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': True, 'crop': True}
logistic: 0.3843882582272675
rf: 0.4288002140427321

{'l_freq': 4, 'h_freq': 8, 'auto_reject': 'ar', 'ica': [1], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True}
logistic: 0.3931429881894278
rf: 0.37596797003401744
    
logistic: 0.3651435232962581
rf: 0.39030692198906847
{'l_freq': 4, 'h_freq': 8, 'auto_reject': 'ar', 'ica': [2], 'set_average_ref': True, 'apply_baseline': False, 'drop_ref_ch': False, 'crop': True}

In [12]:
import numpy as np

In [13]:
# Scratch stand-in for a (17, 400) buffer of uniform random values
# (unseeded, so values differ on every run)
buffer = np.random.random((17, 400))

In [14]:
# Scratch stand-in for a (17, 342) block of new data, same row count as `buffer`
data = np.random.random((17, 342))

In [16]:
# Join along axis 1 (columns): buffer first, then the new data
predict_data = np.concatenate([buffer, data], axis=1)

In [18]:
# Sanity check: expect 17 rows and 400 + 342 = 742 columns
predict_data.shape

(17, 742)