In [1]:
import os
import logging
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler

In [2]:
# Number of recordings per cohort and genotype. Each value is the loop bound
# used when reading the '<prefix><i>.xlsx' files (i = 1..size).
ko_1_size, wt_1_size = 15, 15
ko_2_size, wt_2_size = 41, 41
ko_3_size, wt_3_size = 16, 18

# File-name prefixes that precede the 1-based recording number.
ko_1_file_prefix, wt_1_file_prefix = '1KO', '1WT'
ko_2_file_prefix, wt_2_file_prefix = '2KO', '2WT'
ko_3_file_prefix, wt_3_file_prefix = '3KO', '3WT'

# Class labels written into the target vector.
ko_mice_symbol, wt_mice_symbol = 0, 1

In [3]:
# Length of the trial axis every recording is padded / truncated to.
max_trail_num = 1280

# model params
learning_rate = 1e-4
min_learning_rate = 1e-5
reduce_lr_factor = 0.5
reduce_lr_patience = 5
early_stopping_patience = 10
batch_size = 32
epochs = 13
validation_split = 0.2
metric_cut_percent = metric_lower_cut_percent = metric_upper_cut_percent = 0.2
# Standard deviation of the Gaussian noise added during pre-processing.
noise_deviation = 5e-5

# data augmentation params
data_augmentation_factor = 3
step_size = 40
test_data_size = 10

In [4]:
# Folders holding the per-recording Excel workbooks of each genotype.
ko_directory = r'/data/KO/'
wt_directory = r'/data/WT/'

# Fail-fast visibility check: report whether each folder is reachable.
print(f'KO data dir exists: {os.path.exists(ko_directory)}')
print(f'WT data dir exists: {os.path.exists(wt_directory)}')

KO data dir exists: True
WT data dir exists: True


In [5]:
def read_x_data(file_dir, data_size, file_prefix):
    """Load the two sheets of every recording workbook of one cohort.

    Workbooks are named ``<file_prefix><i>.xlsx`` for ``i = 1..data_size`` and
    are read without a header row as ``int16``.

    Args:
        file_dir: Directory containing the workbooks.
        data_size: Number of workbooks to read.
        file_prefix: File-name prefix placed before the 1-based index.

    Returns:
        A pair ``(train_x, train_record)`` of lists of DataFrames holding
        'Sheet1' and 'Sheet2' of each workbook, in file order.
    """
    train_x = []
    train_record = []
    for file_number in range(1, data_size + 1):
        workbook_path = os.path.join(file_dir, file_prefix + str(file_number) + '.xlsx')
        # Passing a list of sheet names returns a dict of DataFrames, so each
        # workbook is opened and parsed once instead of once per sheet.
        sheets = pd.read_excel(workbook_path, dtype='int16', header=None,
                               sheet_name=['Sheet1', 'Sheet2'])
        train_x.append(sheets['Sheet1'])
        train_record.append(sheets['Sheet2'])

    return train_x, train_record

# Data Pre-Process

In [6]:
# def downsample(data, time_factor=2):
#     """
#     input shape: (n_samples, trials, timesteps, features)
#     output shape: (n_samples, trials, timesteps//factor, features)
#     """
#     return data[:, ::time_factor, :]

In [7]:
def train_data_pre_process(behavior_data_list, record_data_list, with_noise):
    """Turn raw recordings into the per-recording feature tensors.

    For every recording, columns 200..999 of the behaviour frame are kept and
    five extra channels are appended, giving an array of shape
    ``(trials, timesteps, 6)`` with the channels, in order:

        0. behaviour signal (a value of 1 is treated as a lick)
        1. odor-1 indicator: 1 on timesteps 0..99 of trials whose record
           entry is not 1
        2. odor-2 indicator: 1 on timesteps 0..99 of trials whose record
           entry is 1
        3. reward indicator: 1 on the 50 timesteps after the first lick found
           in timesteps 300..499 of an odor-2 trial
        4. relative trial position ``(j + 1) / recorded_trials``
        5. ``recorded_trials / 1850``

    Args:
        behavior_data_list: List of DataFrames, one row per trial.
        record_data_list: List of DataFrames; the row labelled 0 holds one
            entry per recorded trial.
        with_noise: When true, Gaussian noise with standard deviation
            ``noise_deviation`` (module-level setting) is added to all channels.

    Returns:
        List of float arrays, one per recording.
    """
    # Divisor used to scale the trial count; taken over from the original code.
    trail_count_normalizer = 1850.0
    data_size = len(behavior_data_list)
    pre_processed_data = []

    for i in range(data_size):
        behavior_data = np.expand_dims(behavior_data_list[i].values[:, 200:1000], axis=-1)
        operation_data = record_data_list[i].loc[0]

        channel_shape = (behavior_data.shape[0], behavior_data.shape[1], 1)
        odor1_data = np.zeros(channel_shape)
        odor2_data = np.zeros(channel_shape)
        reward_data = np.zeros(channel_shape)
        cur_trail_count_data = np.zeros(channel_shape)
        trail_count_data = np.zeros(channel_shape)

        sum_trail_count = operation_data.shape[0]
        for j in range(sum_trail_count):
            if operation_data[j] == 1:
                odor2_data[j, 0:100, :] = 1

                # First lick inside the 300..499 response window, if any.
                lick_positions = np.flatnonzero(behavior_data[j, 300:500, 0] == 1)
                if lick_positions.size > 0:
                    lick_index = 300 + int(lick_positions[0])
                    reward_data[j, (lick_index + 1):(lick_index + 51), :] = 1
            else:
                odor1_data[j, 0:100, :] = 1

            cur_trail_count_data[j, :, :] = (j + 1) / sum_trail_count

        trail_count_data[:sum_trail_count, :, :] = sum_trail_count / trail_count_normalizer

        # The odor-2 channel must be included here; previously odor1_data was
        # concatenated twice and the odor-2 indicator was lost.
        x_data = np.concatenate(
            (behavior_data, odor1_data, odor2_data, reward_data,
             cur_trail_count_data, trail_count_data),
            axis=2)

        if with_noise:
            x_data += np.random.normal(0, noise_deviation, x_data.shape)

        pre_processed_data.append(x_data)

    return pre_processed_data

In [8]:
def get_preprocessed_x_data(file_dir, data_size, file_prefix, with_noise):
    """Read one cohort from disk and return its pre-processed feature arrays.

    The two lists returned by ``read_x_data`` are forwarded, together with
    ``with_noise``, to ``train_data_pre_process``.
    """
    return train_data_pre_process(
        *read_x_data(file_dir, data_size, file_prefix),
        with_noise,
    )

In [9]:
def padding_x_data(x_data_list, max_trail_count):
    """Force every array to exactly ``max_trail_count`` entries on axis 0.

    Shorter arrays are extended with zeros at the end; arrays that are long
    enough are cut to their first ``max_trail_count`` entries.

    Returns:
        A new list with one array per input array, in the same order.
    """
    def _fit_to_length(sample):
        missing = max_trail_count - len(sample)
        if missing <= 0:
            return sample[:max_trail_count, :, :]
        filler = np.zeros((missing, sample.shape[1], sample.shape[2]))
        return np.concatenate((sample, filler), axis=0)

    return [_fit_to_length(sample) for sample in x_data_list]

In [10]:
def get_model_x_y(ko_x_data, wt_x_data, split=0.2):
    """Stack KO and WT samples into one dataset with matching labels.

    Each class is cut at ``int(n * (1 - split))``. The result holds the leading
    part of both classes first (KO, then WT) followed by the trailing part of
    both classes (KO, then WT), so the last ``split`` share of the arrays is
    drawn from the tail of each class.

    Returns:
        ``(x_data, y_data)``; labels are ``ko_mice_symbol`` / ``wt_mice_symbol``
        stored as floats.
    """
    ko_count, wt_count = len(ko_x_data), len(wt_x_data)
    ko_cut = int(ko_count * (1 - split))
    wt_cut = int(wt_count * (1 - split))

    ko_samples, wt_samples = np.array(ko_x_data), np.array(wt_x_data)
    ko_labels = np.full(ko_count, ko_mice_symbol, dtype=np.float64)
    wt_labels = np.full(wt_count, wt_mice_symbol, dtype=np.float64)

    head_x = np.concatenate([ko_samples[:ko_cut], wt_samples[:wt_cut]])
    tail_x = np.concatenate([ko_samples[ko_cut:], wt_samples[wt_cut:]])
    head_y = np.concatenate([ko_labels[:ko_cut], wt_labels[:wt_cut]])
    tail_y = np.concatenate([ko_labels[ko_cut:], wt_labels[wt_cut:]])

    return np.concatenate([head_x, tail_x]), np.concatenate([head_y, tail_y])

In [11]:
def cal_max_trail_count(data_array):
    """Return the largest first-axis length among the given arrays.

    Raises ``ValueError`` when ``data_array`` is empty.
    """
    trail_counts = [entry.shape[0] for entry in data_array]
    return max(trail_counts)

In [12]:
def get_sliding_window_size(x_data, augmentation_factor, step_size):
    """Window length that yields ``augmentation_factor`` windows per sample.

    With windows starting every ``step_size`` entries, the longest sample fits
    exactly ``augmentation_factor`` windows of the returned length. A factor of
    1 or less means no augmentation and returns the full length.
    """
    longest = max(sample.shape[0] for sample in x_data)
    extra_windows = augmentation_factor - 1
    if extra_windows <= 0:
        return longest
    return longest - extra_windows * step_size

In [13]:
def sliding_windows(x_data, y_data, window_size, step_size, augmentation_factor):
    """Augment the dataset by cutting overlapping windows along axis 1.

    Args:
        x_data: Array indexed as ``x_data[sample, position, :, :]``.
        y_data: One label per sample.
        window_size: Number of axis-1 entries per window.
        step_size: Distance between consecutive window starts.
        augmentation_factor: Values of 1 or less disable augmentation.

    Returns:
        ``(x_augmented, y_augmented)``. When augmentation is disabled,
        ``x_data`` is returned unchanged together with a copy of ``y_data``.
        Otherwise the windows of each sample are emitted consecutively, and
        every label is repeated once per window actually produced, so both
        outputs always have the same length.
    """
    if augmentation_factor <= 1:
        return x_data, np.repeat(y_data, 1)

    window_starts = range(0, x_data.shape[1] - window_size + 1, step_size)
    x_data_augmented = [
        x_data[sample_index, start:start + window_size, :, :]
        for sample_index in range(x_data.shape[0])
        for start in window_starts
    ]
    # Derive the repeat count from the windows generated, not from
    # augmentation_factor; the two only coincide when window_size was computed
    # from the same factor and step.
    y_data_augmented = np.repeat(y_data, len(window_starts))

    return np.array(x_data_augmented), y_data_augmented

In [14]:
def get_max_trail_count(x_data_list):
    """Return the largest first-axis length in ``x_data_list``, or 0 if empty."""
    return max((sample.shape[0] for sample in x_data_list), default=0)

In [15]:
import random

def shuffle_lists(*lists):
    """Shuffle each given list in place, one after another.

    Every list is shuffled independently with ``random.shuffle``; nothing is
    returned.
    """
    for position in range(len(lists)):
        random.shuffle(lists[position])

def interleave_lists(list1, list2, list3):
    """Merge three lists so each is spread evenly through the result.

    At every output position the list that lags furthest behind its
    proportional share (``position * len(list) / total``) supplies the next
    element. Ties go to the earlier argument, and the relative order inside
    each list is preserved.

    Args:
        list1, list2, list3: Sequences to merge; any of them may be empty.

    Returns:
        A new list with ``len(list1) + len(list2) + len(list3)`` elements.
        An empty list is returned when all inputs are empty.
    """
    sources = (list1, list2, list3)
    lengths = [len(source) for source in sources]
    total_length = sum(lengths)
    if total_length == 0:
        # Nothing to merge; also avoids dividing by zero below.
        return []

    weights = [length / total_length for length in lengths]
    used = [0, 0, 0]
    result = []

    for position in range(1, total_length + 1):
        # How far each list is behind its ideal share; exhausted lists are
        # excluded by giving them -inf.
        deficits = [
            position * weight - taken if taken < length else float('-inf')
            for weight, taken, length in zip(weights, used, lengths)
        ]
        chosen = deficits.index(max(deficits))
        result.append(sources[chosen][used[chosen]])
        used[chosen] += 1

    return result

In [16]:
# Cohort 1: read and pre-process both genotypes, with noise injection enabled.
ko_1_x_data = get_preprocessed_x_data(
    file_dir=ko_directory, data_size=ko_1_size,
    file_prefix=ko_1_file_prefix, with_noise=True)
wt_1_x_data = get_preprocessed_x_data(
    file_dir=wt_directory, data_size=wt_1_size,
    file_prefix=wt_1_file_prefix, with_noise=True)

In [17]:
# Cohort 2: read and pre-process both genotypes, with noise injection enabled.
ko_2_x_data = get_preprocessed_x_data(
    file_dir=ko_directory, data_size=ko_2_size,
    file_prefix=ko_2_file_prefix, with_noise=True)
wt_2_x_data = get_preprocessed_x_data(
    file_dir=wt_directory, data_size=wt_2_size,
    file_prefix=wt_2_file_prefix, with_noise=True)

In [18]:
# Cohort 3: read and pre-process both genotypes, with noise injection enabled.
ko_3_x_data = get_preprocessed_x_data(
    file_dir=ko_directory, data_size=ko_3_size,
    file_prefix=ko_3_file_prefix, with_noise=True)
wt_3_x_data = get_preprocessed_x_data(
    file_dir=wt_directory, data_size=wt_3_size,
    file_prefix=wt_3_file_prefix, with_noise=True)

In [19]:
# Mix the three cohorts of each genotype so every cohort is spread evenly
# through the combined list.
ko_x_data = interleave_lists(list1=ko_1_x_data, list2=ko_2_x_data, list3=ko_3_x_data)
wt_x_data = interleave_lists(list1=wt_1_x_data, list2=wt_2_x_data, list3=wt_3_x_data)

In [20]:
# Number of recordings available per genotype.
ko_data_len, wt_data_len = len(ko_x_data), len(wt_x_data)

In [21]:
# Bring every recording to exactly max_trail_num entries on the trial axis.
padded_ko_x_data = padding_x_data(x_data_list=ko_x_data, max_trail_count=max_trail_num)
padded_wt_x_data = padding_x_data(x_data_list=wt_x_data, max_trail_count=max_trail_num)
all_x_data_list = [*padded_ko_x_data, *padded_wt_x_data]

In [22]:
# Stack both genotypes into one array with matching labels.
x_data, y_data = get_model_x_y(
    ko_x_data=padded_ko_x_data,
    wt_x_data=padded_wt_x_data,
    split=validation_split,
)

In [23]:
# Sanity check of the stacked dataset. The recorded output (146, 1280, 800, 6)
# matches 72 KO + 74 WT recordings, max_trail_num padded trials, the 800
# columns kept by the 200:1000 slice, and 6 feature channels.
x_data.shape

(146, 1280, 800, 6)

In [24]:
# Window length that gives data_augmentation_factor windows per recording.
window_size = get_sliding_window_size(
    x_data=x_data,
    augmentation_factor=data_augmentation_factor,
    step_size=step_size,
)

In [25]:
# Recorded output 1200 = 1280 - (3 - 1) * 40.
window_size

1200

In [26]:
# Cut overlapping windows along the trial axis to multiply the sample count.
x_train, y_train = sliding_windows(
    x_data=x_data,
    y_data=y_data,
    window_size=window_size,
    step_size=step_size,
    augmentation_factor=data_augmentation_factor,
)

In [27]:
# Drop every intermediate that is no longer needed once x_train / y_train
# exist, so the large arrays can be reclaimed.
del (
    ko_1_x_data, wt_1_x_data,
    ko_2_x_data, wt_2_x_data,
    ko_3_x_data, wt_3_x_data,
    ko_x_data, wt_x_data,
    padded_ko_x_data, padded_wt_x_data, all_x_data_list,
    x_data, y_data,
)

In [28]:
# Recorded output (438, 1200, 800, 6): 146 recordings x 3 windows each, with
# 1200 trials per window.
x_train.shape

(438, 1200, 800, 6)

In [29]:
from sklearn.metrics import accuracy_score
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, TimeDistributed, LSTM, GlobalAveragePooling1D, Flatten, Bidirectional, GRU, Conv1D, MaxPooling1D
from keras.optimizers import AdamW
# keras.optimizers.Adam runs slowly on M1,M2, so use keras.optimizers.legacy.Adam instead
from keras.optimizers.legacy import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.regularizers import l1, l2
import tensorflow as tf
import keras.backend as K

In [30]:
import tensorflow as tf
import keras.backend as K

def get_sorted_error_sliced(y_true, y_pred):
    """Return the absolute errors with the largest fifth removed.

    The absolute error is transposed and sorted in ascending order along its
    last axis, then the ``n // 5`` largest entries of that axis are dropped
    (``n`` being the size of that axis).

    NOTE(review): the transpose and the ``[1]`` / ``[:, :k]`` indexing require
    a rank-2 error tensor - presumably (batch, outputs); confirm against the
    callers.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor broadcastable against ``y_true``.

    Returns:
        Tensor holding the kept, ascending-sorted errors.
    """
    error = K.abs(y_pred - y_true)
    sorted_error = tf.sort(tf.transpose(error))

    num_samples = tf.shape(sorted_error)[1]
    # Slice by the number of entries to keep. Slicing with ``:-num_to_remove``
    # returns an empty tensor whenever num_to_remove is 0 (fewer than 5
    # samples), which turns every downstream mean into NaN.
    num_to_keep = num_samples - num_samples // 5

    return sorted_error[:, :num_to_keep]

In [31]:
import tensorflow as tf
import keras.backend as K

def custom_error(y_true, y_pred):
    """Mean squared error over the errors kept by ``get_sorted_error_sliced``."""
    kept_errors = get_sorted_error_sliced(y_true, y_pred)
    return K.mean(K.square(kept_errors))

In [32]:
import tensorflow as tf
import keras.backend as K

def custom_accuracy(y_true, y_pred):
    """Mean of ``1 - |error|`` over the errors kept by ``get_sorted_error_sliced``.

    NOTE(review): a second function with this name is defined further down in
    this file and replaces this one once that cell has run.
    """
    kept_errors = get_sorted_error_sliced(y_true, y_pred)
    closeness = 1 - K.abs(kept_errors)
    return K.mean(closeness)

In [33]:
def weighted_mse(y_true, y_pred):
    """Squared error down-weighted by ``exp(-error**2)``, averaged.

    Works on the errors kept by ``get_sorted_error_sliced``. Each squared error
    is multiplied by ``exp(-squared_error)`` before the mean is taken, so large
    errors contribute less than they would in a plain MSE.
    """
    errors = K.abs(get_sorted_error_sliced(y_true, y_pred))
    squared_errors = K.square(errors)
    weights = K.exp(-squared_errors)

    # The plain MSE that used to be computed here was never used, and the
    # result is no longer stored in a local that shadows the function name.
    return K.mean(squared_errors * weights)

In [34]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy

def custom_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy averaged over the lowest-loss four fifths.

    The losses returned by ``binary_crossentropy`` are ordered ascending along
    axis 0 and the ``n // 5`` largest entries are left out of the mean.
    """
    per_item_loss = binary_crossentropy(y_true, y_pred)
    ascending_loss = tf.gather(per_item_loss, tf.argsort(per_item_loss, axis=0))

    total = tf.shape(per_item_loss)[0]
    kept = total - total // 5

    return tf.reduce_mean(ascending_loss[:kept])

In [35]:
from tensorflow.keras import backend as K

def custom_accuracy(y_true, y_pred):
    """Share of predictions that equal the label after rounding.

    NOTE(review): this redefines a ``custom_accuracy`` declared earlier in this
    file; from this cell on, this version is the one in effect.
    """
    rounded_prediction = K.round(y_pred)
    mismatch = K.abs(y_true - rounded_prediction)
    return K.mean(1 - mismatch)

In [36]:
class F1Score(tf.keras.metrics.Metric):
    """F1 score built from Keras' streaming Precision and Recall metrics."""

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        # Fully qualified so the class does not depend on `Precision` /
        # `Recall` being imported by some other cell before instantiation.
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Forward the batch to both underlying metrics."""
        self.precision.update_state(y_true, y_pred, sample_weight)
        self.recall.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return 2 * p * r / (p + r); 1e-6 guards against division by zero."""
        p = self.precision.result()
        r = self.recall.result()
        return 2 * ((p * r) / (p + r + 1e-6))

    def reset_state(self):
        """Clear the accumulated precision / recall statistics.

        Keras calls ``reset_state``; implementing only the deprecated
        ``reset_states`` spelling triggers a deprecation warning.
        """
        self.precision.reset_state()
        self.recall.reset_state()

    def reset_states(self):
        """Deprecated spelling kept for existing callers."""
        self.reset_state()

In [37]:
import tensorflow as tf
from tensorflow.keras.metrics import Metric

class MatthewsCorrelationCoefficient(Metric):
    """Streaming Matthews correlation coefficient for binary predictions.

    Predictions above ``threshold`` count as positive. Confusion-matrix counts
    are accumulated over batches; ``sample_weight`` is accepted but ignored.
    """

    def __init__(self, name='mcc', threshold=0.5, **kwargs):
        super(MatthewsCorrelationCoefficient, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.true_negatives = self.add_weight(name='tn', initializer='zeros')
        self.false_positives = self.add_weight(name='fp', initializer='zeros')
        self.false_negatives = self.add_weight(name='fn', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the confusion-matrix counts of one batch."""
        # Flatten both tensors so a (batch,) label vector and a (batch, 1)
        # prediction column are multiplied element-wise instead of broadcast.
        y_pred = tf.reshape(tf.cast(y_pred > self.threshold, tf.float32), [-1])
        y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])

        tp = tf.reduce_sum(y_true * y_pred)
        tn = tf.reduce_sum((1 - y_true) * (1 - y_pred))
        fp = tf.reduce_sum((1 - y_true) * y_pred)
        fn = tf.reduce_sum(y_true * (1 - y_pred))

        self.true_positives.assign_add(tp)
        self.true_negatives.assign_add(tn)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """Return (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN) + eps)."""
        numerator = (self.true_positives * self.true_negatives -
                     self.false_positives * self.false_negatives)

        # epsilon keeps the denominator non-zero when a row or column of the
        # confusion matrix is empty.
        denominator = tf.sqrt(
            (self.true_positives + self.false_positives) *
            (self.true_positives + self.false_negatives) *
            (self.true_negatives + self.false_positives) *
            (self.true_negatives + self.false_negatives) +
            tf.keras.backend.epsilon()
        )

        return numerator / denominator

    def reset_state(self):
        """Zero all counters.

        Keras calls ``reset_state``; implementing only the deprecated
        ``reset_states`` spelling triggers a deprecation warning.
        """
        self.true_positives.assign(0)
        self.true_negatives.assign(0)
        self.false_positives.assign(0)
        self.false_negatives.assign(0)

    def reset_states(self):
        """Deprecated spelling kept for existing callers."""
        self.reset_state()

In [38]:
from tensorflow.keras.metrics import Metric, SpecificityAtSensitivity

class BalancedAccuracy(Metric):
    """Balanced accuracy: the mean of recall and specificity.

    Both terms use the default 0.5 decision threshold of the underlying Keras
    metrics. Specificity is computed as TN / (TN + FP).
    ``SpecificityAtSensitivity(0.5)`` is not used because it reports the best
    specificity reachable while sensitivity stays at or above 0.5 across a
    sweep of thresholds, which is a different operating point from the one
    used for recall.
    """

    def __init__(self, name='balanced_acc', **kwargs):
        super().__init__(name=name, **kwargs)
        self.recall = tf.keras.metrics.Recall()
        self.true_negatives = tf.keras.metrics.TrueNegatives()
        self.false_positives = tf.keras.metrics.FalsePositives()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Forward the batch to the underlying metrics."""
        self.recall.update_state(y_true, y_pred, sample_weight)
        self.true_negatives.update_state(y_true, y_pred, sample_weight)
        self.false_positives.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return (recall + specificity) / 2."""
        tn = self.true_negatives.result()
        fp = self.false_positives.result()
        # divide_no_nan yields 0 when no negative sample has been seen yet.
        specificity = tf.math.divide_no_nan(tn, tn + fp)
        return (self.recall.result() + specificity) / 2

    def reset_state(self):
        """Clear all accumulated statistics.

        Keras calls ``reset_state``; implementing only the deprecated
        ``reset_states`` spelling triggers a deprecation warning.
        """
        self.recall.reset_state()
        self.true_negatives.reset_state()
        self.false_positives.reset_state()

    def reset_states(self):
        """Deprecated spelling kept for existing callers."""
        self.reset_state()

In [39]:
from keras.mixed_precision import set_global_policy
# Process-wide setting: layers created after this call use the mixed_float16
# dtype policy.
# NOTE(review): the recorded output of this cell warns that the machine has no
# GPU and that mixed_float16 may run slowly there - confirm the policy is
# wanted on this hardware.
set_global_policy('mixed_float16')

The dtype policy mixed_float16 may run slowly because this machine does not have a GPU. Only Nvidia GPUs with compute capability of at least 7.0 run quickly with mixed_float16.


In [40]:
import gc
import numpy as np
from sklearn.model_selection import KFold
from keras import backend as K
from keras.metrics import Precision, Recall, AUC
import time
import pickle
from keras.layers import LayerNormalization, MultiHeadAttention, GlobalMaxPooling1D, ReLU, BatchNormalization, Attention, Reshape, Conv1D, GlobalAveragePooling1D, Dense, Input, MultiHeadAttention

# Start from a clean Keras state and reclaim memory before building the model.
K.clear_session()
gc.collect()

model = Sequential()

# Per-trial feature extractor: the same Conv1D / BatchNorm / pooling stack is
# applied to each entry of the trial axis, then flattened to one vector per
# trial.
model.add(TimeDistributed(Conv1D(16, kernel_size=25, activation='relu', 
       padding='same',
       kernel_initializer='he_normal'),
          input_shape=x_train.shape[1:]))
model.add(TimeDistributed(BatchNormalization(momentum=0.95)))
model.add(TimeDistributed(MaxPooling1D(4)))
model.add(TimeDistributed(Flatten()))

# Sequence model across trials.
model.add(Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(0.002))))
model.add(Dropout(0.5))
model.add(Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(0.002))))
model.add(Dropout(0.5))

# Average over the trial axis and classify.
model.add(GlobalAveragePooling1D())
model.add(Dense(64, activation='relu', kernel_regularizer=l2(0.005)))
model.add(Dropout(0.5))
# The global dtype policy is mixed_float16; the final sigmoid is pinned to
# float32 so the predicted probabilities and the loss are computed in full
# precision.
model.add(Dense(1, activation='sigmoid', dtype='float32'))

optimizer = Adam(learning_rate = learning_rate)

model.compile(optimizer = optimizer, loss = 'binary_crossentropy', metrics = [
                  'accuracy', 
                  Precision(name='precision'), 
                  Recall(name='recall'),
                  AUC(name='auc'),
                  F1Score(),
                  MatthewsCorrelationCoefficient(name='mcc'),
                  BalancedAccuracy()])

model.summary()

# NOTE(review): both callbacks monitor 'val_loss', but they are not passed to
# fit() below and fit() receives no validation data, so they currently have no
# effect. Confirm whether training on the full dataset for a fixed number of
# epochs is intended.
reduce_lr = ReduceLROnPlateau(monitor = 'val_loss',
                              factor = reduce_lr_factor,
                              patience = reduce_lr_patience,
                              min_lr = min_learning_rate)
early_stopping = EarlyStopping(monitor = 'val_loss', patience = early_stopping_patience)

history = model.fit(x_train, y_train,
                    batch_size = batch_size, 
                    epochs = epochs)

# Persist the trained model and its per-epoch metric history.
model.save('conv_bilstm_model.keras')
with open('conv_bilstm_history.pkl', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

K.clear_session()
gc.collect()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 time_distributed (TimeDist  (None, 1200, 800, 16)     2416      
 ributed)                                                        
                                                                 
 time_distributed_1 (TimeDi  (None, 1200, 800, 16)     64        
 stributed)                                                      
                                                                 
 time_distributed_2 (TimeDi  (None, 1200, 200, 16)     0         
 stributed)                                                      
                                                                 
 time_distributed_3 (TimeDi  (None, 1200, 3200)        0         
 stributed)                                                      
                                                                 
 bidirectional (Bidirection  (None, 1200, 256)         3

  m.reset_state()
  m.reset_state()
  m.reset_state()


Epoch 3/13
Epoch 4/13
Epoch 5/13
Epoch 6/13
Epoch 7/13
Epoch 8/13
Epoch 9/13
Epoch 10/13
Epoch 11/13
Epoch 12/13
Epoch 13/13


635