In [1]:
import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.layers import Input, Dense, Lambda, Flatten, Activation, Merge, Concatenate, Add
from keras import layers
from keras.layers.merge import concatenate
from keras.models import Model, Sequential
from keras.objectives import binary_crossentropy
from keras.callbacks import LearningRateScheduler
from keras.models import load_model

from models import vgg
from models.cvae import CVAE
from utils.angles import deg2bit, bit2deg
from utils.losses import mad_loss_tf, cosine_loss_tf, von_mises_loss_tf, maad_from_deg
from utils.losses import gaussian_kl_divergence_tf, gaussian_kl_divergence_np
from utils.losses  import von_mises_log_likelihood_tf, von_mises_log_likelihood_np
from utils.towncentre import load_towncentre
from utils.experiements import get_experiment_id

Using TensorFlow backend.


#### TownCentre data

In [2]:
# Load the TownCentre train / validation / test splits (images and angles in degrees).
xtr, ytr_deg, xval, yval_deg, xte, yte_deg = load_towncentre(
    'data/TownCentre.pkl.gz',
    canonical_split=True,
)

# Encode the angle targets with deg2bit; the *_bit arrays are what the model is
# trained on and are later scored with input_type='biternion'.
ytr_bit = deg2bit(ytr_deg)
yval_bit = deg2bit(yval_deg)
yte_bit = deg2bit(yte_deg)

# Shape bookkeeping, derived once from the training images
# (the earlier duplicate height/width assignment was redundant and is dropped).
image_height, image_width, n_channels = xtr.shape[1:]
flatten_x_shape = xtr[0].flatten().shape[0]
phi_shape = yte_bit.shape[1]

In [3]:
#import matplotlib.pyplot as plt
#%matplotlib inline
# fig, axs = plt.subplots(1, 10, figsize=(30, 15))
# for i in range(0, 10):
#     axs[i].imshow(xtr[i])

#### Notation

$x$ - image,

$\phi$ - head angle,

$u$ - hidden variable

#### Prior network

$ p(u|x) \sim \mathcal{N}(\mu_1(x, \theta), \sigma_1(x, \theta)) $

#### Encoder network

$ q(u|x,\phi) \sim \mathcal{N}(\mu_2(x, \phi, \theta'), \sigma_2(x, \phi, \theta')) $

#### Sample  $u \sim \{p(u|x), q(u|x,\phi) \}$

#### Decoder network

$p(\phi|u,x) \sim \mathcal{VM}(\mu(x,u,\theta''), \kappa(x,u,\theta'')) $

In [4]:
# Value passed to CVAE as n_hidden_units (presumably the size of the hidden variable u).
n_u = 8

# Model that is trained below; exposes full_model and decoder_model.
cvae_model = CVAE(n_hidden_units=n_u)

#### Training

In [8]:
import os

import keras
from utils.custom_keras_callbacks import SideModelCheckpoint

#proper logs format - 'logs/cvae.{epoch:02d}-{val_loss:.2f}.hdf5'

# decoder_ckpt_path = 'logs/cvae.decoder.weights.hdf5'
cvae_ckpt_path = 'logs/cvae.full_model.weights.hdf5'

# The checkpoint is written directly to cvae_ckpt_path; create the target
# directory up front so the first save cannot fail on a fresh checkout
# where 'logs/' does not exist yet.
os.makedirs(os.path.dirname(cvae_ckpt_path), exist_ok=True)

# Keep only the weights of the epoch with the lowest validation loss.
model_ckpt_callback = keras.callbacks.ModelCheckpoint(cvae_ckpt_path,
                                                      monitor='val_loss',
                                                      mode='min',
                                                      save_best_only=True,
                                                      save_weights_only=True,
                                                      verbose=1)

# decoder_ckpt_callback = SideModelCheckpoint('cvae_decoder', 
#                                             model_to_save=cvae_model.decoder_model, 
#                                             save_path=decoder_ckpt_path,
#                                             save_weights_only=True)

In [None]:
# cvae_model.decoder_model.save_weights(decoder_ckpt_path)
# cvae_model.full_model.save_weights(cvae_ckpt_path)

In [15]:
# Train the full model on (image, angle) -> angle, validating on the held-out
# split after every epoch; model_ckpt_callback stores the best weights.
cvae_model.full_model.fit(
    [xtr, ytr_bit],
    [ytr_bit],
    batch_size=10,
    epochs=50,
    validation_data=([xval, yval_bit], yval_bit),
    callbacks=[model_ckpt_callback],
)

Train on 6882 samples, validate on 834 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50


Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x141e86b38>

#### Predictions using decoder part

$ \phi_i = \mu(x_i,u_i,\theta'') $

In [10]:
# Build a fresh CVAE with the same n_hidden_units and restore the weights from
# cvae_ckpt_path, the file written by model_ckpt_callback (save_best_only on val_loss).
cvae_best = CVAE(n_hidden_units=n_u)
cvae_best.full_model.load_weights(cvae_ckpt_path)

In [11]:
from scipy.stats import sem

def _eval_model(cvae_model, x, ytrue_deg, ytrue_bit, data_part):
    """Print evaluation metrics of a CVAE model for one data split.

    Reports the MAAD error and the predicted kappa (mean ± std), followed by
    the ELBO, the KL term and the von Mises log-likelihood (mean ± SEM).
    Every printed line is labelled with ``data_part``.

    Parameters
    ----------
    cvae_model : object
        Model wrapper exposing ``full_model``, ``decoder_model`` and
        ``_cvae_elbo_loss_np`` (e.g. a ``CVAE`` instance).
    x : array
        Input images.
    ytrue_deg : array
        Ground-truth angles, compared against ``bit2deg`` of the predictions.
    ytrue_bit : array
        Ground-truth angles in the representation fed to the model; passed to
        the log-likelihood with ``input_type='biternion'``.
    data_part : str
        Name of the split (e.g. 'train', 'validation', 'test') used in the
        printed labels.

    Returns
    -------
    None
    """
    # The full model takes both the image and the true angle as inputs; the
    # model's own helper turns its output into ELBO / log-likelihood / KL terms.
    cvae_preds = cvae_model.full_model.predict([x, ytrue_bit])
    elbo, _ll, kl = cvae_model._cvae_elbo_loss_np(ytrue_bit, cvae_preds)

    # The decoder model predicts from the image alone: the first two columns
    # are used as the predicted angle, the remaining ones as kappa.
    ypreds = cvae_model.decoder_model.predict(x)
    ypreds_bit = ypreds[:, 0:2]
    kappa_preds = ypreds[:, 2:]

    ypreds_deg = bit2deg(ypreds_bit)

    maad = maad_from_deg(ytrue_deg, ypreds_deg)

    # Label with the actual split name (these two lines used to say "(test)"
    # regardless of which split was evaluated).
    print("MAAD error (%s) : %f ± %f" % (data_part, np.mean(maad), np.std(maad)))

    print("kappa (%s) : %f ± %f" % (data_part, np.mean(kappa_preds), np.std(kappa_preds)))

    log_likelihood = von_mises_log_likelihood_np(ytrue_bit, ypreds_bit, kappa_preds,
                                                 input_type='biternion')

    # NOTE(review): both terms are negated before printing, so the reported KL
    # comes out negative in practice — confirm the sign convention of
    # _cvae_elbo_loss_np before relying on these numbers.
    print("ELBO (%s) : %f ± %f SEM" % (data_part, np.mean(-elbo), sem(-elbo)))

    print("KL(encoder|prior) (%s) : %f ± %f SEM" % (data_part, np.mean(-kl), sem(-kl)))

    print("log-likelihood (%s) : %f±%fSEM" % (data_part,
                                              np.mean(log_likelihood),
                                              sem(log_likelihood)))
    return

In [12]:
# Metrics of the restored best-checkpoint model on the training split.
_eval_model(cvae_best, xtr, ytr_deg, ytr_bit, 'train')

MAAD error (test) : 54.725613 ± 48.245662
kappa (test) : 0.934393 ± 0.810261
ELBO (train) : -1.568528 ± 0.006742 SEM
KL(encoder|prior) (train) : -0.035083 ± 0.000357 SEM
log-likelihood (train) : -1.543154±0.006767SEM


In [13]:
# Metrics of the restored best-checkpoint model on the validation split.
_eval_model(cvae_best, xval, yval_deg, yval_bit, 'validation')

MAAD error (test) : 53.554670 ± 47.014745
kappa (test) : 0.943233 ± 0.826707
ELBO (validation) : -1.570017 ± 0.018833 SEM
KL(encoder|prior) (validation) : -0.037350 ± 0.001082 SEM
log-likelihood (validation) : -1.539291±0.019358SEM


In [14]:
# Metrics of the restored best-checkpoint model on the test split.
_eval_model(cvae_best, xte, yte_deg, yte_bit, 'test')

MAAD error (test) : 47.863967 ± 45.116663
kappa (test) : 0.903833 ± 0.801953
ELBO (test) : -1.519896 ± 0.017887 SEM
KL(encoder|prior) (test) : -0.034243 ± 0.000985 SEM
log-likelihood (test) : -1.485811±0.017619SEM
