# EEG-GAN

## Initialization

### Imports

In [None]:
import os, requests
from matplotlib import rcParams
from matplotlib import pyplot as plt
import numpy as np
from scipy import signal
import pandas as pd

### Dataset

In [None]:
# Local cache file and the remote location it is fetched from.
fname = 'motor_imagery.npz'
url = "https://osf.io/ksqv8/download"

# Only hit the network when the file is not already on disk.
if not os.path.isfile(fname):
  try:
    r = requests.get(url)
  except requests.ConnectionError:
    print("!!! Failed to download data !!!")
  else:
    if r.status_code == requests.codes.ok:
      # Successful response: store the raw payload under the cache name.
      with open(fname, "wb") as fid:
        fid.write(r.content)
    else:
      print("!!! Failed to download data !!!")

### Figure Settings

# Notebook-wide matplotlib defaults: wide figures, larger font,
# no top/right spines, automatic layout.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

### Helper Functions

In [None]:
class process():
    """Namespace for signal preprocessing helpers.

    The helpers are static methods, so they can be called on the class
    (``process.preprocess(data)``) or on an instance.
    """

    def __init__(self):
        return

    @staticmethod
    def preprocess(data):
        """Convert raw voltages into a smoothed, mean-normalised power trace.

        Args:
            data: Mapping with a ``'V'`` entry holding a numeric array.
                All filtering and the normalisation run along axis 0
                (presumably time -- confirm against the dataset layout).

        Returns:
            Array with the same shape as ``data['V']`` in which each
            axis-0 slice has been divided by its mean along axis 0.
        """
        V = data['V'].astype('float32')

        # Zero-phase 3rd-order Butterworth high-pass, cutoff 50 at fs=1000.
        b, a = signal.butter(3, [50], btype='high', fs=1000)
        V = signal.filtfilt(b, a, V, 0)

        # Squared magnitude, i.e. instantaneous power of the filtered signal.
        V = np.abs(V)**2

        # Zero-phase 3rd-order Butterworth low-pass, cutoff 10, to smooth it.
        b, a = signal.butter(3, [10], btype='low', fs=1000)
        V = signal.filtfilt(b, a, V, 0)

        # Normalise so the mean along axis 0 becomes 1.
        V = V/V.mean(0)
        return V

class plots():
    """Namespace of matplotlib helpers for plotting channel traces.

    Every helper is a static method, so it can be called on the class
    (``plots.singlechannel1(...)``) or on an instance. ``data`` arguments
    are indexed as ``data[:, channel]``, i.e. channels live on axis 1.
    """

    def __init__(self):
        return

    @staticmethod
    def _decorate(channel):
        """Apply the shared title, x ticks and y limits to the current axes."""
        plt.title('ch%d'%channel)
        plt.xticks([0, 1000, 2000])
        plt.ylim([0, 4])

    @staticmethod
    def singlechannel1(data, channel, trange):
        """Plot one channel of ``data`` against ``trange`` in a new figure."""
        plt.figure(figsize=(20,10))
        plt.plot(trange, data[:,channel])
        plots._decorate(channel)

    @staticmethod
    def singlechannel2(data, data2, channel, trange):
        """Overlay the same channel of ``data`` and ``data2`` in a new figure."""
        plt.figure(figsize=(20,10))
        plt.plot(trange, data[:,channel])
        plt.plot(trange, data2[:,channel])
        plots._decorate(channel)

    @staticmethod
    def all_channels1(data, trange):
        """Plot channels 0-45 of ``data`` on a 5x10 subplot grid."""
        plt.figure(figsize=(20,10))
        # The channel count (46) is fixed; the grid leaves four cells empty.
        for j in range(46):
            plt.subplot(5,10,j+1)
            plt.plot(trange, data[:,j])
            plots._decorate(j)

    @staticmethod
    def all_channels2(data, data2, trange):
        """Overlay channels 0-45 of ``data`` and ``data2`` on a 5x10 grid."""
        plt.figure(figsize=(20,10))
        # The channel count (46) is fixed; the grid leaves four cells empty.
        for j in range(46):
            plt.subplot(5,10,j+1)
            plt.plot(trange, data[:,j])
            plt.plot(trange, data2[:,j])
            plots._decorate(j)

## GAN

In [None]:
import torch
import torch.nn as nn
from torch.utils import data

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Enable cuDNN's benchmark mode.
# NOTE(review): benchmark mode may pick different kernels between runs, which
# can limit the reproducibility the seeds below aim for -- confirm intent.
torch.backends.cudnn.benchmark = True

# Seed NumPy and PyTorch (CPU and CUDA generators) with the same value.
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

In [None]:
# The dataset: one channel is extracted from each of the first 30 entries
# of the real and the imagined recordings.
specify_channel = 20

# Real data, one row per entry. 1 is the label for real.
cnn_inputs_real = np.array(
    [V_hand_real[trial][:, specify_channel] for trial in range(30)]
)
labels_real = np.ones((30, 1))

# Imaginary data, one row per entry. 0 is the label for imaginary.
cnn_inputs_imagine = np.array(
    [V_hand_imagine[trial][:, specify_channel] for trial in range(30)]
)
labels_imagine = np.zeros((30, 1))

In [None]:
# Training split: entries 0-19 of each class, real first, then imagined.
trainX = np.concatenate((cnn_inputs_real[:20], cnn_inputs_imagine[:20]), axis=0)
trainy = np.concatenate((labels_real[:20], labels_imagine[:20]), axis=0)
# Add a trailing singleton axis: 40 samples x 2000 values x 1.
trainX = trainX.reshape(40, 2000, 1)

# Test split: entries 20-29 of each class, in the same order.
testX = np.concatenate((cnn_inputs_real[20:30], cnn_inputs_imagine[20:30]), axis=0)
testy = np.concatenate((labels_real[20:30], labels_imagine[20:30]), axis=0)
testX = testX.reshape(20, 2000, 1)

In [None]:
class EEGDataset(data.Dataset):
    """Dataset that pairs each sample with the label at the same index."""

    def __init__(self, data, labels):
        # Both containers are stored as given, without copying or converting.
        self.labels = labels
        self.data = data

    def __len__(self):
        """Return the number of samples, taken from the data container."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

In [None]:
# EEGDataset requires both the samples and their labels; wrap the training split.
dataset = EEGDataset(trainX, trainy)