Encoder Training
================

In [1]:
%pylab notebook

import tensorflow as tf
from tensorflow.keras import layers

import primo.models
import primo.datasets



Populating the interactive namespace from numpy and matplotlib


Reserve space on the GPU for running simulations. This must happen before any TensorFlow code runs, because TensorFlow by default claims all available GPU memory:
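The cell that performed this reservation is not included in this notebook. As a hedged sketch, one way to keep TensorFlow from claiming all GPU memory is its memory-growth setting, which makes it allocate memory on demand. Whether the original notebook used this or a fixed memory limit is unknown.

```python
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand instead of reserving all of it
# when the first op runs, leaving the remainder free for the simulations.
# This must run before any TensorFlow op initializes the GPUs, otherwise
# TensorFlow raises a RuntimeError.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
print(f'Memory growth enabled on {len(gpus)} GPU(s)')
```

If a hard cap is preferred, `tf.config.set_logical_device_configuration` with a `tf.config.LogicalDeviceConfiguration(memory_limit=...)` limits TensorFlow to a fixed number of megabytes on a GPU instead.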

Load up the training and validation datasets:

In [2]:
train_dataset = primo.datasets.OpenImagesTrain(
    '/tf/open_images/train/', switch_every=10**5
)

validation_dataset = primo.datasets.OpenImagesVal('/tf/open_images/validation/')

In [3]:
import numpy as np

def keras_batch_generator(dataset_batch_generator, similarity_threshold):
    """Wrap a dataset pair generator so that it yields (inputs, targets) batches for Keras."""
    # TODO: Verify with Callie this understanding is correct https://github.com/uwmisl/cas9-similarity-search/issues/2
    while True:
        # Each batch from the dataset generator is a tuple of:
        #   indices: integers uniquely identifying each image, obtained by enumerating all
        #            the images in the dataset before splitting them into train/validation/test
        #            sets (not used below).
        #   pairs:   array of shape (batch_size, 2, n_features) holding the feature vectors
        #            of the two images in each pair.
        indices, pairs = next(dataset_batch_generator)
        # Euclidean distance between the two feature vectors in each pair.
        distances = np.linalg.norm(pairs[:, 0, :] - pairs[:, 1, :], axis=1)
        # 1 if the pair counts as "similar" (its distance is below the pre-determined
        # similarity threshold), 0 otherwise.
        similar = (distances < similarity_threshold).astype(int)
        # Yield the feature-vector pairs (model inputs) and the 0/1 similarity labels (targets).
        yield pairs, similar
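To see the labeling rule in `keras_batch_generator` on concrete numbers, here is a self-contained numpy check on a made-up batch of three pairs of 2-dimensional feature vectors, using the same threshold of 75.

```python
import numpy as np

# Toy batch of 3 pairs of 2-D feature vectors, shape (3, 2, 2); values are illustrative only.
pairs = np.array([
    [[0.0, 0.0], [3.0, 4.0]],     # distance 5
    [[1.0, 1.0], [1.0, 1.0]],     # distance 0
    [[0.0, 0.0], [60.0, 80.0]],   # distance 100
])
similarity_threshold = 75

# Same rule as keras_batch_generator: Euclidean distance per pair, then threshold.
distances = np.linalg.norm(pairs[:, 0, :] - pairs[:, 1, :], axis=1)
similar = (distances < similarity_threshold).astype(int)
print(similar)  # [1 1 0]
```

`np.linalg.norm(..., axis=1)` computes the same value as `np.sqrt(np.square(...).sum(1))`, the square root of the summed squared differences per pair.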

In [4]:
# See the Feature Extraction subsection of Materials and Methods for how this value was derived.
# It is a Euclidean distance between image feature vectors.
similarity_threshold = 75
# Chosen heuristically:
encoder_training_dataset_batch_size = 100
# Chosen heuristically:
encoder_validation_dataset_batch_size = 2500

encoder_train_batches = keras_batch_generator(
    train_dataset.balanced_pairs(encoder_training_dataset_batch_size, similarity_threshold),
    similarity_threshold
)

encoder_val_batches = keras_batch_generator(
    validation_dataset.random_pairs(encoder_validation_dataset_batch_size),
    similarity_threshold
)
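`balanced_pairs` is implemented in `primo.datasets` and is not shown here. Judging from its name and the `similarity_threshold` it receives, it presumably builds training batches with similar and dissimilar pairs in roughly equal numbers, which matters if uniformly random pairs are mostly dissimilar. The following is a hypothetical numpy sketch of such a sampler using rejection sampling; `hypothetical_balanced_pairs` and its arguments are illustrative names, not the primo API. It yields `(indices, pairs)` tuples in the shape `keras_batch_generator` consumes.

```python
import numpy as np


def hypothetical_balanced_pairs(features, batch_size, threshold, rng=None):
    """Yield (indices, pairs) batches that are half similar, half dissimilar.

    Hypothetical sketch, not the primo implementation: random pairs are drawn
    and kept only while the matching half of the batch still has room.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_similar = batch_size // 2
    n_dissimilar = batch_size - n_similar
    while True:
        similar_idx, dissimilar_idx = [], []
        while len(similar_idx) < n_similar or len(dissimilar_idx) < n_dissimilar:
            i, j = rng.integers(0, len(features), size=2)
            if i == j:
                continue  # skip trivial self-pairs
            distance = np.linalg.norm(features[i] - features[j])
            if distance < threshold:
                if len(similar_idx) < n_similar:
                    similar_idx.append((i, j))
            elif len(dissimilar_idx) < n_dissimilar:
                dissimilar_idx.append((i, j))
        # indices: (batch_size, 2) image indices; pairs: (batch_size, 2, n_features)
        indices = np.array(similar_idx + dissimilar_idx)
        pairs = features[indices]
        yield indices, pairs
```

Rejection sampling keeps the sketch short but loops forever if no pair (or no non-pair) exists under the threshold; a real implementation would likely precompute neighbors instead.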

# TODO: The new predictor is the nucleaseq Cas9 predictor. https://github.com/uwmisl/cas9-similarity-search/issues/3
predictor_train_batch_size = 1000
predictor_train_batches = train_dataset.random_pairs(predictor_train_batch_size)

Create the models and stack them together with the trainer:

The yield predictor is a differentiable DNA hybridization yield predictor (originally learned from the Nupack simulator). In the schematic below, it is shown in brown, to the right of the one-hot box.

![big](../../documentation/similarity_search_schematic.jpg)


In [6]:
encoder = primo.models.Encoder()

# TODO: Replace the yield_predictor with the nucleaseq Cas9 predictor, use that here instead. https://github.com/uwmisl/cas9-similarity-search/issues/3 
yield_predictor = primo.models.Predictor('/tf/primo/data/models/yield-model.h5')
encoder_trainer = primo.models.EncoderTrainer(encoder, yield_predictor)

AttributeError: 'Tensor' object has no attribute 'argmax'
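The internals of `primo.models.EncoderTrainer` are not shown in this notebook. The toy Keras sketch below illustrates one way such a stack could be wired, assuming that the trainer applies a single shared encoder to both feature vectors in a pair, feeds the two resulting soft one-hot sequences to a frozen yield predictor, and treats the predicted yield as the probability that the pair is similar. All sizes (`N_FEATURES`, `SEQ_LEN`) and both toy sub-models are hypothetical stand-ins for the real primo models.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical sizes; the real primo models define their own.
N_FEATURES = 16   # length of an image feature vector
SEQ_LEN = 20      # nucleotides per encoded sequence
N_BASES = 4       # one-hot channels: A, C, G, T

# Toy encoder: feature vector -> (SEQ_LEN, 4) soft one-hot DNA sequence.
# The softmax keeps the output differentiable, unlike a hard argmax.
encoder = tf.keras.Sequential([
    layers.Input(shape=(N_FEATURES,)),
    layers.Dense(SEQ_LEN * N_BASES),
    layers.Reshape((SEQ_LEN, N_BASES)),
    layers.Softmax(axis=-1),
])

# Toy stand-in for the yield predictor: two sequences -> yield in [0, 1].
seq_a = layers.Input(shape=(SEQ_LEN, N_BASES))
seq_b = layers.Input(shape=(SEQ_LEN, N_BASES))
merged = layers.Concatenate()([layers.Flatten()(seq_a), layers.Flatten()(seq_b)])
yield_out = layers.Dense(1, activation='sigmoid')(merged)
predictor = tf.keras.Model([seq_a, seq_b], yield_out)
predictor.trainable = False  # assumption: only the encoder is trained

# Trainer: input is a batch of feature-vector pairs, shape (batch, 2, N_FEATURES).
pair_input = layers.Input(shape=(2, N_FEATURES))
feat_a = layers.Lambda(lambda x: x[:, 0, :])(pair_input)
feat_b = layers.Lambda(lambda x: x[:, 1, :])(pair_input)
predicted_yield = predictor([encoder(feat_a), encoder(feat_b)])
trainer = tf.keras.Model(pair_input, predicted_yield)
trainer.compile(optimizer=tf.keras.optimizers.Adagrad(1e-3), loss='binary_crossentropy')
```

In this sketch, freezing the predictor means the binary cross-entropy gradients pass through it into the encoder only, so the encoder is pushed toward sequences whose predicted hybridization yield is high for similar pairs and low for dissimilar ones.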

Run the training!

In [None]:
encoder_trainer.model.compile(
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=1e-3),
    loss='binary_crossentropy',
)

In [None]:
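# Model.fit accepts Python generators directly; fit_generator is deprecated in TensorFlow 2.
history = encoder_trainer.model.fit(
    encoder_train_batches,
    steps_per_epoch=1000,    # 1000 batches x 100 pairs = 100,000 training pairs per epoch
    epochs=100,
    validation_data=encoder_val_batches,
    validation_steps=1,      # one validation batch of 2500 random pairs per epoch
    verbose=2,
)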

Save the trained encoder:

In [None]:
encoder.save('/tf/primo/data/models/encoder_model.h5')