In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [4]:
#import transformers
#from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

from collections import defaultdict
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import h5py
import json

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd

from random import shuffle,randint

from scipy.fftpack import fft
import scipy.stats as stats
import scipy.io as scio
from scipy.io import wavfile
from scipy import signal
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report


from tqdm import tqdm
from torchmetrics.functional import word_error_rate as wer
from textwrap import wrap
import textgrid as tg
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch
from torch import nn, optim, einsum
import torch.utils.data as Data
from torch.utils.data import DataLoader,TensorDataset
import wave


RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
date='1118'
print(date)
subject = 'PA4'

1118


In [7]:
#FUNCTIONS USED FOR LOAD DATA
def get_timelocked_activity(times, hg, back, forward, hz=False):
    '''
    Get time-locked activity.

    Parameters:
    - times (array-like): List of timepoints in seconds.
    - hg (numpy.ndarray): High gamma array, shaped (elecs_num, whole_duration*hz).
    - back (float): Start time before the timepoints.
    - forward (float): End time after the timepoints.
    - hz (float, optional): Sampling rate of High gamma. Default is False.

    Returns:
    - Y_mat (numpy.ndarray): Time-locked activity array, shaped (trial_num, elecs_num, selected_duration*hz).
    - back (int): Start time in samples (back*hz).
    - forward (int): End time in samples (forward*hz).
    '''
    if hz:
        times = (times*hz).astype(int)
        back = int(back*hz)
        forward = int(forward*hz)
    times = times[times - back > 0]
    times = times[times + forward < hg.shape[1]]

    Y_mat = np.zeros((len(times),hg.shape[0], int(back + forward)), dtype=float)

    for i, index in enumerate(times):
        Y_mat[i, :, :] = hg[:, int(index-back):int(index+forward)]

    return Y_mat,back,forward

def read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,
                  hz=False, block=4):
    '''
    get ecog_mat based on defined parameters
    
    Parameters:
    - mat : shaped(7,num_of_timepoints)
    - hg (numpy.ndarray): High gamma array, shaped (elecs_num, whole_duration*hz).
    - back (float): Start time before the timepoints.
    - forward (float): End time after the timepoints.
    - hz (float, optional): Sampling rate of High gamma. Default is False.
    channelNum:number of electrodes 128 or 256
    block:number of blocks included, usually from 1 to 4
    key_elecs: selected elecs
    kay_labels:selected labels, opplabel means oppose the key_label
    sentence and paragraphs similarly
    '''

    key_elecs = np.array(key_elecs)
    key_index = np.ones(len(mat[0,:])).astype('bool')
    
    sentence_index = np.ones(len(mat[0,:])).astype('bool')
    label_index = np.ones(len(mat[0,:])).astype('bool')
    paragraph_index = np.ones(len(mat[0,:])).astype('bool')  
    
    if key_label:
        print('select label:'+str(key_label),end=' ')
        label_index=~label_index
        for i in key_label:
            temp_index = mat[0,:] ==i
            label_index = np.logical_or(label_index, temp_index)
        if oppolabel:
            print('oppo')
            label_index=~label_index
    key_index = np.logical_and(label_index, key_index)
    if key_sentence:
        print('select sentence:'+str(key_sentence),end=' ')
        sentence_index=~sentence_index
        for i in key_sentence:
            temp_index = mat[2,:] ==i
            sentence_index = np.logical_or(sentence_index, temp_index)
        if opposentence:
            print('oppo')
            sentence_index=~sentence_index
    key_index = np.logical_and(sentence_index, key_index)        
    if key_paragraph:
        print('select paraqraph:'+str(key_paragraph),end=' ')
        paragraph_index=~paragraph_index
        for i in key_paragraph:
            temp_index = mat[3,:] ==i
            paragraph_index = np.logical_or(paragraph_index, temp_index)
        if oppoparagraph:
            print('oppo')
            paragraph_index=~paragraph_index
    key_index = np.logical_and(paragraph_index, key_index)      
    mat = mat[:,key_index]
       
    
    if key_elecs.any():
        ecog_mat = np.zeros((len(mat[0,:]),len(key_elecs), int((back+forward)*hz)), dtype=float)
        print('select elecs:'+str(key_elecs),end=' ')
    else:
        ecog_mat = np.zeros((len(mat[0,:]),channelNum, int((back+forward)*hz)), dtype=float)
        
    for i in range(block):
        block_index= mat[4,:]==i+1
        temp_time_list=mat[1,block_index]
        ecogData=scio.loadmat(path_raw+str(i+1)+filterType+'.mat')#################################
        ecogData=np.array(ecogData['bands'])*10e4
        #print('###',np.nanmean(np.max(ecogData,axis=1)-np.min(ecogData,axis=1)),np.array(np.max(ecogData,axis=1)).shape)######################################################################################
        temp_ecog,back_duration,forward_duration = get_timelocked_activity(times=temp_time_list, hg=ecogData, back=back, forward=forward, hz=hz)
        if key_elecs.any():
            key_elecs = np.array(key_elecs)
            ecog_mat[block_index] = temp_ecog[:,key_elecs,:]
        else:
            ecog_mat[block_index] = temp_ecog

    print('read_ecog:',ecog_mat.shape)
    return ecog_mat,back_duration,forward_duration,mat

'''
EXAMPLE:
ecog_mat,back_duration,forward_duration,mat = read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,
                hz=False, block=4)
'''

'\nEXAMPLE:\necog_mat,back_duration,forward_duration,mat = read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],\n                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,\n                hz=False, block=4)\n'

In [None]:
#FUNCTIONS USED FOR TRAIN AND VALIDATE THE MODEL

def evaluater(Mat, prob_list, loss_func = nn.CrossEntropyLoss(), unit = 3):
    
    if stage in ['sylb','onset_sylb']:
        List = sylbList
        row = 0
    elif stage in ['tone','onset_tone']:
        List = toneList
        row = 5
        
    real_indices = np.where(~np.isnan(Mat[0, :]))[0]
    Mat =  Mat[:, real_indices]
    prob_list = np.array(prob_list)[np.array(real_indices)]
    
    acc_list=[]
    loss_list=[]
    for trial in set(Mat[unit,:]):
        trial_index = Mat[3,:]==trial
        temp_label = Mat[row,trial_index]
        temp_prob = prob_list[trial_index]

        loss = loss_func(torch.tensor(temp_prob), torch.tensor(temp_label, dtype=torch.long))
        loss = loss / len(temp_label)
        temp_out = np.argmax(temp_prob, axis=1)
        correct = np.sum(temp_out == np.array(temp_label))
        acc = correct / len(temp_label)

        acc_list.append(acc)
        loss_list.append(loss)
    
    if plot:
        plotCM(List, acc_list, Mat[row,:], np.argmax(prob_list, axis=1))

    return acc_list, loss_list

def CV_datasets(back,forward,mat,key_elecs,row,CV_list,unbalance=False,List=False,augmented=False):
    #split data to cross validation combinations, and calculate the weight of each labels
    CV_datasets = []
    CV_augmented_datasets=[]
    count_total=[]

    for CVs in range(6):
        print('mat:',mat.shape)
        test_ecog_mat,_,_,test_mat= read_ecog_mat(back=back,forward=forward,mat=mat,channelNum=channelNum,key_elecs=key_elecs,
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=CV_list[CVs],oppoparagraph=False,
                hz=hz, block=4)

        test_label = test_mat[row,:]
        
        if unbalance == True:
            count=np.bincount(test_label.astype('int'))
            if len(count)!=len(List):
                #triggerrd when there are some blocks without all the labels
                print('not a full set!')
            print(count)
            count_total.append(count)
        test_ecog_mat=torch.FloatTensor(test_ecog_mat)
        test_label=torch.tensor(test_label,dtype=torch.long)
        testdataset = Data.TensorDataset(test_ecog_mat, test_label)    
        CV_datasets.append(testdataset)
        

        if augmented is not False:
            aug_mat = stammer (mat=test_mat, multiple_time=augmented,multiple_range=0.1,shift=-0.05)
            aug_ecog_mat,_,_,test_mat= read_ecog_mat(back=back,forward=forward,mat=mat,channelNum=channelNum,key_elecs=key_elecs,
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=CV_list[CVs],oppoparagraph=False,
                hz=hz, block=4)
            aug_label = aug_mat[row,:]
            aug_ecog_mat=torch.FloatTensor(aug_ecog_mat)
            aug_label=torch.tensor(aug_label,dtype=torch.long)
            augdataset = Data.TensorDataset(aug_ecog_mat, aug_label) 
            CV_augmented_datasets.append(augdataset)
            
    if unbalance == True:
        count_total = np.array(count_total)
        
    return CV_datasets,count_total,CV_augmented_datasets

'''
EXAMPLES:
CV_onset_datasets,count_total,_ = CV_datasets(back=overt_back,
                                            forward=overt_forward,
                                            mat=overt_mat,
                                            key_elecs=overt_elecs,
                                            row=0,
                                            CV_list,
                                            unbalance=True,List=stateList,
                                            augmented=False)
class_weight = np.sum(count_total,axis=0)
'''
'''
CV_sylb_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=sylb_mat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list,
                            unbalance=False,List=False,
                            augmented=False)
'''
'''
CV_clip_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=clipTimeMat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list,
                            unbalance=False,List=False,
                            augmented=False)

'''
def CV_train_ENN(CV_datasets, CV_pred_only_datasets, lr, batch_size, patience,
                 class_weight = False, channelNum=256,plot = False):
    #this is used for cross validation on ensumble models 
    test_out_prob_list=[]
    pred_only_prob_list=[]
    torch.cuda.empty_cache()
    for test_CV in range(6):#tqdm(
        CV_train_val_datasets = CV_datasets.copy()
        CV_test_datasets = [CV_train_val_datasets.pop(test_CV),]
        pred_only_datasets = [CV_pred_only_datasets[test_CV],]
        #model_list=[]
        predicted_prob_list = []
        predicted_only_prob_list= []
        for CVs in range(5):
            CV_train_datasets = CV_train_val_datasets.copy()
            CV_val_datasets = [CV_train_datasets.pop(CVs),]
            #label,acc,predicted,predicted_prob, predicted_only,predicted_only_prob,model
            _,_,_,predicted_prob,_,predicted_only_prob,CV_model,loss_func = train( 
                                        lr=lr, batch_size=batch_size, EPOCH=EPOCH, patience=patience, 
                                        CV_train_datasets = CV_train_datasets, #CV_train_val_datasets, #
                                        CV_val_datasets =  CV_val_datasets, #False, #
                                        CV_test_datasets = CV_test_datasets,#False, 
                                        pred_only_datasets = pred_only_datasets, 
                                        class_weight=class_weight)
            
            #model_list.append(CV_model)
            predicted_prob_list.append(predicted_prob)
            predicted_only_prob_list.append(predicted_only_prob)
        
        predicted_prob = np.mean(predicted_prob_list, axis=0)
        test_out_prob_list.extend(predicted_prob)
        predicted_only_prob = np.mean(predicted_only_prob_list, axis=0)
        pred_only_prob_list.extend(predicted_only_prob) 
        
    a, acc_list = evaluater(sylb_label_mat, test_out_prob_list, loss_func = loss_func, unit = 3)    
    if not verbose:
        print(str(a))
    _, loss_list = evaluater(clipTimeMat, pred_only_prob_list, loss_func = loss_func, unit = 3)
    
    return [acc_list, loss_list], test_out_prob_list, pred_only_prob_list

def train(batch_size, lr, EPOCH, patience, CV_train_datasets, 
          CV_val_datasets=False, CV_test_datasets=False, pred_only_datasets=False, 
          class_weight=False):
    #This function directly train and validate the model
    model = torch.load("./"+subject+'.pt')
    loss_func = nn.CrossEntropyLoss() 
    if class_weight is not False:
        weight = [max(class_weight)/x for x in class_weight]
        weight = torch.FloatTensor(weight).to(device)
        loss_func = nn.CrossEntropyLoss(weight=weight)
        
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)#也可以使用SGD优化算法进行训练
    
    early_stopping = EarlyStopping(patience)
    '''
    optimizer = torch.optim.SGD(model.parameters(),lr=lr)#也可以使用SGD优化算法进行训练
    scheduler =  torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer,T_max =  EPOCH)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=patience, 
                                           verbose=False, threshold=0.0001, threshold_mode='rel', 
                                           cooldown=0, min_lr=0, eps=1e-08)
    '''
    for epoch in range(EPOCH):
        #return acc，ground truth & predicited label
        sum_loss_train= 0
        total_train=0
        model.train()
        for train_dataset in CV_train_datasets:
            train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            for data in train_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                pred = model(inputs)
                loss = loss_func(pred, labels)
                sum_loss_train+= loss.item()
                total_train+= labels.size(0)
                loss.backward()
                optimizer.step()
        train_loss = sum_loss_train/total_train
        if verbose:
            print('Epoch {}:train-loss:{:.2e},'.format(epoch+1,train_loss),end='')
       
        if CV_val_datasets is not False:#if there is val datasets performing early-stopping
            sum_loss_val=0
            total_val = 0  
            for val_dataset in CV_val_datasets:
                val_loader = Data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
                for data in val_loader:
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device) 
                    pred = model(inputs)
                    loss = loss_func(pred, labels)
                    sum_loss_val+=loss.item()
                    total_val+= labels.size(0)
            val_loss = sum_loss_val/total_val
            if verbose:
                print('val-loss:{:.2e},'.format(val_loss,),end='')
            
            '''
            scheduler.step(val_loss)##############################################
            #scheduler.step()
            '''
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                if verbose:
                    print("Early stopping")
                break
        
        if CV_test_datasets is not False: #if there is test data for testing
            total_test = 0  
            sum_loss_test= 0
            predicted=[]
            label=[]
            
            for test_dataset in CV_test_datasets:
                test_loader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
                for data in test_loader:
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    pred_prob = model(inputs)
                    loss = loss_func(pred_prob, labels)
                    sum_loss_test+=loss.item()
                    _, pred = torch.max(pred_prob.data, 1)
                    pred_prob = pred_prob.data
                    pred_prob = F.softmax(pred_prob, dim=1) 
                    total_test+= labels.size(0)
                    predicted.append(pred.cpu())
                    label.append(labels.cpu())
            label = torch.cat(label,dim=0)
            predicted = torch.cat(predicted,dim=0)
            correct = (predicted == label).sum()
            test_acc= correct.item()/total_test
            test_loss=sum_loss_test/total_test
            if verbose:
                print('test-loss:{:.2e},Acc:{:.4f}'.format(test_loss,test_acc),end='')
        if verbose:
            print('')
    
    model.load_state_dict(torch.load('checkpoint.pt'))
    
    ##############################EVALUATION model###############################################
    predicted=[]
    predicted_prob=[]
    label=[]
    acc=[]
    if CV_test_datasets is not False:
        total_test = 0  # 总数
        sum_loss_test= 0
        model.eval()
        for test_dataset in CV_test_datasets:
            test_loader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
            for data in test_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)  # 有GPU则将数据置入GPU加速
                pred_prob = model(inputs)
                loss = loss_func(pred_prob, labels)
                sum_loss_test+=loss.item()
                _, pred = torch.max(pred_prob.data, 1)
                pred_prob = pred_prob.data
                pred_prob = F.softmax(pred_prob, dim=1) 
                total_test+= labels.size(0)

                predicted.append(pred.cpu())
                predicted_prob.append(pred_prob.cpu())
                label.append(labels.cpu())

        label = torch.cat(label,dim=0)
        predicted = torch.cat(predicted,dim=0)
        predicted_prob = torch.cat(predicted_prob,dim=0)
        correct = (predicted == label).sum()
        test_acc= correct.item()/total_test
        test_loss=sum_loss_test/total_test
        if verbose:
            print('Final: test-loss:{:.2e},Acc:{:.4f}'.format(test_loss,test_acc))
        label = label.detach().numpy()
        acc = test_acc
        predicted = predicted.detach().numpy()
        predicted_prob = predicted_prob.detach().numpy()
    
    
    ##PRED-ONLY:do not return acc, only predicted result
    predicted_only=[]
    predicted_only_prob=[]
        
    if pred_only_datasets is not False:
        for pred_only_dataset in pred_only_datasets:
            pred_only_loader = Data.DataLoader(pred_only_dataset, batch_size=batch_size, shuffle=False)
            for inputs in pred_only_loader:
                inputs = inputs[0].to(device)
                pred = model(inputs)
                pred_only_prob = pred.data
                pred_only_prob = F.softmax(pred_only_prob, dim=1) 
                _, pred_only = torch.max(pred.data, 1)
                predicted_only.append(pred_only.cpu())
                predicted_only_prob.append(pred_only_prob.cpu())
            
        predicted_only = torch.cat(predicted_only).detach().numpy()
        predicted_only_prob = torch.cat(predicted_only_prob).detach().numpy()
        if verbose:
            print(predicted_only_prob.shape)
    
    return label,acc, predicted,predicted_prob, predicted_only,predicted_only_prob, model,loss_func.cpu()

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping: {self.counter}/{self.patience}',end='')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        torch.save(model.state_dict(), 'checkpoint.pt')	# save the best model
        self.val_loss_min = val_loss

In [9]:
path_raw='../Raw/B'

sylbList = ['shi', 'de', 'ji', 'li', 'bu', 'ge', 'qi', 'zhe', 'ta', 'zhi']
stateList=['silent','speech']
toneList = ['1','2','3','4']
hz = 400
filterType = '_70_150'
channelNum = 256
sylb_mat = np.load('../'+subject+'_sylb_mat'+'.npy',allow_pickle=True)
print(sylb_mat.shape)
_,_,_,sylb_label_mat =  read_ecog_mat(0.01,0.01,sylb_mat,channelNum=256,key_elecs=[1,],
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, 
                key_paragraph=[13,14,15,16],oppoparagraph=True,hz=hz, block=4)
clipTimeMat = np.load(subject+'_opt_clipTimeMat'+'.npy',allow_pickle=True)
print(clipTimeMat.shape)

channelNum = 256
sylb_back = 0.4
sylb_forward = 0.8

CV_list=[[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]
sylb_elecs = np.load('../'+subject+'_anovasylb_elecs.npy')
row=0

(7, 1034)
select paraqraph:[13, 14, 15, 16] oppo
select elecs:[1] read_ecog: (775, 1, 8)
(7, 768)


In [11]:
class timespatCNNRNN(nn.Module):
    def __init__(self, *, duration, typeNum, in_chans, is_timespat,
                 n_filters_time,
                 filter_time_length,
                 n_filters_spat,
                 conv_stride,
                 pool_time_length,
                 pool_stride,
                 n_filters,
                 filter_length, 
                 n_CNN_layer,
                 gruDim,
                 gruLayer,
                 drop_out):
        super().__init__()
        
        self.conv_time = nn.Conv2d(
            1,
            n_filters_time,
            (filter_time_length, 1),
            stride=1,
        )
        
        self.conv_spat = nn.Conv2d(
            n_filters_time,
            n_filters_spat,
            (1, in_chans),
            stride=(conv_stride, 1),
        )
        self.is_timespat = is_timespat
        self.conv_timespat = nn.Conv2d(
            1,
            n_filters_spat,
            (filter_time_length, in_chans),
            stride=(conv_stride, 1),
        )
        
        self.bnorm = nn.BatchNorm2d(
            n_filters_spat,
            #momentum=self.batch_norm_alpha,
            affine=True,
            eps=1e-5)
        self.elu = nn.ELU()
        self.pool = nn.MaxPool2d(kernel_size=(pool_time_length, 1), stride=(pool_stride, 1))
        
        self.conv_pool_block = nn.ModuleList()
        self.conv_pool_block.append(nn.Dropout(p=drop_out))
        self.conv_pool_block.append(nn.Conv2d(
            n_filters_spat,
            n_filters,
            (filter_length, 1),
            stride=(conv_stride, 1),
            padding=(((filter_length - 1) * conv_stride) // 2,0)
        ))
        
        for i in range(n_CNN_layer-1):
            self.conv_pool_block.append(nn.Dropout(p=drop_out))
            self.conv_pool_block.append(nn.Conv2d(
                n_filters,
                n_filters,
                (filter_length, 1),
                stride=(conv_stride, 1),
                padding=(((filter_length - 1) * conv_stride) // 2,0)
            ))
            self.conv_pool_block.append(nn.BatchNorm2d(
                n_filters,
                momentum=0.1,
                affine=True,
                eps=1e-5,
            ))
            self.conv_pool_block.append(nn.ELU())
            self.conv_pool_block.append(nn.MaxPool2d(
                kernel_size=(pool_time_length, 1),
                stride=(pool_stride, 1),
            ))

        self.gru1 = nn.GRU(n_filters, gruDim, gruLayer, batch_first=True, bidirectional=True)
        elec_feature = int(2*gruDim)
        self.fc1 = nn.Linear(elec_feature, typeNum)

    def forward(self, x):
        x = rearrange(x,'(batch 1) electrodes duration -> batch 1 duration electrodes')
        if self.is_timespat:
            x = self.conv_timespat(x)
        else:
            x = self.conv_time(x)
            #print('1',x.shape)
            x = self.conv_spat(x)
        #print('2',x.shape)
        x = self.bnorm(x)
        x = self.elu(x)
        x = self.pool(x)
        for block in self.conv_pool_block:
            x = block(x)
        x = rearrange(x,'batch filter duration 1 -> batch duration filter')
        x = self.gru1(x)[0][:,-1,:]
        #x = self.relu1(x) 
        x = self.fc1(x)
        return x

####################################################################################################
'''
a = torch.zeros([10,32,480])
print(a.shape)
model = timespatCNNRNN(duration=480, typeNum=10, in_chans=32, 
                 n_filters_time=64,
                 filter_time_length=2,
                 n_filters_spat=64,
                 conv_stride=2,
                 pool_time_length=2,
                 pool_stride=2,
                 gruDim=64)
                        
#model = CNNRNN(duration = 480, typeNum = 10, channelNum=256, kernel=2, stride=2, mp_kernel=2, gruDim=64)
b = model(a)
print(b)

for name, param in model.named_parameters():
    print('Parameter name:', name)
    print('Parameter shape:',param.shape)
'''

"\na = torch.zeros([10,32,480])\nprint(a.shape)\nmodel = timespatCNNRNN(duration=480, typeNum=10, in_chans=32, \n                 n_filters_time=64,\n                 filter_time_length=2,\n                 n_filters_spat=64,\n                 conv_stride=2,\n                 pool_time_length=2,\n                 pool_stride=2,\n                 gruDim=64)\n                        \n#model = CNNRNN(duration = 480, typeNum = 10, channelNum=256, kernel=2, stride=2, mp_kernel=2, gruDim=64)\nb = model(a)\nprint(b)\n\nfor name, param in model.named_parameters():\n    print('Parameter name:', name)\n    print('Parameter shape:',param.shape)\n"

In [12]:
CV_list=[[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]
#load datasets of syllable decoding on manually aligned onsets
CV_sylb_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=sylb_mat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list=CV_list,
                            unbalance=False,List=False,
                            augmented=False)

mat: (7, 1034)
select paraqraph:[1, 2] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 169 170 171 172 173 174 175 180 181 182
 183 184 185 186 187 188 189 190 191 197 199 201 204 214 215 226 231 242
 243 253 254] read_ecog: (133, 75, 480)
mat: (7, 1034)
select paraqraph:[3, 4] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 169 170 171 172 173 174 175 180 181 182
 183 184 185 186 187 188 189 190 191 197 199 201 204 214 215 226 231 242
 243 253 254] read_ecog: (132, 75, 480)
mat: (7, 1034)
select paraqraph:[5, 6] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 1

In [14]:
CV_list=[[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]
#load datasets of syllable decoding on the onsets detected by optimized speech detector
CV_clip_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=clipTimeMat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list=CV_list,
                            unbalance=False,List=False,
                            augmented=False)

mat: (7, 768)
select paraqraph:[1, 2] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 169 170 171 172 173 174 175 180 181 182
 183 184 185 186 187 188 189 190 191 197 199 201 204 214 215 226 231 242
 243 253 254] read_ecog: (132, 75, 480)
mat: (7, 768)
select paraqraph:[3, 4] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 169 170 171 172 173 174 175 180 181 182
 183 184 185 186 187 188 189 190 191 197 199 201 204 214 215 226 231 242
 243 253 254] read_ecog: (130, 75, 480)
mat: (7, 768)
select paraqraph:[5, 6] select elecs:[  0   1   2   5   8   9  13  14  15  32  46  48  49  59  64  95  96 100
 110 111 112 126 128 135 136 137 138 139 140 142 143 150 151 152 153 154
 155 156 157 158 159 166 167 168 169 

In [None]:
EPOCH=1000
patience2=50

stage='sylb'
verbose = False
plot = False

reporterList=[]
def hyperopt_syllable(filter_time_length,n_filters,conv_stride,pool_time_length,pool_stride,filter_length,
                      n_CNN_layer,gruDim,gruLayer,drop_out):    
    batch_size_sylb = 512
    lr_sylb = 0.0005#lr
    model2 = timespatCNNRNN(duration=int((sylb_back+sylb_forward)*hz), typeNum=10, in_chans=len(sylb_elecs),
                                    is_timespat=True,
                                    n_filters_time=n_filters, filter_time_length=filter_time_length,
                                    n_filters_spat=n_filters, conv_stride=conv_stride,
                                    pool_time_length=pool_time_length, pool_stride=pool_stride,
                                    n_filters=n_filters, 
                                    filter_length=filter_length, 
                                    n_CNN_layer=n_CNN_layer,
                                    gruDim=gruDim,
                                    gruLayer=gruLayer,
                                    drop_out=drop_out).to(device)
    
    torch.save(model2,("./"+subject+'_sylb.pt'))
    torch.save(model2,("./"+subject+'.pt'))
    lists, _, _ = CV_train_ENN(CV_datasets = CV_sylb_datasets, 
                                   CV_pred_only_datasets = CV_clip_datasets,
                                   lr=lr_sylb, batch_size=batch_size_sylb, patience=patience2,
                                   class_weight = False, channelNum=256)
    
    #a = np.array([np.nanmean(loss_list),is_timespat,batch_size,lr,conv_stride,pool_time_length,pool_stride,filter_length])
    a = np.array([np.nanmean(lists[1])+np.nanmean(lists[0]),np.nanmean(lists[1]),np.nanmean(lists[0]),filter_time_length,n_filters,conv_stride,
        pool_time_length,pool_stride,filter_length, n_CNN_layer,gruDim,gruLayer,drop_out,])
    if verbose:
        print(a)
    reporterList.append(a)
    #print(reporterList)
    return(np.nanmean(lists[1])+np.nanmean(lists[0]))



'''
print(hyperopt_syllable(is_timespat=False,
                        conv_stride=23,
                        n_filters=64,
                        pool_time_length=3,
                        pool_stride=3,
                        filter_length=3,
                        gruDim=64))
'''
from hyperopt import hp,STATUS_OK,Trials,fmin,tpe
def hyperopt_train(params):
    loss=hyperopt_syllable(**params)
    return loss

hyperparas = {
    #'batch_size': hp.choice('batch_size', [4,8]),
    #'lr': hp.choice('lr', [0.0005,0.0001]),
    'filter_time_length':hp.choice('filter_time_length', [4,5,6]),
    'n_filters':hp.choice('n_filters', [256,512,1024,2048]),
    'conv_stride': hp.choice('conv_stride', [1,2,3]),
    'pool_time_length': hp.choice('pool_time_length', [2,3,4]),
    'pool_stride': hp.choice('pool_stride', [1,2,3]),
    'filter_length': hp.choice('filter_length', [2,3,4]),
    'n_CNN_layer':hp.choice('n_CNN_layer',[1,2,3]),
    'gruDim':hp.choice('gruDim', [64,128,256,512]),
    'gruLayer':hp.choice('gruLayer',[1,2,3,4]),
    'drop_out':hp.choice('drop_out', [0.2,0.3,0.4,0.5,0.6,0.7,0.8]),
    #'drop_out':hp.uniform('drop_out', 0.2, 0.8),
}


def f(params):
    try:
        loss = hyperopt_train(params)
    except Exception as e:
        # Handle the exception here (e.g., print the error message)
        print(f"Error: {e}")
        # Assign a high loss value to discourage selecting this parameter combination
        return {'loss': 9999, 'status': STATUS_OK}
    
    return {'loss': loss, 'status': STATUS_OK}

trials=Trials()
best=fmin(f,hyperparas,algo=tpe.suggest,max_evals=500,trials=trials)
print('best',best)

'''use best to index the best parameters'''

[0.5671641791044776, 0.4696969696969697, 0.3787878787878788, 0.4393939393939394, 0.4090909090909091, 0.3728813559322034, 0.6212121212121212, 0.6507936507936508, 0.5423728813559322, 0.5076923076923077, 0.4090909090909091, 0.4090909090909091]
  0%|          | 1/500 [20:44<172:29:11, 1244.39s/trial, best loss: 0.06284551322460175]