In [1]:
import os 
import yaml
import tensorflow as tf 
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go 

In [2]:
# Move to the parent directory so the project-local `utils` package (and the
# relative `configs/` and `docs/` paths used by later cells) can be resolved.
# The move is guarded: an unconditional `os.chdir('..')` climbs one more level
# every time this cell is re-run, which breaks every relative path afterwards.
# NOTE(review): assumes a `utils/` directory only exists at the project root --
# confirm there is none next to this notebook.
if not os.path.isdir('utils'):
    os.chdir('..')

from utils.VAE_utils import VAE, build_variational_encoder, build_variational_decoder, get_latent_space

### Set up encoder & decoder configs

In [3]:
# Parse the VAE YAML config (resolved against the current working directory)
# and split it into its encoder and decoder sections.
config_path = os.path.join(os.getcwd(), "configs/VAE_config.yaml")

with open(config_path, "r") as file:
    vae_config = yaml.safe_load(file)

encoder_config, decoder_config = vae_config['encoder'], vae_config['decoder']

### Load data - MNIST 

In [4]:
# Fetch MNIST through Keras and unpack the (images, labels) pair of each split.
train_split, test_split = tf.keras.datasets.mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [5]:
# Pool the train and test images into a single array; only the images are
# used here, the labels are not needed.
# (Previously X_train was concatenated with itself, which duplicated every
# training image and left X_test unused.)
mnist_digits = np.concatenate([X_train, X_test], axis=0)
# Append a trailing channel axis, cast to float32 and divide by 255.
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

### Build encoder and decoder

In [6]:
# Build the encoder from its config section. The helper returns five values:
# the input and output handles, the encoder model, and the z_mean / z_log_var
# objects (presumably the latent distribution parameters -- confirm in
# utils.VAE_utils).
encoder_parts = build_variational_encoder(encoder_config=encoder_config)
encoder_input, encoder_output, encoder, z_mean, z_log_var = encoder_parts

In [7]:
# Print the encoder's layer table (layer names, output shapes, parameter counts).
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
mnist_inputs (InputLayer)       [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 14, 14, 32)   320         mnist_inputs[0][0]               
__________________________________________________________________________________________________
batch_norm_encoder_1 (BatchNorm (None, 14, 14, 32)   128         encoder_conv_1[0][0]             
__________________________________________________________________________________________________
leaky_relu_encoder_1 (LeakyReLU (None, 14, 14, 32)   0           batch_norm_encoder_1[0][0]       
____________________________________________________________________________________________

In [8]:
# Build the decoder from its config section. The helper returns three values:
# the input and output handles and the decoder model.
decoder_parts = build_variational_decoder(decoder_config=decoder_config)
decoder_input, decoder_output, decoder = decoder_parts

In [9]:
# Print the decoder's layer table (layer names, output shapes, parameter counts).
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   [(None, 2)]               0         
_________________________________________________________________
shape_prod (Dense)           (None, 3136)              9408      
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_1 (Conv2DTransp (None, 14, 14, 64)        36928     
_________________________________________________________________
batch_norm_decoder_1 (BatchN (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_relu_decoder_1 (LeakyR (None, 14, 14, 64)        0         
_________________________________________________________________
decoder_conv_2 (Conv2DTransp (None, 28, 28, 32)        1846

### Instantiate VAE class and train

In [10]:
# Combine the encoder and decoder into the project's VAE model and compile it
# with an Adam optimizer at its default settings. No loss is passed to
# compile(); presumably the VAE class computes its loss internally -- confirm
# in utils.VAE_utils.
variational_autoencoder = VAE(encoder, decoder)
adam_optimizer = tf.keras.optimizers.Adam()
variational_autoencoder.compile(optimizer=adam_optimizer)

# Train on the images alone (no target array is passed) for 10 epochs with
# mini-batches of 128.
history = variational_autoencoder.fit(mnist_digits, epochs=10, batch_size=128)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [11]:
# Ask the project helper for the data behind the latent-space plot, using the
# trained decoder. It returns the heatmap values plus the tick labels
# (sample_range_x / sample_range_y) and tick positions (pixel_range) consumed
# by the plotting cell.
# NOTE(review): presumably n is the grid size per axis, digit_size the pixel
# width of one digit and scale the latent-range extent -- confirm in
# utils.VAE_utils.
latent_grid_args = dict(decoder=decoder, n=30, digit_size=28, scale=1.5)
latent_space, sample_range_x, sample_range_y, pixel_range = get_latent_space(**latent_grid_args)

In [18]:
# Draw the decoded latent grid as a heatmap and export it as a standalone
# HTML page.
# NOTE(review): the heatmap's x/y coordinates are index ranges over
# sample_range_x / sample_range_y, while the tick positions come from
# pixel_range -- confirm that both share the coordinate system of
# latent_space's columns/rows.
heatmap_trace = go.Heatmap(
    z=latent_space,
    x=np.arange(len(sample_range_x)),
    y=np.arange(len(sample_range_y)),
    showscale=False,
)

fig = go.Figure()
fig.add_trace(heatmap_trace)

figure_title = {
    'text': 'MNIST Digits Represented with a Multivariate Gaussian | Variational Autoencoder',
    'font': {'size': 11},
}

# Both axes use explicit tick arrays: positions from pixel_range, labels from
# the corresponding sample range.
x_axis = {
    'title': {'text': 'z[0]'},
    'tickmode': 'array',
    'tickfont': {'size': 10},
    'tickvals': pixel_range,
    'ticktext': sample_range_x,
}
y_axis = {
    'title': {'text': 'z[1]'},
    'tickmode': 'array',
    'tickfont': {'size': 10},
    'tickvals': pixel_range,
    'ticktext': sample_range_y,
    # Flip the y axis so the first row of z is drawn at the top.
    'autorange': 'reversed',
}

fig.update_layout(
    title=figure_title,
    xaxis=x_axis,
    yaxis=y_axis,
    width=800,
    height=800,
    margin={'b': 0, 'l': 0, 'r': 0, 't': 20},
)

# The output path is relative to the current working directory.
fig.write_html('./docs/VAE_mnist.html')

In [20]:
# Export the figure as a static PNG into the current working directory.
# plotly's static image export needs an image engine (e.g. kaleido) installed.
fig.write_image('./VAE_mnist.png')