In [1]:
from pylsl import StreamInlet, resolve_stream
import numpy as np
from scipy.fft import fft2, fftshift
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
from scipy.signal import welch
import mne
import pywt
import pickle
import keras
from keras.models import load_model
from sklearn.metrics import confusion_matrix
import seaborn as sns

pretrained_model = load_model("CNN_MIexe_model_weights.h5")

epoch_2s = np.zeros(shape=(7,501))

# first resolve an EEG stream on the lab network
print("looking for an EEG stream...")
streams = resolve_stream()
print(streams)

data = []
time_count = 0.0
time_all = 0.0

# create a new inlet to read from the stream
inlet = StreamInlet(streams[0])

fs = 250.0
 
channels = ["Fz","C3", "Cz", "C4","Pz","PO7","PO8"]

info = mne.create_info(
    ch_names= channels,
    ch_types= ['eeg']*len(channels),
    sfreq= fs
)

def call_pywt(new_data):
    component_num = 7
    train_cwt = np.ndarray(shape=(new_data.shape[0], new_data.shape[2], component_num))
    for jj in range(0, new_data.shape[0]):
        train_cwt[jj] = new_data[jj].T
    print(np.shape(new_data))

    scales = range(1,31)

    waveletname = 'morl'
    train_size = new_data.shape[0]
    train_data_cwt = np.ndarray(shape=(train_size, len(scales), new_data.shape[2], component_num))

    for ii in range(0,train_size):
        if ii % 40 == 0:
            print(ii)
        for jj in range(0,component_num):
            signal = train_cwt[ii, :, jj]
            coeff, _ = pywt.cwt(signal, scales, waveletname, 1)
            coeff_ = coeff[:,:new_data.shape[2]]  #crop 227 sample for each channel
            train_data_cwt[ii, :, :, jj] = coeff_
    print(np.shape(train_data_cwt))

    # Stack array and convert to image
    from PIL import Image as im 
    train_cwt_stack = np.ndarray(shape=(train_size, len(scales)*component_num, new_data.shape[2]))

    for jj in range(0,train_data_cwt.shape[0]):
        train_cwt_stack[jj] = np.vstack((train_data_cwt[jj,:,:,0], train_data_cwt[jj,:,:,1], train_data_cwt[jj,:,:,2], train_data_cwt[jj,:,:,3], train_data_cwt[jj,:,:,4], train_data_cwt[jj,:,:,5], train_data_cwt[jj,:,:,6]))
    return train_cwt_stack


while True:
    # get a new sample (you can also omit the timestamp part if you're not
    # interested in it)
    sample, timestamp = inlet.pull_sample()
    # print(sample)

    if timestamp:
        data1 = sample[0:6]
        data1.append(sample[7])
        data.append(data1)

        time_count += 1 * 0.004
        time_count = np.round(time_count ,3)
        time_all += 1 * 0.004

        if time_count == 2.004:

            # data = np.asarray(data)
            # epoch_2s = np.delete(epoch_2s, slice(0,25), axis =1)
            # epoch_2s = np.append(epoch_2s, data.T, axis= 1)

            # print(epoch_2s)

            data = np.asarray(data)
            data = np.transpose(data)
            epoch_2s = data
            # print(epoch_2s)
            
            raw = mne.io.RawArray(epoch_2s, info)
            eeg1 = raw.copy().filter(l_freq=1.0, h_freq=30.0, method = 'iir', iir_params= {"order": 5, "ftype":'butter'})
            eeg1 = eeg1.copy().set_eeg_reference(ref_channels="average")

            # print(np.shape(eeg1.get_data()))
            # print("*****************raw 2 seconds*******************")
            # print(eeg1.get_data())

            with open('trained_csp_model.pkl', 'rb') as file:
                trained_csp = pickle.load(file)

            new_data = trained_csp.transform(eeg1.get_data().reshape(1,7,501))

            # print("*****************CSP*******************")
            # print(new_data)

            train_cwt_stack = call_pywt(new_data)

            # print("*****************CWT*******************")
            # print(train_cwt_stack)

            y_pred_raw = pretrained_model.predict(train_cwt_stack)

            print("************************************")
            print(y_pred_raw)
            print("predict class", np.argmax(y_pred_raw))
        
            data = np.ndarray.tolist(data)
            data = []
            time_count = 0.0
        
        # if time_all >= 10:
        #     break

looking for an EEG stream...
[<pylsl.pylsl.StreamInfo object at 0x000001F33CDF1C50>, <pylsl.pylsl.StreamInfo object at 0x000001F3099DA710>, <pylsl.pylsl.StreamInfo object at 0x000001F3099D8E50>]
Creating RawArray with float64 data, n_channels=7, n_times=501
    Range : 0 ... 500 =      0.000 ...     2.000 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 20 (effective, after forward-backward)
- Cutoffs at 1.00, 30.00 Hz: -6.02, -6.02 dB

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
(1, 7, 501)
0
(1, 30, 501, 7)
************************************
[[0.07631462 0.09174915 0.32023987 0.51169634]]
predict class 3
Creating RawArray with float64 data, n_channels=7, n_times=501
    Range : 0 ... 500 =      0.000 ...     2.000 secs
Ready.
Fi

KeyboardInterrupt: 