In [None]:
# importing dependencies
from lib.eeg_transformer import *
from lib.train import *

import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.signal import detrend, filtfilt, butter, iirnotch, welch
import json
from tqdm import trange

In [None]:
# torch.nn is a module that implements various useful functions and functors for building flexible, highly
# customized neural networks. We will use nn to define neural network modules, different kinds of layers and
# different loss functions.
import torch.nn as nn
# torch.nn.functional implements a large variety of activation functions and functional forms of different
# neural network layers. Here we will use it for activation functions.
import torch.nn.functional as F
# torch is the Linear Algebra / Neural Networks library
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# EEG / embedding layout constants shared by the cells below.
seq_len = 20               # atoms per sequence; reused below as params['num_atoms']
eeg_channels = 32          # reused below as the model input dimension
embedding_channels = 256   # reused below as the model output dimension

# Flattened sizes of one sequence.
# NOTE(review): the factors 50 and 3 are not explained in this notebook —
# presumably samples per atom for the EEG and the embedding; confirm against the data.
eeg_size = seq_len * eeg_channels * 50
embedding_size = seq_len * embedding_channels * 3

In [None]:
class EEGDataset(Dataset):
    """EEG recordings (plus target embeddings for supervised experiments) read from an ``.npz`` archive.

    The archive must provide a 3-D ``'EEG'`` array and a ``'channels'`` entry. For the
    supervised experiments (``'rwalk'``, ``'rpoint'``) it must also provide a 3-D ``'E'``
    array with the same first dimension as ``'EEG'``.

    Construction does the following, in order:
      1. load the arrays and close the archive,
      2. keep only the requested split (``params['split_ratio']`` = train/val/test fractions),
      3. merge every ``params['num_atoms']`` consecutive atoms into one item
         (leftover atoms at the end are dropped),
      4. optionally standardise every channel (last axis) using this split's own statistics.

    The steps listed in ``params['pipeline']`` are applied lazily, per item, in ``__getitem__``.

    Args:
        npz_file: path (or file object) accepted by ``np.load``.
        experiment: one of ``'rwalk'``, ``'rpoint'`` (supervised) or ``'imagine'``,
            ``'music'``, ``'speech'``, ``'video'`` (unsupervised).
        params: optional dict of settings. Missing keys fall back to the defaults in
            ``_default_params``; the caller's dict is not modified.
        split: ``'train'``, ``'val'`` or ``'test'``.

    Raises:
        ValueError: if ``experiment`` is not recognised or ``params['num_atoms'] < 1``.
        KeyError: if ``split`` is not one of the three split names.
    """

    _SUPERVISED_EXPERIMENTS = ('rwalk', 'rpoint')
    _UNSUPERVISED_EXPERIMENTS = ('imagine', 'music', 'speech', 'video')

    @staticmethod
    def _default_params():
        """Return a fresh dict holding the default settings (same values as before)."""
        return {
            'num_atoms': 1,
            'standardise': True,
            'pipeline': None,
            'detrend_window': 10,
            'sampling_freq': 500,
            'line_freq': 60,
            'Q_notch': 30,
            'low_cutoff_freq': 2.,
            'high_cutoff_freq': 45.,
            'flatten': False,
            'split_ratio': [0.7, 0, 0.3],  # train, validation, test
        }

    def __init__(self, npz_file, experiment='rpoint', params=None, split='train'):
        # Determine the type of task; fail loudly instead of leaving `supervised` unset.
        if experiment in self._SUPERVISED_EXPERIMENTS:
            self.supervised = True
        elif experiment in self._UNSUPERVISED_EXPERIMENTS:
            self.supervised = False
        else:
            raise ValueError(
                "Unknown experiment %r; expected one of %s"
                % (experiment, self._SUPERVISED_EXPERIMENTS + self._UNSUPERVISED_EXPERIMENTS))

        # np.load keeps the .npz archive open until it is closed, so read inside a `with`.
        with np.load(npz_file) as data:
            self.eeg = data['EEG']
            if self.supervised:
                self.embedding = data['E']
            self.channels = data['channels']
        self.split = split

        # User settings override the defaults; keys the user omitted keep their default.
        merged = self._default_params()
        if params is not None:
            merged.update(params)
        self.params = merged

        num_atoms = self.params['num_atoms']
        if num_atoms < 1:
            raise ValueError("params['num_atoms'] must be >= 1, got %r" % (num_atoms,))

        # Split dataset and assign to variables
        self.split_dataset()

        # Drop trailing atoms that do not fill a complete group of `num_atoms`.
        num_items = self.eeg.shape[0] // num_atoms
        end = num_items * num_atoms
        self.eeg = self.eeg[:end]
        if self.supervised:
            self.embedding = self.embedding[:end]

        # Concatenate each group of `num_atoms` atoms along axis 1.
        _, atom_len, n_ch = self.eeg.shape
        self.eeg = self.eeg.reshape(num_items, atom_len * num_atoms, n_ch)
        if self.supervised:
            _, emb_len, emb_ch = self.embedding.shape
            self.embedding = self.embedding.reshape(num_items, emb_len * num_atoms, emb_ch)

        self.size = len(self.eeg)
        self.num_channels = self.eeg.shape[2]
        self.process()  # Dataset-level processing (standardisation)

    def split_dataset(self):
        """Replace ``self.eeg`` (and ``self.embedding``) by the slice belonging to ``self.split``.

        Boundaries are the cumulative sums of ``int(ratio * n_atoms)``; anything left after the
        validation boundary goes to the test split.
        """
        lims = np.dot(self.params['split_ratio'], self.eeg.shape[0])
        bounds = np.cumsum([int(lim) for lim in lims])

        slices = {'train': slice(0, bounds[0]),
                  'val': slice(bounds[0], bounds[1]),
                  'test': slice(bounds[1], None)}
        chosen = slices[self.split]  # KeyError for an unknown split name

        self.eeg = self.eeg[chosen]
        if self.supervised:
            self.embedding = self.embedding[chosen]
        return

    def process(self):
        """Standardise each channel to zero mean / unit variance if ``params['standardise']`` is set.

        The statistics are computed over all items and time steps of *this* split only.
        Channels with zero variance are centred but not scaled, so they do not become NaN.
        """
        if self.params['standardise']:
            n_items, n_steps, n_ch = self.eeg.shape
            flat = self.eeg.reshape(n_items * n_steps, n_ch)
            mean, std = np.mean(flat, axis=0), np.std(flat, axis=0)
            std = np.where(std == 0, 1.0, std)  # avoid 0/0 on constant channels
            self.eeg = ((flat - mean) / std).reshape(n_items, n_steps, n_ch)
        return

    def __len__(self):
        """Number of items (groups of ``num_atoms`` atoms) in this split."""
        return self.size

    def __getitem__(self, i):
        """Return item ``i`` after applying the steps in ``params['pipeline']``.

        Recognised steps (unrecognised names are skipped, as before):
          * ``'rereference'``: subtract, at every time step, the mean over channels;
          * ``'detrend'``: piecewise-linear detrend along time, with break points every
            ``detrend_window * sampling_freq`` samples;
          * ``'remove_line_freq'``: zero-phase notch filter at ``line_freq``;
          * ``'bandpassfilter'``: zero-phase 5th-order Butterworth band-pass between
            ``low_cutoff_freq`` and ``high_cutoff_freq``.

        The stored arrays are never modified, so repeated access to the same index
        yields the same result.

        Returns:
            ``(eeg, embedding)`` float tensors for supervised experiments, otherwise the
            ``eeg`` float tensor alone. Both are flattened when ``params['flatten']`` is set.
        """
        # Private copy: the original code filtered a view and wrote back into self.eeg.
        eeg_i = self.eeg[i].copy()
        emb_i = self.embedding[i] if self.supervised else None

        pipeline = self.params['pipeline']
        if pipeline is not None:
            fs = self.params['sampling_freq']
            for step in pipeline:

                if step == 'rereference':
                    # Common-average reference across the channel axis.
                    eeg_i = eeg_i - np.mean(eeg_i, axis=1, keepdims=True)

                elif step == 'detrend':
                    breakpoints = np.arange(0, eeg_i.shape[0],
                                            self.params['detrend_window'] * fs,
                                            dtype="int32")
                    eeg_i = detrend(eeg_i, axis=0, bp=breakpoints)

                elif step == 'remove_line_freq':
                    b, a = iirnotch(self.params['line_freq'],
                                    self.params['Q_notch'],
                                    fs)  # scipy 1.2.0
                    eeg_i = filtfilt(b, a, eeg_i, axis=0)

                elif step == 'bandpassfilter':
                    nyq = 0.5 * fs
                    band = [self.params['low_cutoff_freq'] / nyq,
                            self.params['high_cutoff_freq'] / nyq]
                    b, a = butter(5, band, btype='bandpass', analog=False)
                    eeg_i = filtfilt(b, a, eeg_i, axis=0)

            # filtfilt may return a reversed view; torch.from_numpy rejects negative strides.
            eeg_i = np.ascontiguousarray(eeg_i)

        if self.params['flatten']:
            eeg_i = eeg_i.flatten()
            if self.supervised:
                emb_i = emb_i.flatten()

        if self.supervised:
            return torch.from_numpy(eeg_i).float(), torch.from_numpy(emb_i).float()
        return torch.from_numpy(eeg_i).float()

In [None]:
# Preprocessing and split settings handed to EEGDataset.
params = {
    'num_atoms': seq_len,
    'standardise': True,
    'pipeline': ['rereference', 'detrend', 'bandpassfilter'],
    'detrend_window': 50,
    'sampling_freq': 500,
    'Q_notch': 30,
    'low_cutoff_freq': 0.1,
    'high_cutoff_freq': 249.,
    'flatten': False,
    'split_ratio': [0.9, 0, 0.1],  # 90% train, no validation, 10% test
}

# Build the train and test datasets from the same file, each with its own loader.
batch_size = 10
eeg_dataset_train = EEGDataset('rwalk.npz', 'rwalk', params, split='train')
dataloader_train = DataLoader(
    eeg_dataset_train,
    batch_size=batch_size,
    shuffle=True,
)
eeg_dataset_test = EEGDataset('rwalk.npz', 'rwalk', params, split='test')
dataloader_test = DataLoader(
    eeg_dataset_test,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Transformer hyperparameters, read by the model-construction and prediction cells below.
opt = {
    'Transformer-layers': 2,
    'Model-dimensions': 256,
    'feedford-size': 512,
    'headers': 8,
    'dropout': 0.1,
    'src_d': eeg_channels,        # input dimension
    'tgt_d': embedding_channels,  # output dimension
    'timesteps': 60,              # passed as max_len to output_prediction
}

In [None]:
criterion = nn.MSELoss()  # mean squared error
# Set up the model using the hyperparameters defined above
model = make_model(opt['src_d'], opt['tgt_d'], opt['Transformer-layers'], opt['Model-dimensions'],
                   opt['feedford-size'], opt['headers'], opt['dropout'])
# Set up the optimizer: Adam wrapped by NoamOpt (from lib.train).
# NOTE(review): NoamOpt presumably applies a warm-up learning-rate schedule — confirm in lib.train.
model_opt = NoamOpt(model_size=opt['Model-dimensions'], factor=1, warmup=400,
        optimizer = torch.optim.Adam(model.parameters(), lr=0.015, betas=(0.9, 0.98), eps=1e-9))
total_epoch = 2000
train_losses = np.zeros(total_epoch)
test_losses = np.zeros(total_epoch)

# torch.save does not create missing directories; make them up front so the run
# does not crash at the first checkpoint after several epochs of training.
os.makedirs('model_checkpoint', exist_ok=True)
os.makedirs('model_save', exist_ok=True)

for epoch in range(total_epoch):
    model.train()
    train_loss = run_epoch(data_gen(dataloader_train), model,
              SimpleLossCompute(model.generator, criterion, model_opt))
    train_losses[epoch] = train_loss

    # Every 10th epoch: write a resumable checkpoint and a pickle of the whole model.
    # File names use the 0-based epoch index (9.pth, 19.pth, ...).
    if (epoch+1)%10 == 0:
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': model_opt.optimizer.state_dict(),
                    'loss': train_loss,
                    }, 'model_checkpoint/'+str(epoch)+'.pth')
        torch.save(model, 'model_save/model%d.pth'%(epoch)) # save the model

    # Evaluate on the test split; no optimizer is passed to the loss computation.
    # NOTE(review): not wrapped in torch.no_grad() because SimpleLossCompute's internals
    # are not visible here and it may call backward() — confirm in lib.train.
    model.eval()
    test_loss = run_epoch(data_gen(dataloader_test), model,
            SimpleLossCompute(model.generator, criterion, None))
    test_losses[epoch] = test_loss
    print('Epoch[{}/{}], train_loss: {:.6f},test_loss: {:.6f}'
              .format(epoch+1, total_epoch, train_loss, test_loss))

In [None]:
# Pick one (input, target) pair — index 1 — from the test split.
# NOTE(review): reading `.eeg` / `.embedding` directly bypasses EEGDataset.__getitem__,
# which is where params['pipeline'] is applied — confirm this is the intended model input.
test_x = eeg_dataset_test.eeg[1]
test_y = eeg_dataset_test.embedding[1]
# Run the model on that pair and keep its prediction next to the true output.
test_out, true_out = output_prediction(
    model,
    test_x,
    test_y,
    max_len=opt['timesteps'],
    start_symbol=1,
    output_d=opt['tgt_d'],
)