In [1]:
import os, sys
os.chdir('/home/seigyo/Documents/pytorch/brain_decoder')
sys.path.append(os.pardir)
import numpy as np
from numpy.random import RandomState
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import mne
from mne.io import concatenate_raws
from mymodule.utils import data_loader, evaluator
from mymodule.layers import LSTM, Residual_block, Res_net, Wavelet_cnn, NlayersSeqConvLSTM
from mymodule.trainer import Trainer
from mymodule.optim import Eve, YFOptimizer
from sklearn.utils import shuffle
from tensorboardX import SummaryWriter
from load_data import get_data, get_data_multi, get_crops, get_crops_multi
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

epochs = 300
batch_size = 10
cv_splits = 5
torch.manual_seed(1214)
torch.cuda.manual_seed_all(1214)
num_of_subjects = 30

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [2]:
def get_data(id=1, event_code=[6,10,14], filter=[0.5, 36], t=[1, 4.1]):
    # 5,6,7,10,13,14 are codes for executed and imagined hands/feet
    subject_id = id
    event_codes = event_code

    # This will download the files if you don't have them yet,
    # and then return the paths to the files.
    physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

    # Load each of the files
    parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')
             for path in physionet_paths]

    # Concatenate them
    raw = concatenate_raws(parts)

    # bandpass filter
    if filter != None:
        raw.filter(filter[0], filter[1], fir_design='firwin', skip_by_annotation='edge')
    else:
        pass

    # Find the events in this dataset
    events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

    # Use only EEG channels
    eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=t[0], tmax=t[1], proj=False, picks=eeg_channel_inds,
                    baseline=None, preload=True)
    # change time length
    # epochs_train = epochs.copy().crop(tmin=1., tmax=2.)


    # Convert data from volt to millivolt
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
    return X, y

In [3]:
X, y = get_data()
X.shape

Removing orphaned offset at the beginning of the file.
89 events found
Events id: [1 2 3]
45 matching events found
Not setting metadata
Loading data for 45 events and 497 original time points ...
0 bad epochs dropped


(45, 64, 497)

In [4]:
X_spat = np.zeros((45, X.shape[-1], 1, 5, 7))
X_spat.shape

(45, 497, 1, 5, 7)

In [5]:
X_spat[:, :, 0, 1, 0] = X[:, 0, :]
X_spat[:, :, 0, 1, 1] = X[:, 1, :]
X_spat[:, :, 0, 1, 2] = X[:, 2, :]
X_spat[:, :, 0, 1, 3] = X[:, 3, :]
X_spat[:, :, 0, 1, 4] = X[:, 4, :]
X_spat[:, :, 0, 1, 5] = X[:, 5, :]
X_spat[:, :, 0, 1, 6] = X[:, 6, :]

X_spat[:, :, 0, 2, 0] = X[:, 7, :]
X_spat[:, :, 0, 2, 1] = X[:, 8, :]
X_spat[:, :, 0, 2, 2] = X[:, 9, :]
X_spat[:, :, 0, 2, 3] = X[:, 10, :]
X_spat[:, :, 0, 2, 4] = X[:, 11, :]
X_spat[:, :, 0, 2, 5] = X[:, 12, :]
X_spat[:, :, 0, 2, 6] = X[:, 13, :]

X_spat[:, :, 0, 3, 0] = X[:, 14, :]
X_spat[:, :, 0, 3, 1] = X[:, 15, :]
X_spat[:, :, 0, 3, 2] = X[:, 16, :]
X_spat[:, :, 0, 3, 3] = X[:, 17, :]
X_spat[:, :, 0, 3, 4] = X[:, 18, :]
X_spat[:, :, 0, 3, 5] = X[:, 19, :]
X_spat[:, :, 0, 3, 6] = X[:, 20, :]

X_spat[:, :, 0, 0, 0] = X[:, 30, :]
X_spat[:, :, 0, 0, 1] = X[:, 31, :]
X_spat[:, :, 0, 0, 2] = X[:, 32, :]
X_spat[:, :, 0, 0, 3] = X[:, 33, :]
X_spat[:, :, 0, 0, 4] = X[:, 34, :]
X_spat[:, :, 0, 0, 5] = X[:, 35, :]
X_spat[:, :, 0, 0, 6] = X[:, 36, :]

X_spat[:, :, 0, 4, 0] = X[:, 47, :]
X_spat[:, :, 0, 4, 1] = X[:, 48, :]
X_spat[:, :, 0, 4, 2] = X[:, 49, :]
X_spat[:, :, 0, 4, 3] = X[:, 50, :]
X_spat[:, :, 0, 4, 4] = X[:, 51, :]
X_spat[:, :, 0, 4, 5] = X[:, 52, :]
X_spat[:, :, 0, 4, 6] = X[:, 53, :]

In [7]:
class ConvLSTM(nn.Module):
  def __init__(self):
    super(ConvLSTM, self).__init__()
    self.relu = nn.LeakyReLU()
    self.convlstm = NlayersSeqConvLSTM(input_channels=1,
                        hidden_channels=[32, 64],
                        kernel_sizes=[3,3])
    self.conv = nn.Sequential(nn.Conv2d(64,128,(3,3)),
                      nn.LeakyReLU(),
                      nn.BatchNorm2d(128),
                      nn.Conv2d(128,256,(3,5)),
                      nn.LeakyReLU(),
                      nn.BatchNorm2d(256)).cuda()
    self.dropout = nn.Dropout(0.5)
    self.linear = nn.Linear(256, 2)
    
    
  def forward(self, x):
    h, _ = self.convlstm(x)
    h = self.conv(h[:,-1,:,:,:])
    h = h.view(-1, h.size(1))
    h = self.dropout(h)
    h = self.linear(h)
    return h

model = ConvLSTM().cuda()
# model = NlayersSeqConvLSTM(input_channels=1,
#                     hidden_channels=[32, 64],
#                     kernel_sizes=[3,3]).cuda()

In [8]:
# y, outputs = model(Variable(torch.from_numpy(X_spat)).float().cuda())
# y[:,-1,:,:,:].size()

y = model(Variable(torch.from_numpy(X_spat)).float().cuda())
y.size()

torch.Size([45, 2])

In [None]:
model2 = nn.Sequential(nn.Conv2d(64,128,(3,3)),
                      nn.LeakyReLU(),
                      nn.BatchNorm2d(128),
                      nn.Conv2d(128,256,(3,5)),
                      nn.LeakyReLU(),
                      nn.BatchNorm2d(256)).cuda()

o = model2(y[:,-1,:,:,:])
o.size()

In [9]:

def cv_train(model_class, criterion_class, optimizer_class, X, y,
             epoch=100, num_of_cv=10, batch_size=16):
    kf = KFold(n_splits=num_of_cv, shuffle=True)
    accuracy = []
    for train_idx, val_idx in kf.split(X=X, y=y):
        train_x, val_x = X[train_idx], X[val_idx]
        train_y, val_y = y[train_idx], y[val_idx]
        train_loader = data_loader(train_x, train_y, batch_size=batch_size,
                           shuffle=True, gpu=False)
        val_loader = data_loader(val_x, val_y, batch_size=batch_size)
        writer = SummaryWriter()
        model = model_class().cuda()
        criterion = criterion_class()
        optimizer = optimizer_class(model.parameters(), lr=1e-4)
        trainer = Trainer(model, criterion, optimizer,
                  train_loader, val_loader,
                  val_num=1, early_stopping=2,
                  writer=writer, gpu=True)
        trainer.run(epochs=epoch)
        accuracy.append(trainer.val_best_acc)
    return accuracy


all_accs_list = []
all_mean_list = []
all_var_list = []

for idx in range(num_of_subjects):
    X, y = get_data(id=idx+1, event_code=[6,10,14], filter=[0.5, 30], t=[0., 4])
    X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2]).transpose(0,1,3,2)

#     model = Conv_lstm()
#     model.cuda()

    acc = cv_train(ConvLSTM, torch.nn.CrossEntropyLoss,
                   torch.optim.Adam, X_spat, y, epoch=epochs,
                   num_of_cv=cv_splits, batch_size=batch_size)

    mean = np.mean(acc)
    var = np.var(acc)
    print('subject{}   mean_acc:{}, var_acc:{}'.format(idx+1, mean, var))

    all_accs_list.append(acc)
    all_mean_list.append(mean)
    all_var_list.append(var)

all_mean = np.mean(all_accs_list)
all_var = np.var(all_accs_list)

print('all subjects  mean_acc:{}, var_acc:{}'.format(all_mean, all_var))

Removing orphaned offset at the beginning of the file.
89 events found
Events id: [1 2 3]
45 matching events found
Not setting metadata
Loading data for 45 events and 641 original time points ...
0 bad epochs dropped
----------start training----------
epoch:1, tr_loss:0.6228, tr_acc:0.5278,   val_loss:0.3525, val_acc:0.3333
epoch:2, tr_loss:0.4660, tr_acc:0.6944,   val_loss:0.3527, val_acc:0.3333
epoch:3, tr_loss:0.3752, tr_acc:0.8333,   val_loss:0.3528, val_acc:0.3333
epoch:4, tr_loss:0.3534, tr_acc:0.9167,   val_loss:0.3531, val_acc:0.3333
epoch:5, tr_loss:0.3693, tr_acc:0.8889,   val_loss:0.3525, val_acc:0.3333
epoch:6, tr_loss:0.3499, tr_acc:0.8889,   val_loss:0.3512, val_acc:0.3333
epoch:7, tr_loss:0.2827, tr_acc:0.9722,   val_loss:0.3483, val_acc:0.3333
epoch:8, tr_loss:0.2576, tr_acc:0.9722,   val_loss:0.3467, val_acc:0.4444
epoch:9, tr_loss:0.2470, tr_acc:0.9444,   val_loss:0.3445, val_acc:0.6667
epoch:10, tr_loss:0.2555, tr_acc:0.9444,   val_loss:0.3434, val_acc:0.6667
epoch:1

KeyboardInterrupt: 