# EEG - GAN

In this project, we attempt to create synthetic EEG signals using a GAN.

## Initialization

#### Import Libraries

In [None]:
import os, requests
from matplotlib import rcParams
from matplotlib import pyplot as plt
from nilearn import plotting
from nimare import utils
import numpy as np
from scipy import signal

#### Dataset

In [None]:
fname = 'motor_imagery.npz'
url = "https://osf.io/ksqv8/download"

# Download the dataset once; later runs reuse the local copy.
if not os.path.isfile(fname):
  try:
    # Without a timeout a stalled connection would block this cell forever.
    r = requests.get(url, timeout=60)
  except requests.RequestException:
    # RequestException is the base class of ConnectionError and Timeout, so
    # every transport-level failure ends up on the same best-effort path.
    print("!!! Failed to download data !!!")
  else:
    if r.status_code != requests.codes.ok:
      print("!!! Failed to download data !!!")
    else:
      # Only write the file after a successful response, so a failed
      # download never leaves a bogus file that would skip the next attempt.
      with open(fname, "wb") as fid:
        fid.write(r.content)

#### Figure Setup

In [None]:
# Notebook-wide matplotlib defaults: wide figures, larger font, no top/right
# spines, and automatic layout adjustment.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [None]:
class process():
    """Stateless namespace for the signal preprocessing used in this notebook.

    ``preprocess`` is a static method, so it can be called either on the class
    (``process.preprocess(data)``) or on an instance.
    """

    def __init__(self):
        # Nothing to initialise; the class carries no state.
        return

    @staticmethod
    def preprocess(data):
        """Turn the raw voltage in ``data['V']`` into a normalised power envelope.

        Steps, all applied along axis 0:
          1. cast to float32,
          2. 3rd-order Butterworth high-pass at 50 Hz (fs = 1000 Hz), applied
             forward and backward with ``filtfilt``,
          3. squared magnitude,
          4. 3rd-order Butterworth low-pass at 10 Hz, again with ``filtfilt``,
          5. division of each column by its mean over axis 0.

        Parameters
        ----------
        data : mapping
            Must provide ``data['V']`` as an array; the code filters along
            axis 0 (presumably time, with channels on axis 1 - confirm).

        Returns
        -------
        numpy.ndarray
            Array with the same shape as ``data['V']``.
        """
        V = data['V'].astype('float32')

        # High-pass, then square to obtain signal power.
        b, a = signal.butter(3, [50], btype = 'high', fs=1000)
        V = signal.filtfilt(b,a,V,0)
        V = np.abs(V)**2

        # Low-pass the power to get a smooth envelope.
        b, a = signal.butter(3, [10], btype = 'low', fs=1000)
        V = signal.filtfilt(b,a,V,0)

        # Scale every column so that its mean over axis 0 is 1.
        V = V/V.mean(0)
        return V

class plots():
    """Plotting helpers for traces stored as ``data[sample, channel]``.

    All helpers are static, so they work both as ``plots.all_channels1(x)``
    and on an instance. The x axis is the sample index derived from the data
    itself (0 .. n_samples - 1), so the helpers do not depend on any
    notebook-level variable.
    """

    # The overview figures use a 5 x 10 subplot grid and show at most this
    # many channels (fewer when the data has fewer columns).
    _MAX_CHANNELS = 46

    def __init__(self):
        # Nothing to initialise; the class carries no state.
        return

    @staticmethod
    def _draw_channel(channel, *traces):
        """Plot column ``channel`` of each trace on the current axes.

        Also applies the title, x ticks and y limits shared by every figure.
        """
        for trace in traces:
            plt.plot(np.arange(len(trace)), trace[:, channel])
        plt.title('ch%d'%channel)
        plt.xticks([0, 1000, 2000])
        plt.ylim([0, 4])

    @staticmethod
    def _draw_grid(*traces):
        """Open a new figure with one subplot per channel, overlaying all traces."""
        plt.figure(figsize=(20,10))
        # Never index past the narrowest input.
        n_channels = min([plots._MAX_CHANNELS] + [trace.shape[1] for trace in traces])
        for j in range(n_channels):
            plt.subplot(5,10,j+1)
            plots._draw_channel(j, *traces)

    @staticmethod
    def singlechannel1(data, channel):
        """Plot one channel of ``data`` in a new figure."""
        plt.figure(figsize=(20,10))
        plots._draw_channel(channel, data)

    @staticmethod
    def singlechannel2(data, data2, channel):
        """Plot the same channel of ``data`` and ``data2`` overlaid in a new figure."""
        plt.figure(figsize=(20,10))
        plots._draw_channel(channel, data, data2)

    @staticmethod
    def all_channels1(data):
        """Plot every channel of ``data`` (up to 46) on a 5 x 10 subplot grid."""
        plots._draw_grid(data)

    @staticmethod
    def all_channels2(data, data2):
        """Plot every channel of ``data`` and ``data2`` (up to 46) overlaid on a grid."""
        plots._draw_grid(data, data2)

## Manipulating the Dataset

In [None]:
# np.load on an .npz returns a lazy archive that keeps the file open, so it is
# opened in a `with` block and closed once 'dat' has been read into memory.
# NOTE(review): allow_pickle=True unpickles objects from the file, which can
# execute arbitrary code - only load archives from a trusted source.
with np.load(fname, allow_pickle=True) as archive:
    alldat = archive['dat']

# Entry [0][0] is used as the "real" recording and [0][1] as the "imagine"
# recording (presumably the two sessions of the first subject - confirm).
real = alldat[0][0]
imagine = alldat[0][1]

In [None]:
def _epoch_session(session, window):
    """Preprocess one session and cut it into stimulus-locked epochs.

    Returns, in order: the preprocessed signal, the per-trial sample indices,
    the epochs reshaped to (trial, 2000, channel), and the epochs whose
    ``stim_id`` equals 12.
    """
    power = process.preprocess(session)
    n_channels = power.shape[1]
    onsets = session['t_on']

    # One row of sample indices per stimulus onset.
    sample_idx = onsets[:, np.newaxis] + window
    epochs = np.reshape(power[sample_idx, :], (len(onsets), 2000, n_channels))

    # Change 12 to 11 to get tongue data instead of hand data.
    hand_epochs = epochs[session['stim_id'] == 12]
    return power, sample_idx, epochs, hand_epochs


# Sample offsets, relative to stimulus onset, that make up one epoch.
trange = np.arange(0, 2000)

# --- "real" session ---
processed_real, ts, V_epochs_pr, V_hand_real = _epoch_session(real, trange)
nt_pr, nchan_pr = processed_real.shape
nstim_r = len(real['t_on'])

# Trial-averaged data; index V_hand_real directly to look at a single trial.
hand_real_data = V_hand_real.mean(0)

# --- "imagine" session (ts now holds the indices of this session) ---
processed_imagine, ts, V_epochs_pi, V_hand_imagine = _epoch_session(imagine, trange)
nt_pi, nchan_pi = processed_imagine.shape
nstim_i = len(imagine['t_on'])

# Trial-averaged data; index V_hand_imagine directly to look at a single trial.
hand_imagine_data = V_hand_imagine.mean(0)

In [None]:
# Overlay the trial-averaged real and imagined traces for every channel, then
# look at channel 13 (real only) and channel 1 (both) in separate figures.
plots.all_channels2(data=hand_real_data, data2=hand_imagine_data)
plots.singlechannel1(data=hand_real_data, channel=13)
plots.singlechannel2(data=hand_real_data, data2=hand_imagine_data, channel=1)

In [None]:
# The dataset: single-channel epochs from both conditions, split into a
# training part and a test part.
specify_channel = 20

n_trials = 30   # trials taken from each condition
n_train = 20    # trials per condition used for training; the rest is test data

# Slicing would silently return fewer rows than requested, so check up front.
if min(len(V_hand_real), len(V_hand_imagine)) < n_trials:
    raise IndexError('need at least %d trials per condition' % n_trials)

# cnn_inputs_real is for real data. 1 is the label for real.
cnn_inputs_real = np.array(V_hand_real[:n_trials, :, specify_channel])
labels_real = np.ones((n_trials, 1))

# cnn_inputs_imagine is for imaginary data. 0 is the label for imaginary.
cnn_inputs_imagine = np.array(V_hand_imagine[:n_trials, :, specify_channel])
labels_imagine = np.zeros((n_trials, 1))

# The first n_train trials of each condition form the training set ...
trainX = np.concatenate([cnn_inputs_real[:n_train], cnn_inputs_imagine[:n_train]], axis=0)
trainy = np.concatenate([labels_real[:n_train], labels_imagine[:n_train]], axis=0)
# ... with a trailing feature axis: (trials, samples, 1).
trainX = trainX[:, :, np.newaxis]

# The remaining trials of each condition form the test set.
testX = np.concatenate([cnn_inputs_real[n_train:], cnn_inputs_imagine[n_train:]], axis=0)
testy = np.concatenate([labels_real[n_train:], labels_imagine[n_train:]], axis=0)
testX = testX[:, :, np.newaxis]

In [None]:
def CNN(trainX, trainy, testX, testy):
    """Train a small 1-D convolutional classifier and return its training history.

    The network is Conv1D(100 filters, kernel 50, relu) ->
    GlobalAveragePooling1D -> Dropout(0.05) -> Dense(n_outputs), trained with
    Adam for 200 epochs at batch size 40.

    The output layer and loss are matched to the label layout: a single label
    column uses sigmoid + binary cross-entropy, while several columns use
    softmax + categorical cross-entropy. With one output unit, categorical
    cross-entropy gives a degenerate loss, so it is avoided in that case.

    NOTE(review): Sequential, Conv1D, GlobalAveragePooling1D, Dropout and
    Dense are not imported in the visible part of the file - presumably they
    come from Keras; confirm the import exists.

    Parameters
    ----------
    trainX : array indexed as (sample, timestep, feature)
    trainy : array indexed as (sample, output)
    testX, testy : held-out data, passed to ``fit`` as ``validation_data``

    Returns
    -------
    The object returned by ``model.fit``.
    """
    n_timesteps, n_features, n_outputs = trainX.shape[1], trainX.shape[2], trainy.shape[1]
    verbose = 1
    epochs = 200
    batch_size = 40

    if n_outputs == 1:
        output_activation, loss = 'sigmoid', 'binary_crossentropy'
    else:
        output_activation, loss = 'softmax', 'categorical_crossentropy'

    model = Sequential()
    model.add(Conv1D(filters=100, kernel_size=50, activation='relu', input_shape=(n_timesteps,n_features)))
    # Deeper variants (extra Conv1D / MaxPooling1D layers) were tried here and
    # left disabled.
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.05))
    model.add(Dense(n_outputs, activation=output_activation))
    model.compile(loss=loss, optimizer='adam', metrics=['accuracy'])
    history = model.fit(trainX, trainy, epochs=epochs, batch_size=batch_size, verbose=verbose,
                        validation_data=(testX, testy))
    return history

In [None]:
# Final CNN model: two stacked Conv1D layers, global average pooling, light
# dropout and a sigmoid output trained with binary cross-entropy.
n_timesteps = trainX.shape[1]
n_features = trainX.shape[2]
n_outputs = trainy.shape[1]
verbose, epochs, batch_size = 1, 200, 40

model = Sequential([
    Conv1D(filters=75, kernel_size=45, activation='relu', input_shape=(n_timesteps, n_features)),
    Conv1D(filters=75, kernel_size=30, activation='relu'),
    GlobalAveragePooling1D(),
    Dropout(0.003),
    Dense(n_outputs, activation='sigmoid'),
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# The held-out trials are evaluated after every epoch as validation data.
historyyy = model.fit(
    trainX,
    trainy,
    epochs=epochs,
    batch_size=batch_size,
    verbose=verbose,
    validation_data=(testX, testy),
)

In [None]:
# One figure per metric pair: first the training accuracy and loss, then the
# validation accuracy and loss.
for acc_key, loss_key in (('accuracy', 'loss'), ('val_accuracy', 'val_loss')):
    plt.plot(historyyy.history[acc_key])
    plt.plot(historyyy.history[loss_key])
    plt.show()

In [None]:
# Train the baseline CNN on the same train/test split.
CNN(trainX=trainX, trainy=trainy, testX=testX, testy=testy)