# Variational Autoencoder
Based on variational inference

$$P(Z|X) = \frac{P(Z,X)}{P(X)}$$

## Some basic information theory
### Information 
$I = -log(P(x))$
* Measures the factor of uncertainty removed when x is known
* 1 bit can be thought of as information that reduces uncertainty
    by a factor of 2
E.g. Let's say there is 50% chance of weather being sunny and 
50% chance of rain tomorrow. When the weather station tells us it is going to be sunny, they have given us 1 bit of information.

> Uncertainty reduction is the inverse of the event's probability

E.g. If the weather probabilities are sunny 75% and rain 25%,
finding out that it is going to be rainy will reduce our 
uncertainty by $1/0.25 = 4$. This corresponds to $log_2(4) = 2$ bits of information $log_2(\frac{1}{0.25}) = -log_2(0.25)$

### Entropy 
$H = \sum_x-log(P(x)) * P(x))$  
* Can be thought of as average amount of information relayed by a certain distribution
E.g in the above case, the weather station on average transmits
$$ H = 0.75 \times -log(0.75) + 0.25 \times (-log(0.25)) 
  = 0.81$$ bits of useful information

### Cross-entropy
$$H(P,Q) = \sum_x P(x)(-log(Q(x))$$

E.g Lets say we use 2 bits to encode our weather prediction
this can be thought of as us predicting the weather to have
a 25% chance of either being sunny or rainy
The average number of actual bits sent is
$H = 0.75 \times 2 + 0.25 \times 2 = 2 $bits. If using different number of bits for the different predictions $H = 0.75 \times 2 + 0.25 \times 3 = 2.25 $bits 

Based on the entropy and cross-entropy, we can see that our _predicted_ probability distribution Q(x) differs from the _actual_ probability distribution P(x) by 
$KL(P||Q) = 2.25 - 0.81 = 1.54 $bits  
If predictions are perfect i.e. Q(x) = P(x), then H(P,Q) = H(P)  
Therefore, $H(P,Q) = H(P) + KL(P||Q)$  
$KL(P||Q)$ means KL-divergence of Q w.r.t P

\begin{align}
KL(P||Q) &= H(P,Q) - H(P)\\
         &= \sum_x P(x)(-log(Q(x)) -  \sum_xP(x)(-log(P(x))\\
         &= \sum_{x} P(x)(-log(Q(x) - (-log(P(x)))\\
         &= \sum_x P(x)(log(P(x)) - log(Q(x)))\\
         &= \sum_x P(x)(log(\frac{P(x)}{Q(x)}))\\
\end{align}
#### Some properties of KL-divergence
1. $KL(P||Q)$ is alwaysgreater than or equal to 0
2. $KL(P||Q)$ is not the same as $KL(Q||P)$


## Variational Bayes

$$P(Z|X) = \frac{P(Z,X)}{P(X)} = \frac{P(X|Z)P(Z)}{P(X)}$$

We don't know P(X). If we were to compute it,
$P(X) = \int{P(X|Z)P(Z)dZ}$
* Intractable in many cases
* If distributions are high dimensional, integral is multi-integral

Thus, we can try to approximate the distribution. One method to approximate is Monte Carlo method (Gibbs sampling and other sampling methods) which is unbiased with high variance.

Another is variational inference which has low variance but is biased

1. Approximate P(Z|X) with Q(Z) that is tractable e.g. Gaussian
2. Play with the parameters of Q(Z) in a way that it gets close enough to P(Z|X) i.e. minimize $KL(Q(Z)||P(Z|X))$

This brings us to the following objective of minimizing 


\begin{align}
KL(Q(Z)||P(Z|X)) &= \sum_z Q(Z)log(\frac{Q(Z)}{P(Z|X)})\\
                 &= - \sum_zQ(Z) log(\frac{P(Z|X)}{Q(Z)})\\
                 &= - \sum_z Q(Z) log(\frac{P(X,Z)}{P(X) Q(Z)})\\
                 &= - \sum_z Q(Z) (log(\frac{P(X,Z)}{Q(Z)}) - log(P(X)))\\
                 &= - \sum_z Q(Z) log(\frac{P(X,Z)}{Q(Z)}) + log(P(X))\\
\end{align}
\begin{align}
\therefore log(P(X))    &= KL(Q(Z)||P(Z|X)) + \sum_z Q(Z)log(\frac{P(X,Z)}{Q(Z)})\\
                 &= KL(Q(Z)||P(Z|X)) + L\\
\end{align}


As $log(P(X))$ is a constant, to minimize $KL(Q(Z)||P(Z|X))$,
we just need to maximize $L$.

$$\because KL(Q(Z)||P(Z|X)) \geq 0$$,  
$$L \leq P(X)$$ Thus, L is a lower bound of P(X).

\begin{align}
L &= \sum_z Q(Z) log(\frac{P(X,Z)}{Q(Z)})\\
  &= \sum_z Q(Z) log(\frac{P(X|Z)P(Z)}{Q(Z)})\\
  &= \sum_z Q(Z)(log(P(X|Z)) + log(\frac{P(Z)}{Q(Z)}))\\
  &= \sum_z Q(Z) log(P(X|Z))) + \sum_z Q(Z) log(\frac{P(Z)}{Q(Z)})\\
\end{align}

$$\sum_z Q(Z) log(P(X|Z))) = E_{Q(Z)}P(X|Z)$$
$$\sum_z Q(Z) log(\frac{P(Z)}{Q(Z)}) = -KL(Q(Z)||P(Z))$$
Representing L as an autoencoder

X --> Q(Z|X) --> Z --> P(X|Z) --> X'

$E_{Q(Z)}P(X|Z)$ term acts as reconstruction error.
$P(X|Z)$ is deterministic meaning one input will get the same output all the time. Thus, it can be considered $P(X|X')$.

If $P(X|X')$ is gaussian
$$P(X|X') = e^{-|X - X'|^2}$$
$$log(P(X|X')) = -|X - X'|^2$$ --> L2 loss

If Bernoulli distribution, will be similar to cross-entropy

So far the network is all deterministic
> To make it probabilistic, 
    encoder should not parametrize Z but instead the parametrize
    the distribution that generates Z i.e. $\mu$ and $\sigma$



In [3]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import traceback

  from ._conv import register_converters as _register_converters


In [21]:
df = pd.read_csv('../CapsNet/train.csv')
labels = np.asarray(df.label).reshape(42000,1)
train_x = np.asarray(df[df.columns[1:]]).reshape(42000,28*28)/255.0

In [30]:
############################ SIMPLE MODEL ###############################################
mu_data, sigma_data = 3., 1.75
mu_z, sigma_z = 0., 1.
alpha = 0.2
np.random.seed(0)
data_distribution = np.random.normal(mu_data, sigma_data, (40000, 10))
data_distribution /= np.max(data_distribution,axis=1).reshape(40000,1)

latent_distribution = np.random.normal(mu_z, sigma_z, (40000, 16))
tf.Graph().as_default()
X = tf.placeholder(tf.float32, (None, 784))
with tf.name_scope('encoder'):
    en_1 = tf.layers.Dense(512)(X)
    en_1 = tf.maximum(en_1, alpha*en_1)
    en_2 = tf.layers.Dense(256)(en_1)
    en_2 = tf.maximum(en_2, alpha*en_2)
    en_3 = tf.layers.Dense(128)(en_2)
    en_3 = tf.maximum(en_3, alpha*en_3)
    encoded = tf.layers.Dense(32)(en_3)
    #encoded = tf.maximum(encoded, alpha*encoded)
log_stdev = encoded[:,0:16]
mean = encoded[:,16:32]
latent_samples = tf.placeholder(tf.float32, (None,16))
Z = latent_samples * tf.exp(log_stdev) + mean
with tf.name_scope('decoder'):
    de_1 = tf.layers.Dense(128)(Z)
    de_1 = tf.maximum(de_1, alpha*de_1)
    de_2 = tf.layers.Dense(256)(de_1)
    de_2 = tf.maximum(de_2, alpha*de_2)
    de_3 = tf.layers.Dense(512)(de_2)
    de_3 = tf.maximum(de_3, alpha*de_3)
    output = tf.layers.Dense(784,activation=tf.nn.sigmoid)(de_3)
recon_loss = tf.reduce_sum((tf.square(X - output)),axis=1)
kl_loss = - 0.5 * tf.reduce_sum(1 + log_stdev - tf.square(mean) - tf.square(tf.exp(log_stdev)), axis=1)
loss = tf.reduce_mean(recon_loss + kl_loss)
adam = tf.train.AdamOptimizer()
train = adam.minimize(loss)

In [None]:
epochs = 2000
bs = 10
with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess,'./checkpoints/simple.ckpt')
    try:
        for i in range(epochs):
            print('Ep: ' + str(i))
            for b in range(40000):
            #batch = train_x[b*batch_s:]
                sess.run(train,{X:train_x[b*bs:b*bs+1],latent_samples:latent_distribution[b*bs:b*bs+1]})
                if b % 10 == 0:
                    saver.save(sess,'./checkpoints/simple.ckpt')
                    print("Loss: " + str(sess.run(loss,{X:train_x[b*bs:b*bs+1],latent_samples:latent_distribution[b*bs:b*bs+1]})))
    except Exception as e:
        print(e)
    finally:
        saver.save(sess,'./checkpoints/simple.ckpt')
        print('Model saved')
        
    

Ep: 0
Loss: 176.71806
Loss: 86.13999
Loss: 68.960815
Loss: 42.089844
Loss: 37.327477
Loss: 105.81069
Loss: 60.07544
Loss: 40.60073
Loss: 47.660053
Loss: 77.876854
Loss: 38.797943
Loss: 54.746956
Loss: 71.219475
Loss: 41.788605
Loss: 47.27316
Loss: 26.522472
Loss: 47.80396
Loss: 50.082848
Loss: 61.484234
Loss: 65.725784
Loss: 70.058205
Loss: 54.295845
Loss: 31.869999
Loss: 27.829477
Loss: 72.57536
Loss: 26.830242
Loss: 57.81515
Loss: 24.323423
Loss: 41.57086
Loss: 79.89962
Loss: 46.510654
Loss: 53.951756
Loss: 21.056206
Loss: 62.47521
Loss: 19.016842
Loss: 56.53398
Loss: 63.367165
Loss: 57.221268
Loss: 39.302254
Loss: 27.662477
Loss: 53.691757
Loss: 38.90088
Loss: 52.052193
Loss: 58.164482
Loss: 45.844
Loss: 44.99283
Loss: 39.269173
Loss: 40.07831
Loss: 42.23399
Loss: 84.642204
Loss: 33.225983
Loss: 63.901604
Loss: 45.78642
Loss: 32.828465
Loss: 52.030678
Loss: 21.493973
Loss: 45.933163
Loss: 67.14117
Loss: 61.419632
Loss: 65.387886
Loss: 20.618263
Loss: 61.233204
Loss: 52.767204
Loss: 

In [35]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess,'./checkpoints/simple.ckpt')
    print(data_distribution[0:5])
    print(sess.run(output,{X:data_distribution,latent_samples:latent_distribution})[0:5])
    print(sess.run(loss,{X:data_distribution,latent_samples:latent_distribution}))

INFO:tensorflow:Restoring parameters from ./checkpoints/simple.ckpt
[[ 0.49072147  0.27344777  0.36601261  0.8159857   0.46677805  0.35776554
   0.54717307  0.4877827   0.2667516   1.        ]
 [ 0.15562885  0.36178482  0.61852379 -0.16648737  0.54728843  1.
   0.9120218   0.78121508 -0.19731969  0.61954722]
 [ 0.53872489  0.73663366  0.58234488  0.12916687  0.63438279  1.
   0.74062651  0.46483566  0.64851545  0.41318566]
 [ 0.37936773  0.3524125   1.          0.56451019  0.35917102  0.27761476
   0.71340978  0.68306209  0.42316618  0.32498537]
 [ 0.85991503  0.63204934  0.71561783  0.63673528  0.88973944  1.
   0.81692877  0.31732474  0.87881902  0.9355185 ]]
[[ 0.5648229   0.4378252   0.53567755  0.6407115   0.5929521   0.55325466
   0.7483129   0.55739665  0.5691584   0.70712   ]
 [ 0.5706934   0.564654    0.47659892  0.19214869  0.6301438   0.6107527
   0.8660286   0.8391267  -0.31949863  0.39746046]
 [ 0.5444164   0.7693986   0.44586438  0.582441    0.5344002   0.7162865
   0.513

In [5]:
####################### MNIST MODEL ########################################
k = 9
mu_z, sigma_z = 0., 1.
latent_dist = np.random.normal(mu_z,sigma_z,(42000,7,7,32))
tf.Graph().as_default()
X = tf.placeholder(tf.float32, (None, 28, 28, 1))
with tf.name_scope('encoder'):
    en1 = tf.layers.Conv2D(64, (k,k))(X)
    en1 = tf.nn.leaky_relu(en1)
    en2 = tf.layers.Conv2D(64, (7,7))(en1)
    en2 = tf.nn.leaky_relu(en2)
    encoded = tf.layers.Conv2D(64, (3,3),padding='same',strides=(2,2))(en2)
    #encoded = tf.nn.leaky_relu(encoded)
batch_size = tf.shape(encoded)[0]
latent_samples = tf.placeholder(tf.float32,(None,7,7,32))
log_sigma = encoded[:,:,:,0:32]
mu = encoded[:,:,:,32:64]
Z = log_sigma * latent_samples + mu 

with tf.name_scope('decoder'):
    de1 = tf.layers.Conv2DTranspose(64,(3,3),padding='same',strides=(2,2))(Z)
    de1 = tf.nn.leaky_relu(de1)
    de2 = tf.layers.Conv2DTranspose(64,(7,7))(de1)
    de2 = tf.nn.leaky_relu(de2)
    de3 = tf.layers.Conv2DTranspose(64,(9,9))(de2)
    de3 = tf.nn.leaky_relu(de3)
    output = tf.layers.Conv2DTranspose(1,(3,3),padding='same',activation=tf.nn.sigmoid)(de3)
    
squared_loss = tf.reduce_sum(tf.)#tf.reduce_sum(tf.square(X - output))
kl = 0.5 * tf.reduce_sum(tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma)
mnist_train = tf.train.AdamOptimizer().minimize(squared_loss+kl)

In [6]:
epochs = 10
batch_s = 1
with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    #saver.restore(sess,'./checkpoints/simple.ckpt')
    try:
        for i in range(epochs):
            print('Epoch ' + str(i))
            for b in range(42000):
                sess.run(mnist_train,{X:train_x[b*batch_s:b*batch_s + batch_s],latent_samples:latent_dist[b*batch_s:b*batch_s + batch_s]})
                if b % 100 == 0:
                    saver.save(sess,'./checkpoints/mnist.ckpt')
                    print("Loss: " + str(sess.run(squared_loss,{X:train_x[b*batch_s:b*batch_s + batch_s],latent_samples:latent_dist[b*batch_s:b*batch_s + batch_s]})))
    except Exception as e:
        traceback.format_exc(e)
    finally:
        saver.save(sess,'./checkpoints/mnist.ckpt')
        print('Model saved')


Epoch 0
Loss: 125.301704
Loss: 87.7269
Loss: 110.86874
Loss: 84.88245
Loss: 46.786957
Loss: 221.25635
Loss: 80.600525
Loss: 47.646824
Loss: 80.658554
Loss: 133.14206
Loss: 38.80171
Loss: 68.400055
Loss: 119.239525
Loss: 75.19002
Loss: 80.298256
Loss: 46.840652
Loss: 89.259476
Loss: 85.79883
Loss: 108.181274
Loss: 110.957466
Loss: 137.32675
Loss: 100.30733
Loss: 52.80013
Loss: 33.30529
Loss: 90.97588
Loss: 66.37424
Loss: 105.6944
Loss: 25.201963
Loss: 52.026848
Loss: 115.495895
Loss: 85.96664
Loss: 92.94201
Loss: 34.9952
Loss: 131.6127
Loss: 31.889507
Loss: 93.80492
Loss: 132.92883
Loss: 120.321045
Loss: 50.53984
Loss: 56.542458
Loss: 121.10282
Loss: 72.479225
Loss: 79.92617
Loss: 105.00981
Loss: 74.58396
Loss: 87.32008
Loss: 87.19834
Loss: 69.84728
Loss: 86.359276
Loss: 160.2831
Loss: 62.92264
Loss: 124.19446
Loss: 101.26145
Loss: 64.333015
Loss: 73.89307
Loss: 42.36621
Loss: 87.12672
Loss: 109.871185
Loss: 116.77636
Loss: 95.64516
Loss: 34.949604
Loss: 88.78366
Loss: 106.774086
Loss: 

KeyboardInterrupt: 