-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
58ac0b2
commit 1c9a54c
Showing
11 changed files
with
1,560 additions
and
1,644 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"C": 0, "N": 1, "5": 2, "l": 3, "O": 4, "s": 5, "7": 6, "r": 7, "(": 8, "[": 9, "=": 10, "P": 11, "o": 12, "]": 13, "#": 14, "6": 15, "3": 16, " ": 17, "I": 18, "c": 19, "4": 20, "+": 21, "-": 22, "n": 23, "H": 24, "8": 25, "\\": 26, "1": 27, "B": 28, "2": 29, ")": 30, "F": 31, "S": 32} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import matplotlib.pylab as plt | ||
import numpy as np | ||
import seaborn as sns; sns.set() | ||
%matplotlib inline | ||
|
||
import keras | ||
from keras.models import Sequential, Model | ||
from keras.layers import Dense | ||
from keras.optimizers import Adam | ||
import salty | ||
from numpy import array | ||
from numpy import argmax | ||
from sklearn.preprocessing import LabelEncoder | ||
from sklearn.preprocessing import OneHotEncoder | ||
import numpy as np | ||
from sklearn.model_selection import train_test_split | ||
from random import shuffle | ||
|
||
def | ||
|
||
def Encoder(x, latent_rep_size, smile_max_length, epsilon_std = 0.01): | ||
h = Convolution1D(9, 9, activation = 'relu', name='conv_1')(x) | ||
h = Convolution1D(9, 9, activation = 'relu', name='conv_2')(h) | ||
h = Convolution1D(10, 11, activation = 'relu', name='conv_3')(h) | ||
h = Flatten(name = 'flatten_1')(h) | ||
h = Dense(435, activation = 'relu', name = 'dense_1')(h) | ||
|
||
def sampling(args): | ||
z_mean_, z_log_var_ = args | ||
batch_size = K.shape(z_mean_)[0] | ||
epsilon = K.random_normal(shape=(batch_size, latent_rep_size), | ||
mean=0., stddev = epsilon_std) | ||
return z_mean_ + K.exp(z_log_var_ / 2) * epsilon | ||
|
||
z_mean = Dense(latent_rep_size, name='z_mean', activation = 'linear')(h) | ||
z_log_var = Dense(latent_rep_size, name='z_log_var', activation = 'linear')(h) | ||
|
||
def vae_loss(x, x_decoded_mean): | ||
x = K.flatten(x) | ||
x_decoded_mean = K.flatten(x_decoded_mean) | ||
xent_loss = smile_max_length * binary_crossentropy(x, x_decoded_mean) | ||
kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - \ | ||
K.exp(z_log_var), axis = -1) | ||
return xent_loss + kl_loss | ||
|
||
return (vae_loss, Lambda(sampling, output_shape=(latent_rep_size,), | ||
name='lambda')([z_mean, z_log_var])) | ||
|
||
def Decoder(z, latent_rep_size, smile_max_length, charset_length): | ||
h = Dense(latent_rep_size, name='latent_input', activation = 'relu')(z) | ||
h = RepeatVector(smile_max_length, name='repeat_vector')(h) | ||
h = GRU(501, return_sequences = True, name='gru_1')(h) | ||
h = GRU(501, return_sequences = True, name='gru_2')(h) | ||
h = GRU(501, return_sequences = True, name='gru_3')(h) | ||
return TimeDistributed(Dense(charset_length, activation='softmax'), | ||
name='decoded_mean')(h) | ||
|
||
def sample(a, temperature=1.0): | ||
# helper function to sample an index from a probability array | ||
# a = np.log(a) / temperature | ||
# a = np.exp(a) / np.sum(np.exp(a)) | ||
# return np.argmax(np.random.multinomial(1, a, 1)) | ||
# work around from https://github.com/llSourcell/How-to-Generate-Music-Demo/issues/4 | ||
a = np.log(a) / temperature | ||
dist = np.exp(a)/np.sum(np.exp(a)) | ||
choices = range(len(a)) | ||
return np.random.choice(choices, p=dist) |
Oops, something went wrong.