In [1]:
import tensorflow as tf
import numpy as np

# ---- Model dimensions -------------------------------------------------------
HIDDEN_SIZE = 4   # Size of the hidden state (rows of the memory matrix).
MEMORY_SIZE = 6   # Number of MANN memory slots (columns of the memory matrix).
# Batch size is deliberately not a constant: the graph placeholders are created
# with shape=None, so the batch dimension comes from whatever data is fed.

# ---- Sequence ---------------------------------------------------------------
# Number of input steps. The fed sequence actually holds SEQ_LENGTH + 1 elements,
# because labels are the inputs shifted by one step: 0->1, 1->2, 2->3.
SEQ_LENGTH = 3

# ---- Memory addressing ------------------------------------------------------
N_SMALLEST = 1    # How many least-used memory positions the write head selects.
GAMMA = 0.1       # "Update weight decay" applied to the previous usage weights.


In [2]:
# Reset graph - just in case.
tf.reset_default_graph()

# Memory matrix and the write-head mixing parameter.
with tf.name_scope("Memory"):
    # Memory M: HIDDEN_SIZE x MEMORY_SIZE, one memory slot per column.
    # Excluded from training (trainable=False); changed only via memory_set / memory_update_op.
    memory = tf.Variable(tf.zeros([HIDDEN_SIZE, MEMORY_SIZE]), trainable=False, name="Memory_M")
    # Latest vs LRU ratio (squashed by a sigmoid in the write head).
    alpha = tf.Variable(tf.truncated_normal(shape=[1]), name="Alpha")


with tf.name_scope("Previous"):
    # Placeholders for the read / usage ("update") weights of the previous step,
    # one per sequence position; the caller feeds them back in each step.
    prev_read_weights_seq_batch = list()
    prev_update_weights_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        prev_read_weights_seq_batch.append(tf.placeholder(tf.float32, shape=None, name="Prev_rw"))
        prev_update_weights_seq_batch.append(tf.placeholder(tf.float32, shape=None, name="Prev_uw"))

# SET INITIAL MEMORY STATE.
# Each row of the literal below is one memory slot; the transpose stores it as a column.
memory_set = memory.assign(tf.transpose([
    [0.0, 0, 0, 1],
    [0, 0, 1, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [1, 0, 1, 0],
    [1, 1, 0, 0]]))
alpha_set = alpha.assign([0.1])


# Placeholders for inputs.
with tf.name_scope("Input_data"):
    # Define input data buffers: SEQ_LENGTH + 1 of them, so inputs and labels can overlap.
    data_buffers = list()
    for _ in range(SEQ_LENGTH + 1):
        data_buffers.append(tf.placeholder(tf.float32, shape=None, name="data_buffers"))
    print ("data_buffers shape =", data_buffers[0].shape)

    # Sequence of batches: the first SEQ_LENGTH buffers are the inputs.
    input_seq_batch = data_buffers[:SEQ_LENGTH]
    print ("Seq length  =", len(input_seq_batch))
    print ("Batch shape =", input_seq_batch[0].shape)

    # Labels are pointing to the same placeholders!
    # Labels are inputs shifted by one time step.
    labels_seq_batch = data_buffers[1:]
    # Concatenate targets into 2D tensor.
    target_batch = tf.concat(labels_seq_batch, 0)


# Read head: content-based addressing via cosine similarity between each input
# and every memory slot, turned into read weights with a softmax.
with tf.name_scope("Read_head"):

    # L2-normalize each input row (axis 1). The matmul below requires each input to be
    # [batch, HIDDEN_SIZE].
    norm_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        norm_seq_batch.append(tf.nn.l2_normalize(input_seq_batch[i_seq], 1))

    # Normalize memory column-wise (axis 0), i.e. each memory slot separately.
    norm_memory = tf.nn.l2_normalize(memory, 0)
    print("norm_memory =", norm_memory)

    # Calculate cosine similarity: [batch, HIDDEN_SIZE] x [HIDDEN_SIZE, MEMORY_SIZE]
    # -> [batch, MEMORY_SIZE].
    similarity_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        similarity_seq_batch.append(tf.matmul(norm_seq_batch[i_seq], norm_memory))

    # Calculate read weights from similarity (softmax over memory slots).
    read_weights_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        read_weights_seq_batch.append(tf.nn.softmax(similarity_seq_batch[i_seq]))

# Write head: mixes the previous read weights with a mask of least-used positions.
with tf.name_scope("Write_head"):

    write_weights_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        # "Truncation scheme to update the least-used positions".
        # top_k on the NEGATED usage weights returns the N_SMALLEST least-used
        # entries of each row (each batch sample separately).
        top = tf.nn.top_k(-prev_update_weights_seq_batch[i_seq], N_SMALLEST)
        # The smallest of those values is the threshold; >= gives a boolean mask
        # of the least-used positions (ties may select more than N_SMALLEST).
        kth = tf.reduce_min(top.values, axis=1, keep_dims=True)
        top2 = tf.greater_equal(-prev_update_weights_seq_batch[i_seq], kth)
        # Cast the mask to a 0/1 float tensor.
        prev_smallest_lru_weights = tf.cast(top2, tf.float32)

        # ww = sigmoid(alpha) * prev_rw + (1 - sigmoid(alpha)) * least_used_mask.
        write_weights_seq_batch.append(tf.add(tf.sigmoid(alpha) * prev_read_weights_seq_batch[i_seq],
                               (1.0 - tf.sigmoid(alpha)) * prev_smallest_lru_weights,
                               name="Write_weights_ww"))

with tf.name_scope("Memory_update"):
    calculated_mem_update_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        # [HIDDEN_SIZE, batch] x [batch, MEMORY_SIZE] -> [HIDDEN_SIZE, MEMORY_SIZE];
        # contracting over the batch axis sums the per-sample updates.
        calculated_mem_update_seq_batch.append(tf.tensordot(tf.transpose(input_seq_batch[i_seq]),
                                                            write_weights_seq_batch[i_seq], axes=1))
    # Sum updates over all sequence positions.
    mem_update = tf.add_n(calculated_mem_update_seq_batch)
    # The write must happen strictly AFTER every read of `memory` in this step.
    # Without the control dependency TF1 may run the assign before norm_memory is
    # evaluated when both are fetched in the same sess.run, which makes the read
    # weights nondeterministic.
    with tf.control_dependencies(read_weights_seq_batch):
        memory_update_op = memory.assign(memory + mem_update)

with tf.name_scope("Update_head"):
    # uw = GAMMA * prev_uw + rw + ww; fed back as prev_uw in the NEXT step.
    update_weights_seq_batch = list()
    for i_seq in range(SEQ_LENGTH):
        update_weights_seq_batch.append(tf.add(GAMMA * prev_update_weights_seq_batch[i_seq],
                                               read_weights_seq_batch[i_seq] + write_weights_seq_batch[i_seq],
                                               name="Update_weights_uw"))


# Finally - initialize all variables.
initialize_model = tf.global_variables_initializer()

data_buffers shape = <unknown>
Seq length  = 3
Batch shape = <unknown>
norm_memory = Tensor("Read_head/l2_normalize_3:0", shape=(4, 6), dtype=float32)


In [3]:
def create_feed_dict(batch_seq, prev_rw=None, prev_uw=None):
    """Create a feed dictionary mapping data and previous-step weights onto the graph placeholders.

    Args:
        batch_seq: indexable holding at least SEQ_LENGTH + 1 batches; element i is
            fed to data_buffers[i]. Elements 0..SEQ_LENGTH-1 act as inputs and
            1..SEQ_LENGTH as labels.
        prev_rw: SEQ_LENGTH previous read weights, one per sequence position.
            Defaults to the notebook-global ``prev_rw_seq_batch`` (looked up at
            call time, as before).
        prev_uw: SEQ_LENGTH previous usage ("update") weights. Defaults to the
            notebook-global ``prev_uw_seq_batch``.

    Returns:
        dict usable as ``sess.run(..., feed_dict=...)``.
    """
    # Fall back to the globals so existing one-argument calls behave exactly as before.
    if prev_rw is None:
        prev_rw = prev_rw_seq_batch
    if prev_uw is None:
        prev_uw = prev_uw_seq_batch

    # Feed batch to input buffers.
    feed_dict = {data_buffers[i]: batch_seq[i] for i in range(SEQ_LENGTH + 1)}
    # Feed the previous read / usage weights.
    for i in range(SEQ_LENGTH):
        feed_dict[prev_read_weights_seq_batch[i]] = prev_rw[i]
        feed_dict[prev_update_weights_seq_batch[i]] = prev_uw[i]
    return feed_dict

In [4]:

########################
# Execute graph.
sess = tf.InteractiveSession()
# Initialize all variables, then overwrite memory and alpha with hand-picked values.
sess.run(initialize_model)
# Run the assigns on their own first: fetching norm_memory in the same sess.run
# gives no ordering guarantee, so it could be computed from the zero-initialized memory.
memory_, _ = sess.run([memory_set, alpha_set])
norm_memory_ = sess.run(norm_memory)
print("Memory =\n", memory_)
print("norm_memory_ =\n", norm_memory_)

# Batch - of dimensions: SEQUENCE (SEQ_LENGTH + 1) x BATCH x VECTOR SIZE
batch_seq = np.array([[[0, 0, 1, 0], [0, 0, 1, 1]],
                      [[0, 1, 0, 0], [1, 0, 0, 0]],
                      [[0, 0, 1, 0], [0, 1, 0, 0]],
                      [[0, 0, 1, 0], [0, 1, 0, 0]]])

# Reset previous read / usage weights to zeros, one [batch, MEMORY_SIZE] array per
# sequence position. The batch size is taken from the data instead of being hard-coded.
batch_size = batch_seq.shape[1]
prev_rw_seq_batch = [np.zeros([batch_size, MEMORY_SIZE]) for _ in range(SEQ_LENGTH)]
prev_uw_seq_batch = [np.zeros([batch_size, MEMORY_SIZE]) for _ in range(SEQ_LENGTH)]

print("Batch=\n", batch_seq.shape)

N_ITERATIONS = 2
for iteration in range(N_ITERATIONS):
    print("\n=================\nIteration = ", iteration)
    # The read and usage weights produced in this step overwrite prev_rw_seq_batch /
    # prev_uw_seq_batch and are therefore fed back as "previous" in the next iteration.
    (input_seq_batch_, norm_seq_batch_, similarity_seq_batch_, prev_rw_seq_batch,
     write_weights_seq_batch_, prev_uw_seq_batch, mem_update_, memory_) = sess.run(
        [input_seq_batch, norm_seq_batch, similarity_seq_batch, read_weights_seq_batch,
         write_weights_seq_batch, update_weights_seq_batch, mem_update, memory_update_op],
        feed_dict=create_feed_dict(batch_seq))

    for i in range(SEQ_LENGTH):
        print("inputs[", i, "] =\n", input_seq_batch_[i])

    for i in range(SEQ_LENGTH):
        print("prev_rw_seq_batch[", i, "] = ", prev_rw_seq_batch[i])

    for i in range(SEQ_LENGTH):
        print("write_weights_seq_batch_[", i, "] = ", write_weights_seq_batch_[i])

    for i in range(SEQ_LENGTH):
        print("prev_uw_seq_batch[", i, "] = ", prev_uw_seq_batch[i])

    print("mem_update =\n", mem_update_)

    print("memory =\n ", memory_)
        
        

Memory =
 [[ 0.  0.  0.  1.  1.  1.]
 [ 0.  0.  1.  0.  0.  1.]
 [ 0.  1.  0.  0.  1.  0.]
 [ 1.  0.  0.  0.  0.  0.]]
norm_memory_ =
 [[ 0.          0.          0.          0.99999994  0.70710677  0.70710677]
 [ 0.          0.          0.99999994  0.          0.          0.70710677]
 [ 0.          0.99999994  0.          0.          0.70710677  0.        ]
 [ 0.99999994  0.          0.          0.          0.          0.        ]]
Batch=
 (4, 2, 4)

Iteration =  0
inputs[ 0 ] =
 [[ 0.  0.  1.  0.]
 [ 0.  0.  1.  1.]]
inputs[ 1 ] =
 [[ 0.  1.  0.  0.]
 [ 1.  0.  0.  0.]]
inputs[ 2 ] =
 [[ 0.  0.  1.  0.]
 [ 0.  1.  0.  0.]]
prev_rw_seq_batch[ 0 ] =  [[ 0.1143328   0.31078878  0.1143328   0.1143328   0.23188008  0.1143328 ]
 [ 0.23298407  0.23298407  0.11487716  0.11487716  0.1894004   0.11487716]]
prev_rw_seq_batch[ 1 ] =  [[ 0.1143328   0.1143328   0.31078878  0.1143328   0.1143328   0.23188008]
 [ 0.10230691  0.10230691  0.10230691  0.278099    0.20749018  0.20749018]]
prev_rw_seq_ba