In [1]:
import tensorflow as tf
import numpy as np

import h5py, os, io
import requests as rq

from tfomics import moana, evaluate, impress
import matplotlib.pyplot as plt

Matplotlib is building the font cache; this may take a moment.


In [2]:
# Download the synthetic dataset into memory and split it into train / valid / test arrays.
# A (connect, read) timeout is set so that a stalled connection raises instead of
# blocking the kernel forever; requests applies no timeout by default.
data = rq.get(
    'https://www.dropbox.com/s/c3umbo5y13sqcfp/synthetic_dataset.h5?raw=true',
    timeout=(10, 120),
)
data.raise_for_status()  # fail loudly on an HTTP error rather than parsing an error page as HDF5

# The HDF5 file is opened straight from the downloaded bytes; nothing is written to disk.
with h5py.File(io.BytesIO(data.content), 'r') as dataset:
    # Every X array has its last two axes swapped on load
    # (presumably channels-first -> channels-last for Keras; TODO confirm).
    x_train = np.array(dataset['X_train']).astype(np.float32).transpose([0, 2, 1])
    # NOTE(review): the training labels are cast to float32 while the validation/test
    # labels below are cast to int32 -- confirm that this asymmetry is intentional.
    y_train = np.array(dataset['Y_train']).astype(np.float32)
    x_valid = np.array(dataset['X_valid']).astype(np.float32).transpose([0, 2, 1])
    y_valid = np.array(dataset['Y_valid']).astype(np.int32)
    x_test = np.array(dataset['X_test']).astype(np.float32).transpose([0, 2, 1])
    y_test = np.array(dataset['Y_test']).astype(np.int32)

# Quick sanity check: show the label vector of the first test example.
print(y_test[0])

[0 1 0 0 0 0 0 0 0 0 0 0]


In [10]:
# Evaluate every trained model of one results category.
# For each model this cell collects the motif-match statistics computed from its
# Tomtom output and the metrics returned by `model.evaluate` on the test set, then
# saves all per-model rows to `statistics/<category>.npy`.
names = []
aupr = []
auroc = []
losses = []
match_fracs = []
false_fracs = []

category = input('Category: ')

# Each immediate sub-directory of results/<category> is treated as one model.
# `next()` reads only the top level instead of walking the entire tree, and sorting
# makes the row order independent of the filesystem's enumeration order.
results_dir = os.path.join('results', category)
try:
    models = sorted(next(os.walk(results_dir))[1])
except StopIteration:
    # os.walk yields nothing when the directory is missing or unreadable.
    raise FileNotFoundError(f'No results directory found at {results_dir!r}') from None

for model_name in models:
    num_filters = moana.count_meme_entries(os.path.join('motifs', category, model_name + '.txt'))
    # Despite its name, this is the path of the model's Tomtom result file.
    motif_dir = os.path.join('results', category, model_name, 'tomtom.tsv')

    (match_frac, match_any, filter_matches,
     filter_qvalues, motif_qvalues, hit_counts) = evaluate.motif_comparison_synthetic_dataset(
        motif_dir, num_filters=num_filters)

    names.append(model_name)
    match_fracs.append(match_frac)
    # Difference between the "any match" fraction and the match fraction
    # (presumably filters that matched a non-ground-truth motif; TODO confirm).
    false_fracs.append(match_any - match_frac)

    # Open the HDF5 file in a context manager so the handle is released once the
    # model has been loaded, instead of leaking one open file per iteration.
    model_path = os.path.join('models', category, model_name + '.h5')
    with h5py.File(model_path, 'r') as model_file:
        model = tf.keras.models.load_model(model_file)
    print(model.name)

    results = model.evaluate(x_test, y_test)
    # NOTE(review): assumes the model was compiled with metrics in the order
    # [AUROC, AUPR], so that `results` is [loss, auroc, aupr] -- confirm.
    losses.append(results[0])
    auroc.append(results[1])
    aupr.append(results[2])

    # Optional filter visualisation (disabled):
    # index = [type(layer) for layer in model.layers].index(tf.keras.layers.Activation)
    # ppms = moana.filter_activations(x_test, model, layer=index, window=20, threshold=0.5)
    # fig = plt.figure(figsize=(25, 4))
    # impress.plot_filters(ppms, fig, num_cols=8, names=filter_matches, fontsize=14)

# np.save does not create missing parent directories.
os.makedirs('statistics', exist_ok=True)

# One row per statistic, one column per model. Because `names` holds strings, NumPy
# is expected to coerce the whole array to a string dtype; the layout is left
# unchanged so that existing readers of these files keep working.
statistics = np.array([names, losses, auroc, aupr, match_fracs, false_fracs])
np.save(os.path.join('statistics', f'{category}.npy'), statistics, allow_pickle=True)

Category: positional-encoding
model-False
model-True
