In [3]:
import numpy as np
import csv
import math
import antropy as ant
import random

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.ensemble import StackingClassifier, VotingClassifier, GradientBoostingClassifier, AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedGroupKFold, GroupKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier

## Data Setup

In [4]:
begin, end = 1, 61 # (begin is inclusive, end is exclusive)
num_people = 14
count_samples = {
    "active": 8,
    "meditate": 8,
    "neutral": 8
}

class Sample:
    def __init__(self):
        self.data = {
            'RawEEG': [],
            'Alpha': [],
            'Low Beta': [],
            'High Beta': [],
            'Gamma': [],
            'Theta': [],
            'Delta': [],
            'Meditation': [],
            'Attention': []
        }

    def recordDataPoint(self, RawEEG, Attention, Meditation, Alpha, Delta, Theta, LowBeta, HighBeta, Gamma):
        self.data['RawEEG'].append(float(RawEEG))
        self.data['Attention'].append(float(Attention))
        self.data['Meditation'].append(float(Meditation))
        self.data['Alpha'].append(float(Alpha))
        self.data['Delta'].append(float(Delta))
        self.data['Theta'].append(float(Theta))
        self.data['Low Beta'].append(float(LowBeta))
        self.data['High Beta'].append(float(HighBeta))
        self.data['Gamma'].append(float(Gamma))

    '''
    Record a line of data from the CSV output, which takes form RawEEG, Alpha, Delta, Gamma, Low Beta, High Beta, Theta, Attention, Meditation

    '''
    def recordDataLine(self, line):
        self.recordDataPoint(line[0], line[7], line[8], line[1], line[2], line[6], line[4], line[5], line[3])
    
    def getEEG(self):
        return self.data['RawEEG']
    
    def getAttention(self):
        return self.data["Attention"]
    
    def getMeditation(self):
        return self.data["Meditation"]
    
    def getAlpha(self):
        return self.data["Alpha"]
    
    def getDelta(self):
        return self.data["Delta"]
    
    def getTheta(self):
        return self.data["Theta"]
    
    def getLowBeta(self):
        return self.data["Low Beta"]
    
    def getHighBeta(self):
        return self.data["High Beta"]
    
    def getGamma(self):
        return self.data["Gamma"]

    def get(self, key):
        return self.data[key]

    '''
    Filter out all outliers, as defined by being outside 3*std from the mean, and replace with mean of the samples around them
    '''
    def filter_outliers(self):
        sampleBad = False
        for key in ['RawEEG', 'Alpha', 'Theta', 'Low Beta', 'High Beta', "Gamma", 'Delta']:
            data = self.data[key]
            
            filtered = []

            iqr = np.subtract(*np.percentile(data, [75, 25]))
            med = np.median(data)

            for x in data:
                
                if (med - 1.5*iqr > x) or (med + 1.5*iqr < x) or abs(x - np.mean(data)) > 2 * np.std(data):
                    filtered.append(med)
                    # filtered.append(np.median(data[max(0, i-5):i] + data[i+1:min(len(data), i+5)]))
                else:
                    filtered.append(x)
                    
            self.data[key] = filtered
        return sampleBad

In [5]:
# {personNum : {state: [sampleNums]}}
# 0 = key for throwing away all samples of that state

badSamples = {
    1: {"active": [5], "neutral": [2], "meditate": []},
    2: {"active": [0], "neutral": [0], "meditate": [0]},
    3: {"active": [1, 4], "neutral": [1], "meditate": [5, 6, 7, 8]},
    4: {"active": [2], "neutral": [1, 7], "meditate": [1, 8]},
    5: {"active": [], "neutral": [], "meditate": []}, 
    6: {"active": [], "neutral": [2, 6], "meditate": []},
    7: {"active": [5], "neutral": [4, 6, 7], "meditate": [1, 3, 4, 8]}, 
    8: {"active": [5], "neutral": [1], "meditate": [5, 8]}, 
    9: {"active": [], "neutral": [], "meditate": []}, 
    10: {"active": [6, 8], "neutral": [4, 5, 6], "meditate": []},
    11: {"active": [4], "neutral": [4, 8], "meditate": [1, 2, 3, 5, 7]},
    12: {"active": [2, 3, 8], "neutral": [0], "meditate": [6]}, 
    13: {"active": [], "neutral": [8], "meditate": []},
    14: {"active": [4, 5, 8], "neutral": [0], "meditate": [1, 2, 8]}
}

In [11]:
data = []
dataLabels = []
# groups = [x+1 for sublist in ([i]*24 for i in range(num_people)) for x in sublist]
# ^^ this is the list comprehension that works to get all of the indexes of people-groups but 
# I just realized that you cant do it that way bc of outliers not existing so F
groups = []

def transcribeFileToSample(personN: int, sampleN: int, state: str, X, y, outlierFiltering = True):
    sample_data = Sample()

    with open("data/all_data/" + state + "_" + str(personN) + "_" + str(sampleN) + ".csv") as f:
        reader = csv.reader(f)

        header = next(reader)
        
        for row in reader:
            sample_data.recordDataLine(row)

        if (outlierFiltering):   
            if (0 not in badSamples[personN][state] and sampleN not in badSamples[personN][state]):

                for key in sample_data.data:
                    sample_data.data[key] = sample_data.data[key][begin:end]

                sample_data.filter_outliers()
                X.append(sample_data)
                y.append(state)
                groups.append(group)

        else:
            X.append(sample_data)
            y.append(state)

for person in range(num_people):
    for state in count_samples:
        for i in range(8):
            transcribeFileToSample(person + 1, i + 1, state, data, dataLabels)

UnboundLocalError: local variable 'group' referenced before assignment

In [7]:
dataExtracted = []

def safety_check(x):
    if math.isnan(x): return 0
    if math.isinf(x): return 99999999999
    return x

for point in data:
    extractedPoint = []

    extractedPoint.append(np.mean(point.getAlpha()))
    extractedPoint.append(np.mean(point.getLowBeta()))
    extractedPoint.append(np.mean(point.getHighBeta())) 
    extractedPoint.append(np.mean(point.getGamma())) 
    extractedPoint.append(np.mean(point.getTheta()))
    extractedPoint.append(np.std(point.getHighBeta())) 
    extractedPoint.append(np.std(point.getGamma()))
    extractedPoint.append(np.std(point.getDelta()))
    extractedPoint.append(safety_check(ant.sample_entropy(point.getDelta())))
    extractedPoint.append(np.std(point.getLowBeta())) 
    extractedPoint.append(np.std(point.getTheta()))
    
    # extractedPoint.append(safety_check(ant.spectral_entropy(point.getEEG(), sf=1)))
    # extractedPoint.append(np.mean(point.getDelta()))
    # extractedPoint.append(np.std(point.getAlpha())) 

    dataExtracted.append(extractedPoint)

In [8]:
cvclf = RandomForestClassifier(max_depth=20, n_estimators=2000)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=13)
sgkf = GroupKFold(n_splits=4)
scores = cross_val_score(cvclf, dataExtracted, dataLabels, groups=groups, cv=sgkf, n_jobs=-1)
print("%0.2f accuracy with a standard deviation of %0.2f" % (scores.mean(), scores.std()))

0.53 accuracy with a standard deviation of 0.06


In [10]:
# sgkf.split(groups=groups, X=dataExtracted, y=dataLabels)
sgkf = GroupKFold(n_splits=4)
dataLabels = np.array(dataLabels)
groups = np.array(groups)

for train_idxs, test_idxs in sgkf.split(np.array(dataExtracted), dataLabels, groups):
    # print(test_idxs)
    print("TRAIN:", groups[train_idxs])
    print("      ", dataLabels[train_idxs])
    print(" TEST:", groups[test_idxs])
    print("      ", dataLabels[test_idxs])

TRAIN: [ 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  3  3  3  3  3  3  3
  3  3  3  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  5  5  5
  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  6  6  6  6  6  6  6
  6  6  6  6  6  6  6  7  7  7  7  7  7  7  7  7  7  8  8  8  8  8  8  8
  8  8  8  8  8  8  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9  9
  9  9 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 11 11 11 11 11 11 11
 11 11 11 11 11 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 13 13
 13 13 13 13 13 13 13 13 14 14 14 14 14 14 14]
       ['active' 'active' 'active' 'active' 'active' 'active' 'meditate'
 'meditate' 'meditate' 'meditate' 'meditate' 'neutral' 'neutral' 'neutral'
 'neutral' 'neutral' 'neutral' 'active' 'active' 'active' 'active'
 'active' 'meditate' 'meditate' 'meditate' 'neutral' 'neutral' 'neutral'
 'neutral' 'neutral' 'active' 'active' 'active' 'active' 'active' 'active'
 'active' 'meditate' 'meditate' 'meditate' 'neutral' 'neutral' 'neutral'