In [1]:
# Initialize
import tensorflow as tf
import numpy as np

# Set seeds
tf.set_random_seed(42)
np.random.seed(42)
    
# Where model weights are stored.
MODEL_WEIGHT_PATH = "./data/1900_weights"

In [39]:
import os
from unirep import mLSTMCell1900, tf_get_shape, aa_seq_to_int
import pandas as pd


def is_valid_seq(seq, max_len=500):
    """
    True if seq is valid for the babbler, False otherwise.
    """
    l = len(seq)
    valid_aas = "MRHKDESTNQCUGPAVIFYWLO"
    if (l < max_len) and set(seq) <= set(valid_aas):
        return True
    else:
        return False

    
class babbler1900():

    def __init__(self, model_path="./data/1900_weights", batch_size=500):
        self._model_path = model_path
        self._batch_size = batch_size
        
        self._rnn = mLSTMCell1900(1900,
                    model_path=self._model_path,
                        wn=True)
        zero_state = self._rnn.zero_state(self._batch_size, tf.float32)

        self._embed_matrix = tf.get_variable(
            "embed_matrix", dtype=tf.float32, initializer=np.load(os.path.join(self._model_path, "embed_matrix:0.npy"))
        )
        
        with tf.Session() as sess:
            self._zero_state = sess.run(zero_state)
        
    def get_reps(self, seqs):
        seq_ints = [aa_seq_to_int(seq.strip())[:-1] for seq in seqs]
        lengths = [len(x) for x in seq_ints]
        tf_tensor = tf.convert_to_tensor(seq_ints)
        dataset = tf.data.Dataset.from_tensor_slices(tf_tensor).batch(self._batch_size)
        iterator = dataset.make_one_shot_iterator()
        input_tensor = iterator.get_next()

        embed_cell = tf.nn.embedding_lookup(self._embed_matrix, input_tensor)
        _output, _final_state = tf.nn.dynamic_rnn(
            self._rnn,
            embed_cell,
            initial_state=self._zero_state,
            swap_memory=True,
            parallel_iterations=1
        )
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            final_state_, hs = sess.run([_final_state, _output])
            assert final_state_[0].shape[0] == self._batch_size

            final_cell, final_hidden = final_state_
            avg_hidden = np.array([np.mean(x, axis=0) for x in hs])
            together = np.concatenate((avg_hidden, final_hidden, final_cell), axis=1)
            return together

# Given a pandas dataframe with a column "sequence", return a list of pandas dataframes grouped by sequence length
def create_batches(seqs):
    # Get the unique lengths of all these sequences
    batches = []
    lengths = seqs["sequence"].apply(lambda x: len(x))
    unique_lengths = lengths.unique()
    for unique_length in unique_lengths:
        boolean_mask = lengths == unique_length
        seqs_of_length = seqs[boolean_mask]
        batches += [seqs_of_length]
    print("There are {} batches".format(len(batches)))
    return batches
    
# Get representations for a numpy array of sequences where all sequences are the same length
def inference_on_seqs_array(seqs):
        tf.reset_default_graph()
        model = babbler1900(batch_size=seqs.shape[0])
        result = model.get_reps(seqs)
        return result

# Get representations for a pandas dataframe of sequences in coulmn "sequence"
def inference_on_seqs(seqs):
    # Check that all these sequences are valid
    valid_func = lambda x: is_valid_seq(x)
    valid_np = np.vectorize(valid_func)
    valids = valid_np(seqs["sequence"].values)
    assert False not in valids
    
    batches = create_batches(seqs)
    
    ids = None
    reps = pd.DataFrame(columns=list(range(0, 5700)))
    
    for batch in batches:
        print("Getting EclRep representations for {} sequences...".format(batch.shape[0]))
        reps_new = inference_on_seqs_array(batch["sequence"])
        print("Done")
        reps = reps.append(pd.DataFrame(reps_new))
        if ids is not None:
            ids = ids.append(batch)
        else:
            ids = batch
    return ids, reps
    
    
path = "./data/stability_data"
df = pd.read_table(os.path.join(path, "ssm2_stability_scores.txt"))
ids, results = inference_on_seqs(df)
print("Got {} results".format(results.shape))
assert results.shape[0] == seqs.shape[0]
assert results.shape[1] == 5700

# Check that representations are reproducible
import os
import pandas as pd
from tqdm import tqdm_notebook as tqdm

# Load the saved representations
path = "./data/stability_data"
output_path = os.path.join(path, "stability_with_unirep_fusion.hdf")
existing_seqs = pd.read_hdf(output_path, key="ids").reset_index(drop=True)
existing_reps = pd.read_hdf(output_path, key="reps").reset_index(drop=True)
assert existing_seqs.shape[0] == existing_reps.shape[0]
assert np.array_equal(existing_seqs.index, existing_reps.index)

# Create reprensetations for some seqs
print("Checking that these results match the saved truth...")
for index, row in tqdm(existing_seqs.iterrows(), total=existing_seqs.shape[0]):
    check_rep = results.iloc[index].values
    true_rep = existing_reps.iloc[index].values

    if not np.allclose(true_rep, check_rep, atol=0.0001):
        true_check_diff = abs(np.sum(true_rep - check_rep))
        print("{}: {} difference with saved truth".format(index, true_check_diff))

There are 2 batches
Getting EclRep representations for 12022 sequences...
Done
Getting EclRep representations for 829 sequences...
Done
Got (12851, 5700) results
Checking that these results match the saved truth...


HBox(children=(IntProgress(value=0, max=11380), HTML(value='')))

0: 0.00011940242984564975 difference with saved truth
1: 5.786981273558922e-05 difference with saved truth
2: 0.00014238500443752855 difference with saved truth
3: 9.430746285943314e-05 difference with saved truth
4: 0.00010332326928619295 difference with saved truth
5: 3.106344593106769e-05 difference with saved truth
6: 8.777742914389819e-05 difference with saved truth
7: 6.928145739948377e-05 difference with saved truth
8: 9.750597200763877e-06 difference with saved truth
9: 0.00010012062557507306 difference with saved truth
10: 4.042027285322547e-06 difference with saved truth
11: 5.3946343541610986e-05 difference with saved truth
12: 0.00011359375639585778 difference with saved truth
13: 0.000197594563360326 difference with saved truth
14: 4.0183083910960704e-05 difference with saved truth
15: 0.0001630683254916221 difference with saved truth
16: 0.0002511250786483288 difference with saved truth
17: 2.7528407372301444e-05 difference with saved truth
18: 4.224988879286684e-05 diffe

238: 4.7319808800239116e-05 difference with saved truth
239: 0.00017999557894654572 difference with saved truth
240: 5.330742715159431e-05 difference with saved truth
241: 0.00014007074059918523 difference with saved truth
242: 0.00013417209265753627 difference with saved truth
243: 6.232551822904497e-05 difference with saved truth
244: 0.0002834377810359001 difference with saved truth
245: 0.00028701010160148144 difference with saved truth
246: 2.9970440664328635e-05 difference with saved truth
247: 0.00014735189324710518 difference with saved truth
248: 3.706944335135631e-05 difference with saved truth
249: 0.00025496695889160037 difference with saved truth
250: 9.971870895242319e-06 difference with saved truth
251: 8.79946892382577e-05 difference with saved truth
252: 0.00012845656601712108 difference with saved truth
253: 3.612175351008773e-05 difference with saved truth
254: 0.00019585579866543412 difference with saved truth
255: 0.00013001907791476697 difference with saved truth


409: 0.0002045736473519355 difference with saved truth
410: 4.059485218022019e-05 difference with saved truth
411: 5.3010156989330426e-05 difference with saved truth
412: 0.00014775616000406444 difference with saved truth
413: 0.0001034307133522816 difference with saved truth
414: 3.8161164411576465e-05 difference with saved truth
415: 2.73379118880257e-05 difference with saved truth
416: 3.6393092159414664e-05 difference with saved truth
417: 0.0002086151362163946 difference with saved truth
418: 0.0002932990319095552 difference with saved truth
419: 1.6923157090786844e-05 difference with saved truth
420: 1.7167021724162623e-05 difference with saved truth
421: 0.00014552530774381012 difference with saved truth
422: 0.00017671435489319265 difference with saved truth
423: 0.0003173869918100536 difference with saved truth
424: 0.00014641079178545624 difference with saved truth
425: 9.437589324079454e-05 difference with saved truth
426: 7.45380821172148e-05 difference with saved truth
427

656: 2.7912134100915864e-05 difference with saved truth
657: 1.868281833594665e-05 difference with saved truth
658: 7.881095370976254e-05 difference with saved truth
659: 2.8963693694095127e-05 difference with saved truth
660: 0.00018489916692487895 difference with saved truth
661: 0.00025095121236518025 difference with saved truth
662: 0.00017284041678067297 difference with saved truth
663: 4.275114770280197e-05 difference with saved truth
664: 0.0003053766558878124 difference with saved truth
665: 3.258863580413163e-05 difference with saved truth
666: 1.9498461369948927e-06 difference with saved truth
667: 3.800525519181974e-05 difference with saved truth
668: 2.101020436384715e-05 difference with saved truth
669: 7.845467189326882e-05 difference with saved truth
670: 2.5688194000395015e-05 difference with saved truth
671: 0.0001221459824591875 difference with saved truth
672: 0.00020107263117097318 difference with saved truth
673: 6.0603528254432604e-05 difference with saved truth
6

830: 8.509230974595994e-05 difference with saved truth
831: 0.00019755699031520635 difference with saved truth
832: 0.0001449925621272996 difference with saved truth
833: 2.6121990231331438e-05 difference with saved truth
834: 0.00018272835586685687 difference with saved truth
835: 9.637288167141378e-05 difference with saved truth
836: 0.0003742755507119 difference with saved truth
837: 0.0001672345242695883 difference with saved truth
838: 6.660771032329649e-05 difference with saved truth
839: 0.00016600749222561717 difference with saved truth
840: 6.270636731642298e-06 difference with saved truth
841: 9.168932592729107e-05 difference with saved truth
842: 0.0001340503804385662 difference with saved truth
843: 7.805742643540725e-05 difference with saved truth
844: 0.00011828173592220992 difference with saved truth
845: 0.00012628280092030764 difference with saved truth
846: 4.7687361075077206e-05 difference with saved truth
847: 4.933504533255473e-05 difference with saved truth
848: 0

1062: 0.00011160288704559207 difference with saved truth
1063: 5.794597018393688e-05 difference with saved truth
1064: 0.0002439666131976992 difference with saved truth
1065: 0.0001022935175569728 difference with saved truth
1066: 0.0001838456664700061 difference with saved truth
1067: 0.00023400552163366228 difference with saved truth
1068: 6.823657167842612e-05 difference with saved truth
1069: 0.00017940913676284254 difference with saved truth
1070: 2.4985029085655697e-05 difference with saved truth
1071: 4.3666230340022594e-05 difference with saved truth
1072: 0.00012919455184601247 difference with saved truth
1073: 9.429895726498216e-05 difference with saved truth
1074: 4.8949510528473184e-05 difference with saved truth
1075: 3.362934148753993e-05 difference with saved truth
1076: 8.968840120360255e-05 difference with saved truth
1077: 0.0001692397054284811 difference with saved truth
1078: 0.00010452306742081419 difference with saved truth
1079: 3.378093242645264e-05 difference w

1225: 0.0002889984752982855 difference with saved truth
1226: 1.0737970114860218e-05 difference with saved truth
1227: 0.0005169470678083599 difference with saved truth
1228: 7.597162039019167e-05 difference with saved truth
1229: 0.0002474661450833082 difference with saved truth
1230: 0.00030042268917895854 difference with saved truth
1231: 0.00014621164882555604 difference with saved truth
1232: 2.1641344574163668e-05 difference with saved truth
1233: 0.00013159352238290012 difference with saved truth
1234: 0.00010586485586827621 difference with saved truth
1235: 0.00029307749355211854 difference with saved truth
1236: 0.000176824105437845 difference with saved truth
1237: 6.17218975094147e-05 difference with saved truth
1238: 0.00012353697093203664 difference with saved truth
1239: 7.592415204271674e-05 difference with saved truth
1240: 0.00025326735340058804 difference with saved truth
1241: 5.572684131038841e-06 difference with saved truth
1242: 0.00014503634884022176 difference w

1459: 0.0002219493326265365 difference with saved truth
1460: 5.7449142332188785e-05 difference with saved truth
1461: 0.00018064332834910601 difference with saved truth
1462: 0.00027727577253244817 difference with saved truth
1463: 7.451621058862656e-05 difference with saved truth
1464: 0.00016439124010503292 difference with saved truth
1465: 0.00019495647575240582 difference with saved truth
1466: 0.00023137207608669996 difference with saved truth
1467: 0.0002931290364358574 difference with saved truth
1468: 3.182363434461877e-05 difference with saved truth
1469: 0.00012402616266626865 difference with saved truth
1470: 1.7951060726772994e-06 difference with saved truth
1471: 0.00010872143320739269 difference with saved truth
1472: 8.184859325410798e-05 difference with saved truth
1473: 0.0002884164860006422 difference with saved truth
1474: 0.00012298689398448914 difference with saved truth
1475: 1.6110590877360664e-05 difference with saved truth
1476: 9.865735773928463e-05 differenc

1622: 0.00027332562603987753 difference with saved truth
1623: 0.0001218754259753041 difference with saved truth
1624: 0.0001264350430574268 difference with saved truth
1625: 0.0002124357270076871 difference with saved truth
1626: 0.00019856657308991998 difference with saved truth
1627: 0.00023213545500766486 difference with saved truth
1628: 0.00010152138565899804 difference with saved truth
1629: 9.000564023153856e-05 difference with saved truth
1630: 0.00017400793149136007 difference with saved truth
1631: 0.00016444142966065556 difference with saved truth
1632: 0.0004852708661928773 difference with saved truth
1633: 3.2177667890209705e-05 difference with saved truth
1634: 0.00011445891868788749 difference with saved truth
1635: 0.00012835086090490222 difference with saved truth
1636: 2.180713818233926e-05 difference with saved truth
1637: 0.00030521274311468005 difference with saved truth
1638: 6.863319140393287e-05 difference with saved truth
1639: 0.0002205528289778158 difference

1859: 0.00010993555770255625 difference with saved truth
1860: 0.00030858066747896373 difference with saved truth
1861: 0.00022829859517514706 difference with saved truth
1862: 3.683254180941731e-05 difference with saved truth
1863: 1.1405172699596733e-05 difference with saved truth
1864: 5.07099466631189e-05 difference with saved truth
1865: 0.000353882642230019 difference with saved truth
1866: 0.00019064854132011533 difference with saved truth
1867: 2.738579314609524e-05 difference with saved truth
1868: 0.00013130153820384294 difference with saved truth
1869: 0.0001212858478538692 difference with saved truth
1870: 0.0002727036189753562 difference with saved truth
1871: 0.00041355486609973013 difference with saved truth
1872: 5.764831803389825e-05 difference with saved truth
1873: 0.00018299462681170553 difference with saved truth
1874: 7.736403495073318e-05 difference with saved truth
1875: 0.0001299209543503821 difference with saved truth
1876: 0.0003877638082485646 difference wit

2022: 3.215122706023976e-05 difference with saved truth
2023: 5.302396311890334e-05 difference with saved truth
2024: 0.00029548018937930465 difference with saved truth
2025: 0.00010746487532742321 difference with saved truth
2026: 0.0001267275947611779 difference with saved truth
2027: 2.16551088669803e-05 difference with saved truth
2028: 0.00011686817742884159 difference with saved truth
2029: 0.00029934049234725535 difference with saved truth
2030: 0.00034349149791523814 difference with saved truth
2031: 0.00017123861471191049 difference with saved truth
2032: 4.583796180668287e-05 difference with saved truth
2033: 0.0001404786598868668 difference with saved truth
2034: 0.00011159062705701217 difference with saved truth
2035: 0.00036704569356516004 difference with saved truth
2036: 0.00018959846056532115 difference with saved truth
2037: 0.00025150779401883483 difference with saved truth
2038: 0.00016676721861585975 difference with saved truth
2039: 0.0003165910311508924 difference

2273: 0.00016587805293966085 difference with saved truth
2274: 0.00012942575267516077 difference with saved truth
2275: 0.00028047358500771224 difference with saved truth
2276: 4.25911603088025e-05 difference with saved truth
2277: 0.00023989110195543617 difference with saved truth



KeyboardInterrupt: 

In [19]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm

path = "./data/stability_data"
output_path = os.path.join(path, "all_stability_with_eclrep_fusion.hdf")

reps_output = pd.DataFrame(columns=list(range(0, 5700)))

# TODO: Assert that there are no duplicate sequences
# TODO: How do I match the reps with sequences?

for filename in os.listdir(path):
    if filename.endswith(".txt"):
        print("Processing data from {}".format(filename))
        df = pd.read_table(os.path.join(path, filename))
        if "consensus_stability_score" in df.columns:
            stability_name = "consensus_stability_score"
        else:
            stability_name = "stabilityscore"
        df = df[["name", "sequence", stability_name]]
        df.rename(columns={'consensus_stability_score': 'stability', 'stabilityscore': 'stability'}, inplace=True)
        results = inference_on_seqs(df["sequence"].values)
        print(results.shape)
        
#     if model.is_valid_seq(row["sequence"], max_len=500):
#         unirep_fusion = model.get_rep(row["sequence"])
#         unirep_fusion = np.concatenate((unirep_fusion[0], unirep_fusion[1], unirep_fusion[2]))
#         print(unirep_fusion.shape)


ids_output.to_hdf(output_path, index=False, mode="a", key="ids", format="fixed")
reps_output.to_hdf(output_path, index=False, mode="a", key="reps", format="fixed")

Processing data from ssm2_stability_scores.txt
Index(['name', 'sequence', 'stability'], dtype='object')
Getting EclRep representations for 12022 sequences of length 43...


KeyboardInterrupt: 

In [None]:
ids = pd.read_hdf(output_path, key="ids")
print("{} points in ids".format(ids.shape[0]))
reps = pd.read_hdf(output_path, key="reps")
print("{} points in reps".format(reps.shape[0]))
print(reps.iloc[0])