diff --git a/saved_weights/vae10.h5 b/saved_weights/vae10.h5
new file mode 100644
index 0000000..0caabd0
Binary files /dev/null and b/saved_weights/vae10.h5 differ
diff --git a/saved_weights/vae20.h5 b/saved_weights/vae20.h5
new file mode 100644
index 0000000..c55bcfa
Binary files /dev/null and b/saved_weights/vae20.h5 differ
diff --git a/src/graph.py b/src/graph.py
index 44c0f0c..4d75f4f 100644
--- a/src/graph.py
+++ b/src/graph.py
@@ -3,6 +3,7 @@ import sys
 
 import cv2
 import numpy as np
+import random
 from network import VAENetwork
 
 import matplotlib.pyplot as plt
@@ -31,7 +32,28 @@ def generate_latent_plot(self):
         plt.show()
 
     def generate_reconstructed_image(self):
-        
+
+        n = 20
+        digit_size = 28
+        epsilon_std = 0.1
+        figure = np.zeros((digit_size * n, digit_size * n))
+        grid_x = np.linspace(-250, 50, n)
+        grid_y = np.linspace(-125, 200, n)
+
+        for i, yi in enumerate(grid_x):
+            for j, xi in enumerate(grid_y):
+                z_sample = np.array([[xi, yi]]) * epsilon_std
+                x_decoded = self.decoder_model.predict(z_sample)
+                digit = x_decoded[0].reshape(digit_size, digit_size)
+                figure[i * digit_size: (i + 1) * digit_size,
+                       j * digit_size: (j + 1) * digit_size] = digit
+
+        plt.figure(figsize=(10, 10))
+        plt.imshow(figure)
+        plt.show()
+
+    def generate_reconstructed_image(self):
+
         n = 20
         digit_size = 28
         epsilon_std = 0.1
@@ -51,6 +73,40 @@ def generate_reconstructed_image(self):
         plt.imshow(figure)
         plt.show()
 
+        pass
+
+    def generate_reconstructed_image_random_samples(self):
+
+        n = 20
+        digit_size = 28
+        epsilon_std = 1
+        figure = np.zeros((digit_size * n, digit_size * n))
+
+        #Get the space of latent vectors
+        x_encoded = self.encoder_model.predict(self.test_data, batch_size=128)[2]
+        min_vals = x_encoded.min(axis=0)
+        max_vals = x_encoded.max(axis=0)
+        lv_size = len(min_vals)
+
+        # Pick a random set of 400 n-dim vectors
+        samples = np.zeros((n*n,lv_size))
+        for idx_vec in range(0,lv_size):
+            #samples[:,idx_vec] = random.sample(range(min_vals[idx_vec], max_vals[idx_vec]), n*n)
+            samples[:,idx_vec] = random.randrange(int(min_vals[idx_vec]), int(max_vals[idx_vec]), n*n)
+            samples[:,idx_vec] = np.array([random.randint(int(min_vals[idx_vec]),int(max_vals[idx_vec])) for x in range(n*n)])
 
 
-        pass
\ No newline at end of file
+        print(samples.shape)
+
+        for i in range(0,n):
+            for j in range(0,n):
+                x_decoded = self.decoder_model.predict(np.array([samples[i*n+j,:]])*epsilon_std)
+                digit = x_decoded[0].reshape(digit_size, digit_size)
+                figure[i * digit_size: (i + 1) * digit_size,
+                       j * digit_size: (j + 1) * digit_size] = digit
+
+        plt.figure(figsize=(10, 10))
+        plt.imshow(figure)
+        plt.show()
+
+        pass
diff --git a/src/network.py b/src/network.py
index bcab583..9b1f0c8 100644
--- a/src/network.py
+++ b/src/network.py
@@ -30,8 +30,8 @@ def __init__(self, input_shape):
         self.model = self.create_model()
 
     def create_model(self):
-        
-        
+
+
         reconstructed_image = Lambda(lambda x: x*255)(self.decoder_model(self.encoder_model(self.input)[2]))
 
         model = Model(self.input, reconstructed_image)
@@ -51,19 +51,19 @@ def create_model(self):
 
     def generate_encoder(self):
 
-        conv_layer_1 = Conv2D(filters=VAENetwork.NUM_FILTERS, 
+        conv_layer_1 = Conv2D(filters=VAENetwork.NUM_FILTERS,
                               kernel_size=VAENetwork.KERNEL_SIZE,
                               activation='relu',
                               strides=2,
                               padding='same')(self.input)
-        
-        conv_layer_2 = Conv2D(filters=VAENetwork.NUM_FILTERS*2, 
+
+        conv_layer_2 = Conv2D(filters=VAENetwork.NUM_FILTERS*2,
                               kernel_size=VAENetwork.KERNEL_SIZE,
                               activation='relu',
                               strides=2,
                               padding='same')(conv_layer_1)
 
-        # conv_layer_3 = Conv2D(filters=VAENetwork.NUM_FILTERS*4, 
+        # conv_layer_3 = Conv2D(filters=VAENetwork.NUM_FILTERS*4,
        #                       kernel_size=VAENetwork.KERNEL_SIZE,
        #                       activation='relu',
        #                       strides=2,
@@ -79,12 +79,12 @@ def generate_encoder(self):
 
         return encoder
 
     def generate_decoder(self):
-        
+
         decoder_input = Input(shape=(VAENetwork.LATENT_DIMENSIONS,))
         fully_connected = Dense(self.encoder_conv_shape[1]*self.encoder_conv_shape[2]*self.encoder_conv_shape[3], activation='relu')(decoder_input)
         deconv_input = Reshape((self.encoder_conv_shape[1], self.encoder_conv_shape[2], self.encoder_conv_shape[3]))(fully_connected)
-        
+
         deconv_layer_1 = Conv2DTranspose(filters=VAENetwork.NUM_FILTERS*4,
                                          kernel_size=VAENetwork.KERNEL_SIZE,
                                          activation='relu',
@@ -100,15 +100,15 @@ def generate_decoder(self):
         # deconv_layer_3 = Conv2DTranspose(filters=VAENetwork.NUM_FILTERS,
        #                                  kernel_size=VAENetwork.KERNEL_SIZE,
        #                                  activation='relu',
-       #                                  strides=2, 
+       #                                  strides=2,
        #                                  padding='same')(deconv_layer_2)
-        
+
         reconstructed_image = Conv2DTranspose(filters=1,
                                               kernel_size=VAENetwork.KERNEL_SIZE,
                                               activation='sigmoid',
                                               strides=1,
                                               padding='same')(deconv_layer_2)
-        
+
         decoder = Model(decoder_input, reconstructed_image)
         decoder.summary()
 
diff --git a/src/run.py b/src/run.py
index a50a5dc..245a6ca 100644
--- a/src/run.py
+++ b/src/run.py
@@ -15,7 +15,7 @@ def main(args):
     mnist_data = None
     if os.path.exists(MNIST_DATA_PATH):
         mnist_data = mnist.MNIST(MNIST_DATA_PATH)
-    
+
     else:
         print("MNIST data not found in the Data/ directory. Add the MNIST data to the path: ")
         print(os.path.join(os.getcwd(), '..', 'data'))
@@ -27,15 +27,15 @@ def main(args):
 
     x_train = np.array([np.reshape(np.array(x, dtype=np.uint8), (28, 28, 1)) for x in x_train])
     x_test = np.array([np.reshape(np.array(x, dtype=np.uint8), (28, 28, 1)) for x in x_test])
-    
+
     if not os.path.exists(MODEL_WEIGHT_DIRECTORY):
         os.makedirs(MODEL_WEIGHT_DIRECTORY, exist_ok=True)
 
     vae_obj = VAENetwork(x_train[2].shape)
     vae_model = vae_obj.get_model()
-    # vae_obj.train(x_train, x_test)
+    #vae_obj.train(x_train, x_test)
 
-    # vae_obj.save_weights(MODEL_WEIGHT_DIRECTORY)
+    #vae_obj.save_weights(MODEL_WEIGHT_DIRECTORY)
     vae_obj.load_weights(MODEL_WEIGHT_DIRECTORY)
 
     vae_model.summary()
@@ -43,21 +43,14 @@ def main(args):
     graph_generator = GraphGenerator(vae_obj, mnist_data)
     if vae_obj.LATENT_DIMENSIONS == 2:
         graph_generator.generate_latent_plot()
+        graph_generator.generate_reconstructed_image()
 
-    graph_generator.generate_reconstructed_image()
-    img = vae_model.predict(np.reshape(x_test[45], (1, 28, 28, 1)))
-    cv2.imshow('img', np.reshape(x_test[45],(28, 28)))
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-
-    cv2.imshow('img', np.reshape(np.array(img, dtype=np.uint8), (28, 28)))
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
+    graph_generator.generate_reconstructed_image_random_samples()
 
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Train a VAE on top of the MNIST Data.") 
+    parser = argparse.ArgumentParser(description="Train a VAE on top of the MNIST Data.")
     args = parser.parse_args()
     main(args)
 
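Note on the new generate_reconstructed_image_random_samples method in graph.py: it encodes the test set, takes the per-dimension min/max of the latent codes, draws n*n random latent vectors inside that bounding box, decodes them, and tiles the decoded digits into one figure. Below is a minimal standalone sketch of that sampling-and-decoding idea, not the repository's code: it assumes a Keras-style decoder whose predict() maps (batch, latent_dim) codes to (batch, 28, 28, 1) images, draws continuous uniform samples with np.random.uniform instead of integer draws with random.randint, and the function name build_random_sample_figure is illustrative only.

import numpy as np

def build_random_sample_figure(decoder, z_min, z_max, n=20, digit_size=28):
    # z_min and z_max are 1-D arrays of length latent_dim giving the
    # per-dimension bounds observed on the encoded test set.
    latent_dim = len(z_min)

    # Draw n*n latent codes uniformly inside the bounding box.
    samples = np.random.uniform(low=z_min, high=z_max, size=(n * n, latent_dim))

    # Decode all samples in one batch; assumed Keras-style predict().
    decoded = decoder.predict(samples, batch_size=128)

    # Tile the decoded digits into a single (n*digit_size, n*digit_size) image.
    figure = np.zeros((digit_size * n, digit_size * n))
    for i in range(n):
        for j in range(n):
            digit = decoded[i * n + j].reshape(digit_size, digit_size)
            figure[i * digit_size:(i + 1) * digit_size,
                   j * digit_size:(j + 1) * digit_size] = digit
    return figure

The resulting figure can then be displayed with plt.imshow(figure), as the methods in the diff do.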