In [71]:
import sys

# sys.path.append('/Users/m/PTSA_NEW_GIT/')
import numpy as np
from numpy.testing import *
from ptsa.data.readers import BaseEventReader
from ptsa.data.filters.MorletWaveletFilter import MorletWaveletFilter
from ptsa.data.readers.TalReader import TalReader
from ptsa.data.readers import EEGReader
from ptsa.data.filters import MonopolarToBipolarMapper
from ptsa.data.filters import DataChopper
from ptsa.data.common import xr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from scipy.stats.mstats import zscore
import numpy.testing as npt
import cPickle
import os
from collections import namedtuple
from time import time

In [72]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

When running for the first time, set `compute_classifier_flag` to True. Subsequent runs can be executed with the flag set to False, because by then the classifier has already been computed and stored on the hard drive.

In [73]:
# Identifier of the subject whose data this notebook processes.
subject = 'R1111M'

# The classifier is trained (and pickled) only when this flag is True;
# the load cell further down reads it back from the pickle either way.
compute_classifier_flag = False

# Bundle that gets pickled: the trained classifier together with the
# mean / std used to z-score its input features.
ClassifierData = namedtuple(
    'ClassifierData',
    ['lr_classifier', 'mean', 'std'],
)

# Events file for the subject.
e_path = os.path.join(
    '/Users/m/data/events/RAM_FR1/',
    subject + '_events.mat',
)

# Electrode ("tal") structure file for the subject.
tal_path = os.path.join(
    '/Users/m/data/eeg',
    subject,
    'tal',
    subject + '_talLocs_database_bipol.mat',
)

reading electrodes information

In [74]:
def get_monopolar_and_bipolar_electrodes():
    """Read the electrode layout from the file at the module-level ``tal_path``.

    Returns:
        A 2-tuple ``(monopolar_channels, bipolar_pairs)`` holding whatever
        ``TalReader.get_monopolar_channels()`` and
        ``TalReader.get_bipolar_pairs()`` return, in that order.
    """
    reader = TalReader(filename=tal_path)
    # Tuple elements are evaluated left to right, so the monopolar query
    # still runs before the bipolar one.
    return reader.get_monopolar_channels(), reader.get_bipolar_pairs()

reading events

In [75]:
def get_events():
    """Load the events stored at the module-level ``e_path``, keeping WORD events only.

    The file is read with ``BaseEventReader`` using
    ``eliminate_events_with_no_eeg=True``.

    Returns:
        The subset of the events whose ``type`` field equals ``'WORD'``.
    """
    reader = BaseEventReader(filename=e_path, eliminate_events_with_no_eeg=True)
    all_events = reader.read()

    is_word_event = all_events.type == 'WORD'
    return all_events[is_word_event]

In [76]:
def get_bp_data(base_events):
    """Read event-aligned EEG for ``base_events`` and convert it to bipolar pairs.

    Uses the module-level ``monopolar_channels`` and ``bipolar_pairs``.
    Every event is read (no session filtering happens here) with the window
    ``start_time=0.0``, ``end_time=1.6`` and ``buffer_time=1.0``
    (presumably seconds - confirm against the EEGReader documentation).

    Args:
        base_events: events passed straight to ``EEGReader``.

    Returns:
        The result of ``MonopolarToBipolarMapper.filter()`` applied to the
        EEG that was read.
    """
    eeg_reader = EEGReader(events=base_events, channels=monopolar_channels,
                           start_time=0.0, end_time=1.6, buffer_time=1.0)
    base_eegs = eeg_reader.read()

    m2b = MonopolarToBipolarMapper(time_series=base_eegs, bipolar_pairs=bipolar_pairs)
    return m2b.filter()

In [77]:
def compute_wavelets(base_events):
    from time import time
    s = time()

    bp_eegs = get_bp_data(base_events)
    wf = MorletWaveletFilter(time_series=bp_eegs,
                             freqs=np.logspace(np.log10(3), np.log10(180), 8),
                             output='power',
                             frequency_dim_pos=0,
                             verbose=True
                             )
    pow_wavelet, phase_wavelet = wf.filter()
    print 'TOTAL WAVELET TIME=', time() - s
    return pow_wavelet

In [78]:
def prepare_classifier_data(zscore_mean_powers,evs,sess_min,sess_max):
    """Select the features and recall labels of the events in a session range.

    Args:
        zscore_mean_powers: feature array whose first axis is aligned with ``evs``.
        evs: events exposing ``session`` and ``recalled`` fields.
        sess_min: lowest session number to keep (inclusive).
        sess_max: highest session number to keep (inclusive).

    Returns:
        ``(features, recalls)``: the selected rows of ``zscore_mean_powers``
        and the matching ``recalled`` values cast to integers.
    """
    session_mask = (evs.session >= sess_min) & (evs.session <= sess_max)

    # The builtin ``int`` is what the ``np.int`` alias always meant; the alias
    # itself no longer exists in recent NumPy releases.
    recalls = evs[session_mask].recalled.astype(int)
    features = zscore_mean_powers[session_mask, ...]

    return features, recalls

In [79]:
def train_classifier(features, recalls):
    """Fit a logistic-regression classifier on ``features`` / ``recalls``.

    Args:
        features: training feature matrix handed to ``fit``.
        recalls: training labels handed to ``fit``.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    # NOTE(review): class_weight='auto' appears to have been superseded by
    # 'balanced' in later scikit-learn releases, with a different weighting
    # heuristic - confirm before upgrading, as C was chosen with 'auto'.
    hyper_params = dict(
        C=7.2e-4,
        penalty='l2',
        class_weight='auto',
        solver='liblinear',
    )
    model = LogisticRegression(**hyper_params)
    model.fit(features, recalls)
    return model

In [80]:
def compute_continuous_wavelets(session=0):
    """Compute Morlet wavelet power over the whole recording of one session.

    The EEG file is taken from the ``eegfile`` field of the first event in
    the module-level ``base_events`` whose ``session`` equals ``session``.
    The module-level ``monopolar_channels`` and ``bipolar_pairs`` select the
    channels and the bipolar montage.

    Args:
        session: session number to process.

    Returns:
        The power output of ``MorletWaveletFilter.filter()``; the phase
        output is discarded.
    """
    events_in_session = base_events[base_events.session == session]
    dataroot = events_in_session[0].eegfile

    session_eegs = EEGReader(session_dataroot=dataroot,
                             channels=monopolar_channels).read()
    session_bp_eegs = MonopolarToBipolarMapper(time_series=session_eegs,
                                               bipolar_pairs=bipolar_pairs).filter()

    wavelet_filter = MorletWaveletFilter(
        time_series=session_bp_eegs,
        freqs=np.logspace(np.log10(3), np.log10(180), 8),
        output='power',
        frequency_dim_pos=0,
        verbose=True,
    )
    power, _phase = wavelet_filter.filter()
    return power

This function chops a continuous time series into chunks corresponding to events, where the events can in fact be specified using a simple array of offsets. The offsets correspond to the positions at which the time series will be cut. Note that you may cut EEG series, but also time series of wavelet coefficients.

In [81]:
def chop_time_series(time_series,start_offsets):
    """Cut ``time_series`` into fixed-length chunks starting at ``start_offsets``.

    Each chunk is requested from ``DataChopper`` with ``start_time=0.0`` and
    ``end_time=1.6`` relative to its offset.

    Args:
        time_series: continuous session data passed as ``session_data``.
        start_offsets: positions at which the chunks start.

    Returns:
        The result of ``DataChopper.filter()``.
    """
    chopper = DataChopper(
        session_data=time_series,
        start_offsets=start_offsets,
        start_time=0.0,
        end_time=1.6,
    )
    return chopper.filter()

In [82]:
def compute_classifier_features(pow_wavelet):
    """Turn wavelet power into z-scored, time-averaged log-power features.

    WARNING: ``pow_wavelet.data`` is replaced IN PLACE by its base-10
    logarithm, so the caller's object holds log power afterwards.

    Args:
        pow_wavelet: power time series with dimensions named ``events``,
            ``bipolar_pairs``, ``frequency`` and ``time``.

    Returns:
        2-D array with one row per event: the log power averaged over time
        (NaNs ignored), flattened across the remaining axes and z-scored
        over events with ``ddof=1``.
    """
    raw = pow_wavelet.data
    np.log10(raw, out=raw)

    events_first = pow_wavelet.transpose('events', "bipolar_pairs", "frequency", "time")
    time_averaged = np.nanmean(events_first.data, axis=-1)

    n_events = time_averaged.shape[0]
    flat_features = time_averaged.reshape(n_events, -1)

    return zscore(flat_features, axis=0, ddof=1)

This function computes zscoring params - instead of using zscore function we will use more pedestrian way of zscoring

In [83]:
def compute_zscoring_params(log_pow_wavelet):
    """Compute the per-feature mean and standard deviation used for z-scoring.

    Features are the time-averaged values (NaNs ignored) of
    ``log_pow_wavelet``, flattened to one row per event.

    Args:
        log_pow_wavelet: log-power time series with dimensions named
            ``events``, ``bipolar_pairs``, ``frequency`` and ``time``.

    Returns:
        ``(m, s)``: per-feature mean and sample standard deviation
        (``ddof=1``) taken over events.

    Raises:
        AssertionError: if z-scoring by hand with ``m`` and ``s`` does not
            agree with ``scipy``'s ``zscore`` within floating-point tolerance.
    """
    transposed_log_pow_wavelet = log_pow_wavelet.transpose('events', "bipolar_pairs", "frequency", "time")
    mean_powers_nd = np.nanmean(transposed_log_pow_wavelet.data, axis=-1)
    mean_powers_rs = mean_powers_nd.reshape(mean_powers_nd.shape[0], -1)

    m = np.mean(mean_powers_rs, axis=0)
    s = np.std(mean_powers_rs, axis=0, ddof=1)

    # Sanity check of the "pedestrian" z-scoring against scipy. Compare with
    # a tolerance: two independent floating-point code paths are not
    # guaranteed to agree to the last bit.
    npt.assert_allclose((mean_powers_rs - m) / s,
                        zscore(mean_powers_rs, axis=0, ddof=1),
                        rtol=1e-10, atol=1e-12)
    return m, s

this function zscores features using mean and std dev returned by compute_zscoring_params

In [84]:
def compute_features_using_zscoring_params(pow_wavelet,mean,std):
    """Build classifier features by z-scoring with precomputed parameters.

    ``pow_wavelet`` must ALREADY hold log10 power: this function does not
    apply a logarithm itself, and every caller in this notebook
    log-transforms the data before calling it.

    Args:
        pow_wavelet: log-power time series with dimensions named ``events``,
            ``bipolar_pairs``, ``frequency`` and ``time``.
        mean: per-feature mean, as returned by ``compute_zscoring_params``.
        std: per-feature standard deviation, as returned by
            ``compute_zscoring_params``.

    Returns:
        2-D array with one row per event: the time-averaged values
        (NaNs ignored), flattened across the remaining axes, with ``mean``
        subtracted and divided by ``std``.
    """
    transposed_log_pow_wavelet = pow_wavelet.transpose('events', "bipolar_pairs", "frequency", "time")

    mean_powers_nd = np.nanmean(transposed_log_pow_wavelet.data, axis=-1)
    mean_powers_rs = mean_powers_nd.reshape(mean_powers_nd.shape[0], -1)

    return (mean_powers_rs - mean) / std

this is just a test function not used in "production code"

In [85]:
def test_zscoring_computations():
    s=time()
    pow_wavelet = compute_wavelets(base_events)
    print 'TIME: EEG READING + WAVELETS=',time()-s
    zscore_mean_powers = compute_classifier_features(pow_wavelet)
    mean,std = compute_zscoring_params(log_pow_wavelet=pow_wavelet)
    transposed_log_pow_wavelet =pow_wavelet.transpose('events', "bipolar_pairs", "frequency", "time")
    mean_powers_nd = np.nanmean(transposed_log_pow_wavelet.data, axis=-1)
    mean_powers_rs = mean_powers_nd.reshape(mean_powers_nd.shape[0], -1)
    print zscore_mean_powers - (mean_powers_rs-mean)/std
    zscore_mean_powers_alt = compute_features_using_zscoring_params(pow_wavelet,mean,std)
    zscore_mean_powers_alt-zscore_mean_powers

This function computes classifier

In [86]:
def compute_classifier():
    s=time()
    pow_wavelet = compute_wavelets(base_events)
    print 'TIME: EEG READING + WAVELETS=',time()-s
    pow_wavelet = pow_wavelet.remove_buffer(duration=1.0)
    
    np.log10(pow_wavelet.data, out=pow_wavelet.data);
    log_pow_wavelet = pow_wavelet
    mean,std = compute_zscoring_params(log_pow_wavelet=log_pow_wavelet)
    zscore_mean_powers = compute_features_using_zscoring_params(log_pow_wavelet,mean,std)
    
#     zscore_mean_powers = compute_classifier_features(pow_wavelet)
    training_features, training_recalls  = prepare_classifier_data(zscore_mean_powers, base_events,sess_min=0,sess_max=1)
    lr_classifier = train_classifier(training_features,training_recalls)

    validation_features, validation_recalls  = prepare_classifier_data(zscore_mean_powers, base_events,sess_min=2,sess_max=2)
    
    recall_prob_array = lr_classifier.predict_proba(training_features)[:, 1]
    auc = roc_auc_score(training_recalls, recall_prob_array)
    print 'auc=', auc
    
    validation_recall_prob_array = lr_classifier.predict_proba(validation_features)[:, 1]
    auc = roc_auc_score(validation_recalls, validation_recall_prob_array)
    print 'auc=', auc            
    classifier_data = ClassifierData(lr_classifier=lr_classifier,mean=mean,std=std)

    # save the classifier
    with open('classifier_data_'+subject+'.pkl', 'wb') as fid:
        cPickle.dump(classifier_data, fid)

This function takes the following as input:
1. wavelets computed for the entire session
2. ClassifierData tuple - containing the trained classifier and the mean and std dev used for z-scoring
3. start_time (in seconds) - determines the time location of the epoch at which we begin computing the probs time series
4. end_time (in seconds) - determines the last epoch of the probs time series
5. resolution - separation of the time points (in seconds) at which we calculate recall probabilities
6. slice_size - determines the number of chopping operations DataChopper performs per call. Since DataChopper returns EEG time series, using a smaller slice_size puts less strain on memory. In principle we could call DataChopper only once, but if the number of chops is large we might run out of memory...

In [87]:
def compute_probs_ts(pow_wavelet_session,classifier_data,start_time = 10.0, end_time=20.0, slice_size=10, resolution=0.1):
    """Compute a time series of recall probabilities over continuous data.

    The interval ``[start_time, end_time)`` is sampled every ``resolution``
    seconds. At each sample point a chunk of wavelet power is chopped out,
    log-transformed, z-scored with the stored parameters and fed to the
    classifier. Chunks are processed ``slice_size`` at a time to limit memory
    use; a final, shorter slice covers whatever remains when the number of
    sample points is not a multiple of ``slice_size``.

    Args:
        pow_wavelet_session: wavelet power of the whole session; must expose
            the sampling rate as ``pow_wavelet_session['samplerate']``.
        classifier_data: object with ``lr_classifier``, ``mean`` and ``std``
            attributes (a ``ClassifierData`` tuple).
        start_time: time of the first sample point, in seconds.
        end_time: end of the sampled interval, in seconds.
        slice_size: number of sample points chopped per DataChopper call;
            must be a positive integer.
        resolution: spacing between sample points, in seconds.

    Returns:
        ``(time_axis, probs_array)``: two 1-D arrays of equal length holding
        the sample times and the predicted probability of class 1. Both are
        empty when the interval contains no sample point.

    Raises:
        ValueError: if ``slice_size`` is smaller than 1.
    """
    if slice_size < 1:
        raise ValueError('slice_size must be a positive integer')

    lr_classifier = classifier_data.lr_classifier
    mean = classifier_data.mean
    std = classifier_data.std

    # resolution is in seconds; convert it to a number of samples
    samplerate = float(pow_wavelet_session['samplerate'])
    number_of_samples_in_resolution = int(round(resolution*samplerate))

    total_number_of_items = int(round((end_time-start_time)/resolution))

    probs_list = []

    # Stepping by slice_size (rather than looping over total // slice_size
    # full slices) keeps the trailing partial slice.
    for first_item in range(0, total_number_of_items, slice_size):
        items_in_slice = min(slice_size, total_number_of_items - first_item)

        st = start_time + first_item*resolution
        initial_offset = int(round(st*samplerate))
        start_offsets = initial_offset + np.arange(items_in_slice)*number_of_samples_in_resolution

        pow_wavelet_chopped = chop_time_series(time_series=pow_wavelet_session, start_offsets=start_offsets)
        pow_wavelet_chopped = pow_wavelet_chopped.rename({'start_offsets':'events'})

        # NOTE(review): this in-place log assumes the chopped data is a copy
        # and not a view into the session data - confirm for DataChopper.
        np.log10(pow_wavelet_chopped.data, out=pow_wavelet_chopped.data)

        features = compute_features_using_zscoring_params(pow_wavelet_chopped, mean, std)
        probs_list.append(lr_classifier.predict_proba(features)[:, 1])

    if not probs_list:
        # np.hstack refuses an empty list; report "no samples" explicitly.
        return np.array([]), np.array([])

    probs_array = np.hstack(probs_list)
    time_axis = start_time + np.arange(probs_array.shape[0])*resolution
    return time_axis, probs_array

Beginning of the computational pipeline that computes classifiers

In [88]:
# WORD events of `subject`; the functions defined above read this
# module-level variable directly.
base_events = get_events()

In [89]:
# Electrode layout; get_bp_data and compute_continuous_wavelets read these
# two module-level variables directly.
monopolar_channels, bipolar_pairs = get_monopolar_and_bipolar_electrodes()

In [None]:
# Same file name that compute_classifier() writes to.
classifier_pickle_path = 'classifier_data_'+subject+'.pkl'

# Train when explicitly requested, and also when no stored classifier exists
# yet, so a fresh run does not fail on a missing file.
if compute_classifier_flag or not os.path.exists(classifier_pickle_path):
    compute_classifier()

# We read the classifier from disk.
# NOTE: unpickling can execute arbitrary code - only load pickle files that
# were produced by this notebook.
with open(classifier_pickle_path, 'rb') as fid:
    classifier_data = cPickle.load(fid)

lr_classifier = classifier_data.lr_classifier
mean = classifier_data.mean
std = classifier_data.std

Computing wavelets for entire session (continuous mode)

In [None]:
# Wavelet power over the whole recording of session 0 (not chopped into events).
pow_wavelet_session = compute_continuous_wavelets(session=0)

Here we are computing time series of recall probabilities

In [None]:
# Classifier output every 0.1 s between 50 s and 70 s of the session,
# computed 10 sample points per DataChopper call.
time_axis, probs_array = compute_probs_ts(pow_wavelet_session,
                                          classifier_data,
                                          start_time=50.0, 
                                          end_time=70.0, 
                                          slice_size=10, 
                                          resolution=0.1
                                          )

In [None]:
# Label the figure so it can be read on its own; the trailing semicolon
# suppresses the text repr of the last call in the cell output.
plt.plot(time_axis, probs_array)
plt.xlabel('time (s)')
plt.ylabel('recall probability')
plt.title('Classifier output, subject ' + subject);