In [1]:
import json
import h5py
from keras.models import model_from_json
import numpy as np
from train_autoencoder import smile_convert
from sample_autoencoder import load_test_data

Couldn't import dot_parser, loading of dot files will not be possible.


Using Theano backend.


## Use pre-trained weights and model

In [2]:
# Locations of the pre-trained VAE artifacts and the ZINC test molecules.
data_dir = "../data"
model_file = data_dir + "/best_vae_model.json"
weights_file = data_dir + "/best_vae_annealed_weights.h5"
char_file = data_dir + "/zinc_char_list.json"
test_file = data_dir + "/250k_rndm_zinc_drugs_clean.smi"

# How many test molecules are pushed through the encoder/decoder below.
limit = 20

In [3]:
# Load the character vocabulary and the serialized (JSON) model architecture.
# `with` closes each file as soon as it has been parsed.
with open(char_file) as char_fp:
    char_list = json.load(char_fp)
with open(model_file) as model_fp:
    model_dict = json.load(model_fp)

## The saved model file includes both the encoder and decoder in a single Sequential model.

We're going to split this out into separate encoder and decoder models.

This involves first locating the `VariationalDense` layer, which produces the mean and log-sigma of the latent encoding. We then build an encoder model that runs only the layers up to and including it.

In [4]:
# Index of the VariationalDense layer; it is the last layer of the encoder.
# Search explicitly: np.argmax over an all-False list would return 0 and
# silently produce a one-layer "encoder" if the layer were missing.
vae_layer = next(
    (ix for ix, layer in enumerate(model_dict["layers"])
     if layer["name"] == "VariationalDense"),
    None,
)
if vae_layer is None:
    raise ValueError("No 'VariationalDense' layer found in {}".format(model_file))

# Encoder config = every layer up to and including VariationalDense.
# The top-level 'loss'/'optimizer' entries are dropped.
# NOTE(review): presumably this is so model_from_json does not try to compile
# the model — TODO confirm against the installed Keras version.
encoder_config = model_dict.copy()
del encoder_config['loss']
del encoder_config['optimizer']
encoder_config['layers'] = encoder_config['layers'][:vae_layer+1]

# The first layer's batch_input_shape is (batch, max_len, n_chars).
max_len = encoder_config["layers"][0]["batch_input_shape"][1]
n_chars = encoder_config["layers"][0]["batch_input_shape"][2]
encoder = model_from_json(json.dumps(encoder_config))

In [5]:
def set_encoder_weights(weights_file, model, vae_layer):
    """Copy pre-trained weights for the encoder layers out of an HDF5 file.

    The file holds weights for the full (encoder + decoder) model. Layer ``k``
    is the group ``layer_<k>``, and its parameters are the datasets
    ``param_<p>``. File layers ``0 .. vae_layer`` are copied, in order, into
    ``model.layers[0 .. vae_layer]``.

    Parameters
    ----------
    weights_file : str
        Path to the HDF5 weights file of the full model.
    model : Keras model
        Encoder model whose layers receive the weights.
    vae_layer : int
        Index of the VariationalDense layer (the last encoder layer).

    If ``set_weights`` raises AssertionError for a layer, that layer keeps its
    random initialization. A message is printed and loading continues.
    """
    with h5py.File(weights_file, mode='r') as fp:
        n_layers = min(fp.attrs['nb_layers'], vae_layer + 1)
        for k in range(n_layers):
            # print() of one pre-formatted string gives the same output under
            # Python 2 and 3. A multi-argument print() would show a tuple in Py2.
            print('setting weights for layer {} {}'.format(k, vae_layer))
            g = fp['layer_{}'.format(k)]
            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
            print('Weights for this layer have shapes {}'.format([w.shape for w in weights]))
            try:
                model.layers[k].set_weights(weights)
            except AssertionError:
                # NOTE(review): some Keras versions raise a plain Exception on a
                # shape mismatch, which would propagate instead — confirm.
                print('Failed loading weights on layer {}. '
                      'Weights initiated with random'.format(k))

set_encoder_weights(weights_file, encoder, vae_layer)

setting weights for layer 0 8
Weights for this layer have shapes [(9, 35, 9, 1), (9,)]
setting weights for layer 1 8
Weights for this layer have shapes [(9,), (9,), (9,), (9,)]
setting weights for layer 2 8
Weights for this layer have shapes [(9, 9, 9, 1), (9,)]
setting weights for layer 3 8
Weights for this layer have shapes [(9,), (9,), (9,), (9,)]
setting weights for layer 4 8
Weights for this layer have shapes [(10, 9, 11, 1), (10,)]
setting weights for layer 5 8
Weights for this layer have shapes [(10,), (10,), (10,), (10,)]
setting weights for layer 6 8
Weights for this layer have shapes []
setting weights for layer 7 8
Weights for this layer have shapes [(940, 435), (435,)]
setting weights for layer 8 8
Weights for this layer have shapes [(435, 292), (292,), (435, 292), (292,)]


## Decoder

For the decoder to work on its own, its first layer needs an explicit input shape (the encoder's output shape), since it is no longer attached to a preceding layer.

In [6]:
# Decoder config = every layer after VariationalDense.
decoder_config = model_dict.copy()
decoder_layers = list(model_dict['layers'][vae_layer+1:])
# The first decoder layer is no longer fed by a preceding layer, so it needs an
# explicit input shape matching the encoder's output. Copy that layer's dict
# before changing it. Assigning into it directly would also mutate the layer
# dict inside `model_dict`, because dict.copy() is shallow.
decoder_layers[0] = dict(decoder_layers[0], batch_input_shape=encoder.output_shape)
decoder_config['layers'] = decoder_layers
# NOTE(review): the encoder cell drops 'optimizer' as well. Presumably
# model_from_json compiles any config that carries an optimizer, which is not
# wanted for this partial model — TODO confirm against the Keras version.
del decoder_config['optimizer']
decoder = model_from_json(json.dumps(decoder_config))
# Presumably required because the model was never compiled — TODO confirm.
decoder.sample_weight_mode = None

In [7]:
def set_decoder_weights(weights_file, model, vae_layer):
    """Copy pre-trained weights for the decoder layers out of an HDF5 file.

    The file holds weights for the full (encoder + decoder) model. Layer ``k``
    is the group ``layer_<k>``, and its parameters are the datasets
    ``param_<p>``. File layer ``k`` (for ``k > vae_layer``) is copied into
    ``model.layers[k - vae_layer - 1]``.

    Parameters
    ----------
    weights_file : str
        Path to the HDF5 weights file of the full model.
    model : Keras model
        Decoder model whose layers receive the weights.
    vae_layer : int
        Index of the VariationalDense layer (the last encoder layer).

    If ``set_weights`` raises AssertionError for a layer, that layer keeps its
    random initialization. A message is printed and loading continues.
    """
    with h5py.File(weights_file, mode='r') as fp:
        for k in range(vae_layer + 1, fp.attrs['nb_layers']):
            decoder_ix = k - vae_layer - 1
            # One pre-formatted string prints identically under Python 2 and 3.
            print('setting weights for layer {} {} {}'.format(k, decoder_ix, vae_layer))
            g = fp['layer_{}'.format(k)]
            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
            print('Weights for this layer have shapes {}'.format([w.shape for w in weights]))
            try:
                model.layers[decoder_ix].set_weights(weights)
            except AssertionError:
                # NOTE(review): some Keras versions raise a plain Exception on a
                # shape mismatch, which would propagate instead — confirm.
                print('Failed loading weights on layer {}. '
                      'Weights initiated with random'.format(k))

set_decoder_weights(weights_file, decoder, vae_layer)

setting weights for layer 9 0 8
Weights for this layer have shapes [(292,), (292,), (292,), (292,)]
setting weights for layer 10 1 8
Weights for this layer have shapes [(292, 292), (292,)]
setting weights for layer 11 2 8
Weights for this layer have shapes []
setting weights for layer 12 3 8
Weights for this layer have shapes [(292, 501), (501, 501), (501,), (292, 501), (501, 501), (501,), (292, 501), (501, 501), (501,)]
setting weights for layer 13 4 8
Weights for this layer have shapes [(501, 501), (501, 501), (501,), (501, 501), (501, 501), (501,), (501, 501), (501, 501), (501,)]
setting weights for layer 14 5 8
Weights for this layer have shapes [(501, 501), (501, 501), (501,), (501, 501), (501, 501), (501,), (501, 501), (501, 501), (501,)]
setting weights for layer 15 6 8
Weights for this layer have shapes [(501, 35), (35, 35), (35,), (501, 35), (35, 35), (35,), (501, 35), (35, 35), (35,), (35, 35)]


## Try it out

In [8]:
# Read up to 500 molecules from the test file (presumably as one-hot arrays).
# Only the first `limit` of them go through the model below.
n_test_load = 500
test_set = load_test_data(test_file, n_chars, max_len, char_list, limit=n_test_load)

('Training set size is', 500)
Training set size is 500, after filtering to max length of 120
('total chars:', 35)


In [9]:
def reconstruct_smiles(one_hot, truncate=True, chars=None):
    """Turn a batch of per-position character scores back into SMILES strings.

    Each position becomes the character with the highest score along axis 2,
    so `one_hot` is indexed as (molecule, position, character).

    Parameters
    ----------
    one_hot : array-like
        One-hot encodings or decoder outputs, with characters on axis 2.
    truncate : bool
        If True, cut each string at its first space (spaces appear to be used
        as padding). If False, keep internal spaces and only strip
        leading/trailing whitespace.
    chars : sequence of str, optional
        Vocabulary indexed by character id. Defaults to the global `char_list`.

    Returns
    -------
    list of str
        One string per molecule. This is always a list, so the result is the
        same under Python 2 and Python 3, where `map` returns a lazy iterator.
    """
    if chars is None:
        chars = char_list
    char_array = np.array(chars)[np.argmax(one_hot, 2)]
    if truncate:
        return [''.join(row).split(' ', 1)[0] for row in char_array]
    return [''.join(row).strip() for row in char_array]

In [10]:
from theano import function

# encode: compiled Theano function from an encoder input batch to the
# (mean, log-sigma) pair. It applies the last encoder layer
# (VariationalDense)'s get_mean_logsigma to the output of the layer before it.
seq_input = encoder.get_first_input()
encode = function([seq_input], encoder.layers[-1].get_mean_logsigma(encoder.layers[-2].get_output()))

# decode: compiled Theano function from latent vectors (the decoder's first
# input) to the decoder's final output.
enc_input = decoder.get_first_input()
output = decoder.get_output()
decode = function([enc_input], output)

In [11]:
# `limit` comes from the configuration cell at the top of the notebook.
# encode() returns the posterior mean and log-sigma for each molecule.
mu, logsigma = encode(test_set[:limit])

In [12]:
# Ground truth: the input molecules, recovered from their encoding.
reconstruct_smiles(one_hot=test_set[:limit], truncate=True)

[u'CCOC(=O)C[C@H](C)CNC(=O)C(=O)N1CCc2ccc(C)cc21',
 u'ClC(Cl)(Cl)c1nonc1C(Cl)(Cl)Cl',
 u'CCc1cccc(CC)c1NC(=O)NC1CC1',
 u'Cc1ccc(C)c(NC(=S)NCCc2cccs2)c1',
 u'Cc1nnsc1C(=O)Nc1nnc(-c2ccc(Br)cc2)o1',
 u'COc1ccc(C(=O)N(C)[C@@H](C)C/C(N)=N/O)cc1O',
 u'CC(C)c1nsc(NC[C@H](C2CC2)[NH+](C)C)n1',
 u'O=C(Nc1ccc(Oc2ccc(Cl)nn2)cc1)[C@@H](O)c1ccccc1',
 u'CCCCn1nc(C)c(C[NH2+]C[C@@H](C)O)c1Cl',
 u'COC(=O)c1ccc(NC(=O)c2c(C)sc3ncnc(N4CCC[C@H](C)C4)c23)cc1',
 u'CCn1cc(/C=C/C(=O)c2ccc3ccccc3c2)cn1',
 u'CCc1nc2n(n1)CCC[C@H]2NC(=O)c1ccc(-n2cc(C)cn2)cc1',
 u'NC(=O)c1ccc(NC(=O)c2cccn(Cc3ccc(F)cc3)c2=O)cc1',
 u'Cc1ccc(-c2nc(C[NH+]3CCCC[C@H]3c3cccnc3)c(C)o2)s1',
 u'COc1ccc(OC)c(S(=O)(=O)n2cc3c(=O)n(C)c(=O)n(C)c3n2)c1',
 u'O=C1/C(=C/c2ccccc2)Oc2c1ccc1c2CN(Cc2cccs2)CO1',
 u'CC[NH+](CC)[C@](C)(CC)[C@H](O)c1cscc1Br',
 u'Cc1noc(C)c1CCCNC(=O)N[C@H]1CC(=O)N(C2CC2)C1',
 u'CCCNC(=O)[C@H]1CS[C@H](c2ccccc2O)N1C(C)=O',
 u'CC[C@H](NC(=O)NCc1c(C)noc1C)c1ccc(OC)cc1']

In [13]:
# Decode the posterior means directly (no sampling from q(z|x)).
x_hat = decode(mu)

In [14]:
# truncate=False keeps internal spaces, so gaps the decoder emitted mid-string
# stay visible in the reconstructions.
reconstruct_smiles(x_hat, truncate=False)

[u'CCOC(=O)N[C@H](C)CNC(=O)C(=O)N CCc2ccc(C)cc21',
 u'CC (CN C(C)c1n nc1C(',
 u'CC 1cccc(C )c1NC(=O)NC1CC1',
 u'Cc1ccc( )c(NC(=O) CCc2cccs2)c1',
 u'Cc1nncc1C(=O)Nc1ccn(-c2ccc(Br)cc2)n1',
 u'COc1ccc(C(=O)N(C)[C@@H](C)CNC(C)(N)O',
 u'CC(C)c1nnc(NC C@H](CCC   [C@H](C)C)n1',
 u'O=C(Nc1ccc(-c2ccc(Cl)n 2)cc1)[C@@H](O)c1ccccc1',
 u'C#CC  nc(C)c(C[NH2+]C[C@@H](C)C',
 u'COC(=O)c1ccc(NC(=O)c2c(C)nc3ccnc(N3CCC[C@H](C)C',
 u'CC 1cc(/C= /C(=O) 2ccc3ccccc3c2)cn1',
 u'CC 1nc2n(n1)CCC[C@H] NC(=O)c1ccc(-n2cc(C)cc2)cc1',
 u'CC(=O)c1ccc(NC(=O)c2cccc(Nc3ccc(F)cc3)c2=O)cc1',
 u'Cc1ccc(-c2nc(C[NH+]3CCC [C@H]3c3ccc c3)c(C)n2)n1',
 u'COc1ccc(OC)c(S(=O)(=O)c2cc3c(   n    (    (C)c3n2)c1',
 u'O=C1/C(=C/c2ccccc2)O c     2c1N(C(c  c sc',
 u'CC[NH+](CC)[C@](C)(CC [C@H]( )c1cccc1Br',
 u'Cc1noc(C)c1CCCNC(=O)N[C@H]1CC(=O)N(C2CC2)C1',
 u'CCCN( =O)[C@H]1CC[C@H](c2ccccc1O)',
 u'CC[C@H](NC(=O)NCc1c(C)n c1C)c1ccc(OC)cc1']

In [15]:
# Fraction of positions where the decoded character equals the input character.
# NOTE(review): every position up to max_len counts, including what appears to
# be space padding, so this likely overstates accuracy on the SMILES characters.
n_decoded = x_hat.shape[0]
decoded_ids = np.argmax(x_hat, axis=2)
target_ids = np.argmax(test_set[:n_decoded], axis=2)
(decoded_ids == target_ids).mean()

0.94499999999999995

## What if we sample $z$ instead of using only the mean?

In [16]:
from scipy.stats import norm

In [17]:
# Standard-normal prior p(z) = N(0, I) and approximate posterior
# q(z|x) = N(mu, exp(logsigma)), one independent normal per latent dimension.
# NOTE(review): exp(logsigma) is passed as scipy's `scale` (a standard
# deviation), which assumes `logsigma` is the log std. dev. — confirm.
sample_seed = 0  # fixed so the sampled reconstructions below are reproducible
prior = norm(np.zeros(mu.shape), np.ones(mu.shape))
q_dist = norm(mu, np.exp(logsigma))
sampled = q_dist.rvs(random_state=sample_seed)

In [18]:
# Decode one latent sample per molecule. The draws are cast to float32
# (presumably Theano's floatX) before being fed to the compiled function.
x_hat = decode(sampled.astype(np.float32))
n_decoded = x_hat.shape[0]
(np.argmax(x_hat, axis=2) == np.argmax(test_set[:n_decoded], axis=2)).mean()

0.87541666666666662

In [19]:
# Per molecule, a single-sample estimate of log p(z) - log q(z|x) summed over
# latent dimensions. Its expectation under q is -KL(q(z|x) || p(z)).
log_ratio = prior.logpdf(sampled) - q_dist.logpdf(sampled)
np.sum(log_ratio, axis=1)

array([-46.62651038, -44.5470448 , -33.32653342, -45.47039925,
       -55.72006952, -34.23518116, -41.84344367, -58.49255424,
       -39.90348387, -57.36332324, -35.98827463, -58.83815762,
       -32.73888906, -61.64702754, -63.63639275, -61.20123432,
       -58.01327165, -44.96964335, -48.96139468, -42.20242053])

In [20]:
# Reconstructions from the sampled z. truncate=False keeps internal spaces.
reconstruct_smiles(x_hat, truncate=False)

[u'CCOC(=O)[C@@H](C)CC[C@@H](=O)N1CC1c1cc(F)cc1',
 u'OCC(N)(C#N)c1n n',
 u'COc1cccc(C )c1NC(=O)N',
 u'Cc1ccc( )c(NC(=O)NC c2cccc2)c1',
 u'C 1nccc1C(=O)Nc1cnc(-c2cc (Br)cc2) o1',
 u'Cc1(cc( C(=O)N(C)[C@@H](C)C2C( )(F)',
 u'CN CO 1ncc(N CC23(C   3) [NH3+])(C)n1',
 u'O=C(  ccc O c2cc (Cl)n    C1)[C@@H](O)c1ccccc1',
 u'CCCc1cnc( NC(C[NH2+]C[C@@H](C)C)c1',
 u'COC(=O)c1ccc(NC =O)c2c(C)nc3cn ( N2CCC4C4)c(C)c2)cc1',
 u'CCn1c ( (=S/C(=O)c2c (C)c',
 u'CCc1cc2n(n1)CCC[C@H]2NC(=O)    c(- 2n (C)c',
 u'O=C(NOc1ccc(NC =O)Nc ccc Nc2cc(F)ccc2)c2=O)c1',
 u'Cc1ccc(-c2nc(C[NH+]2CC (C(=O)Cc3cccc 4) ([no2)c1',
 u'COc1cc(OC )c(S(=O)(=O)c2c 3c(=O) (C)c(=O)c(C)c3',
 u'O=C(NC(C#Nc2cccc 2 OCc1F  (F)N2(CCc1cccs',
 u'CC[ @H](CC)[C@H](C(C  [C@@H])Cc1cc',
 u'Clc1nc(C)c1CCCNC(=O)N[C@H]1CC(=O)N(C2CC2)C',
 u'C CCC(=O)[C@H]1C[C@@H](c2ccccc2Br) S1)[',
 u'C [C@H](NC(=O)NC NC(C)ccc 1)c1ccc(Br)c(C)c1']

In [21]:
# The same sampled reconstructions, each cut at its first space.
reconstruct_smiles(x_hat, truncate=True)

[u'CCOC(=O)[C@@H](C)CC[C@@H](=O)N1CC1c1cc(F)cc1',
 u'OCC(N)(C#N)c1n',
 u'COc1cccc(C',
 u'Cc1ccc(',
 u'C',
 u'Cc1(cc(',
 u'CN',
 u'O=C(',
 u'CCCc1cnc(',
 u'COC(=O)c1ccc(NC',
 u'CCn1c',
 u'CCc1cc2n(n1)CCC[C@H]2NC(=O)',
 u'O=C(NOc1ccc(NC',
 u'Cc1ccc(-c2nc(C[NH+]2CC',
 u'COc1cc(OC',
 u'O=C(NC(C#Nc2cccc',
 u'CC[',
 u'Clc1nc(C)c1CCCNC(=O)N[C@H]1CC(=O)N(C2CC2)C',
 u'C',
 u'C']