In [2]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from __future__ import print_function
from tensorflow.examples.tutorials import mnist
import os

What is the core idea behind DRAW?
DRAW extends the capability of variational auto-encoders using the following 2 methods:

* (1) <u><b>Progressive Refinement:</b></u>
    * The neural net is asked to merely improve the image, rather than finishing the image in one shot.
    * At each step, the distribution improves only a few small features at a time.
    * The trick is to sample from the <u>iterative refinement distribution</u>, $\boldsymbol { P(C_t | C_{t-1})}$, rather than directly from $\boldsymbol {P(C)}$.
    <img src='draw-1.png' height=100px width=100px>
* (2) <u><b>Spatial Attention:</b></u>
    * This extends the above refinement idea to the spatial domain.
    * Here, we ask the network to improve a small region of the image at a time.

##### <font color="red">PS: In the above case, there are some subtleties to settle, such as how big the attention patch should be and whether the "pen strokes" should be sharp or blurry. All of these dynamic parameters will be learned by the DRAW model.</font>

### The Model

The DRAW model uses RNNs that run for <b>T</b> steps of progressive refinement. The picture below shows the neural network with the RNNs unrolled across time, so everything is feed-forward.

<img src='draw-2.png'>

In [3]:
# Run-time switches

# Directory path setting; presumably where the MNIST data is read from -- TODO confirm against the loading cell.
data_dir = ""

# Attention toggles: they select how read_size / write_size are computed in the model-parameter cell.
read_attn = True
write_attn = False

In [11]:
# Model hyper-parameters

# Image geometry
A = 28                 # width
B = 28                 # height
img_size = B * A       # canvas size (flattened image)

# LSTM sizes (num_hidden size / num_out in LSTM)
enc_size = 256
dec_size = 256

# Side length of the read / write attention grids
read_n = 5
write_n = 5

# Size of the read / write vectors, depending on whether attention is switched on.
if read_attn:
    read_size = 2 * read_n * read_n
else:
    read_size = 2 * img_size

if write_attn:
    write_size = write_n * write_n
else:
    write_size = img_size

z_size = 10            # QSampler output size

T = 10                 # MNIST generation sequence length
batch_size = 100
train_iters = 10000
learning_rate = 1e-3
eps = 1e-8             # epsilon for numerical stability


# Passed as `reuse` to tf.variable_scope; switched to True once the first
# time step has created the variables (hack for variable_scope(reuse=True)).
DO_SHARE = None


In [12]:
# Graph inputs and recurrent cells

# One batch of flattened input images.
X = tf.placeholder(tf.float32, shape=(batch_size, img_size))

# QSampler noise: one standard-normal draw of size z_size per example.
e = tf.random_normal(shape=(batch_size, z_size), mean=0, stddev=1)

# Encoder / decoder LSTM cells (created in the same order as before).
lstm_enc = tf.contrib.rnn.LSTMCell(num_units=enc_size, state_is_tuple=True)
lstm_dec = tf.contrib.rnn.LSTMCell(num_units=dec_size, state_is_tuple=True)



Single-cell equations:<br>

<img src='draw-3.png' width =300px height=300px>

In [13]:
def linear(x,output_dim):
    """
    Learned affine map: returns matmul(x, w) + b.

    x is expected to be 2-D, (batch_size, num_features); its second
    dimension fixes the number of rows of w. The variables "w" and "b" are
    obtained through tf.get_variable, so they live in whatever variable
    scope is active at the call site. b is initialised to zeros.
    """
    in_features = x.get_shape()[1]
    weights = tf.get_variable("w", [in_features, output_dim])
    bias = tf.get_variable(
        "b",
        [output_dim],
        initializer=tf.constant_initializer(0.0),
    )
    return tf.matmul(x, weights) + bias


def filter_bank(gx,gy,sigma_sqr,delta,N):
    """
    Build the Gaussian filter banks of an N x N attention grid.

    gx, gy    : grid centre along the image width / height; (batch, 1) when
                called from attention_window
    sigma_sqr : isotropic variance shared by all filters, (batch, 1)
    delta     : stride between neighbouring filter centres, (batch, 1)
    N         : number of filters per axis

    Returns (Fx, Fy) with shapes (batch, N, A) and (batch, N, B). Every
    filter (row) is normalised to sum to 1 over the pixel axis.
    """
    # 1-based filter index i = 1..N. The offset (i - N/2 - 0.5) is only
    # symmetric around the grid centre for a 1-based index; with 0..N-1 the
    # whole grid would sit one stride off (gx, gy).
    grid_i = tf.reshape(tf.cast(tf.range(1,N+1),tf.float32),[1,-1])
    # Float divisor so the offset stays fractional under Python 2 as well.
    mu_x = gx + (grid_i - N/2.0 - 0.5) * delta
    mu_y = gy + (grid_i - N/2.0 - 0.5) * delta
    # Pixel coordinates along the width (a) and height (b).
    a = tf.reshape(tf.cast(tf.range(A),tf.float32),[1,1,-1])
    b = tf.reshape(tf.cast(tf.range(B),tf.float32),[1,1,-1])

    # Reshape so that (pixel - centre) broadcasts to (batch, N, pixels).
    mu_x = tf.reshape(mu_x,[-1,N,1])
    mu_y = tf.reshape(mu_y,[-1,N,1])
    sigma_sqr = tf.reshape(sigma_sqr,[-1,1,1])
    Fx = tf.exp(-tf.square(a - mu_x) / (2*sigma_sqr))
    Fy = tf.exp(-tf.square(b - mu_y) / (2*sigma_sqr))

    # Normalise each filter; the eps floor avoids dividing by zero when a
    # filter has no mass on the image.
    Fx = Fx / tf.maximum(tf.expand_dims(tf.reduce_sum(Fx,2),2),eps)
    Fy = Fy / tf.maximum(tf.expand_dims(tf.reduce_sum(Fy,2),2),eps)
    return Fx,Fy


def attention_window(scope,h_dec,N):
    """
    Map the decoder output to attention parameters and filter banks.

    scope : variable scope holding the linear layer's variables
    h_dec : decoder output passed to linear(), so it must be 2-D
    N     : attention grid size (filters per axis); must be greater than 1

    Returns (Fx, Fy, gamma): the filter banks from filter_bank and the
    positive intensity gamma of shape (batch, 1).
    """
    with tf.variable_scope(scope,reuse=DO_SHARE):
        params = linear(h_dec,5)

    # Five raw parameters per example, each of shape (batch, 1).
    gx_tilde,gy_tilde,log_sigma2,log_delta,log_gamma = tf.split(params,5,1)
    # Scale the raw centre to image coordinates. Both axes use (size+1)/2;
    # float divisors keep the result identical under Python 2 and 3.
    gx = (A+1)/2.0 * (gx_tilde+1)
    gy = (B+1)/2.0 * (gy_tilde+1)
    # exp() keeps variance, stride and intensity strictly positive.
    sigma_sqr = tf.exp(log_sigma2)
    delta = (max(A,B)-1)/float(N-1) * tf.exp(log_delta) # batch x 1
    gamma = tf.exp(log_gamma)

    Fx,Fy = filter_bank(gx,gy,sigma_sqr,delta,N)
    return Fx,Fy,gamma
    
    

In [14]:
def bce(targ,out):
    """
    Element-wise binary cross-entropy of predictions `out` against targets `targ`.

    eps is added inside each logarithm so the log argument does not reach
    exactly zero. No reduction is applied; the result has the broadcast
    shape of the inputs.
    """
    positive_term = targ * tf.log(out + eps)
    negative_term = (1.0 - targ) * tf.log(1.0 - out + eps)
    return -(positive_term + negative_term)

def read_no_attn(x,x_cap,h_dec_prev):
    """
    Read without attention: join x and x_cap along axis 1.

    h_dec_prev is accepted but not used here.
    """
    pieces = [x, x_cap]
    return tf.concat(pieces, axis=1)

def read_with_attn(x,x_cap,h_dec_prev):
    """
    Attention-based read; takes the same arguments as read_no_attn.

    Placeholder: the attentive read has not been written yet, so calling
    this raises NotImplementedError. Giving the function a body keeps the
    cell syntactically valid, so the other definitions in it still load.
    """
    raise NotImplementedError("read_with_attn is not implemented yet")


In [9]:
#state variables
cs = [0.0]*T #Sequence of canvases
mus,logsigmas,sigmas = [0]*T,[0]*T,[0]*T #Gaussian params generated by SampleQ. We will need this for computing loss

#initial states
# h_dec_prev stands in for the previous decoder output, so it must have
# dec_size columns (the width of lstm_dec), not img_size. Otherwise the
# encoder input at t=0 has a different width than at later steps.
h_dec_prev = tf.zeros((batch_size,dec_size))
enc_state = lstm_enc.zero_state(batch_size,tf.float32)
dec_state = lstm_dec.zero_state(batch_size,tf.float32)

#draw model

#construct the unrolled graph: one feed-forward pass per time step
for t in range(T):
    # Canvas from the previous step (all zeros at the first step).
    c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
    # Error image: what the current canvas still gets wrong.
    x_cap = X-tf.sigmoid(c_prev)
    r = read(X,x_cap,h_dec_prev)
    h_enc,enc_state = encode(enc_state,tf.concat([r,h_dec_prev],axis=1))
    z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
    h_dec,dec_state = decode(dec_state,z)
    cs[t] = c_prev + write(h_dec)
    # Feed this step's decoder output into the next step's read/encode;
    # without this the encoder would see zeros at every step.
    h_dec_prev = h_dec
    # From the second step on, reuse the variables created at t=0.
    DO_SHARE=True
    
    
    
