In [1]:
from keras.layers import Input, Dense, Lambda, Flatten, Activation, Merge, Concatenate, Add
from keras import layers
from keras.layers.merge import concatenate
from keras.models import Model, Sequential
from keras.objectives import binary_crossentropy
from keras.callbacks import LearningRateScheduler

import numpy as np

import keras.backend as K
import tensorflow as tf

from models import vgg
from models.cvae import CVAE
from utils.angles import deg2bit, bit2deg
from utils.losses import mad_loss_tf, cosine_loss_tf, von_mises_loss_tf, maad_from_deg
from utils.losses import gaussian_kl_divergence_tf, gaussian_kl_divergence_np
from utils.losses  import von_mises_log_likelihood_tf, von_mises_log_likelihood_np
from utils.towncentre import load_towncentre
from utils.experiements import get_experiment_id

Using TensorFlow backend.


#### TownCentre data

In [2]:
xtr, ytr_deg, xval, yval_deg, xte, yte_deg = load_towncentre('data/TownCentre.pkl.gz', canonical_split=True)
image_height, image_width = xtr.shape[1], xtr.shape[2]
ytr_bit = deg2bit(ytr_deg)
yval_bit = deg2bit(yval_deg)
yte_bit = deg2bit(yte_deg)

image_height, image_width, n_channels = xtr.shape[1:]
flatten_x_shape = xtr[0].flatten().shape[0]
phi_shape = yte_bit.shape[1]

(8694, 50, 50, 3)
************splitting trval-test************
0.5488135039273248
0.7151893663724195
0.6027633760716439
0.5448831829968969
0.4236547993389047
0.6458941130666561
0.4375872112626925
0.8917730007820798
0.9636627605010293
0.3834415188257777
0.7917250380826646
0.5288949197529045
0.5680445610939323
0.925596638292661
0.07103605819788694
0.08712929970154071
0.02021839744032572
0.832619845547938
0.7781567509498505
0.8700121482468192
0.978618342232764
0.7991585642167236
0.46147936225293185
0.7805291762864555
0.11827442586893322
0.6399210213275238
0.1433532874090464
0.9446689170495839
0.5218483217500717
0.4146619399905236
0.26455561210462697
0.7742336894342167
0.45615033221654855
0.5684339488686485
0.018789800436355142
0.6176354970758771
0.6120957227224214
0.6169339968747569
0.9437480785146242
0.6818202991034834
0.359507900573786
0.43703195379934145
0.6976311959272649
0.06022547162926983
0.6667667154456677
0.6706378696181594
0.2103825610738409
0.1289262976548533
0.3154283509241838

0.41810921174800086
0.17295135427115638
0.10721074542854603
0.8173391114616214
0.47314297846564424
0.8822836719191074
0.733289134316726
0.4097262056307436
0.37351101415568366
0.5156383466512517
0.8890599531897286
0.7372785797141679
0.00515296426902323
0.6941578513691256
0.9195074069058207
0.7104557595044916
0.1770057815674959
0.4835181274274587
0.1403160179234194
0.3589952783396321
0.9371170419405177
0.9233053075587083
0.2828368521760829
0.33963104416619916
0.6002128681312939
0.96319729526038
0.14780133406539042
0.2569166436866691
0.8735568272907714
0.4918922317083445
0.8989610922270317
0.18551789752317627
0.5326685874713607
0.32626963264937237
0.31654255989247604
0.44687696394619913
0.43307744910126844
0.3573468796779544
0.9149707703156186
0.7317441854328928
0.7275469913315297
0.2899134495919554
0.5777094243168404
0.779179433301834
0.7955903685432131
0.34453046075431226
0.7708727565686478
0.735893896807733
0.14150648562190027
0.8659454685664772
0.4413214701804108
0.48641044888866547
0

************splitting train-val************
0.5488135039273248
0.7151893663724195
0.6027633760716439
0.5448831829968969
0.4236547993389047
0.6458941130666561
0.4375872112626925
0.8917730007820798
0.9636627605010293
0.3834415188257777
0.7917250380826646
0.5288949197529045
0.5680445610939323
0.925596638292661
0.07103605819788694
0.08712929970154071
0.02021839744032572
0.832619845547938
0.7781567509498505
0.8700121482468192
0.978618342232764
0.7991585642167236
0.46147936225293185
0.7805291762864555
0.11827442586893322
0.6399210213275238
0.1433532874090464
0.9446689170495839
0.5218483217500717
0.4146619399905236
0.26455561210462697
0.7742336894342167
0.45615033221654855
0.5684339488686485
0.018789800436355142
0.6176354970758771
0.6120957227224214
0.6169339968747569
0.9437480785146242
0.6818202991034834
0.359507900573786
0.43703195379934145
0.6976311959272649
0.06022547162926983
0.6667667154456677
0.6706378696181594
0.2103825610738409
0.1289262976548533
0.31542835092418386
0.363710770942622

0.8822836719191074
0.733289134316726
0.4097262056307436
0.37351101415568366
0.5156383466512517
0.8890599531897286
0.7372785797141679
0.00515296426902323
0.6941578513691256
0.9195074069058207
0.7104557595044916
0.1770057815674959
0.4835181274274587
0.1403160179234194
0.3589952783396321
0.9371170419405177
0.9233053075587083
0.2828368521760829
0.33963104416619916
0.6002128681312939
0.96319729526038
0.14780133406539042
0.2569166436866691
0.8735568272907714
0.4918922317083445
0.8989610922270317
0.18551789752317627
0.5326685874713607
0.32626963264937237
0.31654255989247604
0.44687696394619913
0.43307744910126844
0.3573468796779544
0.9149707703156186
0.7317441854328928
0.7275469913315297
0.2899134495919554
0.5777094243168404
0.779179433301834
0.7955903685432131
0.34453046075431226
0.7708727565686478
0.735893896807733
0.14150648562190027
0.8659454685664772
0.4413214701804108
0.48641044888866547
0.4483691788979973
0.5678460014775075
0.6211692473670547
0.4981795657629434
0.8667885432590956
0.627

In [28]:
np.random.seed(0)

In [4]:
xte.shape

(914, 50, 50, 3)

In [3]:
#import matplotlib.pyplot as plt
#%matplotlib inline
# fig, axs = plt.subplots(1, 10, figsize=(30, 15))
# for i in range(0, 10):
#     axs[i].imshow(xtr[i])

#### Notation

$x$ - image,

$\phi$ - head angle,

$u$ - hidden variable

#### Prior network

$ p(u|x) \sim \mathcal{N}(\mu_1(x, \theta), \sigma_1(x, \theta)) $

#### Encoder network

$ q(u|x,\phi) \sim \mathcal{N}(\mu_2(x, \theta), \sigma_2(x, \theta)) $

#### Sample  $u \sim \{p(u|x), q(u|x,\phi) \}$

#### Decoder network

$p(\phi|u,x) \sim \mathcal{VM}(\mu(x,u,\theta''), \kappa(x,u,\theta'')) $

In [3]:
n_u = 8

cvae = CVAE(n_hidden_units=n_u)

#### Training

In [4]:
import keras
from utils.custom_keras_callbacks import SideModelCheckpoint

#proper logs format - 'logs/cvae.{epoch:02d}-{val_loss:.2f}.hdf5'

model_ckpt_callback = keras.callbacks.ModelCheckpoint('logs/cvae.{epoch:02d}-{val_loss:.2f}.hdf5',
                                                      monitor='val_loss',
                                                      mode='min',
                                                      save_best_only=True,
                                                      verbose=1)

save_decoder_callback = SideModelCheckpoint('cvae_decoder', model_to_save=cvae.decoder_model, save_path='logs/cvae_decoder.{epoch:02d}-{val_loss:.2f}.hdf5')

In [5]:
cvae.full_model.fit([xtr, ytr_bit], [ytr_bit], batch_size=10, epochs=20, validation_split=0.1,
                   callbacks=[model_ckpt_callback, save_decoder_callback])

Train on 6166 samples, validate on 686 samples
Epoch 1/20
val_loss improved from inf to 1.667148, saving cvae_decoder to logs/cvae_decoder.01-1.67.hdf5
Epoch 2/20
val_loss improved from 1.667148 to 1.530816, saving cvae_decoder to logs/cvae_decoder.02-1.53.hdf5
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
val_loss improved from 1.530816 to 1.431167, saving cvae_decoder to logs/cvae_decoder.10-1.43.hdf5
Epoch 11/20
val_loss improved from 1.431167 to 1.405038, saving cvae_decoder to logs/cvae_decoder.11-1.41.hdf5
Epoch 12/20
val_loss improved from 1.405038 to 1.380366, saving cvae_decoder to logs/cvae_decoder.12-1.38.hdf5
Epoch 13/20
val_loss improved from 1.380366 to 1.349127, saving cvae_decoder to logs/cvae_decoder.13-1.35.hdf5
Epoch 14/20
Epoch 15/20
Epoch 16/20
val_loss improved from 1.349127 to 1.258421, saving cvae_decoder to logs/cvae_decoder.16-1.26.hdf5
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
val_loss improved from 1.258421 to

<keras.callbacks.History at 0x120a9bbe0>

#### Predictions using decoder part

$ \phi_i = \mu(x_i,u_i,\theta'') $

In [12]:
from scipy.stats import sem

n_samples = xte.shape[0]
#ute = np.random.normal(0,1, [n_samples,n_u])

#yte_cvae_preds = cvae.full_model.predict([xte, yte_bit])

cvae_preds = cvae.full_model.predict([xte, yte_bit])
elbo_te, ll_te, kl_te = cvae_elbo_np(yte_bit, cvae_preds)

yte_preds = cvae.decoder_model.predict(xte)
yte_preds_bit = yte_preds[:,0:2]
kappa_preds_te = yte_preds[:,2:]

yte_preds_deg = bit2deg(yte_preds_bit)

loss_te = maad_from_deg(yte_preds_deg, yte_deg)
mean_loss_te = np.mean(loss_te)
std_loss_te = np.std(loss_te)

print("MAAD error (test) : %f ± %f" % (mean_loss_te, std_loss_te))

#kappa_preds_te = np.ones([xte.shape[0], 1]) 

print("kappa (test) : %f ± %f" % (np.mean(kappa_preds_te), np.std(kappa_preds_te)))

log_likelihood_loss_te = von_mises_log_likelihood_np(yte_bit, yte_preds_bit, kappa_preds_te,
                                                     input_type='biternion')


print("ELBO (test) : %f ± %f SEM" % (np.mean(-elbo_te), sem(-elbo_te)))
# print("log-likelihood (test) : %f ± %f SEM" % (np.mean(-ll_te), sem(-ll_te)))
print("KL(encoder|prior) (test) : %f ± %f SEM" % (np.mean(-kl_te), sem(-kl_te)))

print("log-likelihood (test) : %f±%fSEM" % (np.mean(log_likelihood_loss_te), sem(log_likelihood_loss_te)))

MAAD error (test) : 27.595090 ± 31.254913
kappa (test) : 4.427681 ± 2.550267
ELBO (test) : -0.895809 ± 0.040572 SEM
KL(encoder|prior) (test) : -0.000172 ± 0.000022 SEM
log-likelihood (test) : -0.893529±0.039949SEM


In [13]:
n_samples = xtr.shape[0]
#utr = np.random.normal(0,1, [n_samples,n_u])

#ytr_cvae_preds = cvae.full_model.predict([xtr, ytr_bit])

cvae_preds = cvae.full_model.predict([xtr, ytr_bit])
elbo_tr, ll_tr, kl_tr = cvae_elbo_np(ytr_bit, cvae_preds)

ytr_preds = cvae.decoder_model.predict(xtr)
ytr_preds_bit = ytr_preds[:,0:2]
kappa_preds_tr = ytr_preds[:,2:]

ytr_preds_deg = bit2deg(ytr_preds_bit)

loss_tr = maad_from_deg(ytr_preds_deg, ytr_deg)
mean_loss_tr = np.mean(loss_tr)
std_loss_tr = np.std(loss_tr)

print("MAAD error (train) : %f ± %f" % (mean_loss_tr, std_loss_tr))

#kappa_preds_tr = np.ones([xtr.shape[0], 1]) 

print("kappa (train) : %f ± %f" % (np.mean(kappa_preds_tr), np.std(kappa_preds_tr)))

log_likelihood_loss_tr = von_mises_log_likelihood_np(ytr_bit, ytr_preds_bit, kappa_preds_tr,
                                                     input_type='biternion')



print("ELBO (train) : %f ± %f SEM" % (np.mean(-elbo_tr), sem(-elbo_tr)))
# print("log-likelihood (train) : %f ± %f SEM" % (np.mean(-ll_tr), sem(-ll_tr)))
print("KL(encoder|prior) (train) : %f ± %f SEM" % (np.mean(-kl_tr), sem(-kl_tr)))

print("log-likelihood (train) : %f±%fSEM" % (np.mean(log_likelihood_loss_tr), sem(log_likelihood_loss_tr)))

MAAD error (train) : 23.851883 ± 28.897305
kappa (train) : 4.431825 ± 2.674264
ELBO (train) : -0.701403 ± 0.009008 SEM
KL(encoder|prior) (train) : -0.000169 ± 0.000008 SEM
log-likelihood (train) : -0.701012±0.009037SEM
