Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added saved_weights/vae10.h5
Binary file not shown.
Binary file added saved_weights/vae20.h5
Binary file not shown.
60 changes: 58 additions & 2 deletions src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import cv2
import numpy as np
import random

from network import VAENetwork
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -31,7 +32,28 @@ def generate_latent_plot(self):
plt.show()

def generate_reconstructed_image(self):
    """Decode a regular grid of 2-D latent points and show them as one mosaic.

    A 20 x 20 grid of coordinates is built with ``np.linspace``, each
    coordinate pair is scaled by ``epsilon_std`` and passed to
    ``self.decoder_model.predict``; the decoded 28 x 28 tiles are pasted
    into a single array that is displayed with matplotlib.
    """
    n = 20
    digit_size = 28
    epsilon_std = 0.1
    side = digit_size * n
    figure = np.zeros((side, side))
    grid_x = np.linspace(-250, 50, n)
    grid_y = np.linspace(-125, 200, n)

    # The outer loop walks grid_x (tile rows) and the inner loop walks
    # grid_y (tile columns); the grid_y value is the FIRST latent component.
    for row, second_component in enumerate(grid_x):
        top = row * digit_size
        for col, first_component in enumerate(grid_y):
            left = col * digit_size
            latent_point = np.array([[first_component, second_component]]) * epsilon_std
            decoded = self.decoder_model.predict(latent_point)
            tile = decoded[0].reshape(digit_size, digit_size)
            figure[top:top + digit_size, left:left + digit_size] = tile

    plt.figure(figsize=(10, 10))
    plt.imshow(figure)
    plt.show()

def generate_reconstructed_image(self):

n = 20
digit_size = 28
epsilon_std = 0.1
Expand All @@ -51,6 +73,40 @@ def generate_reconstructed_image(self):
plt.imshow(figure)
plt.show()

pass

def generate_reconstructed_image_random_samples(self):
    """Decode randomly drawn latent vectors and show them as an n x n mosaic.

    The test data is pushed through ``self.encoder_model`` and output
    index 2 is taken as the latent representation (the same index the
    full model feeds to the decoder). The per-dimension minimum and
    maximum of those vectors define a bounding box, from which ``n * n``
    vectors are drawn uniformly at random. Each one is decoded with
    ``self.decoder_model`` and pasted into a single image that is
    displayed with matplotlib.

    Compared with the previous implementation:
    * the bounds are no longer truncated to integers, so the samples
      cover the real latent range instead of only integer lattice points
      (a range inside (-1, 1) used to collapse to all zeros);
    * the ``random.randrange(lo, hi, n*n)`` call, whose result was
      immediately overwritten and which raised ``ValueError`` when the
      truncated bounds coincided, has been removed.
    """
    n = 20
    digit_size = 28
    epsilon_std = 1
    figure = np.zeros((digit_size * n, digit_size * n))

    # Bounding box of the encoded test set, one (min, max) pair per
    # latent dimension.
    x_encoded = self.encoder_model.predict(self.test_data, batch_size=128)[2]
    min_vals = x_encoded.min(axis=0)
    max_vals = x_encoded.max(axis=0)
    lv_size = len(min_vals)

    # Draw n*n latent vectors, sampling every dimension uniformly inside
    # its own [min, max] interval. random.uniform accepts equal bounds,
    # so a degenerate dimension simply yields a constant column.
    samples = np.zeros((n * n, lv_size))
    for idx_vec in range(lv_size):
        low = float(min_vals[idx_vec])
        high = float(max_vals[idx_vec])
        samples[:, idx_vec] = [random.uniform(low, high) for _ in range(n * n)]

    print(samples.shape)

    for i in range(n):
        for j in range(n):
            # Keep a leading batch axis of size 1 for predict().
            z_sample = np.array([samples[i * n + j, :]]) * epsilon_std
            x_decoded = self.decoder_model.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure)
    plt.show()
22 changes: 11 additions & 11 deletions src/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, input_shape):
self.model = self.create_model()

def create_model(self):


reconstructed_image = Lambda(lambda x: x*255)(self.decoder_model(self.encoder_model(self.input)[2]))

model = Model(self.input, reconstructed_image)
Expand All @@ -51,19 +51,19 @@ def create_model(self):

def generate_encoder(self):

conv_layer_1 = Conv2D(filters=VAENetwork.NUM_FILTERS,
conv_layer_1 = Conv2D(filters=VAENetwork.NUM_FILTERS,
kernel_size=VAENetwork.KERNEL_SIZE,
activation='relu',
strides=2,
padding='same')(self.input)
conv_layer_2 = Conv2D(filters=VAENetwork.NUM_FILTERS*2,

conv_layer_2 = Conv2D(filters=VAENetwork.NUM_FILTERS*2,
kernel_size=VAENetwork.KERNEL_SIZE,
activation='relu',
strides=2,
padding='same')(conv_layer_1)

# conv_layer_3 = Conv2D(filters=VAENetwork.NUM_FILTERS*4,
# conv_layer_3 = Conv2D(filters=VAENetwork.NUM_FILTERS*4,
# kernel_size=VAENetwork.KERNEL_SIZE,
# activation='relu',
# strides=2,
Expand All @@ -79,12 +79,12 @@ def generate_encoder(self):
return encoder

def generate_decoder(self):

decoder_input = Input(shape=(VAENetwork.LATENT_DIMENSIONS,))
fully_connected = Dense(self.encoder_conv_shape[1]*self.encoder_conv_shape[2]*self.encoder_conv_shape[3], activation='relu')(decoder_input)
deconv_input = Reshape((self.encoder_conv_shape[1], self.encoder_conv_shape[2], self.encoder_conv_shape[3]))(fully_connected)


deconv_layer_1 = Conv2DTranspose(filters=VAENetwork.NUM_FILTERS*4,
kernel_size=VAENetwork.KERNEL_SIZE,
activation='relu',
Expand All @@ -100,15 +100,15 @@ def generate_decoder(self):
# deconv_layer_3 = Conv2DTranspose(filters=VAENetwork.NUM_FILTERS,
# kernel_size=VAENetwork.KERNEL_SIZE,
# activation='relu',
# strides=2,
# strides=2,
# padding='same')(deconv_layer_2)

reconstructed_image = Conv2DTranspose(filters=1,
kernel_size=VAENetwork.KERNEL_SIZE,
activation='sigmoid',
strides=1,
padding='same')(deconv_layer_2)

decoder = Model(decoder_input, reconstructed_image)

decoder.summary()
Expand Down
21 changes: 7 additions & 14 deletions src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main(args):
mnist_data = None
if os.path.exists(MNIST_DATA_PATH):
mnist_data = mnist.MNIST(MNIST_DATA_PATH)

else:
print("MNIST data not found in the Data/ directory. Add the MNIST data to the path: ")
print(os.path.join(os.getcwd(), '..', 'data'))
Expand All @@ -27,37 +27,30 @@ def main(args):

x_train = np.array([np.reshape(np.array(x, dtype=np.uint8), (28, 28, 1)) for x in x_train])
x_test = np.array([np.reshape(np.array(x, dtype=np.uint8), (28, 28, 1)) for x in x_test])

if not os.path.exists(MODEL_WEIGHT_DIRECTORY):
os.makedirs(MODEL_WEIGHT_DIRECTORY, exist_ok=True)

vae_obj = VAENetwork(x_train[2].shape)
vae_model = vae_obj.get_model()
# vae_obj.train(x_train, x_test)
#vae_obj.train(x_train, x_test)

# vae_obj.save_weights(MODEL_WEIGHT_DIRECTORY)
#vae_obj.save_weights(MODEL_WEIGHT_DIRECTORY)

vae_obj.load_weights(MODEL_WEIGHT_DIRECTORY)
vae_model.summary()

graph_generator = GraphGenerator(vae_obj, mnist_data)
if vae_obj.LATENT_DIMENSIONS == 2:
graph_generator.generate_latent_plot()
graph_generator.generate_reconstructed_image()


graph_generator.generate_reconstructed_image()

img = vae_model.predict(np.reshape(x_test[45], (1, 28, 28, 1)))

cv2.imshow('img', np.reshape(x_test[45],(28, 28)))
cv2.waitKey(0)
cv2.destroyAllWindows()

cv2.imshow('img', np.reshape(np.array(img, dtype=np.uint8), (28, 28)))
cv2.waitKey(0)
cv2.destroyAllWindows()
graph_generator.generate_reconstructed_image_random_samples()

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a VAE on top of the MNIST Data.")
parser = argparse.ArgumentParser(description="Train a VAE on top of the MNIST Data.")
args = parser.parse_args()
main(args)