Skip to content

Commit

Permalink
vae update
Browse files Browse the repository at this point in the history
  • Loading branch information
wesleybeckner committed Apr 23, 2019
1 parent 58ac0b2 commit 1c9a54c
Show file tree
Hide file tree
Showing 11 changed files with 1,560 additions and 1,644 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
scripts/development/data
# ignore all jupyter notebooks for now?
# scripts/*.ipynb
*.h5

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
2 changes: 1 addition & 1 deletion examples/salty_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
"version": "3.6.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion scripts/development/scrape_and_save_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2617,7 +2617,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.6.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion scripts/development/train_and_save_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2760,7 +2760,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.6.4"
}
},
"nbformat": 4,
Expand Down
42 changes: 6 additions & 36 deletions scripts/molecular_dynamics/therm_cond.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1263,13 +1263,7 @@
"404/404 [==============================] - 0s 193us/step - loss: 0.1055 - mean_squared_error: 0.0170 - val_loss: 0.1179 - val_mean_squared_error: 0.0305\n",
"Epoch 97/100\n",
"404/404 [==============================] - 0s 154us/step - loss: 0.1161 - mean_squared_error: 0.0295 - val_loss: 0.1351 - val_mean_squared_error: 0.0495\n",
"Epoch 98/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 98/100\n",
"404/404 [==============================] - 0s 173us/step - loss: 0.1283 - mean_squared_error: 0.0434 - val_loss: 0.0962 - val_mean_squared_error: 0.0121\n",
"Epoch 99/100\n",
"404/404 [==============================] - 0s 178us/step - loss: 0.1041 - mean_squared_error: 0.0206 - val_loss: 0.1112 - val_mean_squared_error: 0.0283\n",
Expand Down Expand Up @@ -1466,13 +1460,7 @@
"404/404 [==============================] - 0s 161us/step - loss: 0.0939 - mean_squared_error: 0.0108 - val_loss: 0.0894 - val_mean_squared_error: 0.0074\n",
"Epoch 95/100\n",
"404/404 [==============================] - 0s 173us/step - loss: 0.0948 - mean_squared_error: 0.0132 - val_loss: 0.0939 - val_mean_squared_error: 0.0131\n",
"Epoch 96/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 96/100\n",
"404/404 [==============================] - 0s 168us/step - loss: 0.0936 - mean_squared_error: 0.0138 - val_loss: 0.0925 - val_mean_squared_error: 0.0136\n",
"Epoch 97/100\n",
"404/404 [==============================] - 0s 193us/step - loss: 0.0924 - mean_squared_error: 0.0143 - val_loss: 0.0911 - val_mean_squared_error: 0.0138\n",
Expand Down Expand Up @@ -1669,13 +1657,7 @@
"404/404 [==============================] - 0s 168us/step - loss: 0.1282 - mean_squared_error: 0.0107 - val_loss: 0.1333 - val_mean_squared_error: 0.0174\n",
"Epoch 93/100\n",
"404/404 [==============================] - 0s 188us/step - loss: 0.1218 - mean_squared_error: 0.0070 - val_loss: 0.1284 - val_mean_squared_error: 0.0148\n",
"Epoch 94/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 94/100\n",
"404/404 [==============================] - 0s 161us/step - loss: 0.1200 - mean_squared_error: 0.0075 - val_loss: 0.1294 - val_mean_squared_error: 0.0182\n",
"Epoch 95/100\n",
"404/404 [==============================] - 0s 181us/step - loss: 0.1157 - mean_squared_error: 0.0057 - val_loss: 0.1253 - val_mean_squared_error: 0.0167\n",
Expand Down Expand Up @@ -1773,13 +1755,7 @@
"Epoch 41/100\n",
"404/404 [==============================] - 0s 188us/step - loss: 0.3672 - mean_squared_error: 0.0091 - val_loss: 0.3938 - val_mean_squared_error: 0.0403\n",
"Epoch 42/100\n",
"404/404 [==============================] - 0s 171us/step - loss: 0.3606 - mean_squared_error: 0.0108 - val_loss: 0.3929 - val_mean_squared_error: 0.0476\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"404/404 [==============================] - 0s 171us/step - loss: 0.3606 - mean_squared_error: 0.0108 - val_loss: 0.3929 - val_mean_squared_error: 0.0476\n",
"Epoch 43/100\n",
"404/404 [==============================] - 0s 178us/step - loss: 0.3527 - mean_squared_error: 0.0110 - val_loss: 0.3813 - val_mean_squared_error: 0.0440\n",
"Epoch 44/100\n",
Expand Down Expand Up @@ -1977,13 +1953,7 @@
"404/404 [==============================] - 0s 188us/step - loss: 0.3340 - mean_squared_error: 0.0120 - val_loss: 0.3285 - val_mean_squared_error: 0.0111\n",
"Epoch 40/100\n",
"404/404 [==============================] - 0s 191us/step - loss: 0.3279 - mean_squared_error: 0.0144 - val_loss: 0.3227 - val_mean_squared_error: 0.0140\n",
"Epoch 41/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 41/100\n",
"404/404 [==============================] - 0s 173us/step - loss: 0.3156 - mean_squared_error: 0.0105 - val_loss: 0.3131 - val_mean_squared_error: 0.0123\n",
"Epoch 42/100\n",
"404/404 [==============================] - 0s 206us/step - loss: 0.3066 - mean_squared_error: 0.0095 - val_loss: 0.3004 - val_mean_squared_error: 0.0076\n",
Expand Down Expand Up @@ -2390,7 +2360,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"version": "3.6.6"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions scripts/vae/1mil_GDB17.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"C": 0, "N": 1, "5": 2, "l": 3, "O": 4, "s": 5, "7": 6, "r": 7, "(": 8, "[": 9, "=": 10, "P": 11, "o": 12, "]": 13, "#": 14, "6": 15, "3": 16, " ": 17, "I": 18, "c": 19, "4": 20, "+": 21, "-": 22, "n": 23, "H": 24, "8": 25, "\\": 26, "1": 27, "B": 28, "2": 29, ")": 30, "F": 31, "S": 32}
67 changes: 67 additions & 0 deletions scripts/vae/vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns; sns.set()
%matplotlib inline

import keras
from keras.models import Sequential, Model
from keras.layers import Dense
from keras.optimizers import Adam
import salty
from numpy import array
from numpy import argmax
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from sklearn.model_selection import train_test_split
from random import shuffle

def

def Encoder(x, latent_rep_size, smile_max_length, epsilon_std = 0.01):
    """Build the convolutional encoder of the VAE and its matching loss.

    Three 1-D convolutions followed by a flatten and a dense layer produce a
    hidden representation, from which two linear dense heads (``z_mean`` and
    ``z_log_var``) are computed. A latent vector is then drawn with the
    reparameterisation ``z_mean + exp(z_log_var / 2) * epsilon``.

    Parameters
    ----------
    x : Keras tensor
        Encoder input. Presumably a one-hot encoded SMILES batch of shape
        (batch, smile_max_length, charset_length) -- TODO confirm with caller.
    latent_rep_size : int
        Width of the latent vector (units of ``z_mean`` / ``z_log_var``).
    smile_max_length : int
        Multiplier applied to the reconstruction term of the loss.
    epsilon_std : float, optional
        Standard deviation of the Gaussian noise used when sampling.

    Returns
    -------
    tuple
        ``(vae_loss, z)`` where ``vae_loss(x, x_decoded_mean)`` is a Keras
        loss function closing over ``z_mean`` / ``z_log_var`` of this encoder
        and ``z`` is the output tensor of the sampling ``Lambda`` layer.
    """
    # The module header only imports ``Dense``; every other Keras name used
    # here is imported at function scope so the call does not fail with
    # NameError.
    from keras import backend as K
    from keras.layers import Convolution1D, Flatten, Lambda
    from keras.losses import binary_crossentropy

    # Convolution1D(filters, kernel_size): 9x9, 9x9, then 10 filters of width 11.
    h = Convolution1D(9, 9, activation='relu', name='conv_1')(x)
    h = Convolution1D(9, 9, activation='relu', name='conv_2')(h)
    h = Convolution1D(10, 11, activation='relu', name='conv_3')(h)
    h = Flatten(name='flatten_1')(h)
    h = Dense(435, activation='relu', name='dense_1')(h)

    def sampling(args):
        """Draw a latent sample from (z_mean_, z_log_var_)."""
        z_mean_, z_log_var_ = args
        # Batch size is read from the tensor so it may vary at run time.
        batch_size = K.shape(z_mean_)[0]
        epsilon = K.random_normal(shape=(batch_size, latent_rep_size),
                                  mean=0., stddev=epsilon_std)
        # exp(log_var / 2) converts the log-variance to a standard deviation.
        return z_mean_ + K.exp(z_log_var_ / 2) * epsilon

    z_mean = Dense(latent_rep_size, name='z_mean', activation='linear')(h)
    z_log_var = Dense(latent_rep_size, name='z_log_var',
                      activation='linear')(h)

    def vae_loss(x, x_decoded_mean):
        """Scaled reconstruction cross-entropy plus the KL term."""
        x = K.flatten(x)
        x_decoded_mean = K.flatten(x_decoded_mean)
        xent_loss = smile_max_length * binary_crossentropy(x, x_decoded_mean)
        # KL term is averaged (K.mean) over the last axis.
        kl_loss = - 0.5 * K.mean(
            1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return xent_loss + kl_loss

    z = Lambda(sampling, output_shape=(latent_rep_size,),
               name='lambda')([z_mean, z_log_var])
    return (vae_loss, z)

def Decoder(z, latent_rep_size, smile_max_length, charset_length):
    """Build the recurrent decoder of the VAE.

    The latent vector is passed through a dense layer, repeated
    ``smile_max_length`` times, run through three stacked GRU layers that
    return full sequences, and finally mapped at every time step to a softmax
    over ``charset_length`` classes.

    Parameters
    ----------
    z : Keras tensor
        Latent input. Presumably of shape (batch, latent_rep_size)
        -- TODO confirm with caller.
    latent_rep_size : int
        Number of units of the first dense layer (``latent_input``).
    smile_max_length : int
        Number of time steps produced by the ``RepeatVector`` layer.
    charset_length : int
        Number of output classes per time step.

    Returns
    -------
    Keras tensor
        Output of the time-distributed softmax layer ``decoded_mean``.
    """
    # The module header only imports ``Dense``; the remaining layers are
    # imported at function scope so the call does not fail with NameError.
    from keras.layers import GRU, RepeatVector, TimeDistributed

    h = Dense(latent_rep_size, name='latent_input', activation='relu')(z)
    h = RepeatVector(smile_max_length, name='repeat_vector')(h)
    # Three identical GRU layers; names gru_1..gru_3 are kept unchanged so
    # previously saved weights still match by layer name.
    for gru_name in ('gru_1', 'gru_2', 'gru_3'):
        h = GRU(501, return_sequences=True, name=gru_name)(h)
    return TimeDistributed(Dense(charset_length, activation='softmax'),
                           name='decoded_mean')(h)

def sample(a, temperature=1.0):
    """Sample an index from a probability array after temperature scaling.

    The probabilities are re-weighted as ``softmax(log(a) / temperature)``
    and a single index is drawn with ``np.random.choice`` (used instead of
    ``np.random.multinomial`` as a workaround, see
    https://github.com/llSourcell/How-to-Generate-Music-Demo/issues/4).

    Parameters
    ----------
    a : array-like of float
        One-dimensional array of probabilities. Zero entries are allowed and
        are never selected when ``temperature`` is positive.
    temperature : float, optional
        Scaling factor; values below 1 sharpen the distribution, values
        above 1 flatten it. Must be non-zero.

    Returns
    -------
    int
        The sampled index into ``a``.

    Raises
    ------
    ValueError
        If ``temperature`` is zero, or if the resulting distribution is
        invalid (raised by NumPy).
    """
    if temperature == 0:
        raise ValueError("temperature must be non-zero")

    # log(0) = -inf is intended here (it maps back to probability 0), so the
    # divide-by-zero warning is silenced for this step only.
    with np.errstate(divide='ignore'):
        logits = np.log(np.asarray(a, dtype=np.float64)) / temperature

    # Shift so the largest logit is 0. This leaves the softmax unchanged but
    # prevents every exp() from underflowing to 0 at small temperatures,
    # which previously produced 0/0 = NaN probabilities.
    logits = logits - np.max(logits)
    weights = np.exp(logits)
    dist = weights / np.sum(weights)

    choices = np.arange(len(dist))
    return np.random.choice(choices, p=dist)

0 comments on commit 1c9a54c

Please sign in to comment.