# CNN model

In [18]:
from pylsl import StreamInlet, resolve_stream
import numpy as np
from scipy.fft import fft2, fftshift
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
from scipy.signal import welch
import mne
import pywt
import pickle
import keras
from keras.models import load_model
from sklearn.metrics import confusion_matrix
import seaborn as sns
from matplotlib.animation import FuncAnimation
from datetime import datetime


# Keras CNN restored from disk; the acquisition loop below feeds it one
# stacked scalogram per update.
pretrained_model = load_model("CNN_model1s_3c_weights.h5")

# Rolling analysis window, zero-filled at start-up: one row per retained
# channel, one column per sample.
# NOTE(review): despite the "2s" in the name, 251 columns at fs = 250.0 is
# roughly one second of data - confirm the intended window length.
epoch_2s = np.zeros((7, 251))

# Discover the LSL streams that are currently visible.
print("looking for an EEG stream...")
streams = resolve_stream()
print(streams)

# State shared with the acquisition loop: samples collected since the last
# update, and elapsed-time counters in seconds.
data = list()
time_count, time_all = 0.0, 0.0
time_interval = 0.1

# Open an inlet on the first resolved stream.
# NOTE(review): resolve_stream() is called without a predicate, so the first
# entry is not guaranteed to be the EEG stream - verify on the target setup.
inlet = StreamInlet(streams[0])

# Sampling rate in Hz and the channel labels handed to MNE.
fs = 250.0
channels = ["Fz", "C3", "Cz", "C4", "Pz", "PO7", "PO8"]

# MNE measurement info describing every channel as EEG at rate `fs`.
info = mne.create_info(
    ch_names=channels,
    ch_types=["eeg" for _ in channels],
    sfreq=fs,
)

def call_pywt(new_data):
    """Turn component time courses into vertically stacked CWT scalograms.

    Every component of every epoch is transformed with a Morlet continuous
    wavelet transform over scales 1..30 (sampling period 1). The per-component
    scalograms of one epoch are stacked along the scale axis in component
    order, so rows ``j * 30`` to ``(j + 1) * 30 - 1`` belong to component ``j``.

    The number of components is taken from ``new_data.shape[1]`` rather than
    being fixed at 7, so the function works for any component count while
    producing the same result as before for 7-component input.

    Parameters
    ----------
    new_data : array-like, shape (n_epochs, n_components, n_times)
        Real-valued time courses to transform.

    Returns
    -------
    numpy.ndarray, shape (n_epochs, 30 * n_components, n_times)
        Stacked CWT coefficients as float64.
    """
    # Work in float64 so the coefficients do not depend on the input dtype.
    signals = np.asarray(new_data, dtype=np.float64)
    n_epochs, n_components, n_times = signals.shape

    scales = range(1, 31)
    waveletname = 'morl'
    n_scales = len(scales)

    train_cwt_stack = np.empty((n_epochs, n_scales * n_components, n_times))

    for ii in range(n_epochs):
        for jj in range(n_components):
            coeff, _ = pywt.cwt(signals[ii, jj, :], scales, waveletname, 1)
            # Write the scalogram straight into this component's row band;
            # the column slice keeps the width pinned to n_times.
            first_row = jj * n_scales
            train_cwt_stack[ii, first_row:first_row + n_scales, :] = coeff[:, :n_times]

    return train_cwt_stack


# Real-time loop: collect samples from the LSL inlet and, once every
# `time_interval` seconds of data, slide the analysis window forward, then
# filter, re-reference, CSP-transform, wavelet-transform and classify it.

# The fitted CSP transformer never changes while streaming, so it is loaded
# once here instead of being re-read from disk on every update.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load model files from a trusted source.
with open('trained_csp_model.pkl', 'rb') as file:
    trained_csp = pickle.load(file)

# Count samples with integers instead of comparing an accumulated float
# against `time_interval`, which depends on rounding to ever match.
samples_per_update = int(round(time_interval * fs))
sample_period = 1.0 / fs

while True:
    sample, timestamp = inlet.pull_sample()

    # No timestamp means no sample was delivered; try again.
    if not timestamp:
        continue

    # Keep stream channels 0-5 and 7; index 6 is skipped.
    # NOTE(review): presumably the skipped channel is the one missing from
    # `channels` - confirm against the amplifier's channel order.
    data.append(sample[0:6] + [sample[7]])
    time_all += sample_period

    if len(data) < samples_per_update:
        continue

    # Slide the window: drop the oldest columns, append the newest chunk
    # (transposed to channels x samples).
    newest = np.asarray(data).T
    epoch_2s = np.concatenate((epoch_2s[:, samples_per_update:], newest), axis=1)

    # Wrap the window as a single MNE epoch. The copy is kept so that
    # filtering can never write back into the rolling raw buffer.
    raw = mne.EpochsArray(epoch_2s[np.newaxis, :, :], info, verbose=False)
    eeg1 = raw.copy().filter(
        l_freq=6.0,
        h_freq=32.0,
        method='iir',
        iir_params={"order": 5, "ftype": 'butter'},
        verbose=False,
    )
    # eeg1 is already a private copy, so it can be re-referenced in place.
    eeg1 = eeg1.set_eeg_reference(ref_channels="average", verbose=False)

    # CSP projection followed by the stacked wavelet scalogram.
    new_data = trained_csp.transform(eeg1.get_data())
    train_cwt_stack = call_pywt(new_data)

    y_pred_raw = pretrained_model.predict(train_cwt_stack, verbose=False)

    print("************************************")
    print(y_pred_raw)
    print("predict class", np.argmax(y_pred_raw))

    # Start collecting the next chunk.
    data = []

looking for an EEG stream...
[<pylsl.pylsl.StreamInfo object at 0x0000016168A442D0>, <pylsl.pylsl.StreamInfo object at 0x00000161D7096B90>]
************************************
[[0. 0. 1.]]
predict class 2
************************************
[[0. 0. 1.]]
predict class 2
************************************
[[0. 0. 1.]]
predict class 2
************************************
[[0. 0. 1.]]
predict class 2
************************************
[[1.000000e+00 0.000000e+00 9.910584e-35]]
predict class 0
************************************
[[0. 1. 0.]]
predict class 1
************************************
[[1. 0. 0.]]
predict class 0
************************************
[[1. 0. 0.]]
predict class 0
************************************
[[0. 1. 0.]]
predict class 1
************************************
[[0. 1. 0.]]
predict class 1
************************************
[[0.129513   0.30088893 0.5695981 ]]
predict class 2
************************************
[[0.03392842 0.7148223  0.25124928]]
predic

KeyboardInterrupt: 

# CSP+LDA

In [16]:
# from pylsl import StreamInlet, resolve_stream
# import numpy as np
# import matplotlib.pyplot as plt
# import mne
# import pickle
# from sklearn.metrics import confusion_matrix
# import seaborn as sns
# from matplotlib.animation import FuncAnimation


# epoch_2s = np.zeros(shape=(7,501))

# print("looking for an EEG stream...")
# streams = resolve_stream()
# print(streams)

# data = []
# time_count = 0.0
# time_all = 0.0
# time_interval = 0.1

# inlet = StreamInlet(streams[0])

# fs = 250.0
 
# channels = ["Fz","C3", "Cz", "C4","Pz","PO7","PO8"]

# info = mne.create_info(
#     ch_names= channels,
#     ch_types= ['eeg']*len(channels),
#     sfreq= fs
# )

# while True:
#     sample, timestamp = inlet.pull_sample()
#     # print(sample)

#     if timestamp:
#         data1 = sample[0:6]
#         data1.append(sample[7])
#         data.append(data1)  

#         time_count += 1 * 0.004
#         time_count = np.round(time_count ,3)
#         time_all += 1 * 0.004

#         if time_count == time_interval:

#             data = np.asarray(data)
#             epoch_2s = np.delete(epoch_2s, slice(0,int(time_interval/0.004)), axis =1)
#             epoch_2s = np.append(epoch_2s, data.T, axis= 1)
            

#             raw = mne.EpochsArray(epoch_2s.reshape(1,7,501), info, verbose=False)
#             eeg1 = raw.copy().filter(l_freq=6.0, h_freq=32.0, method = 'iir', iir_params= {"order": 5, "ftype":'butter'}, verbose = False)
#             eeg1 = eeg1.copy().set_eeg_reference(ref_channels="average", verbose=False)
            

#             with open('trained_csp2_model.pkl', 'rb') as file:
#                 trained_csp2 = pickle.load(file)
#             new_data2 = trained_csp2.transform(eeg1.get_data().reshape(1,7,501))


#             with open('trained_lda_model.pkl', 'rb') as file:
#                 trained_lda = pickle.load(file)

#             print(trained_lda.predict(new_data2))

#             if trained_lda.predict(new_data2)[0] == 6:
#                 print('class predict LEFT')
#             elif trained_lda.predict(new_data2)[0] == 7:
#                 print('class predict RIGHT')
#             elif trained_lda.predict(new_data2)[0] == 9:
#                 print('class predict NON-MI')
#             else:
#                 print('class predict UP')
        
#             data = np.ndarray.tolist(data)
#             data = []
#             time_count = 0.0


looking for an EEG stream...
[<pylsl.pylsl.StreamInfo object at 0x00000161D5EEB350>, <pylsl.pylsl.StreamInfo object at 0x00000161D3998E10>]
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[7]
class predict RIGHT
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
class predict NON-MI
[9]
cla

KeyboardInterrupt: 