## VAE Training

### Imports

In [1]:
# Automatically re-import edited local modules (models.VAE, utils.loaders) before each cell runs,
# so changes to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from models.VAE import VariationalAutoencoder
from utils.loaders import load_mnist

### Params

In [3]:
# Run identification: outputs for this run go to run/<SECTION>/<RUN_ID>_<DATA_NAME>,
# i.e. run/vae/0002_digits.
SECTION = 'vae'
RUN_ID = '0002'
DATA_NAME = 'digits'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# os.makedirs creates any missing parents (run/, run/vae/, the run folder itself),
# and exist_ok=True makes this cell safe to re-run. It also completes a partially
# created run folder instead of skipping it because the top-level folder exists.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build': save the freshly constructed model to RUN_FOLDER;
# any other value: load weights from RUN_FOLDER/weights/weights.weights.h5.
mode = 'build'

### Data

In [4]:
# load_mnist yields a (train, test) pair, each an (images, labels) pair.
# Only x_train is consumed by training below; the rest is kept for later analysis.
train_split, test_split = load_mnist()
x_train, y_train = train_split
x_test, y_test = test_split

### Architecture

In [5]:
# Convolutional VAE for 28x28 single-channel images with a 2-D latent space.
# The decoder mirrors the encoder: transposed-conv strides [1, 2, 2, 1] match the
# encoder's [1, 2, 2, 1], and its last layer has 1 filter to match the single input
# channel. Assuming 'same' padding in models.VAE (TODO confirm), spatial size goes
# 28 -> 28 -> 14 -> 7 -> 7 in the encoder and back to 28 in the decoder.
# NOTE(review): r_loss_factor presumably scales the reconstruction term relative to
# the KL term. Confirm this in models.VAE.
vae = VariationalAutoencoder(
    input_dim=(28, 28, 1),
    z_dim=2,
    r_loss_factor=1000,
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
)

if mode != 'build':
    # Resume: restore previously saved weights for this run.
    vae.load_weights(os.path.join(RUN_FOLDER, "weights/weights.weights.h5"))
else:
    # Fresh model: save it and plot its architecture into RUN_FOLDER
    # (exact artifacts written are defined in models.VAE).
    vae.save(RUN_FOLDER)
    vae.plot_model(RUN_FOLDER)

In [6]:
# Print the encoder's layer-by-layer summary.
vae.encoder.summary()

In [7]:
# Print the decoder's layer-by-layer summary.
vae.decoder.summary()

### Training

In [8]:
LEARNING_RATE = 5e-4  # passed to vae.compile when compiling the model

In [9]:
# Compile the model with the learning rate from the previous cell.
vae.compile(LEARNING_RATE)

In [10]:
# Training-loop settings, all forwarded to vae.train.
EPOCHS = 200
BATCH_SIZE = 32
INITIAL_EPOCH = 0  # presumably raise this when resuming from loaded weights (mode != 'build')
PRINT_EVERY_N_BATCHES = 100

In [12]:
# Fit on the training images only; labels (y_train) are not passed.
# Outputs are written under RUN_FOLDER.
vae.train(
    x_train=x_train,
    run_folder=RUN_FOLDER,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    initial_epoch=INITIAL_EPOCH,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)

Epoch 1/200
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step6ms/step - kl_loss: 2.5972 - loss: 106.1083 - reconstruction_loss: 103.51
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step6ms/step - kl_loss: 2.4811 - loss: 85.1514 - reconstruction_loss: 82.67
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step6ms/step - kl_loss: 2.5907 - loss: 77.1849 - reconstruction_loss: 74.59
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step6ms/step - kl_loss: 2.6572 - loss: 72.6419 - reconstruction_loss: 69.98
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step6ms/step - kl_loss: 2.7028 - loss: 69.7782 - reconstruction_loss: 67.07
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step7ms/step - kl_loss: 2.7362 - loss: 67.6917 - reconstruction_loss: 64.95
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 