In [None]:
# plot
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio

import os.path
from os import path
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import h5py
import sys

# Default folder that holds the EMG .mat/.pickle files. Loaders concatenate a
# file name directly onto it, so it must end with a path separator. The value
# is machine specific; set the EMG_FOLDER environment variable to override it
# without editing the code.
emg_folder_default = os.environ.get('EMG_FOLDER', 'D:\\emg_data\\')

def loadmat(matfile, vnames):
    """Load two variables (inputs and labels) from a MATLAB ``.mat`` file.

    ``scipy.io.loadmat`` is tried first. It raises ``NotImplementedError`` for
    MATLAB v7.3 (HDF5) files; those are read with h5py instead and transposed
    (as the original loader did) because h5py reports MATLAB arrays with their
    axis order reversed.

    Args:
        matfile: path to the ``.mat`` file.
        vnames: sequence whose first two entries are the names of the input
            (x) and label (y) variables.

    Returns:
        ``(x_data, y_data)`` as numpy arrays.

    Raises:
        Any other error while reading (e.g. ``OSError`` for a missing file,
        ``KeyError`` for a missing variable) is reported together with the
        file name and re-raised. Previously it was swallowed and the function
        then failed with an unrelated ``UnboundLocalError``.
    """
    try:
        try:
            data = sio.loadmat(matfile, variable_names=vnames)
            return data[vnames[0]], data[vnames[1]]
        except NotImplementedError:
            # MATLAB v7.3 files are HDF5 containers; close the handle when done.
            with h5py.File(matfile, 'r') as feature:
                x_data = np.transpose(feature[vnames[0]][:])
                y_data = np.transpose(feature[vnames[1]][:])
            print("Load h5py")
            return x_data, y_data
    except Exception:
        print("Unexpected error:", sys.exc_info()[0])
        print(matfile)
        raise


# plot a frame of high-density EMG
def plot_frame(frame, channels=(60, 40)):
    """Plot one frame of high-density EMG and save both figures as PNG files.

    Figure 1 (``EMG_map.png``): ``frame.transpose()`` drawn with a binary
    colormap, channels on the y axis and time on the x axis.
    Figure 2 (``EMG_raw.png``): raw traces of the channels in ``channels``.

    Args:
        frame: 2-D array indexed ``frame[time, channel]``.
        channels: channel indices plotted in figure 2. The default keeps the
            original choice (ch 60 in red, ch 40 in blue); more channels cycle
            through the same two colours.
    """
    # Figures are left open rather than calling a blocking plt.show(), so
    # callers such as load_data are not interrupted. An inline notebook
    # backend will still display them.
    plt.figure()
    plt.imshow(frame.transpose(), cmap=plt.cm.binary)
    plt.ylabel('Channels')
    plt.xlabel('Time')
    plt.savefig('EMG_map.png')

    trace_colors = ('r', 'b')
    plt.figure()
    for i, ch in enumerate(channels):
        plt.plot(frame[:, ch], trace_colors[i % len(trace_colors)], label='ch {}'.format(ch))
    plt.legend(framealpha=1, frameon=True)
    plt.savefig('EMG_raw.png')

# check the number of samples in each class (i.e., number of frames w/wo spikes)
def checkData(classes):
    """Report the balance of binary labels and return inverse class counts.

    Args:
        classes: 1-D array-like of 0/1 labels (floats are cast to int).

    Returns:
        ``0`` if ``classes`` is empty. Otherwise ``(1/neg, 1/pos)``, where
        ``neg``/``pos`` are the counts of 0 and 1 labels. Returns ``(0, 0)``
        when the labels cannot be counted as exactly two classes: all labels
        equal to 0, labels above 1 or negative, or input that is not 1-D.
        Because the counts are numpy integers, ``neg == 0`` gives ``inf``
        rather than an exception.
    """
    # Accept plain lists as well as numpy arrays.
    classes = np.asarray(classes)
    if classes.shape[0] == 0:
        return 0
    # Labels may be stored as floats; np.bincount needs integers.
    # (The original guard `~isinstance(...)` was always truthy, so the cast
    # was already unconditional.)
    classes = classes.astype(int)
    try:
        neg, pos = np.bincount(classes)
    except (ValueError, TypeError):
        return 0, 0
    total = neg + pos
    print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
        total, pos, 100 * pos / total))
    return 1.0 / neg, 1.0 / pos


# load data from pickle 
import pickle
def load_data(trial, step_size, ch, seg_index, emg_folder_global=None):
    """Load a pickled (x, y) data set for one trial/segment/channel.

    Files are named ``{trial}_{segment}_st{step_size}_ch{ch}_x.pickle`` and
    ``..._y.pickle``. They are looked up in the working directory first, then
    in ``emg_folder_global``. As side effects, the function prints the class
    balance, plots the first frame and prints the first label.

    Args:
        trial: trial name used in the file prefix.
        step_size: step size used in the file prefix.
        ch: channel index used in the file prefix.
        seg_index: index into the segment IDs ``[1, 3, 5]``.
        emg_folder_global: fallback folder (must end with a separator);
            defaults to ``emg_folder_default``. Previously this name was read
            as an undefined global and raised ``NameError``.

    Returns:
        ``(x_data, y_data)`` as stored in the pickles.
    """
    if emg_folder_global is None:
        emg_folder_global = emg_folder_default
    seg = [1, 3, 5]
    segment = seg[seg_index]
    prefix = "{}_{}_st{}_ch{}".format(trial, segment, step_size, ch)
    x_file = "{}_x.pickle".format(prefix)
    y_file = "{}_y.pickle".format(prefix)
    print(prefix)
    if not path.exists(x_file):
        x_file = "{}{}".format(emg_folder_global, x_file)
        y_file = "{}{}".format(emg_folder_global, y_file)
    else:
        print(x_file)

    # SECURITY: pickle.load can execute arbitrary code - only load trusted files.
    with open(x_file, "rb") as pickle_in:
        x_data = pickle.load(pickle_in)
    with open(y_file, "rb") as pickle_in:
        y_data = pickle.load(pickle_in)

    checkData(y_data)
    plot_frame(x_data[0, :, :])
    print(y_data[0])
    return x_data, y_data


# load data from mat files
# including two variables: EMGs and spikes
def _segment_matfile(TR, segment, WS, ST, folder):
    """Return the .mat path of one segment: cwd if present there, else under *folder*."""
    matfile = "{}-SG{}-WS{}-ST{}.mat".format(TR, segment, WS, ST)
    if not path.exists(matfile):
        matfile = "{}{}".format(folder, matfile)
    return matfile


def load_data_mat(TR, SG = 0, ST = 10, MU = 1, WS = 120, TF = 0, MutiSeg = 0, emg_folder_global = None):
    """Load windowed EMG frames and spike labels from ``.mat`` files.

    Each file ``{TR}-SG{segment}-WS{WS}-ST{ST}.mat`` holds the variables
    ``EMGs`` (inputs) and ``Spikes_cst`` (one label column per motor unit).

    Args:
        TR: trial name (e.g. ``1_30_GM``).
        SG: segment index 0-2; selects the primary segment.
        ST: step size (5, 10, 20, 30, 40, 50).
        MU: motor-unit index, or a list/tuple of indices. Indices past the
            last column produce an all-zero label column.
        WS: window size (e.g. 120).
        TF: 0 = keep order; 1 = shuffle; 0 < TF < 1 = keep a random TF
            fraction of the samples.
        MutiSeg: 0 = primary segment only; 1 = also load a second segment;
            2 = also load a third. When >= 1 the data is additionally
            augmented with copies scaled by 0.7 and 0.4.
        emg_folder_global: fallback folder (must end with a separator);
            defaults to ``emg_folder_default``.

    Returns:
        ``(x_data, y_data)``: the input array and a list with one label array
        per requested motor unit. Returns ``([], [])`` if the primary file
        does not exist.

    Raises:
        FileNotFoundError: if an additional segment requested by ``MutiSeg``
            is missing.
    """
    if emg_folder_global is None:
        emg_folder_global = emg_folder_default
    vnames = ['EMGs', 'Spikes_cst']

    # Segment IDs per SG: the primary segment, then the extra segments appended
    # when MutiSeg >= 1 and MutiSeg >= 2 respectively.
    seg = [1, 2, 3]
    seg2 = [2, 3, 1]
    seg3 = [3, 1, 2]

    matfile = _segment_matfile(TR, seg[SG], WS, ST, emg_folder_global)
    if not path.exists(matfile):
        print('{} not exist'.format(matfile))
        return [], []
    print('{} exist'.format(matfile))

    # load mat file with sio or h5py
    x_data, spikes = loadmat(matfile, vnames)

    extra_segments = []
    if MutiSeg >= 1:
        extra_segments.append(seg2[SG])
    if MutiSeg >= 2:
        extra_segments.append(seg3[SG])
    for segment in extra_segments:
        extra_file = _segment_matfile(TR, segment, WS, ST, emg_folder_global)
        if not path.exists(extra_file):
            raise FileNotFoundError('{} not exist'.format(extra_file))
        x_extra, spikes_extra = loadmat(extra_file, vnames)
        x_data = np.concatenate((x_data, x_extra))
        spikes = np.concatenate((spikes, spikes_extra))

    print(x_data.shape)
    if MutiSeg >= 1:
        print('extend')
        # Amplitude augmentation: append copies scaled by 0.7 and 0.4 with identical labels.
        x_data = np.concatenate((x_data, x_data * 0.7, x_data * 0.4))
        spikes = np.concatenate((spikes, spikes, spikes))
    print(x_data.shape)

    # Extract the spike train(s) of the requested motor unit(s).
    if isinstance(MU, (list, tuple)):
        y_data = [spikes[:, c] if c < spikes.shape[1] else spikes[:, -1] * 0 for c in MU]
    else:
        y_data = [spikes[:, MU]]

    # Shuffle/subsample with samples on axis 0, then return one array per unit.
    y_data = np.array(y_data).T
    if TF == 1:
        x_data, y_data = shuffle(x_data, y_data)
    elif TF > 0:
        x_data, _, y_data, _ = train_test_split(x_data, y_data, test_size=1.0 - TF)
    else:
        print('no shuffle')
    return x_data, list(y_data.T)

# split data into train set and test set
def split_data(x_data, y_data, test_size = 0):
    """Split inputs and per-output labels into train and test sets.

    Args:
        x_data: input array with samples on axis 0.
        y_data: list of label arrays, one per output, each of length
            n_samples (the format returned by ``load_data_mat``).
        test_size: forwarded to ``sklearn.model_selection.train_test_split``.
            ``0`` (the default) means "no test set": all samples are returned
            as training data and the test parts are empty. Previously
            ``train_test_split`` rejected 0 and raised ``ValueError``.

    Returns:
        ``(x_train, x_test, y_train, y_test)``; the label parts are lists with
        one array per output.
    """
    # train_test_split needs samples on axis 0: (n_outputs, n) -> (n, n_outputs).
    y_by_sample = np.array(y_data).T
    if test_size == 0:
        x_train, x_test = x_data, x_data[:0]
        y_train, y_test = y_by_sample, y_by_sample[:0]
    else:
        x_train, x_test, y_train, y_test = train_test_split(x_data, y_by_sample, test_size=test_size)
    return x_train, x_test, list(y_train.T), list(y_test.T)

# shuffle the data
def shuffle_data(x_data, y_data):
    """Shuffle inputs and per-output labels with one shared permutation.

    Args:
        x_data: input array with samples on axis 0.
        y_data: list of label arrays, one per output.

    Returns:
        ``(x_shuffled, y_shuffled)``; ``y_shuffled`` is again a list with one
        array per output.
    """
    labels_by_sample = np.array(y_data).T
    x_shuffled, labels_shuffled = shuffle(x_data, labels_by_sample)
    return x_shuffled, list(labels_shuffled.T)


In [None]:
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric
from tensorflow.keras.callbacks import Callback
import tensorflow as tf
# import neptune

# customized metrics
def recall_m(y_true, y_pred):
    """Batch recall: rounded true positives over rounded actual positives.

    Rounding ``clip(y_true * y_pred)`` counts a prediction as positive once it
    rounds up to 1. ``K.epsilon()`` avoids division by zero on batches
    without positives.
    """
    actual_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    return hits / (actual_positives + K.epsilon())

def precision_m(y_true, y_pred):
    """Batch precision: rounded true positives over rounded predicted positives.

    ``K.epsilon()`` avoids division by zero when nothing is predicted positive.
    """
    flagged = K.sum(K.round(K.clip(y_pred, 0, 1)))
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    return hits / (flagged + K.epsilon())

# calculate f1 score
def f1_m(y_true, y_pred):
    """Batch F1 score with predictions thresholded at 0.5.

    Predictions are first binarised (``>= 0.5`` -> 1); precision and recall
    are computed from the binarised values, and ``K.epsilon()`` guards every
    denominator.
    """
    eps = K.epsilon()
    hard_pred = tf.where(y_pred >= 0.5, 1., 0.)

    tp = K.sum(K.round(K.clip(y_true * hard_pred, 0, 1)))
    actual = K.sum(K.round(K.clip(y_true, 0, 1)))
    flagged = K.sum(K.round(K.clip(hard_pred, 0, 1)))

    recall = tp / (actual + eps)
    precision = tp / (flagged + eps)
    return 2 * ((precision * recall) / (precision + recall + eps))


# customized callbacks
class BaseLogger(Callback):
    """Callback that accumulates size-weighted epoch averages of batch metrics.

    Each batch value is weighted by ``logs['size']`` (0 when the batch logs
    carry no ``size`` entry). At epoch end the averages are written back into
    the epoch ``logs`` so later callbacks see them. Unlike the Keras built-in
    of the same name, this class only runs when passed in ``callbacks=``.
    """

    def on_epoch_begin(self, epoch, logs=None):
        self.seen = 0
        self.totals = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

    def on_epoch_end(self, epoch, logs=None):
        # Nothing was weighted (e.g. no 'size' in batch logs): avoid 0/0.
        if logs is None or not self.seen:
            return
        # tf.keras 2.x no longer puts 'metrics' into params; fall back to the
        # accumulated keys, excluding the bookkeeping entries.
        params = getattr(self, 'params', None) or {}
        metric_names = params.get('metrics')
        if metric_names is None:
            metric_names = [k for k in self.totals if k not in ('batch', 'size')]
        for k in metric_names:
            if k in self.totals:
                # Make value available to next callbacks.
                logs[k] = self.totals[k] / self.seen


class AccuracyCallback(Callback):
    """Average one metric over all model outputs and write the means into the logs.

    Multi-output models log one entry per output for each metric (names that
    contain ``metric_name``, with a ``val_`` prefix for validation). At each
    epoch end this callback averages the matching training and validation
    values and stores them as ``logs[metric_name]`` and
    ``logs['val_' + metric_name]``. Later callbacks, such as a
    ``ModelCheckpoint`` monitoring ``metric_name``, can then use them.

    Attributes:
        metric_name: substring used to select log entries.
        metric, val_metric: values matched in the last epoch.
        metric_mean, val_metric_mean: their means (nan if nothing matched).
        best_metric: kept for compatibility; not updated here.
    """

    def __init__(self, metric_name = 'accuracy'):
        super().__init__()
        self.metric_name = metric_name
        self.val_metric = []
        self.metric = []
        self.val_metric_mean = 0
        self.metric_mean = 0
        self.best_metric = 0

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        self.val_metric = []
        self.metric = []
        for log_name, log_value in list(logs.items()):
            if self.metric_name not in log_name:
                continue
            # Keras prefixes validation entries with 'val_'; a plain substring
            # test would misclassify any name that merely contains 'val'.
            if log_name.startswith('val_'):
                self.val_metric.append(log_value)
            else:
                self.metric.append(log_value)

        self.val_metric_mean = np.mean(self.val_metric) if self.val_metric else np.nan
        self.metric_mean = np.mean(self.metric) if self.metric else np.nan
        # Only publish real values: writing nan (e.g. no validation split)
        # would stop checkpoints monitoring these keys from ever saving.
        if self.val_metric:
            logs['val_{}'.format(self.metric_name)] = self.val_metric_mean
        if self.metric:
            logs['{}'.format(self.metric_name)] = self.metric_mean


class updateLogs(AccuracyCallback):
    """Average the per-output ``f1_m`` values into ``logs['f1_m']`` / ``logs['val_f1_m']``.

    This is the same operation as ``AccuracyCallback('f1_m')``. It is kept
    under its old name for existing callers instead of duplicating the
    averaging loop.
    """

    def __init__(self):
        super().__init__('f1_m')
        

In [None]:
import tensorflow.keras as keras
# import keras
from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.layers import Dense, Flatten, Activation, Input
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Conv2D, MaxPooling2D, BatchNormalization, Dropout, LSTM, AveragePooling1D
from tensorflow.keras.regularizers import l2

############ create models with given input shape and output shape ##########
# Reference copy of the default layer widths: filters of the four Conv1D
# layers, then the units of the dense head layer.
gg_nn_nodes = [128] * 3 + [64, 256]

# create convolutional neural network with API interface
def get_cnn1d_api(shape_in, shape_out, nn_nodes = (128, 128, 128, 64, 256), shared_trunk = True):
    '''Create a multi-output 1-D CNN with the Keras functional API.

    The network has two Conv1D blocks, each with two 3-wide ReLU convolutions,
    2x max pooling and 0.5 dropout. These are followed by one head per output:
    a ReLU Dense layer, 0.5 dropout, and a single sigmoid unit named
    ``output_{k}``.

    Args:
        shape_in: ``(timesteps, features)`` of one input window.
        shape_out: number of sigmoid outputs (must be >= 1).
        nn_nodes: five widths - filters of the four Conv1D layers, then units
            of each head's Dense layer.
        shared_trunk: True (default, the previous hard-coded behaviour) shares
            the second conv block across all heads. False gives every head its
            own copy of that block.

    Returns:
        ``(model, loss, metrics)`` where ``loss`` and ``metrics`` are dicts
        keyed by output layer name, ready for ``model.compile``.

    Raises:
        ValueError: if ``shape_out < 1``.
    '''
    nn_nodes = list(nn_nodes)
    print(nn_nodes)
    if shape_out < 1:
        raise ValueError('shape_out must be >= 1, got {}'.format(shape_out))

    visible = Input(shape = shape_in, name='EMG')

    cnn = Conv1D(filters=nn_nodes[0], kernel_size=3, activation='relu')(visible)
    cnn = Conv1D(filters=nn_nodes[1], kernel_size=3, activation='relu')(cnn)
    cnn = MaxPooling1D(pool_size=2)(cnn)
    cnn = Dropout(0.5)(cnn)

    def second_block(x):
        # Second conv block, flattened for the dense heads.
        x = Conv1D(filters=nn_nodes[2], kernel_size=3, activation='relu')(x)
        x = Conv1D(filters=nn_nodes[3], kernel_size=3, activation='relu')(x)
        x = MaxPooling1D(pool_size=2)(x)
        x = Dropout(0.5)(x)
        return Flatten()(x)

    shared = second_block(cnn) if shared_trunk else None

    names = ['output_{}'.format(k) for k in range(1, shape_out + 1)]
    outputs = []
    for name in names:
        features = shared if shared_trunk else second_block(cnn)
        s2 = Dense(nn_nodes[4], activation='relu')(features)
        s2 = Dropout(0.5)(s2)
        outputs.append(Dense(1, activation='sigmoid', name=name)(s2))

    metrics = {name: ['accuracy', f1_m] for name in names}
    loss = {name: 'binary_crossentropy' for name in names}

    # tie together
    model = Model(inputs=visible, outputs=outputs)
    return model, loss, metrics


In [None]:
# use tensorboard for display
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint, LambdaCallback
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split

# build model with configuration
def build_model(mIndex, WS = 120, n_output = 1, nn_nodes = (128, 128, 128, 64, 256), n_features = 64):
    """Build and compile the 1-D CNN for windows of ``WS`` samples.

    Args:
        mIndex: unused; kept for backward compatibility with existing calls.
        WS: window size (timesteps per input).
        n_output: number of sigmoid outputs (motor units).
        nn_nodes: layer widths forwarded to ``get_cnn1d_api``.
        n_features: EMG channels per timestep (default 64, the previous
            hard-coded value).

    Returns:
        The compiled Keras model. A single output uses binary cross-entropy
        with accuracy, mse and ``f1_m``. Multiple outputs use the per-output
        loss/metric dicts from ``get_cnn1d_api``. Both use rmsprop.
    """
    model, loss_cnn, metrics_cnn = get_cnn1d_api((WS, n_features), n_output, list(nn_nodes))

    # Other options: keras.metrics.Precision/Recall/AUC/TruePositives/...
    METRICS = ['accuracy', 'mse', f1_m]

    print(n_output)

    if n_output == 1:
        model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=METRICS)
    else:
        model.compile(optimizer='rmsprop', loss=loss_cnn, metrics=metrics_cnn)
    return model

def train_model(model, x_data, y_data, prefix, epochs = 100, batch_size = 64, validation_split = 0.2):
    """Fit *model* and return the checkpoint with the best training F1.

    Checkpoints ``best_model_{prefix}_{tname}_{l,vl,f,vf}.h5`` are written for
    the best loss, val_loss, f1_m and val_f1_m. TensorBoard logs go to
    ``./logs/hdEMG_{prefix}_{tname}``. Training stops early after 50 epochs
    without improvement in loss.

    Args:
        model: compiled Keras model.
        x_data, y_data: training inputs and labels (``validation_split`` of
            them is held out by Keras).
        prefix: tag used in checkpoint/log names.
        epochs: maximum number of epochs.
        batch_size: mini-batch size (previously fixed at 64).
        validation_split: fraction held out for validation (previously 0.2).

    Returns:
        ``(best_model, tname)`` where ``tname`` is the integer timestamp used
        in the file names. If no F1 checkpoint was written, the lowest-loss
        checkpoint is loaded instead.
    """
    tname = int(time.time())
    base_name = 'best_model_{}_{}'.format(prefix, tname)
    log_name = "hdEMG_{}_{}".format(prefix, tname)
    # os.path.join instead of a hard-coded Windows "\\" path.
    tensorboard = TensorBoard(log_dir=os.path.join('.', 'logs', log_name))

    es = EarlyStopping(monitor='loss', mode='min', verbose=1, patience=50)
    mc = ModelCheckpoint(base_name + '_l.h5', monitor='loss', mode='min', verbose=1, save_best_only=True)
    mc_vl = ModelCheckpoint(base_name + '_vl.h5', monitor='val_loss', mode='min', verbose=1, save_best_only=True)
    mc_f = ModelCheckpoint(base_name + '_f.h5', monitor='f1_m', mode='max', verbose=1, save_best_only=True)
    mc_vf = ModelCheckpoint(base_name + '_vf.h5', monitor='val_f1_m', mode='max', verbose=1, save_best_only=True)

    # The averaging callbacks must come before mc_f/mc_vf: they write the
    # 'f1_m'/'val_f1_m' (and accuracy) means those checkpoints monitor.
    callbacks = [es, mc, mc_vl, AccuracyCallback('accuracy'), AccuracyCallback('f1_m'), tensorboard, mc_f, mc_vf]

    model.fit(x_data,
              y_data,
              validation_split=validation_split,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              callbacks=callbacks)

    # return best model for further evaluation
    best_file = base_name + '_f.h5'
    if not os.path.exists(best_file):
        # f1_m never appeared (or never improved) in the logs, so mc_f never saved.
        print('{} not found, loading {}'.format(best_file, base_name + '_l.h5'))
        best_file = base_name + '_l.h5'
    model = load_model(best_file, custom_objects={"f1_m": f1_m})
    return model, tname

# display model structure
# pip install pydotplus
# pip install pydot
# https://bobswift.atlassian.net/wiki/spaces/GVIZ/pages/20971549/How+to+install+Graphviz+software
def display_model(model, filename = 'model.png', folder = 'C:\\'):
    """Render the model architecture (with shapes) to an image and return it.

    Requires pydot and Graphviz (see the install notes above).

    Args:
        model: Keras model to draw.
        filename: image file name.
        folder: prefix prepended to ``filename`` (default ``C:\\``, the
            previous hard-coded location; pass ``''`` for the working dir).

    Returns:
        An ``IPython.display.Image``. Return it as the last expression of a
        notebook cell to show it; the previous version created the Image and
        discarded it.
    """
    from tensorflow.keras.utils import plot_model
    from IPython.display import Image
    # '\\' instead of the invalid escape sequence '\{' used before.
    target = '{}{}'.format(folder, filename)
    plot_model(model, to_file=target, show_shapes=True)
    return Image(filename=target)
    
# load model with customized metrics
def load_model_custom(model_name, inference = False):
    """Load a saved model that uses the custom ``f1_m`` metric.

    Args:
        model_name: path of the saved model.
        inference: if True, round-trip the model through an h5 copy saved
            without optimizer state and reload it uncompiled (``compile=False``).

    Returns:
        The loaded Keras model.
    """
    model = load_model(model_name, custom_objects={"f1_m": f1_m})
    if inference:
        import tempfile
        # Private temp dir instead of a shared 'tmp.h5' in the working
        # directory, which concurrent runs could overwrite.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, 'tmp.h5')
            model.save(tmp_file, include_optimizer=False, save_format='h5')
            model = load_model(tmp_file, compile=False)
    return model

# validate model with given data sets
def model_validata(model, x_data, y_data, prefix):
    """Predict on ``x_data`` and save labels plus predictions to ``output-<prefix>.csv``.

    Nothing is returned; the CSV written by ``savedata`` is the result.
    """
    predictions = evaluate(model, x_data, y_data)
    savedata(y_data, predictions, "{}".format(prefix))
    
# evaluate model prediction
def evaluate(model, x_val, y_val, showFlag = 0):
    """Predict on ``x_val`` and optionally plot targets next to predictions.

    Args:
        model: object with a Keras-style ``predict`` method.
        x_val: inputs.
        y_val: targets; only used for the plot.
        showFlag: truthy to draw and show the two-panel comparison.

    Returns:
        Predictions as a numpy array. A 3-D result, e.g. a multi-output model
        returning one ``(n_samples, 1)`` array per output, is reshaped to
        ``(n_outputs, n_samples)``.
    """
    print('\n# Generate predictions')
    y_pred = np.asarray(model.predict(x_val))
    if y_pred.ndim == 3:
        y_pred = np.reshape(y_pred, (y_pred.shape[0], y_pred.shape[1]))

    if showFlag:
        _, (ax1, ax2) = plt.subplots(1, 2)
        ax1.plot(y_val)
        ax1.set_title('real_value')
        ax2.plot(y_pred)
        ax2.set_title('predict_value')
        # Previously a bare `plt.show` (no call), which did nothing.
        plt.show()
    return y_pred

# save predictions and actual values to a csv file
def _as_samples_by_outputs(values):
    """Return *values* as a 2-D ``(n_samples, n_outputs)`` array.

    Accepts lists, 1-D arrays, ``(n_outputs, n_samples, 1)`` 3-D arrays and
    2-D arrays in either orientation. Like the original heuristic, a 2-D
    array with fewer rows than columns is assumed to be ``(outputs, samples)``.
    """
    values = np.asarray(values)
    if values.ndim == 3:
        values = np.reshape(values, (values.shape[0], values.shape[1]))
    if values.ndim == 1:
        return np.reshape(values, (values.shape[0], 1))
    if values.shape[0] < values.shape[1]:
        return np.transpose(values)
    return values


def savedata(y_val, y_pred, fname):
    """Save targets and predictions side by side to ``output-<fname>.csv``.

    One row per sample. The target columns come first, then the prediction
    columns (one per output), with the DataFrame index as the first column.

    Args:
        y_val: targets (list of per-output arrays or an array).
        y_pred: predictions (e.g. the array returned by ``evaluate``).
        fname: tag inserted into the file name.

    Two fixes compared with the previous version: a 1-D ``y_pred`` no longer
    raises ``IndexError``, and single-output data is no longer written
    transposed (2 rows x n columns).
    """
    data = np.column_stack((_as_samples_by_outputs(y_val), _as_samples_by_outputs(y_pred)))
    pd.DataFrame(data).to_csv("output-{}.csv".format(fname))


colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


def plot_metrics(history, metrics = ('loss', 'auc', 'precision', 'recall')):
    """Plot training (solid) and validation (dashed) curves in a 2x2 grid.

    Args:
        history: the object returned by ``model.fit``.
        metrics: up to four metric names recorded in ``history.history``.
            The default keeps the original selection; pass the names your
            model actually logs (e.g. ``('loss', 'accuracy', 'f1_m')``).
            Validation curves are drawn only when ``'val_' + name`` exists.

    Usage: ``plot_metrics(model.fit(...))``.
    """
    for n, metric in enumerate(metrics):
        name = metric.replace("_", " ").capitalize()
        plt.subplot(2, 2, n + 1)
        plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
        val_key = 'val_' + metric
        if val_key in history.history:
            plt.plot(history.epoch, history.history[val_key],
                     color=colors[0], linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.8, 1])
        else:
            plt.ylim([0, 1])

        plt.legend()

# not used anymore
def crossValidate(model, trialName, segMax = 3, chMax = 8):
    """(Legacy) Evaluate *model* on every segment for each index in ``range(chMax)``.

    For each (ch, seg) pair, the data is loaded with ``load_data_mat`` and the
    predictions are written to a CSV by ``model_validata`` inside the folder
    ``model.yname[:-3]``.

    NOTE(review): this used to call ``load_data_mat(..., CH=ch)``, which
    raises ``TypeError`` because ``load_data_mat`` has no ``CH`` parameter.
    ``ch`` is now passed as the motor-unit index ``MU``; confirm that this is
    the intended mapping. ``model.yname`` is not a standard Keras attribute,
    so the caller must set it.

    Example: ``crossValidate(model, '1_50_GM')``.
    """
    modelType = ["nn", "cnn", "rnn"]
    modelIndex = 1
    step_size = 5
    tname = int(time.time())
    subFolder = model.yname[:-3]
    for ch in range(chMax):
        for seg in range(segMax):
            prefix = "{}-{}-ST{}-CH{}".format(trialName, modelType[modelIndex], step_size, ch)
            x_test, y_test = load_data_mat(TR=trialName, SG=seg, ST=step_size, MU=ch)
            if len(x_test) == 0:
                # load_data_mat already reported the missing file.
                continue
            prefix4file = "{}-SG{}-T{}".format(prefix, seg, tname)
            print(prefix4file)
            # Restore the exact previous directory even if validation fails
            # (chdir('..') was wrong for nested sub-folders).
            cwd = os.getcwd()
            os.chdir(subFolder)
            try:
                model_validata(model, x_test, y_test, prefix4file)
            finally:
                os.chdir(cwd)
        
# evaluate model with data generator - not used anymore
def evaluate_gen(model, data_set):
    """(Legacy) Predict over an indexable batch generator, plot, and save to CSV.

    Args:
        model: object with a Keras-style ``predict``.
        data_set: sequence of ``(x, y)`` batches supporting ``len()`` and
            indexing (e.g. a Keras ``TimeseriesGenerator``).

    Side effects: shows a two-panel plot of targets vs predictions and writes
    ``output_train_cnn.csv`` with one row per sample (target, prediction).
    """
    predict = []
    realValue = []
    # Renamed from a local that shadowed this function's own name.
    n_batches = len(data_set)
    print('Samples: %d' % n_batches)
    for i in range(n_batches):
        x, y = data_set[i]
        yhat = np.hstack(model.predict(x))
        print('.', end='')
        predict.extend(yhat)
        realValue.extend(y)

    # Creates two subplots and unpacks the output array immediately
    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(realValue)
    ax1.set_title('real_value')
    ax2.plot(predict)
    ax2.set_title('predict_value')

    # save data
    data = np.transpose(np.vstack((realValue, predict)))
    pd.DataFrame(data).to_csv("output_train_cnn.csv")
    # Previously a bare `plt.show` (no call). Shown after saving, so a
    # blocking backend does not delay the CSV.
    plt.show()

# not used in this study
# Best validation scores seen so far by saveModel; module level so they
# persist between epochs (the old version assigned them as locals and always
# failed with UnboundLocalError).
_save_model_best = {'val_acc': 0.0, 'val_loss': sys.float_info.max}


def saveModel(epoch, logs, model=None, filepath='best_model.h5'):
    """Save *model* when validation accuracy improves; break ties by lower val_loss.

    Intended for
    ``LambdaCallback(on_epoch_end=functools.partial(saveModel, model=model, filepath=...))``.

    Args:
        epoch: epoch index (unused).
        logs: epoch logs containing ``val_loss`` and either ``val_acc`` or
            ``val_accuracy`` (the name tf.keras uses for ``'accuracy'``).
        model: the model to save; required.
        filepath: destination passed to ``model.save``.

    Raises:
        ValueError: if ``model`` is not given.
    """
    if model is None:
        raise ValueError('saveModel needs the model to save; bind it with functools.partial')
    val_acc = logs['val_acc'] if 'val_acc' in logs else logs['val_accuracy']
    val_loss = logs['val_loss']

    best = _save_model_best
    if val_acc > best['val_acc']:
        best['val_acc'] = val_acc
        model.save(filepath)
    elif val_acc == best['val_acc'] and val_loss < best['val_loss']:
        best['val_loss'] = val_loss
        model.save(filepath)