Encoder Training
==============

In [1]:
%pylab notebook

import tensorflow as tf
from tensorflow.keras import layers

import primo.models
import primo.datasets


Matplotlib created a temporary config/cache directory at /tmp/matplotlib-1kdc0kbz because the default path (/tf/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


Populating the interactive namespace from numpy and matplotlib


Reserve space on the GPU for running simulations. It's important to do this before running any tensorflow code (which will take all available GPU memory):

Load up the training and validation datasets:

In [2]:
train_dataset = primo.datasets.OpenImagesTrain(
    '/tf/open_images/train/', switch_every=10**5
)

validation_dataset = primo.datasets.OpenImagesVal('/tf/open_images/validation/')

In [3]:
def keras_batch_generator(dataset_batch_generator, similarity_threshold):
    # Yield datasets
    # TODO: Verify with Callie this understanding is correct https://github.com/uwmisl/cas9-similarity-search/issues/2
    while True:
        # This tuple contains:
        # indices: a positive integer uniquely identifying an image. This index is obtained by enumerating all the images in the dataset (before splitting them into test/train/validate datasets)
        # pairs:
        indices, pairs = next(dataset_batch_generator)
        # The Euclidean distances between the two vectors in each pair
        distances = np.sqrt(np.square(pairs[:,0,:] - pairs[:,1,:]).sum(1))
        # Whether or not the images in this pair should be considered 'similar'. This is a boolean value, represented by an int (0 or 1), and is determined by whether the aforementioned Euclidean distances between image feature vectors are under some pre-deterined "similarity threshold".
        similar = (distances < similarity_threshold).astype(int)
        # Yield a pair of sequences, and 0-or-1 indicating whether they're similar.
        yield pairs, similar

In [4]:
# To see how this value was derived, please consult the Materials and Methods subsection under Feature Extraction section.
similarity_threshold = 75
# Intuitively determined:
encoder_training_dataset_batch_size = 100
# Intuitively determined:
encoder_validation_dataset_batch_size = 2500

encoder_train_batches = keras_batch_generator(
    train_dataset.balanced_pairs(encoder_training_dataset_batch_size, similarity_threshold),
    similarity_threshold
)

encoder_val_batches = keras_batch_generator(
    validation_dataset.random_pairs(encoder_validation_dataset_batch_size),
    similarity_threshold
)

# TODO: The new predictor is the nucleaseq Cas9 predictor. https://github.com/uwmisl/cas9-similarity-search/issues/3
predictor_train_batch_size = 1000
predictor_train_batches = train_dataset.random_pairs(predictor_train_batch_size)

Create the models and stack them together with the trainer:

In [5]:
# Yield predictor here is a differentiable DNA hybridization yield predictor (originally learned from the Nupack simulator). Represented in brown to the right of the one-hot box.
![big](../../documentation/similarity_search_schematic.jpg)

/bin/sh: 1: Syntax error: word unexpected (expecting ")")


In [6]:
encoder = primo.models.Encoder()

# TODO: Replace the yield_predictor with the nucleaseq Cas9 predictor, use that here instead. https://github.com/uwmisl/cas9-similarity-search/issues/3 
#yield_predictor = primo.models.Predictor('/tf/primo/data/models/yield-model.h5')
yield_predictor = primo.models.PredictorFunction()
encoder.model.compile()
yield_predictor.model.compile()
encoder_trainer = primo.models.EncoderTrainer(encoder, yield_predictor)

Run the training!

In [7]:
encoder_trainer.model.compile(tf.keras.optimizers.Adagrad(1e-3), 'binary_crossentropy')

In [8]:
history = encoder_trainer.model.fit_generator(
    encoder_train_batches,
#     steps_per_epoch = 1000,
#     epochs = 100,
    steps_per_epoch = 500,
    epochs = 10,
    validation_data = encoder_val_batches,
    validation_steps = 1,
    verbose = 2
)



switching to train_7.h5 and train_a.h5
Epoch 1/10
500/500 - 21s - loss: 1.7592 - val_loss: 0.0861
Epoch 2/10
500/500 - 20s - loss: 1.6681 - val_loss: 0.0832
Epoch 3/10
500/500 - 20s - loss: 1.6605 - val_loss: 0.0765
Epoch 4/10
500/500 - 20s - loss: 1.6576 - val_loss: 0.0814
Epoch 5/10
500/500 - 20s - loss: 1.6559 - val_loss: 0.0842
Epoch 6/10
500/500 - 20s - loss: 1.6546 - val_loss: 0.0928
Epoch 7/10
500/500 - 20s - loss: 1.6552 - val_loss: 0.0855
Epoch 8/10
500/500 - 20s - loss: 1.6512 - val_loss: 0.0726
Epoch 9/10
500/500 - 20s - loss: 1.6503 - val_loss: 0.0817
Epoch 10/10
500/500 - 20s - loss: 1.6510 - val_loss: 0.0867


Save the models:

In [34]:
encoder.save('/tf/primo/data/models/encoder-model-short.h5')

In [27]:
encoder_trainer.model.summary()


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 2, 4096)]    0                                            
__________________________________________________________________________________________________
lambda_2 (Lambda)               ((None, 4096), (None 0           input_1[0][0]                    
__________________________________________________________________________________________________
encoder (Sequential)            (None, 20, 4)        8554576     lambda_2[0][0]                   
                                                                 lambda_2[0][1]                   
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 20, 4, 2)     0           encoder[0][0]              

In [28]:
import primo.tools.filepath as filepaths
import primo.tools.sequences as seqtools
import pandas as pd
query_features_filepath = filepaths.get_query_features_path(isDocker=True)
query_features = pd.read_hdf(query_features_filepath)
query_seqs = encoder.encode_feature_seqs(query_features)
print(f"Query Seqs: {query_seqs}")

def seq_str_to_input(seq):
    return np.transpose(seqtools.seqs_to_onehots(seq), [1, 0, 2])

Query Seqs: ['TAAAAAAAAAAAAGAAAAAA' 'TAAAAAAAAAAAAGAAAAAA' 'GAAAAAAAAAAAAGAAAAAA']


In [33]:
a = encoder_trainer.model.predict(np.array([[query_features.loc['callie_janelle'], query_features.loc['callie_janelle']]]))
print(f"Full model: {a}")

b = encoder_trainer.predictor.model.predict(np.array([
    np.concatenate([
        seq_str_to_input('TAAAAAAAAAAAAGAAAAAA'),
        seq_str_to_input('TAAAAAAAAAAAAGAAAAAA'),
    ]),
]))
print(f"Predictor with sequences: {b}")

Full model: [0.0324022]
Predictor with sequences: [1.]


In [12]:
encoder_trainer.calcdists.predict(np.array([[query_features.loc['callie_janelle'], query_features.loc['luis_lego']]]))

array([101.17676], dtype=float32)

In [13]:
batch = next(encoder_train_batches)

In [14]:
pair = batch[0][12]
seqs = np.array([encoder.model.predict(pair)])
print(seqs.shape)
print(encoder_trainer.predictor.model.predict(seqs))
print(encoder_trainer.calcdists.predict(np.array([pair])))

(1, 2, 20, 4)
[0.03016231]
[96.10688]


In [32]:

encoder_trainer.predictor.model(np.array([
    np.concatenate([
        seq_str_to_input('TAAAAAAAAAAAAGAAAAAA'),
        seq_str_to_input('TAAAAAAAAAAAAGAAAAAA'),
    ]),
    np.concatenate([
        seq_str_to_input('GACATCAACGAACAAAGTAA'),
        seq_str_to_input('GAAAACAAAAAAAAAAAAAA'),
    ]),
]))
#print(np.transpose(seqtools.seqs_to_onehots('GAAAACAAAAAAAAAAAAAA'), [1, 0, 2]).shape)


<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1.000000e+00, 1.999998e-04], dtype=float32)>

In [16]:
encoder_trainer.model.predict(batch[0])

array([0.0460593 , 0.02812302, 0.02792858, 0.05731571, 0.03083841,
       0.02804673, 0.03832094, 0.04700711, 0.02779615, 0.0283931 ,
       0.04557332, 0.0332604 , 0.03016231, 0.05645469, 0.03302583,
       0.0484863 , 0.02797938, 0.05400315, 0.02882586, 0.03496324,
       0.04873855, 0.04492734, 0.04271714, 0.04292899, 0.0295831 ,
       0.02814495, 0.02986963, 0.03983732, 0.02804759, 0.0296664 ,
       0.04291183, 0.04913203, 0.0594243 , 0.03603164, 0.05822189,
       0.02809034, 0.03955992, 0.03785758, 0.04759109, 0.02794483,
       0.02903815, 0.02818676, 0.03485832, 0.05141601, 0.05777366,
       0.03794867, 0.05279316, 0.03054041, 0.02891738, 0.04118821,
       0.04120168, 0.02811972, 0.02828982, 0.02797928, 0.03543286,
       0.02983895, 0.05016077, 0.0286609 , 0.05125837, 0.03111096,
       0.02861081, 0.04331847, 0.03831703, 0.05131742, 0.02793324,
       0.0281919 , 0.05273511, 0.02894383, 0.04839699, 0.04375416,
       0.05929791, 0.02888922, 0.04748857, 0.03278374, 0.03267