In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model

import numpy as np
import requests as rq
import os, io, h5py

from tfomics import moana, evaluate
from tfomics.layers import MultiHeadAttention

import subprocess
import shlex

In [2]:
# Download the synthetic dataset (an HDF5 file) fully into memory.
# The timeout is (connect, read) in seconds so that a stalled connection
# raises an exception instead of hanging the kernel forever.
data = rq.get(
    'https://www.dropbox.com/s/c3umbo5y13sqcfp/synthetic_dataset.h5?raw=true',
    timeout=(10, 60),
)
data.raise_for_status()  # abort on an HTTP error rather than parsing an error page as HDF5

# Open the downloaded bytes as an in-memory HDF5 file and materialise each split
# as a NumPy array before the file is closed.
# Inputs have their last two axes swapped via transpose([0, 2, 1]).
# NOTE(review): presumably this converts (N, 4, L) to (N, L, 4) to match the
# model's Input(shape=(200, 4)) -- confirm against the dataset layout.
# NOTE(review): training labels are cast to float32 while validation/test labels
# are cast to int32 -- confirm the mismatch is intentional.
with h5py.File(io.BytesIO(data.content), 'r') as dataset:
    x_train = np.array(dataset['X_train']).astype(np.float32).transpose([0, 2, 1])
    y_train = np.array(dataset['Y_train']).astype(np.float32)
    x_valid = np.array(dataset['X_valid']).astype(np.float32).transpose([0, 2, 1])
    y_valid = np.array(dataset['Y_valid']).astype(np.int32)
    x_test = np.array(dataset['X_test']).astype(np.float32).transpose([0, 2, 1])
    y_test = np.array(dataset['Y_test']).astype(np.int32)

In [6]:
# Model input: one example has shape (200, 4).
inputs = layers.Input(shape=(200, 4))

# Convolutional block: conv -> batch norm -> ReLU -> max-pool -> dropout.
# Layers are created in the order they are applied.
conv_stage = [
    layers.Conv1D(filters=32, kernel_size=19, use_bias=False, padding='same'),
    layers.BatchNormalization(),
    layers.Activation('relu', name='conv_activation'),
    layers.MaxPool1D(pool_size=4),
    layers.Dropout(0.1),
]
nn = inputs
for stage_layer in conv_stage:
    nn = stage_layer(nn)

# Multi-head self-attention: the same tensor is passed for all three inputs.
# The layer returns two values; the second is kept in `weights`.
attention = MultiHeadAttention(num_heads=8, d_model=64)
nn, weights = attention(nn, nn, nn)
nn = layers.Dropout(0.1)(nn)

# Collapse to a flat feature vector per example for the dense head.
nn = layers.Flatten()(nn)

# Feed-forward block: dense (no bias) -> batch norm -> ReLU -> dropout.
dense_stage = [
    layers.Dense(512, use_bias=False),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Dropout(0.5),
]
for stage_layer in dense_stage:
    nn = stage_layer(nn)

# Output head: 12 independent sigmoid units.
outputs = layers.Dense(12, activation='sigmoid')(nn)

# Assemble the functional model from the input/output tensors and compile it.
model = Model(inputs=inputs, outputs=outputs)

# Track area under the ROC and precision-recall curves during training;
# the metric names determine the keys reported in the training log.
auroc = tf.keras.metrics.AUC(curve='ROC', name='auroc')
aupr = tf.keras.metrics.AUC(curve='PR', name='aupr')

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss='binary_crossentropy',
    metrics=[auroc, aupr],
)

# Train Model
# Learning-rate schedule: when validation AUPR ('val_aupr') has not improved for
# 5 consecutive epochs, multiply the learning rate by 0.2, never going below 1e-7.
# mode='max' because a higher AUPR is better.
lr_decay = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_aupr',
    factor=0.2,
    patience=5,
    verbose=1,
    min_lr=1e-7,
    mode='max',
)

# Fit for up to 75 epochs, evaluating on the validation split after each epoch.
model.fit(
    x_train,
    y_train,
    epochs=75,
    validation_data=(x_valid, y_valid),
    callbacks=[lr_decay],
    verbose=1,
    shuffle=True,
)

Epoch 1/75
657/657 - 3s - loss: 0.4577 - auroc: 0.5820 - aupr: 0.2096 - val_loss: 0.3844 - val_auroc: 0.6948 - val_aupr: 0.3801
Epoch 2/75
657/657 - 2s - loss: 0.3641 - auroc: 0.7234 - aupr: 0.4209 - val_loss: 0.3108 - val_auroc: 0.7992 - val_aupr: 0.5454
Epoch 3/75
657/657 - 2s - loss: 0.3194 - auroc: 0.7984 - aupr: 0.5413 - val_loss: 0.2699 - val_auroc: 0.8631 - val_aupr: 0.6489
Epoch 4/75
657/657 - 2s - loss: 0.2899 - auroc: 0.8403 - aupr: 0.6159 - val_loss: 0.2495 - val_auroc: 0.8881 - val_aupr: 0.6970
Epoch 5/75


KeyboardInterrupt: 