In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Model dimensions:
#   T           - number of time steps
#   m           - embedding size (length of each vector pushed onto the stack)
#   hidden_size - hidden size
T, m, hidden_size = 20, 64, 128

# Becomes True once stack_update has performed its first push.
init = False

In [3]:
# Limit this process to a fixed fraction (20%) of GPU memory.
GPU_MEMORY_FRACTION = 0.2
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)

In [4]:
# Empty neural stack: `s` holds one strength per entry (shape (1, k)) and
# `V` holds the entries as columns (shape (m, k)); both start with k = 0.
s = tf.constant([], shape=(1,0), dtype=tf.float32, name='Strengths')
V = tf.constant([], shape=(m,0), dtype=tf.float32, name='Stack')

# Reset the "first push done" flag together with the stack itself, so that
# re-running this cell alone really starts from an empty stack (otherwise
# stack_update would try to pop from the empty stack on its next call).
init = False

In [5]:
def stack_update(d, u, v, sess):
    '''
    Performs one update (pop, then push) of the neural stack.

    The stack lives in the module-level globals `s` (strengths, one column
    per entry) and `V` (values, one column of length `m` per entry); the
    rightmost column is the top. The pop walks from the top towards the
    bottom, taking up to `u` total strength; each entry contributes its value,
    weighted by the strength taken from it, to the read vector. `v` is then
    pushed on top with strength `d`. The first call only pushes (nothing has
    been pushed yet) and sets the global `init` flag.

    Args:
      d: Push probability (strength assigned to `v`).
      u: Pop probability (total strength removed from the top).
      v: Push value; must be reshapeable to (m, 1).
      sess: tf.Session used to evaluate the pop.

    Returns:
      r: The value read from the stack, a float32 numpy array of shape (m,)
         (all zeros on the first call).

    NOTE(review): every call adds new ops (while_loop, concats) to the default
    graph, so the graph keeps growing with the number of calls.
    NOTE(review): `r` is the popped content; the read in Grefenstette et al.
    (2015) is taken from the top of the stack *after* the push -- confirm
    which semantics are intended.
    '''
    global s, V, init

    if init:
        # Pop operation
        read0 = tf.zeros([m], dtype=tf.float32)   # Read value; built without indexing V, so an empty stack is safe
        idx0  = tf.shape(V)[1] - 1                # Index of the top of the stack
        rem0  = tf.constant(u, dtype=tf.float32)  # Remaining strength; float32 like `s` even if u is a float64

        initialization = (V, s, rem0, read0, idx0)

        def check(stack, strengths, remaining, read, idx):
            # Stop after the bottom entry (idx 0) or once the pop budget is
            # spent; with non-negative strengths every further step would
            # take 0 strength anyway.
            return tf.logical_and(idx >= 0, remaining > 0)

        def update(stack, strengths, remaining, read, idx):
            # Amount we can take from this entry
            this_qty = tf.minimum(remaining, strengths[:, idx])

            # Accumulate this entry's weighted value into the read vector
            read = tf.reshape(read + this_qty * stack[:, idx], tf.shape(read))  # for shape constraints

            # Update remaining strength
            remaining = tf.reshape(remaining - this_qty, tf.shape(remaining))

            # Reduce the strength of entry `idx` by the amount taken
            before = strengths[:, :idx]
            this   = strengths[:, idx:idx + 1] - this_qty   # (1, 1) column
            after  = strengths[:, idx + 1:]

            strengths = tf.reshape(tf.concat(1, [before, this, after]), tf.shape(strengths))

            # Move one entry down the stack
            return (stack, strengths, remaining, read, idx - 1)

        result = tf.while_loop(check, update, initialization)

        # Evaluate the pop: keep the updated strengths (now a numpy array) and the read
        _, s, _, r, _ = sess.run(result)

    else:
        # Nothing has been pushed yet, so there is nothing to pop; return a
        # zero read with the same (m,) shape as the reads of later calls.
        r = np.zeros(m, dtype=np.float32)
        init = True

    # Perform push: append v as the new top column with strength d
    V = tf.concat(1, [V, tf.reshape(tf.constant(v, dtype=tf.float32), (m, 1))])
    s = tf.concat(1, [s, tf.reshape(tf.constant(d, dtype=tf.float32), (1, 1))])

    return r

In [6]:
# Each call pops with strength 1.0 (reading back the previous push) and then
# pushes the next one-hot vector e_i with strength 1.0.
one_hots = np.eye(m, dtype=np.float32)   # rows are the push values; sized by the config constant m
for i in range(5):
    print(stack_update(1.0, 1.0, one_hots[i], sess))
    print('')

[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]

[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]

[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]

[ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0