In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers



In [2]:
# Record the TensorFlow version this notebook was run with
print(tf.__version__)

2.17.1


In [3]:
# Number of convolutional blocks stacked by build_conv_layers
numLayers = 4
# Per-layer Conv2D settings; entry i configures the (i+1)-th encoder conv layer
filters = [32, 64, 64, 128]
kernels = [4, 4, 4, 4]
strides = [1, 2, 2, 1]
# Number of units in the z_mean / z_log_var Dense layers (and of sliders in the explorer)
latentSpaceDim = 16
# Filled in by build_bottleneck with the conv output shape minus its first axis
preBottleneckShape = None
# Multiplier applied to the KL loss inside Reparametrize
betaFactor =  0.075

def build_conv_layers(encoder_input):
  """Stack the encoder's convolutional blocks on top of ``encoder_input``.

  Builds ``numLayers`` blocks, each consisting of a ``Conv2D`` (configured from
  the module-level ``filters`` / ``kernels`` / ``strides`` lists, 'same'
  padding), followed by ``ReLU`` and then ``BatchNormalization``. Layers are
  named with a 1-based block number.

  Returns the output tensor of the last block.
  """
  tensor = encoder_input
  for layer_number in range(1, numLayers + 1):
    idx = layer_number - 1  # 0-based position in the config lists
    tensor = layers.Conv2D(
        filters=filters[idx],
        kernel_size=kernels[idx],
        strides=strides[idx],
        padding='same',
        name=f'encoder_conv_layer_{layer_number}',
    )(tensor)
    tensor = layers.ReLU(name=f'encoder_relu_layer_{layer_number}')(tensor)
    tensor = layers.BatchNormalization(
        name=f'encoder_batch_normalization_layer_{layer_number}'
    )(tensor)
  return tensor

def build_bottleneck(conv_output):
  """Flatten ``conv_output`` and project it to the two latent heads.

  Side effect: stores the shape of ``conv_output`` without its first axis in
  the module-level ``preBottleneckShape``.

  Returns a ``(z_mean, z_log_var)`` tuple, each produced by a ``Dense`` layer
  with ``latentSpaceDim`` units applied to the flattened input.
  """
  global preBottleneckShape
  full_shape = tf.keras.backend.int_shape(conv_output)
  preBottleneckShape = full_shape[1:]  # everything after the first axis

  flattened = layers.Flatten()(conv_output)
  mean_head = layers.Dense(latentSpaceDim, name = 'z_mean')
  log_var_head = layers.Dense(latentSpaceDim, name = 'z_log_var')
  return mean_head(flattened), log_var_head(flattened)

def build_conv_transpose_layers(reshaped):
  """Stack the decoder's transposed-convolution blocks on top of ``reshaped``.

  Builds ``numLayers - 1`` blocks, each a ``Conv2DTranspose`` ('same' padding)
  followed by ``ReLU`` and then ``BatchNormalization``. Block names carry the
  numbers ``numLayers`` down to 2.

  All three layers of a block now share the ``decoder_`` name prefix; the
  transposed convolution was previously named ``encoder_conv_transpose_layer_*``
  although it belongs to the decoder.

  Returns the output tensor of the last block.
  """
  x = reshaped
  for i in reversed(range(1, numLayers)):
    # NOTE(review): numLayers - i walks the config lists forwards (1, 2, ...,
    # numLayers - 1), so the decoder does not mirror the encoder's filter
    # order -- confirm this is intended before changing it, since it alters
    # the architecture.
    config_index = numLayers - i
    conv_transpose_layer = layers.Conv2DTranspose(
        filters = filters[config_index],
        kernel_size = kernels[config_index],
        strides = strides[config_index],
        padding = 'same',
        name = f'decoder_conv_transpose_layer_{i+1}'
    )
    x = conv_transpose_layer(x)
    x = layers.ReLU(name = f'decoder_relu_layer_{i+1}')(x)
    x = layers.BatchNormalization(name = f'decoder_batch_normalization_layer_{i+1}')(x)
  return x

@keras.saving.register_keras_serializable()
class Reparametrize(layers.Layer):
  """Sampling layer: draws ``z = z_mean + exp(0.5 * z_log_var) * epsilon``.

  ``inputs`` must unpack into ``(z_mean, z_log_var)``. On every call the layer
  also registers, via ``add_loss``, the KL term
  ``-0.5 * mean(1 + z_log_var - z_mean**2 - exp(z_log_var))`` (mean taken over
  all elements) multiplied by the module-level ``betaFactor``.
  """
  def call(self, inputs):
    mu, log_variance = inputs

    # Standard-normal noise with the same shape as the mean tensor
    noise = tf.keras.backend.random_normal(shape=tf.shape(mu), mean = 0.0, stddev=1.0)

    kl_terms = 1 + log_variance - tf.square(mu) - tf.exp(log_variance)
    kl_divergence = -0.5 * tf.reduce_mean(kl_terms)
    self.add_loss(kl_divergence * betaFactor)

    sigma = tf.exp(0.5 * log_variance)
    return mu + sigma * noise

In [13]:
import os

# Model files are looked up in the directory the notebook kernel was started from
current = os.getcwd()

# NOTE(review): presumably these files contain the Reparametrize layer registered
# above, which is why that class must be defined before loading -- confirm.
enc = keras.models.load_model(os.path.join(current, 'encoder_model.keras'))
dec = keras.models.load_model(os.path.join(current, 'decoder_model.keras'))

# Print the layer-by-layer structure of both loaded models
enc.summary()
dec.summary()

In [21]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert to float32, divide by 255 and append a trailing channel axis
x_test = np.asarray(x_test, dtype=np.float32) / 255
x_test = x_test.reshape(x_test.shape + (1,))

# Shuffle with a fixed seed so a fresh run reproduces the same order, and apply
# the same permutation to the labels so y_test[i] still describes x_test[i].
# (An in-place np.random.shuffle(x_test) is unseeded and leaves y_test misaligned.)
shuffle_order = np.random.default_rng(0).permutation(len(x_test))
x_test = x_test[shuffle_order]
y_test = y_test[shuffle_order]

print(x_test.shape)


(10000, 28, 28, 1)


In [6]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import display

In [7]:
def generate_image(latent_vector, decoder):
    """Decode a single latent vector into a 28x28 image array.

    ``latent_vector`` is converted to an array and flattened into a one-row
    batch, passed to ``decoder.predict``, and the prediction is reshaped to
    ``(28, 28)`` before being returned.
    """
    single_row_batch = np.reshape(np.array(latent_vector), (1, -1))
    decoded = decoder.predict(single_row_batch)
    return decoded.reshape((28, 28))  # drop batch/channel axes

In [None]:
from ipywidgets import Image

def generate_image_data(latent_vector, decoder):
    """Decode a latent vector and return the image as PNG-encoded bytes.

    The image comes from ``generate_image``. Its values are clamped to
    [0, 1], scaled to 8-bit, enlarged to 336x336 (12x per side) with
    nearest-neighbour resampling, and serialised as PNG.

    Args:
        latent_vector: forwarded to ``generate_image``.
        decoder: forwarded to ``generate_image``.

    Returns:
        bytes: the PNG file contents.
    """
    from PIL import Image as PILImage
    import io

    generated_image = generate_image(latent_vector, decoder)

    # Clamp before scaling: a float outside [0, 1] would land outside 0..255
    # and the uint8 cast would wrap it around instead of saturating.
    pixels = (np.clip(generated_image, 0.0, 1.0) * 255).astype(np.uint8)

    # Convert to PNG bytes
    pil_image = PILImage.fromarray(pixels)
    pil_image_resized = pil_image.resize((12*28, 12*28), PILImage.NEAREST)

    buf = io.BytesIO()
    pil_image_resized.save(buf, format='PNG')
    return buf.getvalue()



def interactive_latent_exploration(encoder, decoder, x_test):
    """Show a decoded image driven by sliders along the principal axes of the encodings.

    ``x_test`` is encoded with ``encoder.predict``; the mean and covariance of
    the flattened encodings are computed and the covariance is
    eigendecomposed. One slider per eigen-direction (range -3..3) is laid out
    four per row, ordered from largest to smallest eigenvalue magnitude.
    Moving a slider shifts the latent vector away from the mean along that
    slider's eigenvector by (slider value * eigenvalue magnitude), and the
    image widget is refreshed in place.

    Args:
        encoder: object whose ``predict(x_test)`` returns an array with one
            encoding per sample along its first axis.
        decoder: forwarded to ``generate_image_data``.
        x_test: batch handed to ``encoder.predict``.

    Returns:
        None. The widgets are rendered with ``display``.
    """
    predictions = encoder.predict(x_test)

    # One row per sample, regardless of the encoder's output rank
    predictions_flat = predictions.reshape(predictions.shape[0], -1)
    mean = np.mean(predictions_flat, axis=0)
    covariance = np.cov((predictions_flat - mean).T)
    e, v = np.linalg.eig(covariance)
    eigenvalues = np.abs(e)
    n_components = eigenvalues.shape[0]

    # Get indices that would sort the eigenvalues from largest to smallest
    sorted_indices = np.argsort(eigenvalues)[::-1]

    # np.linalg.eig pairs eigenvalue e[k] with column v[:, k] and gives no
    # ordering guarantee. Reorder vectors and scales together so that entry j
    # of both belongs to slider j; multiplying the unsorted `v` by coefficients
    # given in sorted order would pair each slider with another component's
    # eigenvector.
    sorted_vectors = v[:, sorted_indices]
    sorted_scales = eigenvalues[sorted_indices]

    # Start from the mean encoding, which is what all sliders at 0 represent
    img_widget = Image(value=generate_image_data(mean, decoder))
    display(img_widget)

    # Slider keys in display order (largest eigenvalue first)
    ordered_slider_keys = [f'latent_dim_{index+1}' for index in sorted_indices]

    def update_image(**kwargs):
        # Look values up by key so the result does not depend on kwargs order
        coefficients = np.array([kwargs[key] for key in ordered_slider_keys])
        latent_vector = mean + sorted_vectors @ (coefficients * sorted_scales)
        img_widget.value = generate_image_data(latent_vector, decoder)

    # Create sliders for each dimension of the latent space; the count comes
    # from the encodings themselves rather than a module-level constant
    latent_sliders = {
        f'latent_dim_{i+1}': widgets.FloatSlider(
            min=-3, max=3, step=0.025, value=0, description=f'Latent Dim {i+1}'
        )
        for i in range(n_components)
    }
    ordered_sliders = [latent_sliders[key] for key in ordered_slider_keys]

    # Create the grid layout for the sliders, four per row
    grid_rows = [
        widgets.HBox(ordered_sliders[start:start + 4])
        for start in range(0, len(ordered_sliders), 4)
    ]
    slider_grid = widgets.VBox(grid_rows)

    # Use interactive to connect sliders with the update function
    interactive_plot = interactive(
        update_image, **{key: latent_sliders[key] for key in ordered_slider_keys}
    )

    # Keep only the last child of the interactive container; the sliders are
    # shown through slider_grid instead of the default vertical list
    interactive_plot.children = [interactive_plot.children[-1]]

    # Display the interactive plot and slider grid
    display(interactive_plot)
    display(slider_grid)

# Launch the explorer with the loaded encoder/decoder and the prepared MNIST test images
interactive_latent_exploration(enc, dec, x_test)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 8ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf\x80H\x…

interactive(children=(Output(),), _dom_classes=('widget-interact',))

VBox(children=(HBox(children=(FloatSlider(value=0.0, description='Latent Dim 1', max=3.0, min=-3.0, step=0.025…