# EEG Explorer

## Initialization

In [1]:
import math
import numpy as np
import pandas as pd
from pandas import DataFrame as df
import scipy.signal as signal
import matplotlib.pyplot as plt

import plotly
from scipy.fftpack import ifft, fft
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# from pycaret.clustering import *

## Session Object

In [2]:
class session():
    # An object for storing data and metadata for one session of EEG recording
    def __init__(self, filename):
        self.f = filename
        
    def get_data(self, type='OpenBCI', fix=['OpenBCI-col_names']):
        # Create a pandas dataframe with channel data (data_chns)
        # fix:
        #   - "OpenBCI-col_names" - removes spaces from column names and makes channel number's 1-indexed
        if type == 'OpenBCI':
            f = open(self.f)
            self.meta = [f.readline() for i in range(4)]
            self.n_chns = int(self.meta[1][22:])
            self.fs = int(self.meta[2][15:18]) # sample rate

            self.data_chns = pd.read_csv(self.f, skiprows=[0,1,2,3])
            self.data_chns = self.data_chns.drop(columns=['Sample Index'])
            
            if 'OpenBCI-col_names' in fix:
                for i in range(self.n_chns):
                    self.data_chns = self.data_chns.rename(columns={self.data_chns.keys()[i]:'eeg_channel_'+str(i+1)})
            
            if 'Time' not in self.data_chns.keys():
                self.data_chns.reset_index(inplace=True)
                self.data_chns = self.data_chns.rename(columns={'index':'Time'})
                self.data_chns['Time'] = self.data_chns['Time'].divide(self.fs)
                    
    def crop_data(self, upto=-1, after=-1):
        # Crop the data upto or after a certain time in seconds
        if upto > 0:
            self.data_chns = self.data_chns.query('Time>=@upto')
            print('Removed data upto ' + str(upto) + ' seconds.\n')
        if after > 0:
            self.data_chns = self.data_chns.query('Time<=@after')
            print('Removed data after ' + str(after) + ' seconds.\n')
    
    def make_fft(self, test_fq=-1):
        timestep = 1/self.fs
        self.fft = df()
        
        NFFT = 0
        exp = 1
        while 2**exp < len(self.data_chns):
            NFFT = 2**exp
            exp += 1
        
        self.fft['Frequency'] = np.fft.fftfreq(NFFT)
        self.fft['Frequency'] = np.fft.fftshift(self.fft['Frequency'])
        self.fft = self.fft.query('Frequency>=0').mul(self.fs)
#         self.fft['Frequency'] = self.fft['Frequency'].mul(self.fs)[NFFT//2:]
#         self.fft = self.fft.query('Frequency>=0')
        for i in range(self.n_chns):
#             print(len(self.fft['Frequency']))
#             print(len(np.fft.fft(self.data_chns['eeg_channel_'+str(i+1)], n=NFFT)[:(NFFT//2)]))
            self.fft['eeg_channel_'+str(i+1)] = (np.real(np.fft.fft(self.data_chns['eeg_channel_'+str(i+1)], n=NFFT))**2)[:NFFT//2]
        
        if test_fq > 0:
            if test_fq < self.fs/2:
                self.test_fq = test_fq
                self.data_chns[str(test_fq)+'_Hz_test_fq'] = [np.sin(test_fq*2*np.pi*i/self.fs) for i in range(len(self.data_chns))]
                self.fft[str(test_fq)+'_Hz_test_fq'] = (np.real(np.fft.fft(self.data_chns[str(test_fq)+'_Hz_test_fq'], n=NFFT))**2)[:NFFT//2]
            else:
                print('WARNING: The test frequency is too high to be detected at a sample rate of '+str(self.fs)+' Hz.')
                
    
    def plot(self, ver='', chns=[]):
        # Plot the data in different helpful ways
        # Versions:
        #    pick-chns
        #    all-chns-in-one
        #    chn-grid
        #    fq
        #    fq-old
        #    test-fft
        
        if ver == 'pick-chns':
            fig = go.Figure(layout=go.Layout(title=go.layout.Title(text=str(len(chns)) + ' EEG Channels')
                        ))
            
            if len(chns) == 0:
                chns = range(1, self.n_chns + 1)
                print(chns)
            for i in range(len(chns)):
                fig.add_trace(go.Scatter(x=self.data_chns['Time'],
                                            y=self.data_chns['eeg_channel_'+str(chns[i])],
                                            mode='lines',
                                            name='EEG Ch.'+str(chns[i])))
            fig.show()

        if ver == 'all-chns-in-one':
            fig = go.Figure()
            for i in range(self.n_chns):
                fig.add_trace(go.Scatter(x=self.data_chns['Time'], y=self.data_chns['eeg_channel_'+str(i+1)],
                        mode='lines',
                        name='EEG Ch.'+str(i+1)))
            fig.show()
            
        if ver == 'chn-grid':
            if len(chns) == 0:
                fig = make_subplots(rows=self.n_chns, cols=1,
                                    vertical_spacing=0.01,
                                    subplot_titles=['EEG Channel '+str(i+1) for i in range(self.n_chns)])
                for i in range(self.n_chns):
                    fig.append_trace(go.Scatter(x=self.data_chns['Time'],
                                                y=self.data_chns['eeg_channel_'+str(i+1)],
                                                mode='lines',
                                                name='EEG Ch.'+str(i+1)),
                                                row=i+1,
                                                col=1)
                fig.update_layout(height=300*self.n_chns, width=800, title_text="Data by Channel")
                fig.show()
            else:
                fig = make_subplots(rows=len(chns), cols=1,
                                    vertical_spacing=0.2,
                                    subplot_titles=['EEG Channel '+str(i+1) for i in range(len(chns))])
                for i in range(len(chns)):
                    fig.append_trace(go.Scatter(x=self.data_chns['Time'],
                                                y=self.data_chns['eeg_channel_'+str(chns[i])],
                                                mode='lines',
                                                name='EEG Ch.'+str(chns[i])),
                                                row=i+1,
                                                col=1)
                fig.update_layout(height=300*len(chns), width=800, title_text="Data by Channel")
                fig.show()
                
        if ver == 'fq-old':
            for i in range(self.n_chns):
#                 plt.plot(self.fft['Frequency'], self.fft['eeg_channel_'+str(i+1)])
                plt.psd(self.data_chns['eeg_channel_'+str(i+1)], Fs=self.fs)
                plt.show()
        
        if ver == 'fq':
            if len(chns) == 0:
                fig = go.Figure()
                fig = make_subplots(rows=self.n_chns, cols=1,
                                    vertical_spacing=0.01,
                                    subplot_titles=['EEG Channel '+str(i+1) for i in range(self.n_chns)])
                for i in range(self.n_chns):
                    fig.append_trace(go.Scatter(x=self.fft['Frequency'], y=self.fft['eeg_channel_'+str(i+1)],
                            mode='lines',
                            name='EEG Ch.'+str(i+1)),
                            row=i+1,
                            col=1,)
                fig.update_layout(height=300*self.n_chns, width=800, title_text="Data by Channel")
                fig.show()
            else:
                fig = go.Figure()
                fig = make_subplots(rows=len(chns), cols=1,
                                    vertical_spacing=0.15,
                                    subplot_titles=['EEG Channel '+str(chns[i]) for i in range(len(chns))])
                for i in range(len(chns)):
                    fig.append_trace(go.Scatter(x=self.fft['Frequency'], y=self.fft['eeg_channel_'+str(chns[i])],
                            mode='lines',
                            name='EEG Ch.'+str(chns[i])),
                            row=i+1,
                            col=1,)
                fig.update_layout(height=300*len(chns), width=800, title_text="Data by Channel")
                fig.show()
                
        if ver == 'test-fft':
            n =len(S1.data_chns)

            X = np.linspace(0, n/self.fs, n)
#             Y = [np.sin(21*(2*np.pi)*i/S1.fs) for i in X]
            Y = [np.sin(self.test_fq*2*np.pi*i) for i in X]
            tst = df({'X':X, 'Y':Y})
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=tst['X'], y=tst['Y'],
                                    mode='lines',
                                    name='Test Frequency (' + str(self.test_fq) + ' Hz)'))
            fig.show()
            
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=self.fft['Frequency'], y=self.fft[str(self.test_fq)+'_Hz_test_fq'],
                                     mode='lines',
                                     name='Test Frequency (' + str(self.test_fq) + ' Hz)'))
            fig.show()
                

                

    def preprocess(self, hpf=0, lpf=0, rmv_avg=False, inplace=True):
        # Handle preprocessing to create a dataframe of processed data
        # Scale the data from -1 to 1
        # High-pass filter = hpf
        # Low-pass filter = lpf
        # Remove Average = rmv_avg
        
        # first_mod = True
        
        self.proc_data_chns = pd.DataFrame()
        self.proc_data_chns['Time'] = self.data_chns['Time']
        
        self.s_factor = S1.data_chns.drop(columns=S1.data_chns.keys()[17:]).drop(columns='Time').max().max() # max val for all chns
        for i in range(self.n_chns):
            self.proc_data_chns['eeg_channel_'+str(i+1)] = self.data_chns['eeg_channel_'+str(i+1)].div(self.s_factor)

        if rmv_avg:
            for i in range(self.n_chns):
                self.proc_data_chns['eeg_channel_'+str(i+1)] = self.proc_data_chns['eeg_channel_'+str(i+1)] - np.mean(self.proc_data_chns['eeg_channel_'+str(i+1)])
#                 print(np.mean(self.data_chns['eeg_channel_'+str(i+1)]))

        
        if lpf > 0 and hpf > 0:
            for i in range(self.n_chns):
                b, a = signal.butter(2, lpf, 'low', fs=self.fs)
                self.proc_data_chns['eeg_channel_'+str(i+1)] = signal.filtfilt(b, a, self.proc_data_chns['eeg_channel_'+str(i+1)])

                b, a = signal.butter(2, hpf, 'high', fs=self.fs)
                self.proc_data_chns['eeg_channel_'+str(i+1)] = signal.filtfilt(b, a, self.proc_data_chns['eeg_channel_'+str(i+1)])

        elif lpf > 0:
            for i in range(self.n_chns):
                b, a = signal.butter(2, lpf, 'low', fs=self.fs)
                self.proc_data_chns['eeg_channel_'+str(i+1)] = signal.filtfilt(b, a, self.proc_data_chns['eeg_channel_'+str(i+1)])
        elif hpf > 0:
            for i in range(self.n_chns):
                b, a = signal.butter(2, hpf, 'high', fs=self.fs)
                self.proc_data_chns['eeg_channel_'+str(i+1)] = signal.filtfilt(b, a, self.proc_data_chns['eeg_channel_'+str(i+1)])


        if inplace:
            self.data_chns = self.proc_data_chns
        return self.proc_data_chns

### Create a Session (S1)

In [3]:
# S1 = session('OpenBCI-RAW-2021-08-13_18-03-39.txt')
S1 = session('OpenBCI-RAW-eyes_open_closed_open.txt')
S1.get_data()
S1.crop_data(upto=25)
S1.preprocess(hpf=5, lpf=50, rmv_avg=True)
S1.make_fft(test_fq=15)

Removed data upto 25 seconds.



In [15]:
# plot
versions = ['pick-chns', 'all-chns-in-one', 'chn-grid', 'fq', 'fq-old', 'test-fft']

S1.plot(ver=versions[0], chns=[1, 3, 6, 8])

## Component Analysis

### Pycaret Tutorials

In [11]:
S2

<__main__.session at 0x2347772ba30>

### Examples

In [21]:
# adapted example: show the ???data???

S2 = session('OpenBCI-RAW-2021-08-13_18-03-39.txt')
S2.get_data()

features = []
for i in S2.data_chns.keys():
    if 'eeg_channel' in i:
        features.append(i)

  
fig = px.scatter_matrix(S2.data_chns,
                       dimensions=['eeg_channel_'+str(i) for i in [1,3,6,8]])
fig.update_traces(diagonal_visible=False)
fig.show()

In [20]:
df = px.data.iris()
features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]

fig = px.scatter_matrix(
    df,
    dimensions=features,
    color="species"
)
fig.update_traces(diagonal_visible=False)
fig.show()

In [15]:
import plotly.express as px
from sklearn.decomposition import PCA

df = px.data.iris()
features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]

pca = PCA()
components = pca.fit_transform(df[features])
labels = {
    str(i): f"PC {i+1} ({var:.1f}%)"
    for i, var in enumerate(pca.explained_variance_ratio_ * 100)
}

fig = px.scatter_matrix(
    components,
    labels=labels,
    dimensions=range(4),
    color=df["species"]
)
fig.update_traces(diagonal_visible=False)
fig.show()

In [None]:
# adapted example: show the Principal Components


In [16]:
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from sklearn.datasets import load_boston

boston = load_boston()
df = pd.DataFrame(boston.data, columns=boston.feature_names)
n_components = 4

pca = PCA(n_components=n_components)
components = pca.fit_transform(df)

total_var = pca.explained_variance_ratio_.sum() * 100

labels = {str(i): f"PC {i+1}" for i in range(n_components)}
labels['color'] = 'Median Price'

fig = px.scatter_matrix(
    components,
    color=boston.target,
    dimensions=range(n_components),
    labels=labels,
    title=f'Total Explained Variance: {total_var:.2f}%',
)
fig.update_traces(diagonal_visible=False)
fig.show()

In [17]:
import plotly.express as px
from sklearn.decomposition import PCA

df = px.data.iris()
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]

pca = PCA(n_components=3)
components = pca.fit_transform(X)

total_var = pca.explained_variance_ratio_.sum() * 100

fig = px.scatter_3d(
    components, x=0, y=1, z=2, color=df['species'],
    title=f'Total Explained Variance: {total_var:.2f}%',
    labels={'0': 'PC 1', '1': 'PC 2', '2': 'PC 3'}
)
fig.show()

### Real Data

In [38]:
### trying with pycaret
S2 = session('OpenBCI-RAW-2021-08-13_18-03-39.txt')
S2.get_data()
S2.crop_data(upto=1)
S2.preprocess(hpf=5, lpf=50, rmv_avg=True)
S2.make_fft()

Removed data upto 1 seconds.



In [49]:
S2.data_chns

Unnamed: 0,Time,eeg_channel_1,eeg_channel_2,eeg_channel_3,eeg_channel_4,eeg_channel_5,eeg_channel_6,eeg_channel_7,eeg_channel_8,eeg_channel_9,eeg_channel_10,eeg_channel_11,eeg_channel_12,eeg_channel_13,eeg_channel_14,eeg_channel_15,eeg_channel_16
125,1.000,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.128135e+04,8.349023e+03,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
126,1.008,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,3.283362e+06,4.500000e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
127,1.016,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,8.434513e+05,1.428863e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
128,1.024,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.654963e+06,2.123692e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
129,1.032,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.357101e+06,1.526511e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1260,10.080,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-8.250093e+05,-1.218255e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1261,10.088,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-8.832839e+05,-1.264134e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1262,10.096,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.086565e+06,-1.384472e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1263,10.104,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-2.075317e+06,-2.921972e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37


In [68]:
raw_df = S2.data_chns.drop(columns='Time')
raw_df

Unnamed: 0,eeg_channel_1,eeg_channel_2,eeg_channel_3,eeg_channel_4,eeg_channel_5,eeg_channel_6,eeg_channel_7,eeg_channel_8,eeg_channel_9,eeg_channel_10,eeg_channel_11,eeg_channel_12,eeg_channel_13,eeg_channel_14,eeg_channel_15,eeg_channel_16
125,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.128135e+04,8.349023e+03,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
126,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,3.283362e+06,4.500000e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
127,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,8.434513e+05,1.428863e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
128,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.654963e+06,2.123692e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
129,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,1.357101e+06,1.526511e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1260,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-8.250093e+05,-1.218255e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1261,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-8.832839e+05,-1.264134e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1262,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.086565e+06,-1.384472e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37
1263,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-7.523164e-37,-2.075317e+06,-2.921972e+06,-7.523164e-37,-7.523164e-37,-7.523164e-37,-1.504633e-36,-1.504633e-36,-7.523164e-37,-7.523164e-37,-7.523164e-37


In [77]:
from pycaret.clustering import *
pca_data = setup(data=raw_df, session_id=None) # simplified setup function
# pca_data = setup(data=raw_df, pca=True, pca_method='linear', pca_components=2) # real setup function

IntProgress(value=0, description='Processing: ', max=3)

ValueError: Setting a random_state has no effect since shuffle is False. You should leave random_state to its default (None), or set shuffle=True.

In [48]:
from pycaret.datasets import get_data
jewellery = get_data('jewellery')
from pycaret.clustering import *
exp_name = setup(data = jewellery)
kmeans = create_model('kmeans')
kmeans_predictions = predict_model(model = kmeans, data = unseen_data)

Unnamed: 0,Age,Income,SpendingScore,Savings
0,58,77769,0.791329,6559.829923
1,59,81799,0.791082,5417.661426
2,62,74751,0.702657,9258.992965
3,59,74373,0.76568,7346.334504
4,87,17760,0.348778,16869.50713


IntProgress(value=0, description='Processing: ', max=3)

ValueError: Setting a random_state has no effect since shuffle is False. You should leave random_state to its default (None), or set shuffle=True.