# Visualizing the discrepancy between distances in latent space and input space

Following Figure 1 (right) in the paper, we auto-encode a 3D cone-like shape with a 2D latent space. Then, we visualize how the distances of decoded latent samples in the input space differ from their corresponding distance in the latent space.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Dense, Lambda
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras import backend as K

import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
    width=500,
    height=300,
    margin=go.Margin(l=100, r=100, b=20, t=20),
    showlegend=False
)
upper_color = '#494EB2'
lower_color = '#FF881E'
config={'showLink': False}

## Generate the data
A upside-down cone, whose top and bottom represent two different classes. We use a subset of the generated points for visualization.

In [2]:
num_points = 100000
plotted_points_per_class = 400

# Generate x and y uniformly in [-0.5, 0.5]^2
x = np.random.rand(num_points) - 0.5
y = np.random.rand(num_points) - 0.5
def get_z(x, y):
    distance = np.sqrt(x**2 + y**2)
    return np.tanh(4*distance)**2
points = np.stack((x, y, get_z(x, y)), axis=-1)

# Cut out the middle for a clear separation between both classes
upper_points = points[points[:, 2] >= 0.8]
lower_points = points[points[:, 2] <= 0.2]
points = np.concatenate((upper_points, lower_points))

upper_points = upper_points[:plotted_points_per_class]
lower_points = lower_points[:plotted_points_per_class]

## Plot the data in 3D

In [3]:
upper_plot = go.Scatter3d(
    x = upper_points[:, 0],
    y = upper_points[:, 1],
    z = upper_points[:, 2],
    mode = 'markers',
    marker = {'size': 3, 'color': upper_color}
)
lower_plot = go.Scatter3d(
    x = lower_points[:, 0],
    y = lower_points[:, 1],
    z = lower_points[:, 2],
    mode = 'markers',
    marker = {'size': 3, 'color': lower_color}
)
data = [upper_plot, lower_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

In addition, show the data generating function:

In [4]:
x = np.linspace(-0.5, 0.5, 20)
y = np.linspace(-0.5, 0.5, 20)
x, y = np.meshgrid(x, y)
surface = go.Surface(x=x, y=y, z=get_z(x, y), showscale=False)

data = [surface, upper_plot, lower_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

## Construct a beta-VAE
The smoothness of the distortions visualized in Figure 1 in the paper indicates soft activation functions, so we use sigmoid here.

In [5]:
original_dim = 3
latent_dim = 2
intermediate_dim = 20
beta = 1.0

# Build the encoder
encoder_input = Input(shape=(original_dim,))
encoder_head = Sequential([
    Dense(intermediate_dim, activation='softplus', 
          input_shape=(original_dim, )),
    Dense(intermediate_dim, activation='softplus')
])
h = encoder_head(encoder_input)
latent_mean = Dense(latent_dim, activation='linear')(h)
latent_var = Dense(latent_dim, activation='softplus')(h)

# Sample from the encoder output
def sampling(args):
    z_mean, z_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0.,
                              stddev=1.)
    return z_mean + K.sqrt(z_var) * epsilon

latent_sampled = Lambda(sampling)([latent_mean, latent_var])

# Build the decoder
decoder = Sequential([
    Dense(intermediate_dim, activation='softplus', 
          input_shape=(latent_dim,)),
    Dense(intermediate_dim, activation='softplus'),
    Dense(original_dim, activation='linear')
])
reconstruction = decoder(latent_sampled)

# Create the VAE and encoder
vae = Model(encoder_input, reconstruction)
encoder = Model(encoder_input, [latent_mean, latent_var])

# Create a generator from latent space to input space
generator_input = Input(shape=(latent_dim,))
generator = Model(generator_input, decoder(generator_input))

# Compute the VAE loss

# Fix the generator's variance to be 0.1
recon_loss = tf.distributions.Normal(loc=0., scale=0.1).log_prob(
    encoder_input - reconstruction)
recon_loss = - K.sum(recon_loss, axis=-1)
kl_loss = - 0.5 * K.sum(1 + K.log(latent_var) 
                        - K.square(latent_mean) 
                        - latent_var, 
                        axis=-1)
vae_loss = K.mean(recon_loss + beta * kl_loss)

vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 3)            0                                            
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 20)           500         input_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 2)            42          sequential_1[1][0]               
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 2)            42          sequential_1[1][0]               
__________________________________________________________________________________________________
lambda_1 (

## Train the VAE

In [6]:
history = vae.fit(points, epochs=50, shuffle=True, verbose=0)
data = [go.Scatter(y=history.history['loss'])]
iplot(go.Figure(data=data, layout=layout), config=config)

Calculate the average mean and variance output by the encoder for a point from the upper / lower class.

In [7]:
upper_latent_mean, upper_latent_log_var = encoder.predict(upper_points)
lower_latent_mean, lower_latent_log_var = encoder.predict(lower_points)
print('Average absolute latent mean (upper points)', 
      np.mean(np.absolute(upper_latent_mean), axis=0))
print('Average latent variance (upper points)', 
      np.mean(np.exp(upper_latent_log_var), axis=0))
print('Average absolute latent mean (lower points)', 
      np.mean(np.absolute(lower_latent_mean), axis=0))
print('Average latent variance (lower points)', 
      np.mean(np.exp(lower_latent_log_var), axis=0))

Average absolute latent mean (upper points) [0.8765453  0.93472856]
Average latent variance (upper points) [1.1339548 1.1475053]
Average absolute latent mean (lower points) [0.0819083  0.05027945]
Average latent variance (lower points) [1.0076865 1.0070825]


## Plot the 2D latent representations 
with a line connecting two points from the upper class.

In [8]:
upper_latent, _ = encoder.predict(upper_points)
lower_latent, _ = encoder.predict(lower_points)
upper_plot = go.Scatter(
    x = upper_latent[:, 0],
    y = upper_latent[:, 1],
    mode = 'markers',
    marker = {'color': upper_color}
)
lower_plot = go.Scatter(
    x = lower_latent[:, 0],
    y = lower_latent[:, 1],
    mode = 'markers',
    marker = {'color': lower_color}
)

# Choose two upper points that are far away in the latent space
upper_min = upper_latent[np.argmin(upper_latent, axis=0)[0]]
upper_max = upper_latent[np.argmax(upper_latent, axis=0)[0]]

# Plot a line between these points
line = go.Scatter(
    x = [upper_min[0], upper_max[0]],
    y = [upper_min[1], upper_max[1]],
    mode = 'lines',
    line = {'color': 'red', 'width': 3}
)

data = [upper_plot, lower_plot, line]
iplot(go.Figure(data=data, layout=layout), config=config)

## Plot the 3D reconstructions

In [9]:
upper_reconstructed = generator.predict(upper_latent)
lower_reconstructed = generator.predict(lower_latent)

upper_plot = go.Scatter3d(
    x = upper_reconstructed[:, 0],
    y = upper_reconstructed[:, 1],
    z = upper_reconstructed[:, 2],
    mode = 'markers',
    marker = {'size': 3, 'color': upper_color}
)
lower_plot = go.Scatter3d(
    x = lower_reconstructed[:, 0],
    y = lower_reconstructed[:, 1],
    z = lower_reconstructed[:, 2],
    mode = 'markers',
    marker = {'size': 3, 'color': lower_color}
)
data = [upper_plot, lower_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

## Plot the distortion of the connecting line

In [10]:
line_x = np.linspace(upper_min[0], upper_max[0], 50)
line_y = np.linspace(upper_min[1], upper_max[1], 50)
line_latent = np.stack([line_x, line_y]).T
line_reconstructed = generator.predict(line_latent)

line_plot = go.Scatter3d(
    x = line_reconstructed[:, 0],
    y = line_reconstructed[:, 1],
    z = line_reconstructed[:, 2],
    mode = 'lines',
    line = {'color': 'red', 'width': 5}
)
data = [upper_plot, lower_plot, line_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

## Plot the surface of the latent manifold

In [13]:
# Calculate the range of the latent space
latent, _ = encoder.predict(points)
min_values = np.amin(latent, axis=0)
max_values = np.amax(latent, axis=0)

# Generate a meshgrid of points in the latent space
points_per_axis = 100
latent_x = np.linspace(min_values[0], max_values[0], points_per_axis)
latent_y = np.linspace(min_values[1], max_values[1], points_per_axis) 
latent_x_grid, latent_y_grid = np.meshgrid(latent_x, latent_y)

# Project each point of the meshgrid into the input space

# Reshape the 2-D meshgrid coordinates to 1-D
latent_x_grid = latent_x_grid.reshape(points_per_axis*points_per_axis)
latent_y_grid = latent_y_grid.reshape(points_per_axis*points_per_axis)
latent_coords = np.stack((latent_x_grid, latent_y_grid)).T
reconstructed = generator.predict(latent_coords)

# Undo the reshape above
x = reconstructed[:, 0].reshape(points_per_axis, points_per_axis)
y = reconstructed[:, 1].reshape(points_per_axis, points_per_axis)
z = reconstructed[:, 2].reshape(points_per_axis, points_per_axis)

# Plot the surface of the latent manifold in the input space
surface = go.Surface(x=x, y=y, z=z, showscale=False)
data = [surface]
iplot(go.Figure(data=data, layout=layout), config=config)

## Plot a heatmap of distortions
As in Figure 1 in the paper, we can visualize the local distortion of the latent space via $\sqrt{\det(J_z^T J_z)}$.

In [14]:
def get_jacobians(output_tensor, input_tensor):
    # Calculate the batch of jacobian matrices
    gradients = []
    for scalars in tf.unstack(output_tensor, axis=1):
        gradients.append(tf.gradients(scalars, [input_tensor])[0])
    return tf.stack(gradients, axis=1)

def get_distortions(output_tensor, input_tensor):
    # Calculate the batch of distortions
    jacobians = get_jacobians(output_tensor, input_tensor)
    jacobians_t = tf.transpose(jacobians, [0, 2, 1])
    dets = tf.matrix_determinant(tf.matmul(jacobians_t, jacobians))
    return tf.sqrt(dets)

distortions = get_distortions(generator.output, generator_input)

sess = K.get_session()
evaluated_gradients = sess.run(distortions, 
                               feed_dict={generator_input: latent_coords})
evaluated_gradients = evaluated_gradients.reshape(points_per_axis, 
                                                  points_per_axis)

heatmap = go.Heatmap(x=latent_x, y=latent_y, z=evaluated_gradients)
data = [heatmap]
iplot(go.Figure(data=data, layout=layout), config=config)