In [None]:
import tensorflow as tf 
import tensorflow_addons as tfa
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
import tensorflow.keras.models as M

In [None]:
p_min = 0.001  # lower clip bound for predicted probabilities; keeps log(p) finite
p_max = 0.999  # upper clip bound; keeps log(1 - p) finite

def logloss(y_true, y_pred):
    """Mean binary cross-entropy with predictions clipped to [p_min, p_max].

    Passed to ``model.compile(metrics=...)`` in this notebook. Clipping avoids
    log(0) and caps the per-element penalty at -log(p_min).

    Args:
        y_true: binary targets. Cast to the dtype of the clipped predictions,
            so integer label tensors can be passed directly.
        y_pred: predicted probabilities; must be broadcast-compatible with
            ``y_true`` (all operations below are element-wise).

    Returns:
        Scalar tensor: the negative mean log-likelihood over all elements.
    """
    y_pred = tf.clip_by_value(y_pred, p_min, p_max)
    # Without this cast, `y_true * log(...)` fails when y_true is an integer tensor.
    y_true = tf.cast(y_true, y_pred.dtype)
    return -K.mean(y_true * K.log(y_pred) + (1 - y_true) * K.log(1 - y_pred))

In [None]:
def create_res_model(n_features, n_features_2, n_targets=206, learning_rate=0.002):
    """Build and compile a two-input, three-head MLP with an averaging skip connection.

    Architecture (as built below):
      * Head1 maps ``Input1`` (``n_features`` wide) to a 256-d embedding.
      * Head2 maps ``concat([Input2, Head1 output])`` to a 256-d embedding.
      * The Head1 and Head2 embeddings are averaged (skip connection), and the
        result is fed to Head3, which emits ``n_targets`` sigmoid probabilities.

    Args:
        n_features: width of the first input vector (``Input1``).
        n_features_2: width of the second input vector (``Input2``); also the
            width of Head2's first two Dense layers.
        n_targets: number of sigmoid outputs. Defaults to 206, the value
            previously hard-coded here.
        learning_rate: Adam learning rate. Defaults to 0.002.

    Returns:
        A compiled ``tf.keras.Model`` taking ``[input_1, input_2]`` and
        producing a ``(batch, n_targets)`` tensor of probabilities.
    """
    print(f'the input dim is {n_features}, {n_features_2}')

    input_1 = L.Input(shape = (n_features,), name = 'Input1')
    input_2 = L.Input(shape = (n_features_2,), name = 'Input2')

    # Head1: encode the first input into a 256-d representation.
    head_1 = M.Sequential([
        L.BatchNormalization(),
        L.Dropout(0.3),
        L.Dense(512, activation='elu'),
        L.BatchNormalization(),
        L.Dropout(0.5),
        L.Dense(256, activation='elu')
        ], name='Head1')

    head_1_out = head_1(input_1)
    head_2_in = L.Concatenate()([input_2, head_1_out])

    # Head2: combine the second input with Head1's embedding. Its final width
    # (256) must match Head1's so the two can be averaged below.
    head_2 = M.Sequential([
        L.BatchNormalization(),
        L.Dropout(0.3),
        L.Dense(n_features_2, activation='relu'),
        L.BatchNormalization(),
        L.Dropout(0.5),
        L.Dense(n_features_2, activation='elu'),
        L.BatchNormalization(),
        L.Dropout(0.5),
        L.Dense(256, activation='relu'),
        L.BatchNormalization(),
        L.Dropout(0.5),
        L.Dense(256, activation='selu')
        ], name='Head2')

    head_2_out = head_2(head_2_in)
    # Skip connection: element-wise mean of the Head1 and Head2 embeddings.
    skip_avg = L.Average()([head_1_out, head_2_out])

    # Head3: map the averaged embedding to independent per-target probabilities.
    head_3 = M.Sequential([
        L.BatchNormalization(),
        L.Dropout(0.3),
        L.Dense(256, activation='swish'),
        L.BatchNormalization(),
        L.Dense(256, activation='selu'),
        L.BatchNormalization(),
        L.Dense(n_targets, activation='sigmoid')
        ], name='Head3')

    output = head_3(skip_avg)

    model = M.Model(inputs = [input_1, input_2], outputs = output)
    model.compile(
        # Use `learning_rate`, not the deprecated `lr` alias: newer Keras
        # optimizers (TF >= 2.11) drop `lr` with only a warning, which would
        # silently fall back to the default rate of 0.001.
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.001),
        # Documented form is a list of metrics; Keras 3 rejects a bare callable.
        metrics=[logloss],
    )

    return model