In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [6]:
import warnings
warnings.filterwarnings('ignore')

from collections import defaultdict
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

import h5py
import json

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os

import pandas as pd

from random import shuffle,randint

from scipy.fftpack import fft
import scipy.stats as stats
import scipy.io as scio
from scipy.io import wavfile
from scipy import signal
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report


from tqdm import tqdm
from torchmetrics.functional import word_error_rate as wer
from textwrap import wrap
import textgrid as tg
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch
from torch import nn, optim, einsum
import torch.utils.data as Data
from torch.utils.data import DataLoader,TensorDataset
import wave


RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
date='1117'
print(date)
subject = 'PA4'

1117


In [7]:
#FUNCTIONS USED FOR LOAD DATA
def get_timelocked_activity(times, hg, back, forward, hz=False):
    '''
    Get time-locked activity.

    Parameters:
    - times (array-like): List of timepoints in seconds.
    - hg (numpy.ndarray): High gamma array, shaped (elecs_num, whole_duration*hz).
    - back (float): Start time before the timepoints.
    - forward (float): End time after the timepoints.
    - hz (float, optional): Sampling rate of High gamma. Default is False.

    Returns:
    - Y_mat (numpy.ndarray): Time-locked activity array, shaped (trial_num, elecs_num, selected_duration*hz).
    - back (int): Start time in samples (back*hz).
    - forward (int): End time in samples (forward*hz).
    '''
    if hz:
        times = (times*hz).astype(int)
        back = int(back*hz)
        forward = int(forward*hz)
    times = times[times - back > 0]
    times = times[times + forward < hg.shape[1]]

    Y_mat = np.zeros((len(times),hg.shape[0], int(back + forward)), dtype=float)

    for i, index in enumerate(times):
        Y_mat[i, :, :] = hg[:, int(index-back):int(index+forward)]

    return Y_mat,back,forward

def read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,
                  hz=False, block=4):
    '''
    get ecog_mat based on defined parameters
    
    Parameters:
    - mat : shaped(7,num_of_timepoints)
    - hg (numpy.ndarray): High gamma array, shaped (elecs_num, whole_duration*hz).
    - back (float): Start time before the timepoints.
    - forward (float): End time after the timepoints.
    - hz (float, optional): Sampling rate of High gamma. Default is False.
    channelNum:number of electrodes 128 or 256
    block:number of blocks included, usually from 1 to 4
    key_elecs: selected elecs
    kay_labels:selected labels, opplabel means oppose the key_label
    sentence and paragraphs similarly
    '''

    key_elecs = np.array(key_elecs)
    key_index = np.ones(len(mat[0,:])).astype('bool')
    
    sentence_index = np.ones(len(mat[0,:])).astype('bool')
    label_index = np.ones(len(mat[0,:])).astype('bool')
    paragraph_index = np.ones(len(mat[0,:])).astype('bool')  
    
    if key_label:
        print('select label:'+str(key_label),end=' ')
        label_index=~label_index
        for i in key_label:
            temp_index = mat[0,:] ==i
            label_index = np.logical_or(label_index, temp_index)
        if oppolabel:
            print('oppo')
            label_index=~label_index
    key_index = np.logical_and(label_index, key_index)
    if key_sentence:
        print('select sentence:'+str(key_sentence),end=' ')
        sentence_index=~sentence_index
        for i in key_sentence:
            temp_index = mat[2,:] ==i
            sentence_index = np.logical_or(sentence_index, temp_index)
        if opposentence:
            print('oppo')
            sentence_index=~sentence_index
    key_index = np.logical_and(sentence_index, key_index)        
    if key_paragraph:
        print('select paraqraph:'+str(key_paragraph),end=' ')
        paragraph_index=~paragraph_index
        for i in key_paragraph:
            temp_index = mat[3,:] ==i
            paragraph_index = np.logical_or(paragraph_index, temp_index)
        if oppoparagraph:
            print('oppo')
            paragraph_index=~paragraph_index
    key_index = np.logical_and(paragraph_index, key_index)      
    mat = mat[:,key_index]
       
    
    if key_elecs.any():
        ecog_mat = np.zeros((len(mat[0,:]),len(key_elecs), int((back+forward)*hz)), dtype=float)
        print('select elecs:'+str(key_elecs),end=' ')
    else:
        ecog_mat = np.zeros((len(mat[0,:]),channelNum, int((back+forward)*hz)), dtype=float)
        
    for i in range(block):
        block_index= mat[4,:]==i+1
        temp_time_list=mat[1,block_index]
        ecogData=scio.loadmat(path_raw+str(i+1)+filterType+'.mat')#################################
        ecogData=np.array(ecogData['bands'])*10e4
        #print('###',np.nanmean(np.max(ecogData,axis=1)-np.min(ecogData,axis=1)),np.array(np.max(ecogData,axis=1)).shape)######################################################################################
        temp_ecog,back_duration,forward_duration = get_timelocked_activity(times=temp_time_list, hg=ecogData, back=back, forward=forward, hz=hz)
        if key_elecs.any():
            key_elecs = np.array(key_elecs)
            ecog_mat[block_index] = temp_ecog[:,key_elecs,:]
        else:
            ecog_mat[block_index] = temp_ecog

    print('read_ecog:',ecog_mat.shape)
    return ecog_mat,back_duration,forward_duration,mat

'''
EXAMPLE:
ecog_mat,back_duration,forward_duration,mat = read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,
                hz=False, block=4)
'''

'\nEXAMPLE:\necog_mat,back_duration,forward_duration,mat = read_ecog_mat(back,forward,mat,channelNum,key_elecs=[],\n                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=[],oppoparagraph=False,\n                hz=False, block=4)\n'

In [8]:
#FUNCTIONS USED FOR TRAIN AND VALIDATE THE MODEL
def CV_datasets(back,forward,mat,key_elecs,row,CV_list,unbalance=False,List=False,augmented=False):
    CV_datasets = []
    CV_augmented_datasets=[]
    count_total=[]
    #split data to cross validation combinations, and calculate the weight of each labels
    for CVs in range(6):
        print('mat:',mat.shape)
        test_ecog_mat,_,_,test_mat= read_ecog_mat(back=back,forward=forward,mat=mat,channelNum=channelNum,key_elecs=key_elecs,
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=CV_list[CVs],oppoparagraph=False,
                hz=hz, block=4)

        test_label = test_mat[row,:]
        
        if unbalance == True:
            count=np.bincount(test_label.astype('int'))
            if len(count)!=len(List):
                print('not a full set!')
                dataPlus1 = np.concatenate((data,np.arange(len(List))))
                count=np.bincount(dataPlus1)-1    
                del dataPlus1
            print(count)
            count_total.append(count)
        test_ecog_mat=torch.FloatTensor(test_ecog_mat)
        test_label=torch.tensor(test_label,dtype=torch.long)
        testdataset = Data.TensorDataset(test_ecog_mat, test_label)    
        CV_datasets.append(testdataset)
        

        if augmented is not False:
            aug_mat = stammer (mat=test_mat, multiple_time=augmented,multiple_range=0.1,shift=-0.05)
            aug_ecog_mat,_,_,test_mat= read_ecog_mat(back=back,forward=forward,mat=mat,channelNum=channelNum,key_elecs=key_elecs,
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, key_paragraph=CV_list[CVs],oppoparagraph=False,
                hz=hz, block=4)
            aug_label = aug_mat[row,:]
            aug_ecog_mat=torch.FloatTensor(aug_ecog_mat)
            aug_label=torch.tensor(aug_label,dtype=torch.long)
            augdataset = Data.TensorDataset(aug_ecog_mat, aug_label) 
            CV_augmented_datasets.append(augdataset)
    if unbalance == True:
        count_total = np.array(count_total)
        
    return CV_datasets,count_total,CV_augmented_datasets

'''
EXAMPLES:
CV_onset_datasets,count_total,_ = CV_datasets(back=overt_back,
                                            forward=overt_forward,
                                            mat=overt_mat,
                                            key_elecs=overt_elecs,
                                            row=0,
                                            CV_list,
                                            unbalance=True,List=stateList,
                                            augmented=False)
class_weight = np.sum(count_total,axis=0)
'''
'''
CV_sylb_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=sylb_mat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list,
                            unbalance=False,List=False,
                            augmented=False)
'''
'''
CV_clip_datasets,_,_ = CV_datasets(back=sylb_back,
                            forward=sylb_forward,
                            mat=clipTimeMat,
                            key_elecs=sylb_elecs,
                            row=0,
                            CV_list,
                            unbalance=False,List=False,
                            augmented=False)

'''
def CV_train_onset(CV_datasets, lr, batch_size, patience, class_weight, val_ratio, channelNum=256):
    loss_func = nn.CrossEntropyLoss()
    #This function used to preform cross validation
    acc_list=[]
    loss_list=[]
    label_list=[]
    test_out_list=[]
    test_out_prob_list=[]

    for test_CV in range(6):#tqdm(

        CV_train_val_datasets = CV_datasets.copy()
        CV_test_datasets = [CV_train_val_datasets.pop(test_CV),]
        
        CV_train_datasets = []
        CV_val_datasets = []
        for train_val_dataset in CV_train_val_datasets:
            val_size = int(val_ratio * len(train_val_dataset))
            train_size = len(train_val_dataset) - val_size
            train_dataset, val_dataset = Data.random_split(train_val_dataset, [train_size, val_size])
            CV_train_datasets.append(train_dataset)
            CV_val_datasets.append(val_dataset)
        #label,acc,predicted,predicted_prob, predicted_only,predicted_only_prob,model
        label,_,_,predicted_prob,_,_,_,loss_func = train( 
                                        lr=lr, batch_size=batch_size, EPOCH=EPOCH, patience=patience, 
                                        CV_train_datasets = CV_train_datasets, #CV_train_val_datasets, #
                                        CV_val_datasets =  CV_val_datasets, #False, #
                                        CV_test_datasets = CV_test_datasets,#False, 
                                        pred_only_datasets = False, #pred_only_datasets, 
                                        class_weight=class_weight)
            
        #predicted_prob_list.append(predicted_prob)
        #predicted_prob = np.mean(predicted_prob_list, axis=0)
        
        loss = loss_func(torch.tensor(predicted_prob), torch.tensor(label))
        loss = loss / len(label)
        predicted = np.argmax(predicted_prob, axis=1)
        correct = np.sum(predicted == np.array(label))
        acc = correct / len(label)
        #tune.track.log(loss=loss)####################################################
       
        
        acc_list.append(acc)
        loss_list.append(loss)
        label_list.extend(label)
        test_out_list.extend(predicted) 
        test_out_prob_list.extend(predicted_prob) 
        if verbose:
            plotCM(List = stateList,test_acc_list=acc_list,test_label_list=label_list,test_out_list=test_out_list)
        
    return acc_list, loss_list, label_list, test_out_prob_list
'''
_, _, label_list, onset_prob = CV_train_onset(CV_datasets=CV_onset_datasets, 
                                    lr, batch_size, patience, 
                                    class_weight, val_ratio, 
                                    channelNum=256)
'''
def train(batch_size, lr, EPOCH, patience, CV_train_datasets, 
          CV_val_datasets=False, CV_test_datasets=False, pred_only_datasets=False, 
          class_weight=False):
    #This function directly train and validate the model
    model = torch.load("./"+subject+'.pt')
    loss_func = nn.CrossEntropyLoss() 
    if class_weight is not False:
        weight = [max(class_weight)/x for x in class_weight]
        weight = torch.FloatTensor(weight).to(device)
        loss_func = nn.CrossEntropyLoss(weight=weight)
        
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)#也可以使用SGD优化算法进行训练
    
    early_stopping = EarlyStopping(patience)
    '''
    optimizer = torch.optim.SGD(model.parameters(),lr=lr)#也可以使用SGD优化算法进行训练
    scheduler =  torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer,T_max =  EPOCH)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=patience, 
                                           verbose=False, threshold=0.0001, threshold_mode='rel', 
                                           cooldown=0, min_lr=0, eps=1e-08)
    '''
    for epoch in range(EPOCH):
        #return acc，ground truth & predicited label
        sum_loss_train= 0
        total_train=0
        model.train()
        for train_dataset in CV_train_datasets:
            train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            for data in train_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                pred = model(inputs)
                loss = loss_func(pred, labels)
                sum_loss_train+= loss.item()
                total_train+= labels.size(0)
                loss.backward()
                optimizer.step()
        train_loss = sum_loss_train/total_train
        if verbose:
            print('Epoch {}:train-loss:{:.2e},'.format(epoch+1,train_loss),end='')
       
        if CV_val_datasets is not False:#if there is val datasets performing early-stopping
            sum_loss_val=0
            total_val = 0  
            for val_dataset in CV_val_datasets:
                val_loader = Data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
                for data in val_loader:
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device) 
                    pred = model(inputs)
                    loss = loss_func(pred, labels)
                    sum_loss_val+=loss.item()
                    total_val+= labels.size(0)
            val_loss = sum_loss_val/total_val
            if verbose:
                print('val-loss:{:.2e},'.format(val_loss,),end='')
            
            '''
            scheduler.step(val_loss)##############################################
            #scheduler.step()
            '''
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                if verbose:
                    print("Early stopping")
                break
        
        if CV_test_datasets is not False: #if there is test data for testing
            total_test = 0  
            sum_loss_test= 0
            predicted=[]
            label=[]
            
            for test_dataset in CV_test_datasets:
                test_loader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
                for data in test_loader:
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    pred_prob = model(inputs)
                    loss = loss_func(pred_prob, labels)
                    sum_loss_test+=loss.item()
                    _, pred = torch.max(pred_prob.data, 1)
                    pred_prob = pred_prob.data
                    pred_prob = F.softmax(pred_prob, dim=1) 
                    total_test+= labels.size(0)
                    predicted.append(pred.cpu())
                    label.append(labels.cpu())
            label = torch.cat(label,dim=0)
            predicted = torch.cat(predicted,dim=0)
            correct = (predicted == label).sum()
            test_acc= correct.item()/total_test
            test_loss=sum_loss_test/total_test
            if verbose:
                print('test-loss:{:.2e},Acc:{:.4f}'.format(test_loss,test_acc),end='')
        if verbose:
            print('')
    
    model.load_state_dict(torch.load('checkpoint.pt'))
    
    ##############################EVALUATION model###############################################
    predicted=[]
    predicted_prob=[]
    label=[]
    acc=[]
    if CV_test_datasets is not False:
        total_test = 0  # 总数
        sum_loss_test= 0
        model.eval()
        for test_dataset in CV_test_datasets:
            test_loader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
            for data in test_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)  # 有GPU则将数据置入GPU加速
                pred_prob = model(inputs)
                loss = loss_func(pred_prob, labels)
                sum_loss_test+=loss.item()
                _, pred = torch.max(pred_prob.data, 1)
                pred_prob = pred_prob.data
                pred_prob = F.softmax(pred_prob, dim=1) 
                total_test+= labels.size(0)

                predicted.append(pred.cpu())
                predicted_prob.append(pred_prob.cpu())
                label.append(labels.cpu())

        label = torch.cat(label,dim=0)
        predicted = torch.cat(predicted,dim=0)
        predicted_prob = torch.cat(predicted_prob,dim=0)
        correct = (predicted == label).sum()
        test_acc= correct.item()/total_test
        test_loss=sum_loss_test/total_test
        print('Final: test-loss:{:.2e},Acc:{:.4f}'.format(test_loss,test_acc))
        label = label.detach().numpy()
        acc = test_acc
        predicted = predicted.detach().numpy()
        predicted_prob = predicted_prob.detach().numpy()
    
    
    ##PRED-ONLY:do not return acc, only predicted result
    predicted_only=[]
    predicted_only_prob=[]
        
    if pred_only_datasets is not False:
        for pred_only_dataset in pred_only_datasets:
            pred_only_loader = Data.DataLoader(pred_only_dataset, batch_size=batch_size, shuffle=False)
            for inputs in pred_only_loader:
                inputs = inputs[0].to(device)
                pred = model(inputs)
                pred_only_prob = pred.data
                pred_only_prob = F.softmax(pred_only_prob, dim=1) 
                _, pred_only = torch.max(pred.data, 1)
                predicted_only.append(pred_only.cpu())
                predicted_only_prob.append(pred_only_prob.cpu())
            
        predicted_only = torch.cat(predicted_only).detach().numpy()
        predicted_only_prob = torch.cat(predicted_only_prob).detach().numpy()
        print(predicted_only_prob.shape,'#$#')
    
    return label,acc, predicted,predicted_prob, predicted_only,predicted_only_prob, model,loss_func.cpu()

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping: {self.counter}/{self.patience}',end='')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        #if self.verbose:
            #print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). save model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')	# save the best model
        self.val_loss_min = val_loss

In [14]:
#FUNCTIONS USED FOR EVALUATION THE PERFORMANCE AND VISUALIZATION
def plotCM(List,test_acc_list,test_label_list,test_out_list,wer_list=False):
    # plot confusion matrix
    C2= confusion_matrix(test_label_list, test_out_list,normalize='true')#,
    #C2= confusion_matrix(test_out, predicted,normalize='true')#test_out->
    labels=List
    print(C2) 
    disp = ConfusionMatrixDisplay(confusion_matrix = C2, display_labels=labels)
    fig=plt.figure()
    fig.set_size_inches(200,200)
    disp.plot()
    plt.title('N='+str(len(test_acc_list))+"  Acc:%0.2f±%0.2f" % (np.mean(test_acc_list),np.std(test_acc_list)))
    if save:
        plt.savefig(stage+'_CM.png', bbox_inches='tight')
    return 0

def seq_F1_score(interval, test_mat, predicted, plot=True,
           wd=20,smooth_threshold=0.6,onset_duration=0.08,offset_duration=0.08,word_duration = 0.8,eps = 0.05):
    # caluculate the F1_score as the performance for hyperopt 
    label = test_mat[0,:]
    hwd=int(wd/2)
    smooth_out=predicted.copy()
    for i in range(hwd,len(predicted)-hwd+1):
        smooth_out[i]=np.nanmean(predicted[i-hwd:i+hwd])
    bi_out=smooth_out>smooth_threshold
    
    onset_predicted=np.zeros((bi_out.shape))   
    onset_label=np.zeros((bi_out.shape))
        
    clipTimeIndex=[]
    stableOnset = onset_duration//interval
    stableOffset = offset_duration//interval
    sayAword=word_duration//interval

    i = int(stableOnset)
    while i < len(bi_out)-stableOnset:
        if bi_out[i]==1 and bi_out[i-1]==0:
            if np.mean(bi_out[i:int(i+stableOnset)])>1-eps:
                if np.mean(bi_out[int(i-stableOffset):i])<eps:
                    onset_predicted[i-100:i+100] = 1
                    clipTimeIndex.append(i)
                    i+=int(sayAword-1)
        i+=1
    
    i = int(stableOnset)
    while i < len(label)-stableOnset:
        if label[i]==1 and label[i-1]==0:
            if np.mean(label[i:int(i+stableOnset)])>1-eps:
                if np.mean(label[int(i-stableOffset):i])<eps:
                    onset_label[i-100:i+100] = 1 ####误差+-0.25s内
                    i+=int(sayAword-1)
        i+=1    
        
    from sklearn.metrics import accuracy_score
    ACC = accuracy_score(onset_label, onset_predicted)   
    
    if plot:
        clipTimeIndex = np.array(clipTimeIndex)
        clipTimeMat = test_mat[:,clipTimeIndex]
        trial_values = set(test_mat[3,:])
        print (trial_values)
        fig, axs = plt.subplots(len(trial_values), 1, figsize=(500, len(trial_values) *3))
        # Iterate over each split timescale and plot it as a subplot
        for i, index in enumerate(trial_values):
            current_index = [j for j, val in enumerate(test_mat[3,:]) if val == index]
            current_timescale = [test_mat[1,j] for j in current_index]
            #print(len(current_timescale))
            current_predicted = [predicted[j] for j in current_index]
            #print(len(current_predicted))
            current_label = [test_mat[0,j] for j in current_index]
            current_smooth_out = [smooth_out[j] for j in current_index]
            current_bi_out = [bi_out[j] for j in current_index]
            
            current_split_index = [t for t, val in enumerate(clipTimeMat[3,:]) if val == index]
            current_clip_time = [clipTimeMat[1,j] for j in current_split_index]
            
            axs[i].set_xlim(min(current_timescale), max(current_timescale))
            axs[i].plot(current_timescale, current_predicted, 'r+')
            axs[i].plot(current_timescale, current_smooth_out, 'y+')
            axs[i].plot(current_timescale, current_bi_out, 'k+')
            axs[i].plot(current_timescale, current_label, 'b+')

            for T in current_clip_time:
                #axs[i].axvline(T, color='green')
                axs[i].text(T, 0.1, f'{T:.2f}', color='green', ha='center', va='bottom')
            axs[i].set_xticks(current_clip_time)
            axs[i].tick_params(axis='x', which='major', bottom=True, top=True, direction='in', length=100, width=1, color='green')
            axs[i].set_title('Trial ' + str(index))

            
            
        # Adjust the spacing between subplots
        plt.tight_layout()

        # Save the figure
        if save:
            plt.savefig('view_onset.png')
        
    return ACC

'''
EXAMPLE
plot = True
F1 = seq_F1_score(interval,onset_label_mat, predicted = np.vstack(onset_prob)[:,1], plot=True,
                    wd=int(0.018*hz),smooth_threshold=0.81,onset_duration=0.02,offset_duration=0.08,
                    word_duration = 0.5, eps = 0.08)
print(F1)
'''

'\nEXAMPLE\nplot = True\nF1 = seq_F1_score(interval,onset_label_mat, predicted = np.vstack(onset_prob)[:,1], plot=True,\n                    wd=int(0.018*hz),smooth_threshold=0.81,onset_duration=0.02,offset_duration=0.08,\n                    word_duration = 0.5, eps = 0.08)\nprint(F1)\n'

In [15]:
#FUNCTIONS FOR HYPERPARAMETERS TUNING
def hyperopt_onset(overt_back = 0.25,
                   overt_forward = 0.25,
                   gruDim = 3,
                   gruLayer = 64,
                   drop_out = 0.5,
                   smooth_window = 0.042,
                   smooth_threshold = 0.76, 
                   onset_duration = 0.02,
                   offset_duration = 0.02,
                   eps = 0.05,
                   batch_size = 4096, 
                   lr = 0.0005):    
    outPutLoss = 9999 
    F1 = 0
    
    F1 = seq_F1_score(interval, onset_label_mat, predicted = np.vstack(onset_prob)[:,1], plot=False,
                    wd=int(smooth_window*hz),
                    smooth_threshold = smooth_threshold,
                    onset_duration = onset_duration,
                    offset_duration = offset_duration,
                    word_duration = 0.5, 
                    eps = eps)
    
    a = np.array([outPutLoss,F1,
                  overt_back, overt_forward, gruDim, gruLayer, drop_out,
                  smooth_window, smooth_threshold, onset_duration, offset_duration, eps,
                  batch_size, lr])
    if verbose:
        print(a)
    reporterList.append(a)
    #np.savetxt("hyperoptReporter.csv", reporterList, delimiter = ",", fmt = '%s')
    
    return outPutLoss, F1


'''
EXAMPLE:
reporterList = []
print(hyperopt_onset(overt_back = 0.25,
                   overt_forward = 0.25,
                   gruDim = 3,
                   gruLayer = 64,
                   drop_out = 0.5,
                   smooth_window = 0.042,
                   smooth_threshold = 0.76, 
                   onset_duration = 0.02,
                   offset_duration = 0.02,
                   batch_size = 4096, 
                   lr = 0.0005))
'''
from hyperopt import hp,STATUS_OK,Trials,fmin,tpe
def hyperopt_train(params):
    loss, F1 =hyperopt_onset(**params)
    loss = -F1
    return loss
def f(params):
    try:
        loss = hyperopt_train(params)
    except Exception as e:
        # Handle the exception here (e.g., print the error message, hyperparas conflicting)
        print(f"Error: {e}")
        # Assign a high loss value to discourage selecting this parameter combination
        return {'loss': 9999, 'status': STATUS_OK}
    
    return {'loss': loss, 'status': STATUS_OK}

In [11]:
path_raw='../Raw/B'
sylbList = ['shi', 'de', 'ji', 'li', 'bu', 'ge', 'qi', 'zhe', 'ta', 'zhi']
stateList=['silent','speech']
toneList = ['1','2','3','4']
hz = 400
filterType = '_70_150'
interval = 0.01
channelNum = 256

overt_mat = np.load('../'+subject+'_overt_mat'+'.npy',allow_pickle=True)
print(overt_mat.shape)
_,_,_,onset_label_mat =  read_ecog_mat(0.01,0.01,overt_mat,channelNum=256,key_elecs=[1,],
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, 
                key_paragraph=[13,14,15,16],oppoparagraph=True,hz=hz, block=4)
print(onset_label_mat.shape)
resp_elecs = np.load('../'+subject+'_resp_elecs.npy')


sylb_mat = np.load('../'+subject+'_sylb_mat'+'.npy',allow_pickle=True)
print(sylb_mat.shape)
_,_,_,sylb_label_mat =  read_ecog_mat(0.01,0.01,sylb_mat,channelNum=256,key_elecs=[1,],
                key_label=[],oppolabel=False,key_sentence=[],opposentence=False, 
                key_paragraph=[13,14,15,16],oppoparagraph=True,hz=hz, block=4)
print(sylb_label_mat)

CV_onset_datasets,count_total,_ = CV_datasets(back = 0.25,
                                            forward = 0.25,
                                            mat=overt_mat,
                                            key_elecs=resp_elecs,
                                            row=0,
                                            CV_list=[[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]],
                                            unbalance=True,List=stateList,
                                            augmented=False)
#this is the weight of each type, use to generate weighted crossentrophy_loss function
class_weight = np.sum(count_total,axis=0)

(5, 178692)
select paraqraph:[13, 14, 15, 16] oppo
select elecs:[1] read_ecog: (137374, 1, 8)
(5, 137374)
(7, 1034)
select paraqraph:[13, 14, 15, 16] oppo
select elecs:[1] read_ecog: (775, 1, 8)
[[  8.        5.        4.      ...   2.        1.        0.     ]
 [ 24.2053   25.07445  26.36158 ... 532.05439 533.36673 534.45248]
 [  1.        1.        1.      ...   7.        7.        7.     ]
 ...
 [  1.        1.        1.      ...   3.        3.        3.     ]
 [  0.        0.        3.      ...   3.        1.        0.     ]
 [ 32.       21.       19.      ...  12.        7.        0.     ]]
mat: (5, 178692)
select paraqraph:[1, 2] select elecs:[  0   1   2   5   8   9  10  12  13  14  15  29  30  31  32  45  46  48
  49  53  59  61  64  68  69  75  78  80  94  95  96 100 107 110 111 112
 113 114 126 127 128 130 131 133 134 135 136 137 138 139 140 142 143 150
 151 152 153 154 155 156 157 158 159 165 166 167 168 169 170 171 172 173
 174 175 180 181 182 183 184 185 186 187 188 189 19

In [13]:
#generate the speech detector model
overt_back = 0.25
overt_forward = 0.25
val_ratio = 0.1 
CV_list=[[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]
batch_size = 2048
lr=0.001
EPOCH=50
patience=10

class CRNN(nn.Module):
    def __init__(self, *, duration, typeNum, in_chans, 
                 num_layers=4, gruDim=256, drop_out=0.5):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=in_chans, out_channels=gruDim, kernel_size=3, stride=1, padding=0)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
        self.max_pooling = nn.MaxPool1d(kernel_size=2, stride=None, padding=0)
        self.dropout = nn.Dropout(p=drop_out)
        gru_layers = []
        for i in range(num_layers):
            if i == 0:
                gru_layers.append(nn.GRU(gruDim, gruDim, 1, batch_first=True, bidirectional=True))
            else:
                gru_layers.append(nn.GRU(gruDim * 2, gruDim, 1, batch_first=True, bidirectional=True))
            
        # Create the sequential model with stacked GRU layers
        self.gru_layers = nn.Sequential(*gru_layers)
        #self.gru1 = nn.GRU(channelNum, gruDim, 3, batch_first=True, bidirectional=True)
        elec_feature = int(2*gruDim)
        self.fc1 = nn.Linear(elec_feature, typeNum)

    def forward(self, x):
        #x = x.permute(0, 2, 1)  # Convert (batch_size, seq_len, input_size) to (batch_size, input_size, seq_len)
        x = self.conv1d(x)
        x = self.leaky_relu(x)
        x = self.max_pooling(x)
        x = rearrange(x,'batch electrodes duration -> batch duration electrodes')
        for gru_layer in self.gru_layers:
            x, _ = gru_layer(x)
            #x = self.elu(x)
            x = self.dropout(x)
        
        x = x[:, -1, :]
        x = self.fc1(x)

        return x
    
model1 = CRNN(duration = int((overt_back+overt_forward)*hz), typeNum = 2, 
             in_chans=len(resp_elecs), gruDim=256).to(device)

torch.save(model1,("./"+subject+'_onset.pt'))
torch.save(model1,("./"+subject+'.pt'))
######

In [None]:
torch.cuda.empty_cache()
verbose = True
stage = 'onset'
plot = False
save = False
########################################################################################
_, _, _, onset_prob = CV_train_onset(CV_datasets=CV_onset_datasets, val_ratio = val_ratio,
                                    lr=lr, batch_size=batch_size, patience=patience, 
                                    class_weight=class_weight, channelNum=256)
'''
'''

Epoch 1:train-loss:2.41e-04,val-loss:2.67e-04,test-loss:1.79e-04,Acc:0.7785
Epoch 2:train-loss:1.56e-04,val-loss:1.90e-04,test-loss:1.22e-04,Acc:0.8669
Epoch 3:train-loss:1.14e-04,val-loss:1.80e-04,test-loss:1.12e-04,Acc:0.8736
Epoch 4:train-loss:9.76e-05,val-loss:1.42e-04,test-loss:8.74e-05,Acc:0.9072
Epoch 5:train-loss:1.05e-04,val-loss:2.02e-04,EarlyStopping: 1/10test-loss:1.58e-04,Acc:0.8322
Epoch 6:train-loss:9.01e-05,val-loss:1.60e-04,EarlyStopping: 2/10test-loss:1.34e-04,Acc:0.8502
Epoch 7:train-loss:8.38e-05,val-loss:1.36e-04,test-loss:8.82e-05,Acc:0.8963
Epoch 8:train-loss:7.55e-05,val-loss:1.77e-04,EarlyStopping: 1/10test-loss:1.52e-04,Acc:0.8230
Epoch 9:train-loss:8.19e-05,val-loss:1.34e-04,test-loss:1.31e-04,Acc:0.8574
Epoch 10:train-loss:7.37e-05,val-loss:1.14e-04,test-loss:9.13e-05,Acc:0.8913
Epoch 11:train-loss:7.65e-05,val-loss:1.23e-04,EarlyStopping: 1/10test-loss:9.29e-05,Acc:0.8901
Epoch 12:train-loss:7.13e-05,val-loss:1.35e-04,EarlyStopping: 2/10test-loss:1.22e-04,A

In [40]:
np.save(subject+'_opt_onset_prob.npy',onset_prob)###
#reload saved data
#onset_prob = np.load(subject+'_opt_onset_prob.npy',allow_pickle=True)###
print(onset_prob.shape)

(137374, 2)


In [39]:
##############################DO NOT RUN!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
plot = False
verbose = False

hyperparas = {
    #'overt_back':hp.choice('overt_back',[0.1,0.15,0.2,0.25,0.3]),
    #'overt_forward':hp.choice('overt_forward', [0.1,0.15,0.2,0.25,0.3]),
    #'gruDim':hp.choice('gruDim', [64,128,256,512]),
    #'gruLayer':hp.choice('gruLayer', [1,2,3]),
    #'drop_out':hp.choice('drop_out', [0.2,0.3,0.4,0.5,0.6,0.7,0.8]),
    'smooth_window':hp.choice('smooth_window',np.arange(0.01, 0.05, 0.002).tolist()),
    'smooth_threshold':hp.choice('smooth_threshold',np.arange(0.2, 0.9, 0.01).tolist()),
    'onset_duration':hp.choice('onset_duration',np.arange(0.02, 0.3, 0.02).tolist()),
    'offset_duration':hp.choice('offset_duration',np.arange(0.02, 0.3, 0.02).tolist()),
    'eps':hp.choice('eps',np.arange(0.02, 0.2, 0.02).tolist()),
    
    #'batch_size': hp.choice('batch_size', [512,1024,2048,4096,9192]),
    #'lr': hp.choice('lr', [0.001,0.005,0.0005,0.0001]),
}

trials=Trials()
reporterList = []
best=fmin(f,hyperparas,algo=tpe.suggest,max_evals=500,trials=trials)
verbose=False
print('best',best)

100%|██████████| 500/500 [35:18<00:00,  4.24s/trial, best loss: -0.9879744347547571]
best {'eps': 4, 'offset_duration': 2, 'onset_duration': 11, 'smooth_threshold': 46, 'smooth_window': 18}


In [None]:
#generate clipTimeMat for hyperopt of sylb decoder and tone decoder
clipTimeMat, plot_data = slicer(interval, onset_label_mat, predicted=np.vstack(onset_prob)[:,1], sylb_mat = sylb_mat, 
                       smooth_window = 0.046,
                       smooth_threshold = 0.66,
                       onset_duration = 0.24,
                       offset_duration = 0.06,
                       word_duration = 0.5,eps = 0.1)

In [None]:
print(clipTimeMat.shape)
np.save(subject+'_opt_clipTimeMat.npy',clipTimeMat)###