In [1]:
import numpy as np
from scipy import signal
from sklearn.model_selection import LeavePOut
import random
from eeg_lib.cca import GCCA_SSVEP
from eeg_lib.cca import MsetCCA_SSVEP
from eeg_lib.utils import standardise
import matplotlib.pyplot as plt
import pandas as pd

def synth_x(f, Ns, noise_power=0.5, fs=256):
    """Generate a synthetic sinusoid of exactly ``Ns`` samples.

    Parameters
    ----------
    f : float
        Frequency of the sinusoid in Hz.
    Ns : int
        Number of samples to generate.
    noise_power : float, optional
        Upper bound of a random amplitude gain. A single factor
        ``1 + random.random() * noise_power`` is drawn per call and applied
        to the whole waveform; no per-sample noise is added.
    fs : float, optional
        Sampling rate in Hz used to build the time axis.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``Ns``.
    """
    # Build the time axis from integer sample indices. np.arange with a float
    # stop/step (0, Ns/fs, 1/fs) can come out one sample long or short due to
    # rounding, which breaks element-wise addition with the recorded signal.
    t = np.arange(Ns) / fs
    return np.sin(t*2*np.pi*f)*(1+random.random()*noise_power)

# Pre-computed second-order-sections filter consumed by sos_filter_256 via
# scipy.signal.sosfilt. Five cascaded sections, one per row, each laid out as
# [b0, b1, b2, a0, a1, a2] as sosfilt expects.
# NOTE(review): the name suggests an SSVEP band-pass designed for a 256 Hz
# sampling rate; the design parameters are not recorded here - TODO confirm.
SOS_SSVEP_BANDPASS_256HZ = np.array(
    [   [5.18442631e-04, 5.91022291e-04, 5.18442631e-04, 1.00000000e00, -1.58700686e00, 6.47826110e-01,],
        [1.00000000e00, -6.71721317e-01, 1.00000000e00, 1.00000000e00, -1.56164716e00, 7.42956116e-01,],
        [1.00000000e00, -1.19862825e00, 1.00000000e00, 1.00000000e00, -1.53434369e00, 8.53024717e-01,],
        [1.00000000e00, -1.36462221e00, 1.00000000e00, 1.00000000e00, -1.52074686e00, 9.31086238e-01,],
        [1.00000000e00, -1.41821305e00, 1.00000000e00, 1.00000000e00, -1.52570664e00, 9.80264626e-01,],
    ])

# Elliptic band-pass design used by sos_filter_OpenBCI.
fs_openbci = 200  # sampling rate (Hz) the filter is designed for
filt_ord = 10     # order passed to signal.ellip
pb_rip = 0.2      # maximum pass-band ripple (dB)
sb_atten = 80     # minimum stop-band attenuation (dB)

# Pass-band edges in Hz.
fc_lo, fc_hi = 4, 28

# scipy expects the edges as fractions of the Nyquist frequency (fs / 2).
wc_lo, wc_hi = (edge/(fs_openbci*0.5) for edge in (fc_lo, fc_hi))

sos_openbci = signal.ellip(
    N=filt_ord,
    rp=pb_rip,
    rs=sb_atten,
    Wn=(wc_lo, wc_hi),
    btype='bandpass',
    output='sos',
)

def load_array_data(file_path):
    """Read a bracketed, comma-separated list of integers from a text file.

    The file is expected to contain a single list such as ``[1, 2, 3]``,
    optionally followed by a newline. Surrounding whitespace, a missing
    trailing newline and a missing pair of brackets are all tolerated.

    Parameters
    ----------
    file_path : str or path-like
        Location of the text file.

    Returns
    -------
    list of int
        Parsed values; an empty list when the file holds no values.

    Raises
    ------
    ValueError
        If any token cannot be parsed as an integer.
    """
    # The context manager closes the handle even when parsing fails below.
    with open(file_path, "r") as data_file:
        payload = data_file.read().strip()

    # Peel off one opening and one closing bracket independently, so that a
    # single-element list ("[5]") loses both of them.
    if payload.startswith('['):
        payload = payload[1:]
    if payload.endswith(']'):
        payload = payload[:-1]

    if not payload.strip():
        return []

    # int() ignores surrounding whitespace, so "1,2" and "1, 2" both parse.
    return [int(token) for token in payload.split(',')]

def load_array_data_float(file_path):
    """Read a bracketed, comma-separated list of floats from a text file.

    The file is expected to contain a single list such as ``[1.5, 2, 3e-1]``,
    optionally followed by a newline. Surrounding whitespace, a missing
    trailing newline and a missing pair of brackets are all tolerated.

    Parameters
    ----------
    file_path : str or path-like
        Location of the text file.

    Returns
    -------
    list of float
        Parsed values; an empty list when the file holds no values.

    Raises
    ------
    ValueError
        If any token cannot be parsed as a float.
    """
    # The context manager closes the handle even when parsing fails below.
    with open(file_path, "r") as data_file:
        payload = data_file.read().strip()

    # Peel off one opening and one closing bracket independently, so that a
    # single-element list ("[0.5]") loses both of them.
    if payload.startswith('['):
        payload = payload[1:]
    if payload.endswith(']'):
        payload = payload[:-1]

    if not payload.strip():
        return []

    # float() ignores surrounding whitespace, so "1,2" and "1, 2" both parse.
    return [float(token) for token in payload.split(',')]

def load_array_data_online(file_path):
    """Read a plain comma-separated list of floats (no brackets) from a file.

    Parameters
    ----------
    file_path : str or path-like
        Location of the text file.

    Returns
    -------
    list of float
        One value per comma-separated token, in file order.

    Raises
    ------
    ValueError
        If any token (including an empty one) cannot be parsed as a float.
    """
    # Use a context manager so the handle is released deterministically; this
    # loader is called once per channel/block, so leaked handles add up.
    with open(file_path, "r") as data_file:
        tokens = data_file.read().split(',')
    return [float(token) for token in tokens]

def average_every_n(values, size):
    """Keep every ``size``-th element of ``values``, starting with the first.

    Despite the name, no averaging is performed: the sequence is decimated by
    plain strided slicing, so the return type matches whatever slicing
    ``values`` produces.
    """
    stride = slice(None, None, size)
    return values[stride]

def sos_filter_OpenBCI(values):
    """Filter ``values`` with the module-level ``sos_openbci`` cascade.

    Returns whatever ``scipy.signal.sosfilt`` produces for the input.
    """
    filtered = signal.sosfilt(sos_openbci, values)
    return filtered

def sos_filter_256(values):
    """Filter ``values`` with the module-level ``SOS_SSVEP_BANDPASS_256HZ`` cascade.

    Returns whatever ``scipy.signal.sosfilt`` produces for the input.
    """
    filtered = signal.sosfilt(SOS_SSVEP_BANDPASS_256HZ, values)
    return filtered

def _segment_trials(data, no_samples, filter_fn, remove_DC, apply_filter, ds_rate, downsample):
    """Shared implementation behind process_data and process_data_OpenBCI.

    Truncates ``data`` to a whole number of ``no_samples``-long trials,
    optionally removes the mean, band-pass filters and decimates it, and
    returns the trials stacked as an array of shape
    ``(1, no_samples, n_trials)``.
    """
    # Accept plain Python lists too: the arithmetic and reshape below need an
    # ndarray, and a list used to crash on .reshape when remove_DC was False.
    data = np.asarray(data)
    data = data[:(len(data)//no_samples)*no_samples]

    # Skip the mean on empty input (fewer samples than one trial) instead of
    # dividing by zero; the result is then an array with zero trials.
    if remove_DC and len(data):
        data = data - data.mean()

    # The filter is applied whenever requested. It used to be nested under
    # remove_DC, which silently ignored apply_filter=True when remove_DC=False.
    if apply_filter:
        # Drop the first trial's worth of output after filtering
        # (presumably to discard the filter's start-up transient - TODO confirm).
        data = filter_fn(data)[no_samples:]
        # Diagnostic plot of the first 256 filtered samples.
        plt.plot(data[:256])

    if downsample:
        print("downsampling",len(data))
        data = average_every_n(data, ds_rate)
        print("to",len(data))
        # Decimation can leave a partial trial at the end; trim it again.
        data = data[:(len(data)//no_samples)*no_samples]

    n_trials = len(data)//no_samples
    trials = data.reshape(n_trials, no_samples)
    # Transpose so that the last axis indexes trials: out[0, s, t] = trials[t, s].
    return np.array(trials.T.reshape(1, no_samples, n_trials))


def process_data(data, no_samples, no_train, remove_DC=True, apply_filter=False, ds_rate=1, downsample=False):
    """Split a 1-D recording into trials, filtering with ``sos_filter_256``.

    Parameters
    ----------
    data : sequence or numpy.ndarray
        1-D signal; trailing samples that do not fill a whole trial are dropped.
    no_samples : int
        Number of samples per trial.
    no_train : int
        Unused; kept for call-site compatibility.
    remove_DC : bool, optional
        Subtract the mean of the (truncated) signal.
    apply_filter : bool, optional
        Filter with ``sos_filter_256`` and drop the first ``no_samples``
        output samples. Honoured independently of ``remove_DC``.
    ds_rate : int, optional
        Decimation stride used when ``downsample`` is true.
    downsample : bool, optional
        Decimate with ``average_every_n`` before segmenting.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, no_samples, n_trials)``.
    """
    # The lambda defers the lookup of sos_filter_256 until filtering is requested.
    return _segment_trials(data, no_samples, lambda x: sos_filter_256(x),
                           remove_DC, apply_filter, ds_rate, downsample)


def process_data_OpenBCI(data, no_samples, no_train, remove_DC=True, apply_filter=False, ds_rate=1, downsample=False):
    """Split a 1-D recording into trials, filtering with ``sos_filter_OpenBCI``.

    Identical to ``process_data`` except for the filter applied when
    ``apply_filter`` is true. ``no_train`` is unused and kept for call-site
    compatibility.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, no_samples, n_trials)``.
    """
    # The lambda defers the lookup of sos_filter_OpenBCI until filtering is requested.
    return _segment_trials(data, no_samples, lambda x: sos_filter_OpenBCI(x),
                           remove_DC, apply_filter, ds_rate, downsample)

def prepare_data(data, frequency, fs, fs_synth, no_samples, no_train, remove_DC=True, apply_filter=True, downsample=False, ds_avg=1, synth_power=0, noise_power=0):
    """Optionally mix a synthetic sinusoid into ``data`` and segment it into trials.

    Parameters
    ----------
    data : sequence or numpy.ndarray
        1-D recording.
    frequency : float
        Frequency (Hz) of the synthetic sinusoid produced by ``synth_x``.
    fs : float
        Unused; kept for call-site compatibility.
    fs_synth : float
        Sampling rate handed to ``synth_x``.
    no_samples, no_train, remove_DC, apply_filter, downsample :
        Forwarded unchanged to ``process_data``.
    ds_avg : int, optional
        Forwarded to ``process_data`` as ``ds_rate``.
    synth_power : float, optional
        Gain applied to the synthetic sinusoid before it is added; 0 leaves
        the recording unchanged.
    noise_power : float, optional
        Forwarded to ``synth_x``.

    Returns
    -------
    The value returned by ``process_data``.
    """
    # Forward the caller's noise_power; it used to be replaced by a literal 0,
    # so the argument had no effect.
    synthetic = synth_x(frequency, len(data), noise_power=noise_power, fs=fs_synth)
    # Adding an ndarray also turns a plain list into an ndarray for process_data.
    data = data + synth_power*synthetic
    return process_data(data, no_samples, no_train, remove_DC=remove_DC, apply_filter=apply_filter, ds_rate=ds_avg, downsample=downsample)

def prepare_data_OpenBCI(data, frequency, fs, fs_synth, no_samples, no_train, remove_DC=True, apply_filter=True, downsample=False, ds_avg=1, synth_power=0, noise_power=0):
    """Optionally mix a synthetic sinusoid into ``data`` and segment it with ``process_data_OpenBCI``.

    Parameters
    ----------
    data : sequence or numpy.ndarray
        1-D recording.
    frequency : float
        Frequency (Hz) of the synthetic sinusoid produced by ``synth_x``.
    fs : float
        Unused; kept for call-site compatibility.
    fs_synth : float
        Sampling rate handed to ``synth_x``.
    no_samples, no_train, remove_DC, apply_filter, downsample :
        Forwarded unchanged to ``process_data_OpenBCI``.
    ds_avg : int, optional
        Forwarded to ``process_data_OpenBCI`` as ``ds_rate``.
    synth_power : float, optional
        Gain applied to the synthetic sinusoid before it is added; 0 leaves
        the recording unchanged.
    noise_power : float, optional
        Forwarded to ``synth_x``.

    Returns
    -------
    The value returned by ``process_data_OpenBCI``.
    """
    # Forward the caller's noise_power; it used to be replaced by a literal 0,
    # so the argument had no effect.
    synthetic = synth_x(frequency, len(data), noise_power=noise_power, fs=fs_synth)
    # Adding an ndarray also turns a plain list into an ndarray for the next step.
    data = data + synth_power*synthetic
    return process_data_OpenBCI(data, no_samples, no_train, remove_DC=remove_DC, apply_filter=apply_filter, ds_rate=ds_avg, downsample=downsample)

def generate_train_test_idxs(data, no_train, max_splits=None):
    """Enumerate leave-p-out index splits over the trial (last) axis of ``data``.

    Parameters
    ----------
    data : array-like with a ``shape`` attribute
        Only ``data.shape[-1]`` (the number of trials) is used.
    no_train : int
        Passed to scikit-learn's ``LeavePOut`` as ``p``. Each returned pair is
        ``(rest, held_out)`` with ``len(held_out) == no_train``;
        ``test_gcca_mset`` in this file trains on the second element and
        evaluates on the first.
    max_splits : int or None, optional
        Return at most this many splits, taken from the start of the
        enumeration. ``None`` (default) returns every split. The number of
        splits is C(n_trials, no_train), which reaches millions for ~30
        trials, so callers that only use the first few runs should set this.

    Returns
    -------
    list of (numpy.ndarray, numpy.ndarray)
        Index pairs as produced by ``LeavePOut.split``.
    """
    from itertools import islice

    no_trials = data.shape[-1]
    splits = LeavePOut(p=no_train).split(range(no_trials))
    if max_splits is not None:
        # Stop the lazy generator early instead of materialising every split.
        splits = islice(splits, max_splits)
    return list(splits)

def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into one flat list."""
    merged = []
    for sublist in t:
        merged.extend(sublist)
    return merged

def test_gcca_mset(data, data_idxs, freqs, fs, no_samples, number_runs=10):
    # Nf x Nc x Ns x Nt
    gcca = GCCA_SSVEP(freqs, fs, Nh=1)
    mset_cca = MsetCCA_SSVEP(freqs)
    gcca_total_acc = []
    mset_total_acc = []
    gcca_freq_acc = dict((key,[]) for key in freqs)
    mset_freq_acc = dict((key,[]) for key in freqs)

    for i in range(number_runs):
        # Nf x Nc x Ns x Nt
        train = data[:,:,:,data_idxs[i][1]]
        test = data[:,:,:,data_idxs[i][0]]

        gcca.fit(train)
        mset_cca.fit(train)
        
        for freq, value in enumerate(freqs):
            gcca_res = []
            mset_res = []
            #print("############################# Frequency:", value, " #############################")
            for test_idx in range(test.shape[-1]):
                test_now = test[freq, :, :, test_idx]
                
                gcca_decode = gcca.classify(test_now)
#                 print(gcca_decode)
                for key, prob in gcca_decode.items():
                    gcca_decode[key] = abs(prob)
                
                gcca_res.append(max(gcca_decode, key=gcca_decode.get))
                
                mset_decode = mset_cca.classify(test_now)
#                 print(mset_decode)
                mset_res.append(max(mset_decode, key=mset_decode.get))
            #print("GCCA accuracy {gcca_acc}\nMsetCCA {mset_acc}".format(gcca_acc=gcca_res.count(value)/len(gcca_res),mset_acc=mset_res.count(value)/len(mset_res)))
            
            gcca_total_acc.append(gcca_res.count(value)/len(gcca_res))
            mset_total_acc.append(mset_res.count(value)/len(mset_res))
            gcca_freq_acc[value].append(gcca_res)
            mset_freq_acc[value].append(mset_res)
    
    total_gcca = sum(gcca_total_acc)/len(gcca_total_acc)
    print("GCCA Total Average Accuracy:", sum(gcca_total_acc)/len(gcca_total_acc))
    total_mset = sum(mset_total_acc)/len(mset_total_acc)
    print("MsetCCA Total Average Accuracy:", sum(mset_total_acc)/len(mset_total_acc))
    
    gcca_freq_scores = []
    mset_freq_scores = []
    
    for key, value in gcca_freq_acc.items():
        flattened = flatten(value)
        print("GCCA {frequency}hz accuracy:{result}".format(frequency=key, result=flattened.count(key)/len(flattened)))
        gcca_freq_scores.append(flattened.count(key)/len(flattened))
    for key, value in mset_freq_acc.items():
        flattened = flatten(value)
        print("MsetCCA {frequency}hz accuracy:{result}".format(frequency=key, result=flattened.count(key)/len(flattened)))
        mset_freq_scores.append(flattened.count(key)/len(flattened))
        
    return total_gcca, total_mset, gcca_freq_scores, mset_freq_scores

In [None]:
import os
import random
from itertools import islice

# --- Configuration ----------------------------------------------------------
# Root folder with one sub-folder per stimulus frequency ("8hz", "10hz", ...).
# Set the EEG_DATA_DIR environment variable to run on another machine instead
# of editing the path in several places.
DATA_DIR = os.environ.get("EEG_DATA_DIR", "/Users/rishil/Desktop/FYP/EEG-decoding/eeg_lib/log/mnakanishi")

ds = 1 #downsample averaging size
downsample = ds > 1
freqs = [8,10,12]
fs_synth = 250
fs = int(fs_synth/ds)
over_n_seconds = 1
number_of_samples = fs*over_n_seconds
number_of_train = 8
number_runs = 1  # how many train/test splits are evaluated per channel
removeDC = False
applyFilter = False
synth_power = 0
synth_noise = 0

blocks = list(range(1, 7))
channel_locations = ['Pz', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'O1', 'Oz', 'O2']

# Results keyed by electrode name.
gcca_arr = {}
mset_arr = {}
gcca_f = {location: [] for location in channel_locations}
mset_f = {location: [] for location in channel_locations}


def load_trials(frequency, channel):
    """Load every block of one frequency/channel recording and segment it into trials."""
    samples = []
    for block in blocks:
        file_name = "{f}hz_channel_0{c}_0{b}".format(f=frequency, c=channel, b=block)
        samples += load_array_data_online(os.path.join(DATA_DIR, "{f}hz".format(f=frequency), file_name))
    return prepare_data(samples, frequency, fs, fs_synth, number_of_samples, number_of_train, remove_DC=removeDC, apply_filter=applyFilter, downsample=downsample, ds_avg=ds, synth_power=synth_power, noise_power=synth_noise)


print(blocks)
# Channel files are numbered from 1 in the same order as channel_locations.
for channel, location in enumerate(channel_locations, start=1):
    values_per_freq = [load_trials(frequency, channel) for frequency in freqs]

    # Shapes are printed highest frequency first.
    print(*(values.shape for values in reversed(values_per_freq)))

    data_packed = np.array(values_per_freq)

    print(data_packed.shape)

    # Only the first `number_runs` splits are evaluated, so take them lazily.
    # Materialising every leave-p-out split means C(n_trials, number_of_train)
    # index pairs (about 5.85 million for 30 trials and p=8).
    n_trials = data_packed.shape[-1]
    split_iter = LeavePOut(p=number_of_train).split(range(n_trials))
    train_test_idxs = list(islice(split_iter, number_runs))
    print(len(train_test_idxs))
    print((train_test_idxs[0]))
    gcca, mset, gccaf, msetf = test_gcca_mset(data_packed, train_test_idxs, freqs, fs, number_of_samples, number_runs=number_runs)
    gcca_arr[location] = gcca
    mset_arr[location] = mset
    gcca_f[location].append(gccaf)
    mset_f[location].append(msetf)

[1, 2, 3, 4, 5, 6]
(1, 250, 30) (1, 250, 30) (1, 250, 30)
(3, 1, 250, 30)


In [19]:
import json

# Overall accuracy per electrode, pretty-printed with keys in sorted order.
print("GCCA: " + json.dumps(gcca_arr, sort_keys=True, indent=4))
print("MsetCCA: " + json.dumps(mset_arr, sort_keys=True, indent=4))

# Per-frequency accuracies per electrode, shown as the raw dictionaries.
print("GCCA Frequencies: {}".format(gcca_f))
print("MsetCCA Frequencies: {}".format(mset_f))

GCCA: {
    "O1": 0.6388888888888888,
    "O2": 0.5972222222222222,
    "Oz": 0.6805555555555557,
    "PO3": 0.6249999999999999,
    "PO4": 0.5833333333333334,
    "PO5": 0.625,
    "PO6": 0.48611111111111116,
    "POz": 0.6944444444444445,
    "Pz": 0.48611111111111116
}
MsetCCA: {
    "O1": 0.6805555555555555,
    "O2": 0.5694444444444444,
    "Oz": 0.8333333333333334,
    "PO3": 0.625,
    "PO4": 0.5694444444444445,
    "PO5": 0.6111111111111112,
    "PO6": 0.4583333333333333,
    "POz": 0.7361111111111112,
    "Pz": 0.5
}
GCCA Frequencies: {'Pz': [[0.4166666666666667, 0.625, 0.4166666666666667]], 'PO5': [[0.75, 0.625, 0.5]], 'PO3': [[0.75, 0.6666666666666666, 0.4583333333333333]], 'POz': [[0.9583333333333334, 0.7083333333333334, 0.4166666666666667]], 'PO4': [[0.7083333333333334, 0.5833333333333334, 0.4583333333333333]], 'PO6': [[0.2916666666666667, 0.5, 0.6666666666666666]], 'O1': [[0.7916666666666666, 0.5833333333333334, 0.5416666666666666]], 'Oz': [[0.9583333333333334, 0.75, 0.33