In [4]:
!pip install -q tensorflowjs
!pip install jax-unirep
!pip install tensorflow_decision_forests 
!nvidia-smi

Collecting tensorflow_decision_forests
  Downloading tensorflow_decision_forests-0.2.4-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.4 MB)
[K     |████████████████████████████████| 13.4 MB 4.2 MB/s 
Collecting wurlitzer
  Downloading wurlitzer-3.0.2-py3-none-any.whl (7.3 kB)
Installing collected packages: wurlitzer, tensorflow-decision-forests
Successfully installed tensorflow-decision-forests-0.2.4 wurlitzer-3.0.2
NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [5]:
#@title Runtime
use_tpu = False #@param

In [6]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import tensorflow as tf
import urllib
from dataclasses import dataclass
import tensorflowjs as tfjs
import tensorflow_decision_forests as tfdf
import json
np.random.seed(0)



In [7]:
urllib.request.urlretrieve(
    "https://github.com/ur-whitelab/peptide-dashboard/raw/master/ml/data/hemo-positive.npz",
    "positive.npz",
)
urllib.request.urlretrieve(
    "https://github.com/ur-whitelab/peptide-dashboard/raw/master/ml/data/hemo-negative.npz",
    "negative.npz",
)
with np.load("positive.npz") as r:
    pos_data = r[list(r.keys())[0]]
with np.load("negative.npz") as r:
    neg_data = r[list(r.keys())[0]]

# create labels and stich it all into one
# tensor
labels = np.concatenate(
    (
        np.ones((pos_data.shape[0], 1), dtype=pos_data.dtype),
        np.zeros((neg_data.shape[0], 1), dtype=pos_data.dtype),
    ),
    axis=0,
)

features = np.concatenate((pos_data, neg_data), axis=0)

In [8]:
print('Positive data', pos_data.shape[0])
print('Negative data', neg_data.shape[0])

Positive data 1826
Negative data 7490


In [9]:
def decoder(seq_vector):
  alphabet = ['A','R','N','D','C','Q','E','G','H','I', 'L','K','M','F','P','S','T','W','Y','V']
  seq = []
  for i, index in enumerate(seq_vector):
    if index == 0:
      break
    seq.append(alphabet[index-1])
  seq = ''.join(seq)
  return seq

decoded_features = []
for f in features:
  decoded_features.append(decoder(f))

In [10]:
# Generating UniRep representations
from jax_unirep import get_reps
h_avg, h_final, c_final = get_reps(decoded_features)
unirep_features = h_avg



In [19]:
@dataclass
class Config:
    batch_size: int
    buffer_size: int
    reg_strength: float
    lr: float
    drop_rate: float
        
config = Config(
                batch_size=32, 
                buffer_size=10000,
                reg_strength=0,
                lr=1e-2,
                drop_rate=0.1
               )

In [37]:
def build_model():
    # return tfdf.keras.RandomForestModel(task=tfdf.keras.Task.CLASSIFICATION,name='hemo-unirep-random-forest')
    return tfdf.keras.GradientBoostedTreesModel(num_trees=1500)

In [38]:
# we now need to shuffle before creating TF dataset
# so that our train/test/val splits are random
i = np.arange(len(labels))
np.random.shuffle(i)
labels = labels[i]
unirep_features = unirep_features[i]
data = tf.data.Dataset.from_tensor_slices((unirep_features, labels))
# now split into val, test, train and batch
N = len(data)  
split = int(0.1 * N)
test_data = data.take(split).batch(config.batch_size)
nontest = data.skip(split)
val_data, train_data = nontest.take(split).batch(config.batch_size), \
    nontest.skip(split).shuffle(config.buffer_size).batch(config.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

In [68]:
tf.keras.backend.clear_session()
model = tfdf.keras.RandomForestModel()
# model = tfdf.keras.GradientBoostedTreesModel(num_trees=1500)
model.compile(
    metrics=[tf.keras.metrics.AUC(from_logits=False), tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
model.fit(train_data)

Use /tmp/tmppw7u0oa_ as temporary training directory
Starting reading the dataset
Dataset read in 0:00:21.157116
Training model
Model trained in 0:01:28.520847
Compiling model


<keras.callbacks.History at 0x7fdc90abc310>

In [71]:
model.summary()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
	31 : data:0.1453 [NUMERICAL]
	31 : data:0.1447 [NUMERICAL]
	31 : data:0.1261 [NUMERICAL]
	31 : data:0.125 [NUMERICAL]
	31 : data:0.1219 [NUMERICAL]
	31 : data:0.1125 [NUMERICAL]
	31 : data:0.1094 [NUMERICAL]
	31 : data:0.103 [NUMERICAL]
	30 : data:0.97 [NUMERICAL]
	30 : data:0.885 [NUMERICAL]
	30 : data:0.808 [NUMERICAL]
	30 : data:0.800 [NUMERICAL]
	30 : data:0.719 [NUMERICAL]
	30 : data:0.549 [NUMERICAL]
	30 : data:0.414 [NUMERICAL]
	30 : data:0.391 [NUMERICAL]
	30 : data:0.258 [NUMERICAL]
	30 : data:0.1743 [NUMERICAL]
	30 : data:0.1688 [NUMERICAL]
	30 : data:0.1558 [NUMERICAL]
	30 : data:0.1513 [NUMERICAL]
	30 : data:0.1252 [NUMERICAL]
	30 : data:0.1158 [NUMERICAL]
	30 : data:0.1114 [NUMERICAL]
	29 : data:0.902 [NUMERICAL]
	29 : data:0.878 [NUMERICAL]
	29 : data:0.787 [NUMERICAL]
	29 : data:0.674 [NUMERICAL]
	29 : data:0.673 [NUMERICAL]
	29 : data:0.518 [NUMERICAL]
	29 : data:0.460 [NUMERICAL]
	29 : data:0.411 [NUMERI

In [72]:
test_result = model.evaluate(test_data)
print(f'Test accuracy {test_result[2]:.2f}. Test AUC {test_result[1]:.2f}')

Test accuracy 0.84. Test AUC 0.78
