In [1]:
from __future__ import print_function
import os

import numpy as np
import zipfile
import tarfile
from six.moves.urllib.request import urlretrieve
import shutil 
import random
import math

import string
import tensorflow as tf

# HYPERPARAMETERS.
# Size of batch
BATCH_SIZE = 1
# Length of sequence (=  number of units of controller (recurrent layer))
SEQ_LENGTH = 2

# Length of a vector - a single hidden state
HIDDEN_SIZE = 16

# Size of the input vector - 59 = [A-Z] + [a-z] + ' '
INPUT_SIZE = 59

#### MANN-related parameters.http://localhost:8888/notebooks/tensorflow-sandbox/ptb/pure_formulas/ptb_char_lstm_placeholder_memory.ipynb#
# Size of the local memory of each cell.
NUMBER_OF_SLOTS = 10
# Length of a single vector - a single memory slot.
ADDRESS_SIZE = 4
CONTENT_SIZE = 4
SLOT_SIZE = ADDRESS_SIZE + CONTENT_SIZE

# Dirs - must be absolute paths!
#LOG_DIR = '/tmp/tf/ptb_char_lstm_ntm_memory3D/B100S20_H64_N64A8C16/'
LOG_DIR = "/tmp/tf/ptb_char_lstm_ntm_memory3D/B"+\
    str(BATCH_SIZE)+"S"+str(SEQ_LENGTH)+"_H"+str(HIDDEN_SIZE)+"_N"+\
    str(NUMBER_OF_SLOTS)+"A"+str(ADDRESS_SIZE)+"C"+str(CONTENT_SIZE)+"_only_memory_output/"
print("Writing TB log to:",LOG_DIR)

# Local dir where PTB files will be stored.
PTB_DIR = "/home/tkornuta/data/ptb/"

# Filenames.
TRAIN = "ptb.train.txt"
VALID = "ptb.valid.txt"
TEST = "ptb.test.txt"



Writing TB log to: /tmp/tf/ptb_char_lstm_ntm_memory3D/B1S2_H16_N10A4C4_only_memory_output/


### Check/maybe download PTB.

In [2]:
def maybe_download_ptb(path, 
                       filename='simple-examples.tgz', 
                       url='http://www.fit.vutbr.cz/~imikolov/rnnlm/', 
                       expected_bytes =34869662):
  # Eventually create the PTB dir.
  extract = False
  if not tf.gfile.Exists(path):
    tf.gfile.MakeDirs(path)
  """Download a file if not present, and make sure it's the right size."""
  _filename = path+filename
  if not os.path.exists(_filename):
    print('Downloading %s...' % filename)
    _filename, _ = urlretrieve(url+filename, _filename)
    extract = True
  statinfo = os.stat(_filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified', (_filename), '(', statinfo.st_size, ')')
  else:
    print(statinfo.st_size)
    raise Exception(
      'Failed to verify ' + _filename + '. Can you get to it with a browser?')
  return extract

extract = maybe_download_ptb(PTB_DIR)

Found and verified /home/tkornuta/data/ptb/simple-examples.tgz ( 34869662 )


### Extract dataset-related files from the PTB archive.

In [4]:
def extract_ptb(path, filename='simple-examples.tgz', files=["ptb.train.txt", "ptb.valid.txt", "ptb.test.txt", 
                                       "ptb.char.train.txt", "ptb.char.valid.txt", "ptb.char.test.txt"]):
    """Extracts files from PTB archive."""
    # Extract
    tar = tarfile.open(path+filename)
    tar.extractall(path)
    tar.close()
    # Copy files
    for file in files:
        shutil.copyfile(PTB_DIR+"simple-examples/data/"+file, PTB_DIR+file)
    # Delete directory
    shutil.rmtree(PTB_DIR+"simple-examples/")        


if extract:
    extract_ptb(PTB_DIR)
    print("PTB extracted")
else:
    print("No need to extract PTB")    

No need to extract PTB


### Load train, valid and test texts.

In [5]:
def read_data(filename, path):
    with open(path+filename, 'r') as myfile:
        data=myfile.read()# .replace('\n', '')
        return data

train_text = read_data(TRAIN, PTB_DIR)
train_size=len(train_text)
print(train_size, train_text[:100])

valid_text = read_data(VALID, PTB_DIR)
valid_size=len(valid_text)
print(valid_size, valid_text[:64])

test_text = read_data(TEST, PTB_DIR)
test_size=len(test_text)
print(test_size, test_text[:64])

5101618  aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memote
399782  consumers may want to move their telephones a little closer to 
449945  no it was n't black monday 
 but while the new york stock excha


### Utility functions to map characters to vocabulary IDs and back.

In [6]:
first_letter = ord(string.ascii_uppercase[0]) # ascii_uppercase before lowercase! 
print("vocabulary size = ", INPUT_SIZE)
print(first_letter)

def char2id(char):
  """ Converts char to id (int) with one-hot encoding handling of unexpected characters"""
  if char in string.ascii_letters:# or char in string.punctuation or char in string.digits:
    return ord(char) - first_letter + 1
  elif char == ' ':
    return 0
  else:
    # print('Unexpected character: %s' % char)
    return 0
  
def id2char(dictid):
  """ Converts single id (int) to character"""
  if dictid > 0:
    return chr(dictid + first_letter - 1)
  else:
    return ' '

def characters(probabilities):
  """Turn a 1-hot encoding or a probability distribution over the possible
  characters back into its (most likely) character representation."""
  return [id2char(c) for c in np.argmax(probabilities, 1)]

def batches2string(batches):
  """Convert a sequence of batches back into their (most likely) string
  representation."""
  s = [''] * batches[0].shape[0]
  for b in batches:
    s = [''.join(x) for x in zip(s, characters(b))]
  return s

#print(len(string.punctuation))
#for i in string.ascii_letters:
#    print (i, char2id(i))


print(char2id('a'), char2id('A'), char2id('z'), char2id('Z'), char2id(' '), char2id('ï'))
print(id2char(char2id('a')), id2char(char2id('A')))
#print(id2char(65), id2char(33), id2char(90), id2char(58), id2char(0))
#bankno
sample = np.zeros(shape=(1, INPUT_SIZE), dtype=np.float)
sample[0, char2id(' ')] = 1.0
print(sample)

vocabulary size =  59
65
33 1 58 26 0 0
a A
[[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.]]


### Helper class for batch generation

In [7]:
class BatchGenerator(object):
  def __init__(self, text, batch_size, seq_length, vocab_size):
    """
    Initializes the batch generator object. Stores the variables and first "letter batch".
    text is text to be processed
    batch_size is size of batch (number of samples)
    seq_length represents the length of sequence
    vocab_size is number of words in vocabulary (assumes one-hot encoding)
    """
    # Store input parameters.
    self._text = text
    self._text_size = len(text)
    self._batch_size = batch_size
    self._seq_length = seq_length
    self._vocab_size = vocab_size
    # Divide text into segments depending on number of batches, each segment determines a cursor position for a batch.
    segment = self._text_size // batch_size
    # Set initial cursor position.
    self._cursor = [ offset * segment for offset in range(batch_size)]
    # Store first "letter batch".
    self._last_letter_batch = self._next_letter_batch()
  
  def _next_letter_batch(self):
    """
    Returns a batch containing of encoded single letters depending on the current batch 
    cursor positions in the data.
    Returned "letter batch" is of size batch_size x vocab_size
    """
    letter_batch = np.zeros(shape=(self._batch_size, self._vocab_size), dtype=np.float)
    # Iterate through "samples"
    for b in range(self._batch_size):
      # Set 1 in position pointed out by one-hot char encoding.
      letter_batch[b, char2id(self._text[self._cursor[b]])] = 1.0
      self._cursor[b] = (self._cursor[b] + 1) % self._text_size
    return letter_batch
  
  def next(self):
    """Generate the next array of batches from the data. The array consists of
    the last batch of the previous array, followed by num_unrollings new ones.
    """
    # First add last letter from previous batch (the "additional one").
    batches = [self._last_letter_batch]
    for step in range(self._seq_length):
      batches.append(self._next_letter_batch())
    # Store last "letter batch" for next batch.
    self._last_letter_batch = batches[-1]
    return batches


In [8]:
# Create objects for training, validation and testing batch generation.
train_batches = BatchGenerator(train_text, BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)

# For validation  - process the whole text as one big batch.
VALID_BATCH_SIZE = int(np.floor(valid_size/SEQ_LENGTH))
valid_batches = BatchGenerator(valid_text, VALID_BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# Get a single batch! 
valid_batch = valid_batches.next()

# For texting  - process the whole text as one big batch.
TEST_BATCH_SIZE = int(np.floor(test_size/SEQ_LENGTH))
test_batches = BatchGenerator(test_text, TEST_BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# Get a single batch! 
test_batch = test_batches.next()


### Helper functions - used during graph definition

In [134]:
# Function adding visualization to a given "matrix" (keys, memory etc.) with additional normalization
def visualize_hot_cold_normalized(matrix_, axis_of_reduction_, name_):
    with tf.name_scope(name_+"_vis_hc_norm"):
        # Eps for normalization in visualization
        EPS = 1e-10
        
        # Create hot-cold visualization (red=positive/blue=negative)
        zeros = tf.zeros_like(matrix_) 
        print("zeros=",zeros)
        
        # Get negative values only.
        neg = tf.less(matrix_, zeros)
        blue = tf.multiply(tf.cast(neg, tf.float32), matrix_)
        min_blue = tf.reduce_min(matrix_, axis=axis_of_reduction_, keep_dims=True) + EPS
        norm_blue = 255.0 * blue/min_blue
        
        # Get positive values only.
        pos = tf.greater(matrix_, zeros)
        red = tf.multiply(tf.cast(pos, tf.float32), matrix_)
        max_red = tf.reduce_max(matrix_, axis=axis_of_reduction_, keep_dims=True) + EPS
        norm_red = 255.0 * red/max_red
        
        # Stack them into three channel image with hot-cold values.
        rgb = tf.stack([norm_red, zeros, norm_blue], axis=2)

        #print("name_=",name_)
        #print("np.int32(zeros.shape[1]=", np.int32(zeros.shape[1]))
        # TODO: find fix and get rid of the batch_size(!)
        rgb_reshaped = tf.reshape(rgb, [1, -1,  np.int32(zeros.shape[1]), 3])

        # Visualize read weights as image.
        rgb_reshaped_summary = tf.summary.image(name_+"_visv", rgb_reshaped)
        
        
def visualize_hot_cold_normalized_3D (matrix_3d_, axis_of_reduction_, name_):
    print("matrix_3d_=",matrix_3d_)
    #print("matrix_3d_[0]=",[0])
    #tf.foldl(matrix_3d_, )
    #two_slices = tf.slice (matrix_3d_, [0, 0, 0], [2, -1, -1])
    #print("two_slices=",two_slices)
    
    #matrix_list = tf.unstack(two_slices, axis=0)
    #for i in range(len(matrix_list)):
    #    visualize_hot_cold(matrix_list[i], name_+"_a"+str(i))
    

In [135]:
# Function adding visualization to a given "matrix" (keys, memory etc.)
def visualize_hot_cold(matrix_, name_):
    with tf.name_scope(name_+"_vis_hc"):
        # Eps for normalization in visualization
        EPS = 1e-10
        # Create hot-cold visualization (red=positive/blue=negative)
        zeros = tf.zeros_like(matrix_) 
        
        # Get negative values only.
        neg = tf.less(matrix_, zeros)
        blue = 255.0 * tf.multiply(tf.cast(neg, tf.float32), matrix_)
        
        # Get positive values only.
        pos = tf.greater(matrix_, zeros)
        red = 255.0 * tf.multiply(tf.cast(pos, tf.float32), matrix_)
        
        # Stack them into three channel image with hot-cold values.
        rgb = tf.stack([red, zeros, blue], axis=2)

        #print("name_=",name_)
        #print("np.int32(zeros.shape[1]=", np.int32(zeros.shape[1]))
        # TODO: find fix and get rid of the batch_size(!)
        rgb_reshaped = tf.reshape(rgb, [1, -1,  np.int32(zeros.shape[1]), 3])

        # Visualize read weights as image.
        rgb_reshaped_summary = tf.summary.image(name_+"_vis", rgb_reshaped)

#def visualize_hot_cold_3D (matrix_3d_, name_):
#    two_slices = tf.slice (matrix_3d_, [0, 0, 0], [2, -1, -1])
#    matrix_list = tf.unstack(two_slices)
#    for i in range(len(matrix_list)):
#        visualize_hot_cold(matrix_list[i], name_+"_a"+str(i))
 

In [136]:
# Function "fixing" tensordot lack of dimensions.
def tensordot_fix(matrix_a_, matrix_b_, axes_, output_shape_):
    tensor_without_dims = tf.tensordot(matrix_a_, matrix_b_, axes=axes_)
    # TF bugfix - tensordot not returning proper shapes when only partial shapes are known. :]
    # https://github.com/tensorflow/tensorflow/issues/6682
    tensor = tf.reshape(tensor_without_dims, output_shape_)
    return tensor 
    

# Head related ops.

In [149]:
# Additional ops used by NTM.
def build_focusing_by_location_3D(keys_BxS_, beta_Bx1_, prev_memory_BxSxN_):
    """Computes content addressing. Uses both address and content part for calculation of the similarity.
    Memory is 3D i.e. every sample in the batch has its own "memory slice".
    Args:
        keys_BxS_: a 2-D Tensor [BATCH_SIZE x SLOT_SIZE] 
        beta_Bx1_: a 2-D Tensor - key strength [BATCH_SIZE x 1]
        prev_memory_BxSxN_: a 3-D !! Tensor [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
    """
    with tf.name_scope("focusing_by_location"):
        
        # Normalize key - along samples [BATCH_SIZE x 1 x SLOT_SIZE]
        keys_Bx1xS = tf.expand_dims(keys_BxS_, 1)
        print("!!keys_Bx1xS=",keys_Bx1xS)
        norm_keys_Bx1xS = tf.nn.l2_normalize(keys_Bx1xS,2, name="norm_keys_Bx1xS")
        print("!!keys_Bx1xS=",keys_Bx1xS)

        # Normalize memory - along slots [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
        norm_memory_BxSxN = tf.nn.l2_normalize(prev_memory_BxSxN_, 1, name="norm_memory_BxSxN")
        print("!!keys_Bx1xS=",norm_memory_BxSxN)
        
        # Calculate batched cosine similarity [BATCH_SIZE x 1 x NUMBER_OF_SLOTS]
        similarity_Bx1xN= tf.matmul(norm_keys_Bx1xS, norm_memory_BxSxN, name="similarity_Bx1xN")
        
        # Element-wise multiplication [BATCH_SIZE x 1 x NUMBER_OF_SLOTS]
        beta_Bx1x1 = tf.expand_dims(beta_Bx1_, 1)
        strengthtened_similarity_Bx1xN = tf.matmul(beta_Bx1x1, similarity_Bx1xN, name="strengthtened_similarity_Bx1xN")

        # Calculate weighting based on similarity along the "slot dimension" [BATCH_SIZE x NUMBER_OF_SLOTS].
        weighting_Bx1xN = tf.nn.softmax(strengthtened_similarity_Bx1xN, dim=2)
            
        # "Squeeze" results into 2D tensor [BATCH_SIZE x NUMBER_OF_SLOTS]
        weighting_BxN = tf.reshape(weighting_Bx1xN, [-1, NUMBER_OF_SLOTS], name="weighting_BxN")
        print ("weighting_BxN=",weighting_BxN)
        return weighting_BxN

def circular_index(idx, size):
    if idx < 0: return size + idx
    if idx >= size : return idx - size
    else: return idx

def build_circular_convolution(batch, kernel):
    """Computes circular convolution.
    Args:
        batch: a 2-D Tensor [BATCH_SIZE x NUMBER_OF_SLOTS] 
        kernel: a 2-D Tensor [BATCH_SIZE x KERNEL_SIZE (e.g. 3)]
    """
    with tf.name_scope("circular_convolution"):
        size = int(batch.get_shape()[1])
        kernel_size = int(kernel.get_shape()[1])
        kernel_shift = int(math.floor(kernel_size/2.0))

        kernels = []
        for i in range(size):
            # Create a list of index vectors.
            indices = [circular_index(i+j, size) for j in range(kernel_shift, -kernel_shift-1, -1)]
            # Reorganize batch according to indices. 
            reorganized_batch = tf.gather(batch, indices, axis=1)
            # Perform convolution.
            kernels.append(tf.reduce_sum(reorganized_batch * kernel, 1))
            
        # Sum elements lying on the same positions.
        result_without_dims = tf.transpose(tf.dynamic_stitch([i for i in range(size)], kernels))
        result_BxN = tf.reshape(result_without_dims, [-1, size])
        return result_BxN


def build_sharpening(batch, gamma):
    """Computes sharpening.
    Args:
        batch: a 2-D Tensor [BATCH_SIZE x NUMBER_OF_SLOTS] 
        gamma: a 1-D Tensor [BATCH_SIZE x 1]
    """
    EPS = 1e-30
    with tf.name_scope("sharpening"):    
        number_of_slots = int(batch.get_shape()[1])

        # Duplicate gammas - tf.tile is not working for partially unknown shape :] 
        gammas = []
        for i in range(number_of_slots):
            # Truncates gamma to 50!
            gammas.append(tf.minimum(gamma[:,0], 50))
        gammas_stacked = tf.transpose(tf.stack(gammas))
        # Calculate powered batch [BATCH_SIZE x NUMBER_OF_SLOTS].
        powed_batch = tf.pow(batch, gammas_stacked)+EPS

        # "Normalization" [BATCH_SIZE x NUMBER_OF_SLOTS].
        sharpened_batch = (powed_batch) / (tf.reduce_sum(powed_batch, axis=1, keep_dims=True))

        return sharpened_batch


In [150]:
def build_head(keys_BxS, beta_Bx1, memory_BxSxN, interpolation_gate_Bx1, prev_weights_BxN, shift_Bx3, gamma_Bx1):
    # 1. Content addressing  [BATCH_SIZE x NUMBER_OF_SLOTS].
    content_weights_BxN = build_focusing_by_location_3D(keys_BxS, beta_Bx1, memory_BxSxN)
    print("content_weights_BxN=",content_weights_BxN)
    
    # 2. Perform "gated interpolation" [BATCH_SIZE x NUMBER_OF_SLOTS].
    gated_weights_BxN = tf.add(interpolation_gate_Bx1 * content_weights_BxN,
        (1 - interpolation_gate_Bx1) * prev_weights_BxN,
        name="gated_weights_BxN")

    # 3. Shift  weights [BATCH_SIZE x NUMBER_OF_SLOTS].
    shifted_weights_BxN = build_circular_convolution(gated_weights_BxN, shift_Bx3)

    # 4. "Sharpen" the weights [BATCH_SIZE x NUMBER_OF_SLOTS].
    sharpened_weights_BxN = build_sharpening(shifted_weights_BxN, gamma_Bx1)

    # Add histograms to TensorBoard.
    tf.summary.histogram("1.content_weights_BxN", content_weights_BxN)
    tf.summary.histogram("2.gated_weights_BxN", gated_weights_BxN)
    tf.summary.histogram("3.shifted_weights_BxN", shifted_weights_BxN)
    tf.summary.histogram("4.sharpened_weights_BxN", sharpened_weights_BxN)
    # Add hot-cold visbuild_ualizations.
    visualize_hot_cold(matrix_=content_weights_BxN, name_="1.content_weights_BxN")          
    visualize_hot_cold(matrix_=gated_weights_BxN, name_="2.gated_weights_BxN")          
    visualize_hot_cold(matrix_=shifted_weights_BxN, name_="3.shifted_weights_BxN")          
    visualize_hot_cold(matrix_=sharpened_weights_BxN, name_="4.sharpened_weights_BxN")          

    return sharpened_weights_BxN



# Memory related ops.

In [167]:
def build_memory_preservation(write_weights_BxN_, erase_vector_BxC_, prev_memory_BxSxN_):
    """Computes how much memory will be preserved using weights and erase vector as params.
    Batched version, i.e. all computations are computed without iteration through the batch samples.
    Args:
        write_weights_BxN: a 2-D Tensor [BATCH_SIZE x NUMBER_OF_SLOTS] 
        erase_vector_BxC: a 2-D Tensor [BATCH_SIZE x CONTENT_SIZE]
        prev_memory_BxSxN: a 3-D !! Tensor [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
    """
    with tf.name_scope("memory_preservation"):
        
        # Expand dimensions of weights and erase vectors to 3D.
        write_weights_Bx1xN = tf.expand_dims(write_weights_BxN_, axis=1)
        erase_BxCx1 = tf.expand_dims(erase_vector_BxC_, axis=2)
        
        # Calculate the erase content mask.
        erase_content_mask_BxCxN = tf.matmul(erase_BxCx1, write_weights_Bx1xN)
        
        # Calculate the preserved mask.
        preserved_content_mask_BxCxN = tf.ones_like(erase_content_mask_BxCxN) - erase_content_mask_BxCxN
        #print("preserved_content_mask_BxCxN=",preserved_content_mask_BxCxN)
        
        # Create the preserved address mask.
        #preserved_address_mask_BxAxN = tf.ones([int(tf.shape(prev_memory_BxSxN_)[0]), ADDRESS_SIZE, NUMBER_OF_SLOTS], tf.float32)
        preserved_address_mask_BxAxN = tf.ones_like(tf.slice (prev_memory_BxSxN_, [0, ADDRESS_SIZE, 0], [-1, CONTENT_SIZE, NUMBER_OF_SLOTS]))
        #print("preserved_address_mask_BxAxN=",preserved_address_mask_BxAxN)
        
        # Concatenate the latter two.
        preserved_memory_mask_BxSxN = tf.concat(
            [preserved_address_mask_BxAxN, preserved_content_mask_BxCxN],
            axis=1)

        # Finally, calculate the preserved memory part.
        preserved_memory_BxSxN = tf.multiply(preserved_memory_mask_BxSxN, prev_memory_BxSxN_)
        print("preserved_memory_BxSxN=",preserved_memory_BxSxN)
        
        return preserved_memory_BxSxN

def build_memory_update(write_weights_BxN_, add_vector_BxC_, prev_memory_BxSxN_):
    """Computes the update that will be added to the memory.
    Assumes that memory is a 3D tensor  [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
    Batched version, i.e. all computations are computed without iteration through the batch samples.
    Args:
        write_weights_BxN: a 2-D Tensor [BATCH_SIZE x NUMBER_OF_SLOTS] 
        add_vector_BxC_: a 2-D Tensor [BATCH_SIZE x CONTENT_SIZE]
    """
    with tf.name_scope("memory_update"):
        
        # Expand dimensions of weights and erase vectors to 3D.
        write_weights_Bx1xN = tf.expand_dims(write_weights_BxN_, axis=1)
        add_vector_BxCx1 = tf.expand_dims(add_vector_BxC_, axis=2)
        
        # Calculate the content update.
        content_update_BxCxN = tf.matmul(add_vector_BxCx1, write_weights_Bx1xN)
        
        # Create the addres part - all zeros, so it won't change.        
        #address_update_BxAxN = tf.zeros([int(write_weights_Bx1xN.get_shape()[0]), ADDRESS_SIZE, NUMBER_OF_SLOTS], tf.float32)
        address_update_BxAxN = tf.zeros_like(tf.slice (prev_memory_BxSxN_, [0, ADDRESS_SIZE, 0], [-1, CONTENT_SIZE, NUMBER_OF_SLOTS]))
        print("address_update_BxAxN=",address_update_BxAxN)
        
        # Concatenate the latter two.
        memory_update_BxSxN = tf.concat(
            [address_update_BxAxN, content_update_BxCxN],
            axis=1)

        return memory_update_BxSxN

def build_memory_output(read_weights_BxN_, prev_memory_BxSxN_):
    """Creates ops computing the memory output.
    Assumes that memory is a 3D tensor  [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
    Batched version, i.e. all computations are computed without iteration through the batch samples.
    Args:
        read_weights_BxN: a 2-D Tensor [BATCH_SIZE x NUMBER_OF_SLOTS] 
        prev_memory_BxSxN_: a 3-D !! Tensor [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
    """
    with tf.name_scope("memory_output"):
        
        # Expand dimensions of weights to 3D  [BATCH_SIZE x 1 x NUMBER_OF_SLOTS]
        read_weights_Bx1xN = tf.expand_dims(read_weights_BxN_, axis=1)
        print("read_weights_Bx1xN=",read_weights_Bx1xN)
        
        # Get the content  [BATCH_SIZE x NUMBER_OF_SLOTS x CONTENT_SIZE]
        prev_content_BxCxN = tf.slice (prev_memory_BxSxN_, [0, ADDRESS_SIZE, 0], [-1, CONTENT_SIZE, NUMBER_OF_SLOTS])
        prev_content_BxNxC = tf.transpose(prev_content_BxCxN, perm=[0, 2, 1])
        print("prev_content_BxNxC=",prev_content_BxNxC)
        
        # Calculate output [BATCH_SIZE x 1 x CONTENT_SIZE]
        output_Bx1xC = tf.matmul(read_weights_Bx1xN, prev_content_BxNxC, name="output_Bx1xC")
        print("output_Bx1xC=",output_Bx1xC)
        
        # Squeeze the output to [BATCH_SIZE x CONTENT_SIZE] 
        #output_BxC = tf.squeeze(output_Bx1xC, name="output_BxC")
        output_BxC = tf.reshape(output_Bx1xC, [-1, CONTENT_SIZE], name="output_BxC")
        print("output_BxC=",output_BxC)

        return output_BxC
    

In [168]:
# Definition of graph of the MANN controller cell.
def controller_cell(input_BxI, # input x [BATCH_SIZE x INPUT_SIZE]
                    prev_cell_state_BxH, # BATCH x HIDDEN
                    prev_cell_output_BxH, # BATCH X HIDDEN
                    prev_memory_output_BxS, # read vector from the memory returned by previous cell [BATCH_SIZE x SLOT_SIZE]
                    prev_memory_BxSxN, # Value of the memory from previous time state [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
                    prev_read_weights_BxN, # read weights from previous time state (t-1) [BATCH_SIZE x NUMBER_OF_SLOTS]
                    prev_write_weights_BxN, # read weights from previous time state (t-1) [BATCH_SIZE x NUMBER_OF_SLOTS]
                    name):
    """Creates a controller cell with shared memory"""
    #with tf.name_scope(name):
    
    with tf.name_scope("lstm"):
        #######################################################################################
        # B) Calculate LSTM hidden state and output.

        # Concatenate h_prev ("prev output") with x.
        x_prev_h = tf.concat([input_BxI, prev_cell_output_BxH], 1)

        # Calculate forget, input and output gates activations.
        forget_gate = tf.sigmoid(tf.matmul(x_prev_h, Wf) + bf, name="forget_gate")
        input_gate = tf.sigmoid(tf.matmul(x_prev_h, Wi) + bi, name="Input_gate")
        output_gate = tf.sigmoid(tf.matmul(x_prev_h, Wo) + bo, name="Output_gate")

        # Calcualte cell state update (C~).
        cell_update = tf.tanh(tf.matmul(x_prev_h, Wc) + bc, name="Cell_update")
        # New cell state (C).
        cell_state_BxH = tf.add(forget_gate * prev_cell_state_BxH, input_gate * cell_update, name = "Cell_state")

        # Calculate cell output (h).
        cell_output_BxH = output_gate * tf.tanh(cell_state_BxH)

        # Add histograms to TensorBoard.
        tf.summary.histogram("input_BxI", input_BxI)
        tf.summary.histogram("prev_cell_state_BxH", prev_cell_state_BxH)
        tf.summary.histogram("prev_cell_output_BxH", prev_cell_output_BxH)
        tf.summary.histogram("prev_memory_output_BxS", prev_memory_output_BxS)
        tf.summary.histogram("cell_state_BxH", cell_state_BxH)
        tf.summary.histogram("cell_output_BxH", cell_output_BxH)

        # Add hot-cold visualizations.
        visualize_hot_cold_normalized(matrix_=input_BxI, axis_of_reduction_=None, name_="input_BxI")
        visualize_hot_cold_normalized(matrix_=prev_cell_state_BxH, axis_of_reduction_=None, name_="prev_cell_state_BxH")
        visualize_hot_cold_normalized(matrix_=prev_cell_output_BxH, axis_of_reduction_=None, name_="prev_cell_output_BxH")
        visualize_hot_cold_normalized(matrix_=prev_memory_output_BxS, axis_of_reduction_=None, name_="prev_memory_output_BxS")
        visualize_hot_cold_normalized(matrix_=cell_state_BxH, axis_of_reduction_=None, name_="cell_state_BxH")      
        visualize_hot_cold_normalized(matrix_=cell_output_BxH, axis_of_reduction_=None, name_="cell_output_BxH")      

    # Build the read head.
    with tf.name_scope("read_head"): 
        #######################################################################################
        # B) "Emit" parameters on the basis of controller output.       
        # 1. Content addressing
        # Calculate read keys [BATCH_SIZE x SLOT_SIZE] range: -inf - +inf!
        rkey_BxS = tf.nn.relu(tf.matmul(cell_output_BxH, Wh_rk, name="key_BxS"))
        # Calculate betas - used for sharpening/smoothing similarity - scalar [BATCH x 1] range: +1 - +inf!
        rbeta_Bx1 = 5 + tf.log1p(tf.exp(tf.matmul(cell_output_BxH, Wh_rbeta)), name="beta_Bx1")

        # 2. Focusing by location - interpolation gate - [BATCH_SIZE x 1] - range: 0 - +1.
        rinterpolation_gate_Bx1 = tf.sigmoid(tf.matmul(cell_output_BxH, Wh_rig), name="interpolation_gate_Bx1")

        # 3. Shift weighting - for circular convolution [BATCH_SIZE x 3] - range: 0 - +1.
        rshift_Bx3 = tf.sigmoid(tf.matmul(cell_output_BxH, Wh_rsh), name="shift_Bx3")

        # 4. Sharpening
        # Gamma - scalar [BATCH x 1] range: 1 - +inf
        rgamma_Bx1 = 5 + tf.log1p(tf.exp(tf.matmul(cell_output_BxH, Wh_rgamma)))

        #######################################################################################
        # C) Build the read head! [BATCH_SIZE x NUMBER_OF_SLOTS]
        read_weights_BxN = build_head(rkey_BxS, rbeta_Bx1, prev_memory_BxSxN, rinterpolation_gate_Bx1, prev_read_weights_BxN, rshift_Bx3, rgamma_Bx1)
        print("read_weights_BxN=",read_weights_BxN)

         # Add histograms to TensorBoard.
        tf.summary.histogram("7.read_weights_BxN", read_weights_BxN)
        # Add hot-cold visualizations.
        visualize_hot_cold(matrix_=read_weights_BxN, name_="7.read_weights_BxN")          

    # Read from the memory.
    with tf.name_scope("memory_output"):
        # Calculate the memory output ("read vector").
        memory_output_BxS = build_memory_output(read_weights_BxN, prev_memory_BxSxN)

        # Add histograms to TensorBoard.
        tf.summary.histogram("8.memory_output_BxS", memory_output_BxS)

        # Add hot-cold visualizations.
        visualize_hot_cold_normalized(matrix_=memory_output_BxS, axis_of_reduction_=None, name_="8.memory_output_BxS")          

    # Build the write head.
    with tf.name_scope("write_head"): 
        #######################################################################################
        # B) "Emit" parameters on the basis of controller output.       
        # 1. Content addressing
        # Calculate write keys [BATCH_SIZE x SLOT_SIZE] range: -inf - +inf!
        wkey_BxS = tf.nn.relu(tf.matmul(cell_output_BxH, Wh_wk, name="key_BxS"))
        # Calculate betas - used for sharpening/smoothing similarity - scalar [BATCH x 1] range: +1 - +inf!
        wbeta_Bx1 = 5 + tf.log1p(tf.exp(tf.matmul(cell_output_BxH, Wh_wbeta)), name="beta_Bx1")

        # 2. Focusing by location - interpolation gate - [BATCH_SIZE x 1] - range: 0 - +1.
        winterpolation_gate_Bx1 = tf.sigmoid(tf.matmul(cell_output_BxH, Wh_wig), name="interpolation_gate_Bx1")

        # 3. Shift weighting - for circular convolution [BATCH_SIZE x 3] - range: 0 - +1.
        wshift_Bx3 = tf.sigmoid(tf.matmul(cell_output_BxH, Wh_wsh), name="shift_Bx3")

        # 4. Sharpening
        # Gamma - scalar [BATCH x 1] range: 1 - +inf!
        wgamma_Bx1 = 5 + tf.log1p(tf.exp(tf.matmul(cell_output_BxH, Wh_wgamma)))

        #######################################################################################
        # C) Build the read head! [BATCH_SIZE x NUMBER_OF_SLOTS]
        write_weights_BxN = build_head(wkey_BxS, wbeta_Bx1, prev_memory_BxSxN, winterpolation_gate_Bx1, prev_read_weights_BxN, wshift_Bx3, wgamma_Bx1)
        print("write_weights_BxN=",write_weights_BxN)

         # Add histograms to TensorBoard.
        tf.summary.histogram("7.write_weights_BxN", write_weights_BxN)
        # Add hot-cold visualizations.
        visualize_hot_cold(matrix_=write_weights_BxN, name_="7.write_weights_BxN")          

    with tf.name_scope("memory_update"):
        #######################################################################################
        # A) Get shared variables.   
        with tf.variable_scope("memory_update_variables", reuse=True):
            # Add vector
            Wh_wa = tf.get_variable("Wh_wa", [HIDDEN_SIZE, CONTENT_SIZE])
            # Erase vector
            Wh_we = tf.get_variable("Wh_we", [HIDDEN_SIZE, CONTENT_SIZE])

            # Add histograms to TensorBoard.
            tf.summary.histogram("Wh_wa", Wh_wa)
            tf.summary.histogram("Wh_we", Wh_we)

        #######################################################################################
        # B) "Emit" parameters on the basis of controller output.       
        # Add vector [BATCH_SIZE x CONTENT_SIZE] range: -inf - +inf!
        add_vector_BxC = tf.nn.relu(tf.matmul(cell_output_BxH, Wh_wa, name="add_vector_BxC"))

        # Erase vector [BATCH_SIZE x CONTENT_SIZE] range: 0 - +1
        erase_vector_BxC = 0.3 * tf.sigmoid(tf.matmul(cell_output_BxH, Wh_we, name="erase_vector_BxS"))

        # 3. Calculate the preserved content [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
        preserved_memory_BxSxN = build_memory_preservation(write_weights_BxN, erase_vector_BxC, prev_memory_BxSxN)
        print("???? preserved_memory_BxSxN=",preserved_memory_BxSxN)
        
        # Calculate update of memory [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
        memory_update_BxSxN = build_memory_update(write_weights_BxN, add_vector_BxC, prev_memory_BxSxN)
        print("!!! memory_update_BxSxN=",memory_update_BxSxN)

        # 5. Calculate "updated" memory [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS]
        #memory_updated_BxSxN = prev_memory_BxSxN          
        memory_updated_BxSxN = tf.add_n([preserved_memory_BxSxN])      
        #memory_updated_BxSxN = tf.tanh(preserved_memory_BxSxN + memory_update_BxSxN)
        print("??? memory_updated_BxSxN=",memory_updated_BxSxN)

        # Add histograms to TensorBoard.
        tf.summary.histogram("1.add_vector_BxC", add_vector_BxC)
        tf.summary.histogram("2.erase_vector_BxC", erase_vector_BxC)
        tf.summary.histogram("3.preserved_memory_BxSxN", preserved_memory_BxSxN)
        tf.summary.histogram("4.memory_update_BxSxN", memory_update_BxSxN)
        tf.summary.histogram("5.memory_updated_BxSxN", memory_updated_BxSxN)
        tf.summary.histogram("0.prev_memory_BxSxN", prev_memory_BxSxN)

        # Add hot-cold visualizations.
        visualize_hot_cold_normalized(matrix_=add_vector_BxC, axis_of_reduction_=None, name_="1.add_vector_BxC")          
        visualize_hot_cold_normalized(matrix_=erase_vector_BxC, axis_of_reduction_=None, name_="2.erase_vector_BxC")          
        visualize_hot_cold_normalized_3D(matrix_3d_=preserved_memory_BxSxN, axis_of_reduction_=None, name_="3.preserved_memory_BxSxN")          
        visualize_hot_cold_normalized_3D(matrix_3d_=memory_update_BxSxN, axis_of_reduction_=None, name_="4.memory_update_BxSxN")          
        visualize_hot_cold_normalized_3D(matrix_3d_=memory_updated_BxSxN, axis_of_reduction_=None, name_="5.memory_updated_BxSxN")          
        visualize_hot_cold_normalized_3D(matrix_3d_=prev_memory_BxSxN, axis_of_reduction_=None, name_="0.prev_memory_BxSxN")          

    return cell_state_BxH, cell_output_BxH, memory_output_BxS, memory_updated_BxSxN, read_weights_BxN, read_weights_BxN # FOR NOW!!

print("Cell definition OK")

Cell definition OK


## Built graph

In [171]:
# Reset graph - just in case.
tf.reset_default_graph()

# Create shared (reuseable) space for all cells.   
with tf.variable_scope("lstm_variables"):
    # Input gate: input, previous output, and bias.
    Wf = tf.Variable(tf.truncated_normal([INPUT_SIZE+HIDDEN_SIZE, HIDDEN_SIZE], -0.1, 0.1), name="Wf")
    bf = tf.Variable(tf.zeros([1, HIDDEN_SIZE]), name="bf")

    # Forget gate: input, previous output, and bias.
    Wi = tf.Variable(tf.truncated_normal([INPUT_SIZE+HIDDEN_SIZE, HIDDEN_SIZE], -0.1, 0.1), name="Wi")
    bi = tf.Variable(tf.zeros([1, HIDDEN_SIZE]), name="bi")

    # Memory cell: input, state and bias.                             
    Wc = tf.Variable(tf.truncated_normal([INPUT_SIZE+HIDDEN_SIZE, HIDDEN_SIZE], -0.1, 0.1), name="Wc")
    bc = tf.Variable(tf.zeros([1, HIDDEN_SIZE]), name="bc")

    # Output gate: input, previous output, and bias.
    Wo = tf.Variable(tf.truncated_normal([INPUT_SIZE+HIDDEN_SIZE, HIDDEN_SIZE], -0.1, 0.1), name="Wo")
    bo = tf.Variable(tf.zeros([1, HIDDEN_SIZE]), name="bo")
    
    # Initialization placeholders.
    init_cell_state = tf.placeholder(tf.float32, shape=[None, HIDDEN_SIZE], name="init_cell_state")
    init_cell_output = tf.placeholder(tf.float32, shape=[None, HIDDEN_SIZE], name="init_cell_output")

    # Add histograms to TensorBoard.
    tf.summary.histogram("Wf", Wf)
    tf.summary.histogram("Wi", Wi)
    tf.summary.histogram("Wc", Wc)
    tf.summary.histogram("Wo", Wo)
    
# Create shared (reuseable) space for all cells.   
with tf.variable_scope("read_head_variables"):
    # 1. Content addressing.
    Wh_rk = tf.get_variable("Wh_rk", [HIDDEN_SIZE, SLOT_SIZE])
    Wh_rbeta = tf.get_variable("Wh_rbeta", [HIDDEN_SIZE, 1])

    # 2. Interpolation gate.
    Wh_rig = tf.get_variable("Wh_rig", [HIDDEN_SIZE, 1])

    # 3. Shift.
    Wh_rsh = tf.get_variable("Wh_rsh", [HIDDEN_SIZE, 3])

    # 4. Sharpening.
    Wh_rgamma = tf.get_variable("Wh_rgamma", [HIDDEN_SIZE, 1])

    # Add histograms to TensorBoard.
    tf.summary.histogram("Wh_rk", Wh_rk)
    tf.summary.histogram("Wh_rbeta", Wh_rbeta)
    tf.summary.histogram("Wh_rig", Wh_rig)
    tf.summary.histogram("Wh_rsh", Wh_rsh)
    tf.summary.histogram("Wh_rgamma", Wh_rgamma)
    
with tf.variable_scope("write_head_variables"):
    # 1. Content addressing.
    Wh_wk = tf.get_variable("Wh_wk", [HIDDEN_SIZE, SLOT_SIZE])
    Wh_wbeta = tf.get_variable("Wh_wbeta", [HIDDEN_SIZE, 1])

    # 2. Interpolation gate.
    Wh_wig = tf.get_variable("Wh_wig", [HIDDEN_SIZE, 1])

    # 3. Shift.
    Wh_wsh = tf.get_variable("Wh_wsh", [HIDDEN_SIZE, 3])

    # 4. Sharpening.
    Wh_wgamma = tf.get_variable("Wh_wgamma", [HIDDEN_SIZE, 1])

    # Add histograms to TensorBoard.
    tf.summary.histogram("Wh_wk", Wh_wk)
    tf.summary.histogram("Wh_wbeta", Wh_wbeta)
    tf.summary.histogram("Wh_wig", Wh_wig)
    tf.summary.histogram("Wh_wsh", Wh_wsh)
    tf.summary.histogram("Wh_wgamma", Wh_wgamma)

# Memory initialization related ops.
with tf.name_scope("memory_initialization"):
    #memory = tf.get_variable("memory", tf.truncated_normal(shape=[HIDDEN_SIZE, NUMBER_OF_SLOTS]), trainable=False)
    #init_memory_state_sample_op = tf.truncated_normal(shape=[SLOT_SIZE, NUMBER_OF_SLOTS], mean=0.5, stddev=0.2, dtype=tf.float32)
    #init_memory_state_op = tf.clip_by_value(tf.truncated_normal(shape=[SLOT_SIZE, NUMBER_OF_SLOTS], dtype=tf.float32), 0.0, 1.0)

    # Sets sparse memory addresses.
    #memory_content = tf.slice(init_memory_state_op, [ADDRESS_SIZE, 0], [CONTENT_SIZE, NUMBER_OF_SLOTS])
    #sparse_address = tf.eye(ADDRESS_SIZE) # must be = NUMBER_OF_SLOTS!
    #init_sparse_addressing_memory_state_op = tf.concat([sparse_address, memory_content], axis = 0)

    # Placeholder for "initial" memory output [BATCH_SIZE X SLOT_SIZE].
    init_memory_output = tf.placeholder(tf.float32, shape=[None, SLOT_SIZE], name="init_memory_read")
    # Placeholder for "initial" memory state [BATCH_SIZE x SLOT_SIZE x NUMBER_OF_SLOTS].
    input_memory_state = tf.placeholder(tf.float32, shape=[None, SLOT_SIZE, NUMBER_OF_SLOTS], name="input_memory_state_BxSxN")
    
    
with tf.variable_scope("memory_update_variables"):
    # Add vector
    Wh_wa = tf.get_variable("Wh_wa", [HIDDEN_SIZE, CONTENT_SIZE])
    # Erase vector
    Wh_we = tf.get_variable("Wh_we", [HIDDEN_SIZE, CONTENT_SIZE])

    # Add histograms to TensorBoard.
    tf.summary.histogram("Wh_wa", Wh_wa)
    tf.summary.histogram("Wh_we", Wh_we)

# 0. Previous variables.
with tf.name_scope("Previous_variables"):

    # Placeholders for previous weights.
    # TODO: Each cell (element in sequence) has its own prev read/update vector - ONE FOR THE WHOLE BATCH [NUMBER_OF_SLOTS]
    prev_read_weights_batch_seq = list()    
    for i in range(SEQ_LENGTH):
        prev_read_weights_batch_seq.append(tf.placeholder(tf.float32, shape=[None, NUMBER_OF_SLOTS], name="prev_read_weights_BxN"+str(i)))
    prev_write_weights_batch_seq = list()    
    for i in range(SEQ_LENGTH):
        prev_write_weights_batch_seq.append(tf.placeholder(tf.float32, shape=[None, NUMBER_OF_SLOTS], name="prev_write_weights_BxN"+str(i)))

# 1. Placeholders for inputs.
with tf.name_scope("input_data"):
    # Define input data buffers.
    data_buffers = list()
    for i in range(SEQ_LENGTH + 1):
        # Collect placeholders for inputs/labels: Batch x Vocab size.
        data_buffers.append(tf.placeholder(tf.float32, shape=[None, INPUT_SIZE], name="data_buffers_BxI_"+str(i)))

    # Sequence of batches.
    input_BxI_L = data_buffers[:SEQ_LENGTH]

    # Labels are pointing to the same placeholders!
    # Labels are inputs shifted by one time step.
    target_BxI_L = data_buffers[1:]  
    # Concatenate targets into 2D tensor.
    target_LBxI = tf.concat(target_BxI_L, 0)

    # Add histograms to TensorBoard.
    tf.summary.histogram("input_BxI_L", input_BxI_L)


# Build list of outpus for sequence (at the end of size SEQ_LENGTH).
memory_outputs_batch_seq = list()
cell_output_batch_seq = list()
# Two lists that will be "returned" and later passed as previous states. 
read_weights_batch_seq = list()  
write_weights_batch_seq = list()  
intermediate_memory_states = list()

# Build "unrolled controller" - "link" oldest memory output to initial placeholder.
print("Building NTM cells - it might take a while...")
# 2D tensors.
cell_state = init_cell_state 
cell_output = init_cell_output
memory_output = init_memory_output
# 3D tensor.
memory_state = input_memory_state
# For every buffer in input sequence batch buffers...
for i in range(SEQ_LENGTH):
    # ... add cell...     
    cell_state, cell_output, memory_output, memory_state, read_weights, write_weights = controller_cell(
        input_BxI_L[i], 
        cell_state,
        cell_output,
        memory_output,
        memory_state,
        prev_read_weights_batch_seq[i],
        prev_write_weights_batch_seq[i],
        "cell_"+str(i))
    # .. and store outputs and read/write weights.
    memory_outputs_batch_seq.append(memory_output)
    cell_output_batch_seq.append(cell_output)
    read_weights_batch_seq.append(read_weights)
    write_weights_batch_seq.append(write_weights)
    intermediate_memory_states.append(memory_state)

with tf.name_scope("memory_state"):
    # Final memory state.
    final_memory_state = memory_state

    # Add histograms to TensorBoard.
    visualize_hot_cold_normalized_3D(matrix_3d_=input_memory_state, axis_of_reduction_=None, name_="input_memory_state_SxN")          
    visualize_hot_cold_normalized_3D(matrix_3d_=final_memory_state, axis_of_reduction_=None, name_="final_memory_state_SxN")          
    for i in range(SEQ_LENGTH):
        tf.summary.histogram("intermediate_memory_state_SxN_"+str(i), intermediate_memory_states[i])
        visualize_hot_cold_normalized(matrix_=intermediate_memory_states[i], axis_of_reduction_=None, name_="intermediate_memory_state_SxN_"+str(i))          

print("Building output...")
with tf.variable_scope("output_variables"):
    # Cell output weights and biases.
    Wcol = tf.get_variable("Wcol", [HIDDEN_SIZE, INPUT_SIZE])
    bcol = tf.get_variable("bcol", [INPUT_SIZE])
    # Memory output weights and biases.
    Wmol = tf.get_variable("Wmol", [CONTENT_SIZE, INPUT_SIZE])
    bmol = tf.get_variable("bmol", [INPUT_SIZE])
    
# 3. Output ops.
with tf.name_scope("output"):
    # Concatenate cell output sequence along time axis.
    cell_output_batch = tf.concat(cell_output_batch_seq, 0) 
    # Memory logits.
    cell_logits_batch = tf.nn.xw_plus_b(cell_output_batch, Wcol, bcol, name = "cell_logits_batch")

    # Concatenate memory output sequence along time axis.
    memory_output_batch = tf.concat(memory_outputs_batch_seq, 0) 
    # Memory logits.
    memory_logits_batch = tf.nn.xw_plus_b(memory_output_batch, Wmol, bmol, name = "memory_logits_batch")
    
    # Add fully connected softmax layer on top.
    logits_batch = memory_logits_batch #+ cell_logits_batch
    prediction_batch = tf.nn.softmax(logits_batch)

 
# 4. Loss ops.
with tf.name_scope("loss"):
    # Loss function(s) - reduce accross every batch and sequence (output generated by every LSTM cell).
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target_LBxI, logits=logits_batch))
    # Add loss summary.
    loss_summary = tf.summary.scalar("loss", loss)

# 5. Training ops.  
with tf.name_scope("optimization"):
    # Learning rate decay.
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(0.1, global_step, 5000, 0.9, staircase=True)
    # Optimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate)
    grads_and_vars = optimizer.compute_gradients(loss)
    gradients, v = zip(*grads_and_vars)
    # Gradient clipping.
    gradients, _ = tf.clip_by_global_norm(gradients, 0.1)
    optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)

    print("Building summaries for all variables and gradients - it might take a while...")
    for i, (grad, var) in enumerate(grads_and_vars):
        if grad is not None:
            gradients[i] = (tf.clip_by_value(grad, -10, 10), var)
            tf.summary.histogram(var.name, var)
            tf.summary.histogram(var.name + '/grad', grad)
    
    
# Merge all summaries.
merged_summaries = tf.summary.merge_all()

print("Graph definition OK")

Building NTM cells - it might take a while...
zeros= Tensor("lstm/input_BxI_vis_hc_norm/zeros_like:0", shape=(?, 59), dtype=float32)
zeros= Tensor("lstm/prev_cell_state_BxH_vis_hc_norm/zeros_like:0", shape=(?, 16), dtype=float32)
zeros= Tensor("lstm/prev_cell_output_BxH_vis_hc_norm/zeros_like:0", shape=(?, 16), dtype=float32)
zeros= Tensor("lstm/prev_memory_output_BxS_vis_hc_norm/zeros_like:0", shape=(?, 8), dtype=float32)
zeros= Tensor("lstm/cell_state_BxH_vis_hc_norm/zeros_like:0", shape=(?, 16), dtype=float32)
zeros= Tensor("lstm/cell_output_BxH_vis_hc_norm/zeros_like:0", shape=(?, 16), dtype=float32)
!!keys_Bx1xS= Tensor("read_head/focusing_by_location/ExpandDims:0", shape=(?, 1, 8), dtype=float32)
!!keys_Bx1xS= Tensor("read_head/focusing_by_location/ExpandDims:0", shape=(?, 1, 8), dtype=float32)
!!keys_Bx1xS= Tensor("read_head/focusing_by_location/norm_memory_BxSxN:0", shape=(?, 8, 10), dtype=float32)
weighting_BxN= Tensor("read_head/focusing_by_location/weighting_BxN:0", shape=(?

INFO:tensorflow:Summary name lstm_variables/Wc:0 is illegal; using lstm_variables/Wc_0 instead.
INFO:tensorflow:Summary name lstm_variables/Wc:0/grad is illegal; using lstm_variables/Wc_0/grad instead.
INFO:tensorflow:Summary name lstm_variables/bc:0 is illegal; using lstm_variables/bc_0 instead.
INFO:tensorflow:Summary name lstm_variables/bc:0/grad is illegal; using lstm_variables/bc_0/grad instead.
INFO:tensorflow:Summary name lstm_variables/Wo:0 is illegal; using lstm_variables/Wo_0 instead.
INFO:tensorflow:Summary name lstm_variables/Wo:0/grad is illegal; using lstm_variables/Wo_0/grad instead.
INFO:tensorflow:Summary name lstm_variables/bo:0 is illegal; using lstm_variables/bo_0 instead.
INFO:tensorflow:Summary name lstm_variables/bo:0/grad is illegal; using lstm_variables/bo_0/grad instead.
INFO:tensorflow:Summary name read_head_variables/Wh_rk:0 is illegal; using read_head_variables/Wh_rk_0 instead.
INFO:tensorflow:Summary name read_head_variables/Wh_rk:0/grad is illegal; using 

###  Definition of tensor graph

In [175]:
def create_feed_dict(set_type_):
    """Creates feed dictionaries and set initial values for placeholders for different sets"""
    feed_dict = dict()
    
    if set_type_=="train":
        # Get next batch and create a feed dict.
        next_batch = train_batches.next()
        # Feed batch to input buffers.
        for i in range(SEQ_LENGTH + 1):
            feed_dict[data_buffers[i]] = next_batch[i]

        # Set previous weights of read and write heades.
        for i in range(SEQ_LENGTH):
            feed_dict[prev_read_weights_batch_seq[i]] = prev_rw_batch_seq[i]
            feed_dict[prev_write_weights_batch_seq[i]] = prev_ww_batch_seq[i]

        # Reset "initial" memory output.
        feed_dict[init_cell_state] = np.zeros([BATCH_SIZE, HIDDEN_SIZE])
        feed_dict[init_cell_output] = np.zeros([BATCH_SIZE, HIDDEN_SIZE])
        feed_dict[init_memory_output] = np.zeros([BATCH_SIZE, SLOT_SIZE])
        # Pass memory state.
        feed_dict[input_memory_state] = init_mem_state
        #initial_memory_state #memory_state
            
    elif set_type_=="valid":
        for i in range(SEQ_LENGTH + 1):
            feed_dict[data_buffers[i]] = valid_batch[i]

        # Set previous weights of read and write heades.
        for i in range(SEQ_LENGTH):
            feed_dict[prev_read_weights_batch_seq[i]] = prev_rw_batch_seq[i]
            feed_dict[prev_write_weights_batch_seq[i]] = prev_ww_batch_seq[i]

        # Reset "initial" memory output.
        feed_dict[init_cell_state] = np.zeros([valid_size, HIDDEN_SIZE])
        feed_dict[init_cell_output] = np.zeros([valid_size, HIDDEN_SIZE])
        feed_dict[init_memory_output] = np.zeros([valid_size, SLOT_SIZE])
        # Pass memory state.
        #feed_dict[input_memory_state] = np.random.randn([valid_size, SLOT_SIZE, NUMBER_OF_SLOTS])

    else: # test
        for i in range(SEQ_LENGTH + 1):
            feed_dict[data_buffers[i]] = test_batch[i]
        
        # Set previous weights of read and write heades.
        for i in range(SEQ_LENGTH):
            feed_dict[prev_read_weights_batch_seq[i]] = prev_rw_batch_seq[i]
            feed_dict[prev_write_weights_batch_seq[i]] = prev_ww_batch_seq[i]

        # Reset "initial" memory output.
        feed_dict[init_cell_state] = np.zeros([test_size, HIDDEN_SIZE])
        feed_dict[init_cell_output] = np.zeros([test_size, HIDDEN_SIZE])
        feed_dict[init_memory_output] = np.zeros([test_size, SLOT_SIZE])
        # Reset memory state.
        #feed_dict[input_memory_state] = np.random.randn([valid_size, SLOT_SIZE, NUMBER_OF_SLOTS])
       
    return feed_dict

print("Feed_dict definition OK")

Feed_dict definition OK


### Session execution

In [176]:
# Eventually clear the log dir.
if tf.gfile.Exists(LOG_DIR):
  tf.gfile.DeleteRecursively(LOG_DIR)
# Create (new) log dir.
tf.gfile.MakeDirs(LOG_DIR)

print("Log dir CLEARED")

Log dir CLEARED


In [179]:

# Create session to execute graph.
sess=tf.InteractiveSession()

# Create summary writers, point them to LOG_DIR.
train_writer = tf.summary.FileWriter(LOG_DIR + '/train', sess.graph)
valid_writer = tf.summary.FileWriter(LOG_DIR + '/valid')
test_writer = tf.summary.FileWriter(LOG_DIR + '/test')

# Initialize global variables.
tf.global_variables_initializer().run()
print('Variables initialized')

# Create initial previous read and update - full of zeros. 
# TODO: change BATCH_SIZE to 1 - reduce_mean all previous r/w weights. :]
prev_rw_batch_seq = list()
for i in range(SEQ_LENGTH):
    prev_rw_batch_seq.append(np.zeros([BATCH_SIZE, NUMBER_OF_SLOTS]))
prev_ww_batch_seq = list()
for i in range(SEQ_LENGTH):
    prev_ww_batch_seq.append(np.zeros([BATCH_SIZE, NUMBER_OF_SLOTS]))


# Initialize memory! random matrix 
#initial_memory_state = sess.run(init_memory_state_op)
#memory_state = initial_memory_state
#print("memory_state =",memory_state.shape)
init_mem_state = np.random.randn(BATCH_SIZE, SLOT_SIZE, NUMBER_OF_SLOTS)
print(init_mem_state)


# Determine how long to perform the training and how often the test loss on validation batch will be computed. 
num_steps = 10000 #// (BATCH_SIZE*SEQ_LENGTH)
#num_steps = train_size // (BATCH_SIZE*SEQ_LENGTH)
summary_frequency = 10
#validation_frequency = 1000
print("Number of iterations per epoch =", num_steps)
for step in range(num_steps):
    memory_state, prev_rw_batch_seq, prev_ww_batch_seq, summaries, _, loss_, lr_ = sess.run([
        final_memory_state, read_weights_batch_seq, write_weights_batch_seq, merged_summaries, optimizer, loss, learning_rate],
        feed_dict=create_feed_dict("train"))
    # Every (100) steps collect statistics.
    if step % summary_frequency == 0:
        # Add summary.
        train_writer.add_summary(summaries, step*BATCH_SIZE*SEQ_LENGTH)
        train_writer.flush()
        print('Training set BPC at step %d: %0.5f learning rate: %f' % (step, loss_, lr_))

        #print("memory=\n", memory_)
        # Print loss from last batch.
    
    #if step % validation_frequency == 0:
    #    # Validation set BPC.
    #    print('=' * 80)
    #    print("Calculating BPC on validation set")
    #    v_summaries, v_loss = sess.run([merged_summaries, loss], feed_dict=create_feed_dict("valid"))
    #    print("Validation set BPC: %.5f" % v_loss)
    #    valid_writer.add_summary(v_summaries, step*BATCH_SIZE*SEQ_LENGTH)
    #    valid_writer.flush()
    # End of statistics collection

#for gv in grads_and_vars:
#    #print(str(sess.run(gv[0])) + " - " + gv[1].name)        
#    print(" - " + gv[1].name)        

# Test set BPC.
#print('=' * 80)tf.relu(
#print("Calculating BPC on test set")
#t_summary, t_loss = sess.run([merged_summaries, loss], feed_dict=create_feed_dict("test"))
#print("Final test set BPC: %.5f" % t_loss)
#test_writer.add_summary(t_summary, step*BATCH_SIZE*SEQ_LENGTH)
#test_writer.flush()
    
# Close writers and session.
train_writer.close()
valid_writer.close()
test_writer.close()
sess.close() 

Variables initialized
[[[ 1.30571256 -0.135432   -0.44877223 -0.3938189  -0.64267935 -0.40237742
   -1.73061261 -0.11024676 -0.39007158 -0.09439799]
  [ 1.60313754 -0.6506589  -1.48975658 -1.53622514  1.21927447  0.39993844
    1.77962442  1.22417579 -0.69161813  0.14447423]
  [ 0.13321996  0.835875   -1.0019517   0.72414486  0.68147479  0.58661264
    1.5739242  -1.72167776  2.98298157  0.75737819]
  [-0.79495484 -0.12943645  1.20061979  0.41584828 -0.32108833  1.70173021
   -2.26009351  0.62246498  1.55788836  0.67115552]
  [-0.84867757 -1.2484619  -1.81962797 -0.2694866   2.34114377 -1.24178708
    0.891323    0.34955352 -0.84810419 -0.70257412]
  [-1.21784673 -0.91226095  0.68973001  0.07040969 -0.91981963  2.06499206
   -0.85778984  1.90451432 -0.0519562   0.67395742]
  [ 2.06087467 -0.24335     0.90650813  0.41763805  1.73007673  0.14043824
    1.99501459 -2.54144358  1.55097999  0.5698739 ]
  [-0.52060255 -0.8352278   0.1836332  -1.04721283 -0.35105628  0.02056907
    0.34924683

TypeError: 'list' object cannot be interpreted as an integer