# EEG - GAN

## Initialization

#### Import Libraries

In [None]:
import os, requests
from matplotlib import rcParams
from matplotlib import pyplot as plt
import numpy as np
from scipy import signal

#### Dataset

In [None]:
fname = 'motor_imagery.npz'
url = "https://osf.io/ksqv8/download"

# Download the dataset once; later cells load it from ``fname``.
if not os.path.isfile(fname):
    try:
        # Without a timeout, requests may wait forever on a stalled server.
        r = requests.get(url, timeout=60)
    except requests.RequestException:
        # Base class of ConnectionError, Timeout, etc., so a timeout is
        # reported the same way as a refused connection.
        print("!!! Failed to download data !!!")
    else:
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(fname, "wb") as fid:
                fid.write(r.content)

#### Figure Setup

In [None]:
# Default figure styling for every plot in this notebook: wide figures,
# larger font, no top/right spines, automatic layout.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [None]:
class process:
    """Namespace for the signal preprocessing used in this notebook.

    The methods are static and are meant to be called on the class itself,
    e.g. ``process.preprocess(dat)``.
    """

    def __init__(self):
        # The original signature had no ``self``, so ``process()`` raised
        # TypeError; instantiating is now harmless.
        return

    @staticmethod
    def preprocess(data, fs=1000, high_cutoff=50, low_cutoff=10, order=3):
        """Return the normalised broadband-power envelope of ``data['V']``.

        Steps: Butterworth high-pass at *high_cutoff* Hz, squared magnitude,
        Butterworth low-pass at *low_cutoff* Hz, then division of each column
        by its mean over axis 0. Both filters are applied with ``filtfilt``
        (forward and backward) along axis 0.

        Args:
            data: mapping with key ``'V'``. Assumed (samples, channels), which
                matches how the notebook epochs the result (``V[ts, :]``).
            fs: sampling rate in Hz used to design the filters.
            high_cutoff: high-pass cutoff in Hz.
            low_cutoff: low-pass cutoff in Hz.
            order: Butterworth filter order.

        Returns:
            Array with the same shape as ``data['V']``; every column has mean 1.
        """
        V = data['V'].astype('float32')
        b, a = signal.butter(order, [high_cutoff], btype='high', fs=fs)
        V = signal.filtfilt(b, a, V, 0)
        # Squared magnitude of the high-passed signal.
        V = np.abs(V) ** 2
        # Smooth the power with a low-pass filter.
        b, a = signal.butter(order, [low_cutoff], btype='low', fs=fs)
        V = signal.filtfilt(b, a, V, 0)
        # Express every channel relative to its own average power.
        V = V / V.mean(0)
        return V

class plots:
    """Plotting helpers for 2-D arrays indexed as ``array[:, channel]``.

    The methods are static and are meant to be called on the class itself,
    e.g. ``plots.all_channels2(hand_real_data, hand_imagine_data, trange)``.
    Every call opens a new 20x10 figure; calling ``plt.show()`` is left to
    the caller.
    """

    def __init__(self):
        # The original signature had no ``self``, so ``plots()`` raised
        # TypeError; instantiating is now harmless.
        return

    @staticmethod
    def _decorate(channel):
        """Apply the title, x-ticks and y-limits shared by every panel."""
        plt.title('ch%d' % channel)
        plt.xticks([0, 1000, 2000])
        plt.ylim([0, 4])

    @staticmethod
    def _single(traces, channel, trange):
        """Overlay column *channel* of each array in *traces* on one new figure."""
        plt.figure(figsize=(20, 10))
        for trace in traces:
            plt.plot(trange, trace[:, channel])
        plots._decorate(channel)

    @staticmethod
    def _grid(traces, trange, n_channels):
        """Draw one panel per channel (10 per row), overlaying every array in *traces*."""
        plt.figure(figsize=(20, 10))
        # Never index past the narrowest input (the old fixed range(46)
        # raised IndexError on arrays with fewer columns).
        count = min([n_channels] + [trace.shape[1] for trace in traces])
        # Keep the original 5x10 layout; add rows only beyond 50 channels.
        n_rows = max(5, (count + 9) // 10)
        for j in range(count):
            plt.subplot(n_rows, 10, j + 1)
            for trace in traces:
                plt.plot(trange, trace[:, j])
            plots._decorate(j)

    @staticmethod
    def singlechannel1(data, channel, trange):
        """Plot column *channel* of *data* against *trange* in a new figure."""
        plots._single((data,), channel, trange)

    @staticmethod
    def singlechannel2(data, data2, channel, trange):
        """Overlay column *channel* of *data* and *data2* in a new figure."""
        plots._single((data, data2), channel, trange)

    @staticmethod
    def all_channels1(data, trange, n_channels=46):
        """Plot the first *n_channels* columns of *data*, one panel each."""
        plots._grid((data,), trange, n_channels)

    @staticmethod
    def all_channels2(data, data2, trange, n_channels=46):
        """Overlay the first *n_channels* columns of *data* and *data2*, one panel each."""
        plots._grid((data, data2), trange, n_channels)

## Manipulating the Dataset

In [None]:
# Load the per-subject/per-experiment records. ``with`` closes the .npz file
# handle that np.load would otherwise keep open after 'dat' is read.
# NOTE: allow_pickle=True unpickles objects from the downloaded file; only
# use it with a trusted source.
with np.load(fname, allow_pickle=True) as npz:
    AllData = npz['dat']

AllData[0][0].keys()

In [None]:
# Checking srate for every subject and experiment.
# The filters in this notebook are designed with a hard-coded fs=1000, so
# any other rate is flagged explicitly.
for subject in AllData:
    for experiment in subject:
        print(experiment['srate'])
        if experiment['srate'] != 1000:
            print('WARNING: srate differs from the fs=1000 assumed by the filters')

This is how we access subject 0, experiment 0 (real movement).

In [None]:
RealPatientZero = AllData[0][0]  # subject 0, experiment 0 (real movement)

We store the real-movement experiment of each of the 7 participants in an array called `RealPatients`, so it holds one entry per participant.

In [None]:
# Experiment 0 (real movement) of every participant, one entry per subject.
# ``AllData[:][0]`` would NOT do this: ``[:]`` returns the whole array, so the
# following ``[0]`` selects subject 0's pair of experiments instead.
RealPatients = [subject[0] for subject in AllData]

This is how we access subject 0, experiment 1 (imagined movement).

In [None]:
ImaginaryPatientZero = AllData[0][1]  # subject 0, experiment 1 (imagined movement)

We also create `ImaginaryPatients`, which holds the imagined-movement experiment of each participant.

In [None]:
# Experiment 1 (imagined movement) of every participant, one entry per subject.
# ``AllData[:][1]`` would NOT do this: ``[:]`` returns the whole array, so the
# following ``[1]`` selects subject 1's pair of experiments instead.
ImaginaryPatients = [subject[1] for subject in AllData]

In [None]:
# Pipeline for preprocessing:

## We start off by picking subject 0 and experiment 0 (real movement)
data1 = RealPatientZero
# process.preprocess performs exactly the steps this cell used to spell out:
# 50 Hz high-pass -> squared magnitude -> 10 Hz low-pass -> divide each
# channel by its mean. Reuse it rather than keeping a second copy in sync.
V = process.preprocess(data1)

In [None]:
# Cut 2000-sample epochs starting at each cue onset, then average the
# broadband power separately over tongue (stim_id 11) and hand (stim_id 12)
# trials.
nt, nchan = V.shape
nstim = len(data1['t_on'])

trange = np.arange(0, 2000)
# Sample indices of every epoch: one row per cue onset.
ts = np.add.outer(data1['t_on'], trange)
V_epochs = V[ts].reshape(nstim, 2000, nchan)

V_tongue = V_epochs[data1['stim_id'] == 11].mean(axis=0)
V_hand = V_epochs[data1['stim_id'] == 12].mean(axis=0)

In [None]:
# let's find the electrodes that distinguish tongue from hand movements
# note the behaviors happen some time after the visual cue
# Reuse the shared grid plotter instead of a hand-copied loop: one panel per
# channel, tongue average drawn first, hand average second.
plots.all_channels2(V_tongue, V_hand, trange)
plt.show()

In [None]:
# Trial order sorted by stimulus id, so trials of the same stimulus are adjacent.
isort = data1['stim_id'].argsort()

# Shape of electrode 20's epochs in that order: (nstim, 2000).
V_epochs[isort][:, :, 20].shape

In [None]:
# Raster of every trial for electrode 20 (which responds well to hand
# movements): rows are trials ordered by stimulus id, columns are samples.
isort = np.argsort(data1['stim_id'])
plt.subplot(1, 3, 1)
plt.imshow(
    V_epochs[isort][:, :, 20].astype('float32'),
    cmap='magma',
    vmin=0,
    vmax=7,
    aspect='auto',
)
plt.colorbar()
plt.show()

In [None]:
# Raster of every trial for electrode 42, which seems to respond to tongue
# movements: rows are trials ordered by stimulus id, columns are samples.
plt.subplot(1, 3, 1)
isort = np.argsort(data1['stim_id'])
plt.imshow(
    V_epochs[isort][:, :, 42].astype('float32'),
    cmap='magma',
    vmin=0,
    vmax=7,
    aspect='auto',
)
plt.colorbar()
plt.show()

## OLD METHOD

In [None]:
# Reload the dataset under the names used by the older analysis below.
# ``with`` closes the .npz file handle that np.load would otherwise keep open.
# NOTE: allow_pickle=True unpickles objects from the downloaded file; only
# use it with a trusted source.
with np.load(fname, allow_pickle=True) as npz:
    alldat = npz['dat']

real = alldat[0][0]     # subject 0, experiment 0 (real movement)
imagine = alldat[0][1]  # subject 0, experiment 1 (imagined movement)

In [None]:
def _epochs_for_stimulus(dat, stim_id, epoch_len=2000):
    """Preprocess one experiment and cut it into cue-locked epochs.

    Args:
        dat: one experiment record with keys 'V', 't_on' and 'stim_id'.
        stim_id: stimulus code to keep (this notebook uses 11 for tongue
            and 12 for hand).
        epoch_len: epoch length in samples, starting at each cue onset.

    Returns:
        (processed, epochs, selected): the ``process.preprocess`` output;
        all epochs reshaped to (n_cues, epoch_len, n_channels); and the
        epochs whose stim_id equals *stim_id*.
    """
    processed = process.preprocess(dat)
    n_cues = len(dat['t_on'])
    n_channels = processed.shape[1]
    # Sample indices of every epoch: one row per cue onset.
    idx = dat['t_on'][:, np.newaxis] + np.arange(0, epoch_len)
    epochs = np.reshape(processed[idx, :], (n_cues, epoch_len, n_channels))
    return processed, epochs, epochs[dat['stim_id'] == stim_id]


STIM_ID = 12  # Change this to 11 to get tongue data and 12 to get hand data

processed_real, V_epochs_pr, V_hand_real = _epochs_for_stimulus(real, STIM_ID)
processed_imagine, V_epochs_pi, V_hand_imagine = _epochs_for_stimulus(imagine, STIM_ID)

# Shape bookkeeping kept from the original cell in case later cells read it.
nt_pr, nchan_pr = processed_real.shape
nstim_r = len(real['t_on'])
nt_pi, nchan_pi = processed_imagine.shape
nstim_i = len(imagine['t_on'])

trange = np.arange(0, 2000)

# Averages over the selected trials, shape (2000, n_channels). Index
# V_hand_real / V_hand_imagine along axis 0 to inspect a single trial.
hand_real_data = V_hand_real.mean(0)
hand_imagine_data = V_hand_imagine.mean(0)

In [None]:
# Trial-averaged traces for the real and imagined hand conditions.
hand_real_data, hand_imagine_data = (
    np.mean(V_hand_real, axis=0),
    np.mean(V_hand_imagine, axis=0),
)

In [None]:
# (n imagined hand trials, 2000 samples, n channels)
np.shape(V_hand_imagine)

In [None]:
# Real vs imagined hand averages, every channel.
plots.all_channels2(data=hand_real_data, data2=hand_imagine_data, trange=trange)
# A single imagined-movement trial (trial 0), channel 13.
plots.singlechannel1(data=V_hand_imagine[0], channel=13, trange=trange)
# Real vs imagined hand averages, channel 1.
plots.singlechannel2(data=hand_real_data, data2=hand_imagine_data, channel=1, trange=trange)

## GAN

In [None]:
import torch
import torch.nn as nn
from torch.utils import data

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed the RNGs so training runs are repeatable.
np.random.seed(1)
torch.manual_seed(1)  # seeds the RNG on all devices, CUDA included
torch.cuda.manual_seed(1)

# cuDNN autotuning (benchmark=True) may choose different kernels from run to
# run, which makes results nondeterministic and defeats the seeding above.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
#The dataset
specify_channel = 20  # electrode whose trace is used as the model input
n_trials = 30         # trials taken from each condition

# Fail early with a clear message instead of an IndexError part-way through.
for _name, _epochs in (("V_hand_real", V_hand_real), ("V_hand_imagine", V_hand_imagine)):
    if len(_epochs) < n_trials:
        raise ValueError(f"{_name} has {len(_epochs)} trials; {n_trials} are required")

# cnn_inputs_real is for real data. 1 is the label for real.
# Shape (n_trials, 2000): the specify_channel trace of each trial.
cnn_inputs_real = np.array(V_hand_real[:n_trials, :, specify_channel])
labels_real = np.ones((n_trials, 1))

# cnn_inputs_imagine is for imaginary data. 0 is the label for imaginary.
cnn_inputs_imagine = np.array(V_hand_imagine[:n_trials, :, specify_channel])
labels_imagine = np.zeros((n_trials, 1))

In [None]:
# Per condition: the first n_train trials go to training, the rest to test.
n_train = 20

trainX = np.concatenate((cnn_inputs_real[:n_train], cnn_inputs_imagine[:n_train]), axis=0)
trainy = np.concatenate((labels_real[:n_train], labels_imagine[:n_train]), axis=0)

testX = np.concatenate((cnn_inputs_real[n_train:], cnn_inputs_imagine[n_train:]), axis=0)
testy = np.concatenate((labels_real[n_train:], labels_imagine[n_train:]), axis=0)

# Add a trailing unit axis, (trials, samples) -> (trials, samples, 1), derived
# from the data rather than hard-coded as [40, 2000, 1] / [20, 2000, 1].
# NOTE(review): torch.nn.Conv1d expects (batch, channels, length); confirm the
# model consumes this (batch, length, 1) layout or transpose before training.
trainX = trainX[:, :, np.newaxis]
testX = testX[:, :, np.newaxis]

In [None]:
class EEGDataset(data.Dataset):
    """Map-style dataset pairing each EEG sample with its label.

    Args:
        data: indexable collection of samples (e.g. ``trainX``).
        labels: indexable collection of labels, the same length as *data*.

    Raises:
        ValueError: if *data* and *labels* have different lengths.
    """

    def __init__(self, data, labels):
        # Without this check, a shorter *labels* only fails later with an
        # IndexError, and a longer one is silently truncated by __len__.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length "
                f"(got {len(data)} and {len(labels)})"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at *index*."""
        return self.data[index], self.labels[index]

In [None]:
# EEGDataset requires both samples and labels; calling it with no arguments
# raised TypeError. Wrap the training split and the held-out test split.
dataset = EEGDataset(trainX, trainy)
test_dataset = EEGDataset(testX, testy)
