# TensorFlow Scan Function
## Refs
1.  https://rdipietro.github.io/tensorflow-scan-examples/  

## The reason for learning about the scan function

The output of an RNN comes from repeatedly applying the RNN cell's basic function, which maps the current input and the previous RNN output to a new RNN output.  For the RNN to learn to remember things for several steps, it needs the gradient of the current output with respect to inputs several steps back.  Training an RNN is a gradient descent process, just like training other neural nets.  So to train the weights to pay attention to past data, the gradient must capture how the weights, applied at each of several past steps, change the current output.  Because the same weights are reused at every step, that gradient has to be chained back through every step, which becomes awkward when the number of steps is not known in advance.  A simple example will illustrate.  
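Written out for a generic recursion $h_t = f(h_{t-1}, x_t; w)$ that reuses one weight $w$ at every step and starts from an $h_0$ that does not depend on $w$ (notation introduced here for illustration), the chain rule gives

$$
\frac{d h_T}{d w} = \sum_{k=1}^{T} \left( \prod_{j=k+1}^{T} \frac{\partial h_j}{\partial h_{j-1}} \right) \frac{\partial h_k}{\partial w}
$$

Here $\partial h_k / \partial w$ is the direct effect of $w$ on step $k$ with $h_{k-1}$ held fixed, and the empty product for $k = T$ is 1.  The term for step $k$ carries $T - k$ factors, so the gradient computation grows with the length of the sequence.  For the power example that follows, $h_k = w h_{k-1}$ with $h_0 = x$, so every term is $w^{T-k} \cdot w^{k-1} x$ and the $T$ terms sum to $T w^{T-1} x$.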

Suppose you're using a recursion to raise w to an integer power, and you want to take the derivative of the resulting function with respect to w.  Here's TensorFlow code to accomplish that. 

In [2]:
#Uses the TensorFlow 1.x graph/session API. Under TensorFlow 2.x, replace the
#import below with `import tensorflow.compat.v1 as tf` plus `tf.disable_v2_behavior()`.
import tensorflow as tf
import numpy as np


g1 = tf.Graph()

with g1.as_default():

    # starting value of the recursion
    x = tf.Variable(1.0, dtype=tf.float32, name='x')

    # the weight we differentiate with respect to
    w = tf.Variable(1.0, dtype=tf.float32, name='w')

    # three explicit applications of the recursion y <- y * w, giving w^3 * x
    wx = x * w
    w2x = wx * w
    w3x = w2x * w

    # d(w^3 x)/dw; tf.gradients returns one gradient per entry in the list
    gradW = tf.gradients(w3x, [w])[0]

with tf.Session(graph=g1) as sess:
    # save the graph so it can be viewed in TensorBoard
    writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    deriv = sess.run([w3x, gradW])
    print(deriv, gradW)
    writer.close()


[1.0, 3.0] Tensor("gradients/AddN:0", shape=(), dtype=float32)


In [3]:
sess.close()

Use TensorBoard (`tensorboard --logdir=logs/`) to look at the graph of this program.  It's simple: three multiplications by w in sequence.  So it produces $w^3 x$ and takes the derivative of that function with respect to w, with x = 1.0.  The printed result shows that $w^3 x = 1.0$ (w and x are both initialized to 1.0) and that the derivative, $3w^2 x$, is 3.0, as you'd expect.
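The same numbers can be reproduced without TensorFlow.  The sketch below is plain Python, and the `Dual` class and `power_times_x` helper are hypothetical names introduced here.  It carries each value together with its derivative with respect to w through the successive multiplications, applying the product rule at every step.  This is forward-mode differentiation, whereas `tf.gradients` works in reverse mode, but both apply the same chain rule.

```python
# Forward-mode differentiation with dual numbers: each value carries
# (value, derivative with respect to w). An illustration of the chain rule,
# not TensorFlow's implementation (tf.gradients uses reverse mode).
from dataclasses import dataclass


@dataclass
class Dual:
    val: float  # the ordinary value
    dw: float   # derivative of the value with respect to w

    def __mul__(self, other: "Dual") -> "Dual":
        # product rule: d(a*b) = da*b + a*db
        return Dual(self.val * other.val, self.dw * other.val + self.val * other.dw)


def power_times_x(w_val: float, x_val: float, n: int) -> Dual:
    """Compute w**n * x by n successive multiplications, tracking d/dw."""
    w = Dual(w_val, 1.0)        # seed: dw/dw = 1
    result = Dual(x_val, 0.0)   # x does not depend on w
    for _ in range(n):
        result = result * w
    return result


r = power_times_x(1.0, 1.0, 3)
print([r.val, r.dw])  # [1.0, 3.0], the same values as the TensorFlow run above
```

Only the loop count `n` changes between powers; the per-step rule stays the same.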

## Q
1.  What would you have to do to calculate the derivative of $w^4 x$?
2.  How would you approach the problem of calculating the derivative of $w^n x$ in TensorFlow, where n can be any integer?  
3.  Suppose the input to your TensorFlow model was a list of integers fed in one at a time, and your task was to return the correct derivative of w raised to each integer power.  How would you attack that?  
4.  How does this relate to recurrent neural nets?  

The code below may help you understand why this matters for RNNs.  

In [4]:
g2 = tf.Graph()

with g2.as_default():

    # one step of the recursion: multiply the running value by W
    def fn(X, W):
        return W * X

    # starting value of the recursion
    x = tf.Variable(1.0, dtype=tf.float32, name='x')

    # the weight we differentiate with respect to
    w = tf.Variable(1.0, dtype=tf.float32, name='w')

    # apply the recursion three times; the Python loop runs while the graph is
    # built, so the graph still holds three separate multiplications
    wnx = x
    for _ in range(3):
        wnx = fn(wnx, w)

    # d(w^3 x)/dw
    gradW = tf.gradients(wnx, [w])[0]

with tf.Session(graph=g2) as sess:
    writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    deriv = sess.run([wnx, gradW])
    print(deriv, gradW)
    writer.close()

[1.0, 3.0] Tensor("gradients/AddN:0", shape=(), dtype=float32)


## Q
How is this related to the recursion in an RNN?

The recursion for a simple Elman net is roughly $h_t = f(W_h h_{t-1} + W_x x_t)$, where $h_t$ is the hidden state at time t and $x_t$ is the input at time t.  Adjusting the weight matrices $W_h$ and $W_x$ requires taking gradients backward in time through several applications of the recursion f().  If the number of steps is fixed, you can take the approach you've seen above: build a for-loop with the correct number of steps, so that every step is laid out explicitly on the computation graph and the gradient can flow back through each one.

In many problems, however, the training sequences do not have a fixed length.  Protein folding and machine translation are both examples where each input example has a different length.  For these problems you need a more general representation of the derivative of n applications of the RNN recursion.  This is somewhat analogous to encoding the general expression for the derivative of a variable raised to an integer power, $\frac{d}{dw} w^n x = n w^{n-1} x$, instead of performing each of the multiplications as in the examples above.  The iterative application of the recursion and its gradient can then be represented as a single block on the graph: a block that computes the derivative correctly no matter the length of the sequence.  
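A scalar version of the Elman recursion makes the point concrete.  This sketch assumes scalar weights `w_h` and `w_x` in place of the matrices $W_h$ and $W_x$, takes $f = \tanh$ and $h_0 = 0$, and uses the final hidden state as the loss; all function names are hypothetical.  The backward loop runs exactly as many steps as the forward loop, so the same code handles sequences of any length.  A central finite difference is included as a check.

```python
import math


def elman_forward(w_h, w_x, xs, h0=0.0):
    """Scalar Elman recursion h_t = tanh(w_h * h_{t-1} + w_x * x_t).

    Returns every hidden state [h_0, h_1, ..., h_T] with T = len(xs).
    """
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w_h * hs[-1] + w_x * x))
    return hs


def elman_bptt(w_h, w_x, xs, h0=0.0):
    """Backpropagation through time for the loss L = h_T.

    The backward loop runs as many steps as the forward loop, so any
    sequence length works. Returns (dL/dw_h, dL/dw_x).
    """
    hs = elman_forward(w_h, w_x, xs, h0)
    g_h = 1.0                  # dL/dh_T
    g_wh = g_wx = 0.0
    for t in range(len(xs), 0, -1):
        g_a = g_h * (1.0 - hs[t] ** 2)   # tanh'(a_t) = 1 - h_t**2
        g_wh += g_a * hs[t - 1]          # a_t contains w_h * h_{t-1}
        g_wx += g_a * xs[t - 1]          # a_t contains w_x * x_t (xs is 0-based)
        g_h = g_a * w_h                  # gradient flowing back to h_{t-1}
    return g_wh, g_wx


def finite_diff(w_h, w_x, xs, eps=1e-6):
    """Central-difference estimate of the same two derivatives, as a check."""
    def final_h(a, b):
        return elman_forward(a, b, xs)[-1]
    return ((final_h(w_h + eps, w_x) - final_h(w_h - eps, w_x)) / (2 * eps),
            (final_h(w_h, w_x + eps) - final_h(w_h, w_x - eps)) / (2 * eps))


# Two sequences of different lengths go through the same code.
for xs in ([0.5, -1.0, 2.0], [0.5, -1.0, 2.0, 0.3, 1.2]):
    print(len(xs), elman_bptt(0.8, 0.4, xs), finite_diff(0.8, 0.4, xs))
```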

In [8]:
import tensorflow as tf

g3 = tf.Graph()
with g3.as_default():
    # one step of the recursion; tf.scan traces fn once, so a single
    # variable w is created and reused at every step
    def fn(previous_output, current_input):
        w = tf.Variable(2.0, name='w')
        return w * previous_output * current_input

    inputList = tf.placeholder(tf.float32, shape=(None))  # the input sequence
    seqLength = tf.placeholder(tf.int32, shape=[1])       # its length, fed explicitly
    print(seqLength)
    elems = inputList
    initializer = tf.constant(1.0)
    # out[0] = fn(initializer, elems[0]); out[t] = fn(out[t-1], elems[t])
    out = tf.scan(fn, elems, initializer=initializer)

    outShape = tf.shape(out)
    nList = outShape[0]
    trainables = tf.trainable_variables()  # just [w]
    # gradient of the final output, out[seqLength-1], with respect to w
    grad = tf.gradients(tf.slice(out, seqLength-1, [1]), trainables)[0]
    print(grad)

with tf.Session(graph=g3) as sess:
    writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    print(sess.run([outShape, out, grad, nList], {inputList: [1.0, 1.0, 1.0, 1.0, 1.0], seqLength: [5]}))
    print(sess.run([outShape, out, grad, nList], {inputList: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], seqLength: [6]}))
    writer.close()

Tensor("Placeholder_1:0", shape=(1,), dtype=int32)
Tensor("gradients/scan/while/mul/Enter_grad/b_acc_3:0", dtype=float32)
[array([5], dtype=int32), array([  2.,   4.,   8.,  16.,  32.], dtype=float32), 80.0, 5]
[array([6], dtype=int32), array([  2.,   4.,   8.,  16.,  32.,  64.], dtype=float32), 192.0, 6]


As you can see, when the input sequence has length 5, the gradient of the final scan output is the derivative of $w^5$, namely $5w^4 = 80$ at $w = 2$.  When it has length 6, the derivative is taken of $w^6$ and the calculation yields $6w^5 = 192$.  The scan function applies the input function (fn in the example above) recursively to produce the network output, and it applies the derivative of fn recursively to build the gradients of that output, including dependences on inputs and outputs (potentially) many steps in the past.  
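To see what the scan block has to do, here is a plain-Python stand-in.  The `scan` and `grad_last_wrt_w` helpers are hypothetical and are not the TensorFlow API, but they mirror the `fn(previous_output, current_input)` convention used above.  The gradient is accumulated by a reverse sweep whose length matches the input, which is what makes it independent of sequence length.  TensorFlow itself builds a backward `while` loop on the graph, and this sketch does not reproduce those internals.

```python
from typing import Callable, List, Sequence


def scan(fn: Callable[[float, float], float], elems: Sequence[float],
         initializer: float) -> List[float]:
    """Apply fn(previous_output, current_input) along elems, keeping every output.

    Follows the same argument order as the fn passed to tf.scan above.
    """
    outputs = []
    acc = initializer
    for x in elems:
        acc = fn(acc, x)
        outputs.append(acc)
    return outputs


def grad_last_wrt_w(w: float, elems: Sequence[float], initializer: float) -> float:
    """d(out[-1])/dw for fn(prev, cur) = w * prev * cur, by a reverse sweep.

    With out_t = w * out_{t-1} * x_t:
      direct effect of w on step t:     out_{t-1} * x_t
      effect of out_{t-1} on out_t:     w * x_t
    """
    outs = scan(lambda prev, cur: w * prev * cur, elems, initializer)
    prevs = [initializer] + outs[:-1]   # the state each step started from
    grad = 0.0
    upstream = 1.0                      # d out[-1] / d out_t, starting at the last step
    for t in reversed(range(len(elems))):
        grad += upstream * prevs[t] * elems[t]
        upstream *= w * elems[t]
    return grad


for n in (5, 6):
    elems = [1.0] * n
    print(scan(lambda prev, cur: 2.0 * prev * cur, elems, 1.0),
          grad_last_wrt_w(2.0, elems, 1.0))
# [2.0, 4.0, 8.0, 16.0, 32.0] 80.0
# [2.0, 4.0, 8.0, 16.0, 32.0, 64.0] 192.0
```

Changing the length of `elems` changes only how many times each loop runs.  The 80.0 and 192.0 agree with the `tf.scan` run above.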

In [27]:
import tensorflow as tf

g4 = tf.Graph()
with g4.as_default():

    # one step of the recursion, as in g3
    def fn(previous_output, current_input):
        w = tf.Variable(2.0, name='w')
        return w * previous_output * current_input

    inputList = tf.placeholder(tf.float32, shape=(None))
    # the sequence length is now read from the input instead of being fed
    seqLength = tf.shape(inputList)
    print(seqLength)
    elems = inputList
    # seed the scan with a variable so the final output can be carried over
    lastOutput = tf.Variable(1.0, name='lastOutput')
    out = tf.scan(fn, elems, initializer=lastOutput)
    print(out)

    # store the final output in lastOutput so the next run starts from it
    lou = lastOutput.assign(tf.squeeze(tf.slice(out, seqLength-1, [1])))
    # lastOutput is created before tf.scan builds w, so trainables[0] is
    # lastOutput and grad is d(final output)/d(lastOutput), which is 2**n
    # for the all-ones inputs below
    trainables = tf.trainable_variables()
    grad = tf.gradients(tf.slice(out, seqLength-1, [1]), trainables)[0]
    print(grad)

with tf.Session(graph=g4) as sess:
    writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    print(sess.run([out, grad, seqLength, lou], {inputList: [1.0, 1.0, 1.0, 1.0, 1.0]}))
    print(sess.run([out, grad, seqLength, lou], {inputList: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}))
    writer.close()

Tensor("Shape:0", shape=(?,), dtype=int32)
Tensor("scan/TensorArrayPack:0", dtype=float32)
Tensor("gradients/scan/while/Enter_1_grad/Exit:0", shape=(), dtype=float32)
[array([  2.,   4.,   8.,  16.,  32.], dtype=float32), 32.0, array([5], dtype=int32), 32.0]
[array([   64.,   128.,   256.,   512.,  1024.,  2048.], dtype=float32), 64.0, array([6], dtype=int32), 2048.0]
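The numbers in this run can be reproduced in plain Python.  In the cell above, `lastOutput` is created before `tf.scan` builds `w`, so `tf.trainable_variables()[0]` is `lastOutput`.  As a result, `grad` is the derivative of the final output with respect to the starting state, not with respect to `w`.  The sketch below computes exactly that derivative and then carries the final output into the next call, as `lastOutput.assign(...)` does.  The `CarriedScan` class is a hypothetical helper, not a TensorFlow API.

```python
class CarriedScan:
    """Keep the last output of a scan and use it as the next starting state.

    Hypothetical plain-Python analogue of the g4 cell: the starting state plays
    the role of lastOutput, and fn(prev, cur) = w * prev * cur.
    """

    def __init__(self, w: float, initial_state: float = 1.0):
        self.w = w
        self.state = initial_state

    def run(self, elems):
        outs = []
        acc = self.state
        for x in elems:
            acc = self.w * acc * x
            outs.append(acc)
        # out[-1] = state * prod(w * x_t), so its derivative with respect to
        # the starting state is that product (w**n when every x_t is 1)
        grad_wrt_state = 1.0
        for x in elems:
            grad_wrt_state *= self.w * x
        self.state = outs[-1]   # the role played by lastOutput.assign(...)
        return outs, grad_wrt_state, self.state


runner = CarriedScan(w=2.0)
print(runner.run([1.0] * 5))  # ([2.0, 4.0, 8.0, 16.0, 32.0], 32.0, 32.0)
print(runner.run([1.0] * 6))  # ([64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0], 64.0, 2048.0)
```

The gradient with respect to `w` for the second run would be $6w^5 \cdot 32 = 6144$, not 64.  This is consistent with `grad` in the cell above being taken with respect to `lastOutput`.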


In [8]:
import tensorflow as tf

g5 = tf.Graph()
with g5.as_default():

    def init_weights(shape, name, glorot=False):
        n_inputs, n_outputs = shape
        if glorot:
            # Glorot (Xavier) uniform range
            init_range = tf.sqrt(6.0 / (n_inputs + n_outputs))
            return tf.Variable(tf.random_uniform(shape, -init_range, init_range), name=name)
        return tf.Variable(tf.random_normal(shape, stddev=0.01), name=name)

    def bias_variable(shape, name):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial, name=name)

    netDim = 200
    inputDim = 1
    # create the weights once, outside the recursion, so their initializers are
    # not built inside the scan's while loop; fn closes over them
    b_h = bias_variable(shape=[1, netDim], name='b_h')
    W_hh = init_weights(shape=[netDim, netDim], name='W_hh')
    W_ih = init_weights(shape=[inputDim, netDim], name='W_ih')

    # recursion: h_t = f(h_{t-1} W_hh + i_t W_ih + b_h)
    # (an output layer would add o_t = f2(h_{t-1} W_oh))
    def fn(htm1, current_input):
        # note: current_input is a scalar element of inputList, while tf.matmul
        # needs rank-2 operands (e.g. tf.reshape(current_input, [1, 1]))
        ht = tf.nn.elu(tf.matmul(htm1, W_hh) + tf.matmul(current_input, W_ih) + b_h)
        print(tf.shape(ht))  # prints the symbolic shape tensor, not its value
        return ht

    inputList = tf.placeholder(tf.float32, shape=(None))
    label = tf.placeholder(tf.float32, shape=[None])  # not used by the graph yet
    seqLength = tf.shape(inputList)
    print(seqLength)
    elems = inputList
    lastOutput = tf.Variable(tf.ones([1, netDim]), name='lastOutput')
    out = tf.scan(fn, elems, initializer=lastOutput)
    print(out)

    # tf.slice with size [1] gives a rank-1 result, so after squeeze temp is a
    # scalar; assigning it to the [1, 200] lastOutput raises the ValueError
    # below. Selecting the last step with out[-1] would keep its [1, 200] shape.
    temp = tf.squeeze(tf.slice(out, seqLength-1, [1]))
    print(temp)
    lou = lastOutput.assign(temp)
    trainables = tf.trainable_variables()
    grad = tf.gradients(tf.slice(out, seqLength-1, [1]), trainables)[0]
    print(grad)

with tf.Session(graph=g5) as sess:
    writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    sess.run(tf.global_variables_initializer())
    # the label feed is a scalar, which does not match its shape [None]
    print(sess.run([out, grad, seqLength, lou], {inputList: [1.0, 1.0, 1.0, 1.0, 1.0], label: 5.0}))
    writer.close()

Tensor("Shape:0", shape=(?,), dtype=int32)
Tensor("scan/while/Shape:0", shape=(2,), dtype=int32)
Tensor("scan/TensorArrayPack:0", dtype=float32)
Tensor("Squeeze:0", shape=(), dtype=float32)


ValueError: Shapes (1, 200) and () are not compatible

In [28]:
print(out)

Tensor("scan/TensorArrayPack:0", dtype=float32)


In [29]:
print(out[0])

Tensor("Squeeze_1:0", shape=(), dtype=float32)


In [30]:
print(tf.unstack(out))

ValueError: Cannot infer num from shape <unknown>