In [1]:
# ## imports
import os
import tensorflow as tf

from models.VAE_a import VariationalAutoencoder
from utils.loaders_a import load_mnist
from utils.gpu_utils import gpu_optim
from utils.custom_utils import mk_run_folders
from utils.custom_utils import timer, benchmark, copy_weights, try_func


In [2]:
# ## run configuration
#
# Switches read by this cell and by the cells further down.
gpu_flag = True      # when True, gpu_optim() is called at the end of this cell
build_flag = False   # True: save the freshly built model / False: load stored weights
train_flag = True    # when True, the training cell runs vae.train

# Identifiers that make up the run directory name.
SECTION = 'vae'
RUN_ID = '0002'
DATA_NAME = 'digits'

# Resulting layout: run/<SECTION>/<RUN_ID>_<DATA_NAME>
RUN_FOLDER = 'run/{}/'.format(SECTION) + '_'.join((RUN_ID, DATA_NAME))

# Project helper; receives the run folder plus the three sub-folder names.
mk_run_folders(RUN_FOLDER, 'viz', 'images', 'weights')

if gpu_flag:
    # Project helper that prepares the GPU before any model is built.
    gpu_optim()


run dirs OK
1 Physical GPUs, 1 Logical GPUs


In [3]:
# ## training hyper-parameters
#
# Every value here is consumed by a later cell: LEARNING_RATE by vae.compile,
# the remaining ones by vae.train.

LEARNING_RATE = 5e-4
BATCH_SIZE = 32
EPOCHS = 1
INITIAL_EPOCH = 0
PRINT_EVERY_N_BATCHES = 100


# ## dataset
# load_mnist() returns two pairs, unpacked here into the train and test names.
(x_train, y_train), (x_test, y_test) = load_mnist()


In [4]:
# ## architecture
#
# Build and compile the VAE, then either persist the fresh model
# (build_flag = True) or restore previously stored weights (build_flag = False).
# Finally print the layer summaries of both sub-models.

vae = VariationalAutoencoder(
    input_dim=(28, 28, 1),
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    z_dim=2,
    r_loss_factor=1000,
)

vae.compile(LEARNING_RATE)

if build_flag:
    vae.save(RUN_FOLDER)
else:
    # Build the checkpoint prefix exactly once so the path that is loaded and
    # the path that is reported can never drift apart. Passing the components
    # separately lets os.path.join pick one consistent separator.
    weights_path = os.path.join(RUN_FOLDER, 'weights', 'weights_')
    vae.load_weights(weights_path)
    print("weights loaded: ", weights_path)

vae.encoder.summary()
vae.decoder.summary()


wights loaded:  run/vae/0002_digits\weights/weights_
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 28, 28, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 28, 28, 32)   0           encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 14, 14, 64)   18496       leaky_re_lu[0][0]                
_______________________________________

In [5]:
# ## training
#
# Runs vae.train on x_train through the project `timer` wrapper, but only when
# train_flag is set.

if train_flag:
    # Bind the timed wrapper to its own name instead of reassigning
    # `vae.train`: rebinding the attribute made this cell non-idempotent,
    # because every re-run wrapped the already wrapped method once more and
    # nested an additional timer around the training call.
    timed_train = timer(vae.train)

    timed_train(
        x_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        run_folder=RUN_FOLDER,
        print_every_n_batches=PRINT_EVERY_N_BATCHES,
        initial_epoch=INITIAL_EPOCH,
    )


   1/1875 [..............................] - ETA: 1s - loss: 40.7345 - reconstruction_loss: 35.2928 - kl_loss: 5.4418

   9/1875 [..............................] - ETA: 11s - loss: 41.7488 - reconstruction_loss: 36.0429 - kl_loss: 5.7059

  17/1875 [..............................] - ETA: 11s - loss: 42.0188 - reconstruction_loss: 36.3271 - kl_loss: 5.6918

  25/1875 [..............................] - ETA: 11s - loss: 42.0352 - reconstruction_loss: 36.3434 - kl_loss: 5.6918

  34/1875 [..............................] - ETA: 11s - loss: 42.1419 - reconstruction_loss: 36.4417 - kl_loss: 5.7003

  43/1875 [..............................] - ETA: 11s - loss: 42.3315 - reconstruction_loss: 36.6236 - kl_loss: 5.7079

  52/1875 [..............................] - ETA: 11s - loss: 42.6356 - reconstruction_loss: 36.9283 - kl_loss: 5.7073



  61/1875 [..............................] - ETA: 11s - loss: 42.7920 - reconstruction_loss: 37.1111 - kl_loss: 5.6809

  70/1875 [>.............................] - ETA: 10s - loss: 42.7908 - reconstruction_loss: 37.1453 - kl_loss: 5.6455



  79/1875 [>.............................] - ETA: 10s - loss: 42.6682 - reconstruction_loss: 37.0349 - kl_loss: 5.6334



  88/1875 [>.............................] - ETA: 10s - loss: 42.6500 - reconstruction_loss: 37.0237 - kl_loss: 5.6263

  97/1875 [>.............................] - ETA: 10s - loss: 42.6468 - reconstruction_loss: 37.0209 - kl_loss: 5.6259

 102/1875 [>.............................] - ETA: 11s - loss: 42.6548 - reconstruction_loss: 37.0285 - kl_loss: 5.6263

 111/1875 [>.............................] - ETA: 11s - loss: 42.7524 - reconstruction_loss: 37.1237 - kl_loss: 5.6286

 120/1875 [>.............................] - ETA: 10s - loss: 42.7846 - reconstruction_loss: 37.1564 - kl_loss: 5.6282

 129/1875 [=>............................] - ETA: 10s - loss: 42.6844 - reconstruction_loss: 37.0610 - kl_loss: 5.6234

 138/1875 [=>............................] - ETA: 10s - loss: 42.6219 - reconstruction_loss: 37.0038 - kl_loss: 5.6181

 147/1875 [=>............................] - ETA: 10s - loss: 42.5848 - reconstruction_loss: 36.9693 - kl_loss: 5.6155

 156/1875 [=>............................] - ETA: 10s - loss: 42.6025 - reconstruction_loss: 36.9858 - kl_loss: 5.6167

 165/1875 [=>............................] - ETA: 10s - loss: 42.6412 - reconstruction_loss: 37.0257 - kl_loss: 5.6155

 174/1875 [=>............................] - ETA: 10s - loss: 42.6159 - reconstruction_loss: 37.0000 - kl_loss: 5.6159

 183/1875 [=>............................] - ETA: 10s - loss: 42.6238 - reconstruction_loss: 37.0089 - kl_loss: 5.6150

 192/1875 [==>...........................] - ETA: 10s - loss: 42.6332 - reconstruction_loss: 37.0186 - kl_loss: 5.6146

 201/1875 [==>...........................] - ETA: 10s - loss: 42.6569 - reconstruction_loss: 37.0431 - kl_loss: 5.6137

 210/1875 [==>...........................] - ETA: 10s - loss: 42.6644 - reconstruction_loss: 37.0476 - kl_loss: 5.6167

 219/1875 [==>...........................] - ETA: 10s - loss: 42.6932 - reconstruction_loss: 37.0778 - kl_loss: 5.6154

 228/1875 [==>...........................] - ETA: 10s - loss: 42.7155 - reconstruction_loss: 37.1011 - kl_loss: 5.6144

 237/1875 [==>...........................] - ETA: 10s - loss: 42.7472 - reconstruction_loss: 37.1307 - kl_loss: 5.6166

 246/1875 [==>...........................] - ETA: 9s - loss: 42.8313 - reconstruction_loss: 37.2185 - kl_loss: 5.6129 

 255/1875 [===>..........................] - ETA: 9s - loss: 42.8117 - reconstruction_loss: 37.2016 - kl_loss: 5.6101

 264/1875 [===>..........................] - ETA: 9s - loss: 42.8799 - reconstruction_loss: 37.2690 - kl_loss: 5.6109

 273/1875 [===>..........................] - ETA: 9s - loss: 42.9406 - reconstruction_loss: 37.3340 - kl_loss: 5.6067

 282/1875 [===>..........................] - ETA: 9s - loss: 43.0001 - reconstruction_loss: 37.3906 - kl_loss: 5.6095

 291/1875 [===>..........................] - ETA: 9s - loss: 43.0507 - reconstruction_loss: 37.4387 - kl_loss: 5.6120

 300/1875 [===>..........................] - ETA: 9s - loss: 43.0298 - reconstruction_loss: 37.4152 - kl_loss: 5.6146

 305/1875 [===>..........................] - ETA: 9s - loss: 43.0299 - reconstruction_loss: 37.4143 - kl_loss: 5.6156

 315/1875 [====>.........................] - ETA: 9s - loss: 43.0197 - reconstruction_loss: 37.4038 - kl_loss: 5.6159

 325/1875 [====>.........................] - ETA: 9s - loss: 43.0082 - reconstruction_loss: 37.3920 - kl_loss: 5.6161

 334/1875 [====>.........................] - ETA: 9s - loss: 43.0280 - reconstruction_loss: 37.4097 - kl_loss: 5.6182

 343/1875 [====>.........................] - ETA: 9s - loss: 43.0558 - reconstruction_loss: 37.4358 - kl_loss: 5.6200

 352/1875 [====>.........................] - ETA: 9s - loss: 43.0656 - reconstruction_loss: 37.4440 - kl_loss: 5.6217

 361/1875 [====>.........................] - ETA: 9s - loss: 43.0816 - reconstruction_loss: 37.4627 - kl_loss: 5.6189



 370/1875 [====>.........................] - ETA: 9s - loss: 43.0889 - reconstruction_loss: 37.4708 - kl_loss: 5.6181

 379/1875 [=====>........................] - ETA: 9s - loss: 43.1169 - reconstruction_loss: 37.4990 - kl_loss: 5.6179

 388/1875 [=====>........................] - ETA: 8s - loss: 43.1211 - reconstruction_loss: 37.5052 - kl_loss: 5.6159



 397/1875 [=====>........................] - ETA: 8s - loss: 43.1027 - reconstruction_loss: 37.4881 - kl_loss: 5.6146

 402/1875 [=====>........................] - ETA: 8s - loss: 43.0837 - reconstruction_loss: 37.4703 - kl_loss: 5.6134

 411/1875 [=====>........................] - ETA: 8s - loss: 43.0852 - reconstruction_loss: 37.4718 - kl_loss: 5.6133

 420/1875 [=====>........................] - ETA: 8s - loss: 43.0646 - reconstruction_loss: 37.4504 - kl_loss: 5.6142

 429/1875 [=====>........................] - ETA: 8s - loss: 43.0620 - reconstruction_loss: 37.4477 - kl_loss: 5.6142






















































































































































































































































































































































































Epoch 00001: saving model to run/vae/0002_digits\weights\weights


Time elapsed: min 0, sec 13


In [6]:
# copy weights
# Calls the project helper copy_weights on the run folder with add='_'.
# NOTE(review): the recorded output ("copy OK: ...weights/weights_.data-00000-of-00001")
# suggests it duplicates the checkpoint written during training under the
# 'weights_' prefix that the architecture cell loads -- confirm in utils.custom_utils.

copy_weights(RUN_FOLDER, add='_')

copy OK: run/vae/0002_digits\weights/weights_.data-00000-of-00001


In [7]:
# Reset the global state Keras has accumulated during this run.
tf.keras.backend.clear_session()
