In [1]:
import tensorflow as tf
import numpy as np


### Playing around with the functional ops

In [None]:
tf.map_fn?

In [None]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

def double(v):
    return 2 * v

a = tf.Variable([[1,1],[2,2],[3,3],[4,4]])
f = tf.map_fn(double, a)

sess.run(tf.initialize_all_variables())
print sess.run(f)

### Define a sparse cross entropy that works on unknown batch size and seq len

In [None]:
def sparse_xent((logits, labels)):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)

# shape will be [n_time, n_batch, n_class]
# with 2 timesteps, 3 sequences in the batch, and 2 classes (for brevity)
logits = tf.Variable([[[.4, .6], [.7, .3], [.1, .9]],
           [[.6, .4], [.3, .7], [.8, .2]],])
labels = tf.Variable([[1, 0, 1], [0, 1, 0]])
xent = tf.map_fn(sparse_xent, (logits, labels), dtype=tf.float32, back_prop=True)
loss = tf.reduce_mean(xent)
sess.run(tf.initialize_all_variables())
print sess.run(xent)

In [None]:
a = tf.Variable(1e50, trainable=False, name="best_valid_loss")
sess.run(a.initializer)

In [None]:
sess.run(a.assign(1))
a.eval()

### Playing around with `while_loop`, `TensorArray`, and `scatter` to build a function that extracts an entity mention tensor from the sequence label tensor

The output of the sequence tagging layer encodes the boundaries of the entity mentions.  In a batch, there may be hugely different number of mentions per sequence.

To predict relations between mentions, we need to know what the extracted mentions are.

# Function to extract the mentions

In [73]:
# create fake sequences over a 'batch'
# sequences are 
# [ O B I O B I I O ]
# [ B I O O O B O B ]
fake_tagged_seqs = np.array([[0,1,2,0,1,2,2,0],
                             [1,2,0,1,0,1,0,1]])

def extract_mentions(seqs, outside_token):
    """ Iterate over a batch of sequences, extracting the mentions encoded in them.
    
    Args:
      seqs: Tensor with shape [batch_size, max_timesteps]
      
    Returns:
      mentions: Tensor with shape [batch_size, max_extracted_mentions, 3].
        This tensor is padded to the max number of extracted mentions.
        The final dimension `3` encodes mention metadata where:
          0: mention left boundary index `i`
          1: mention right boundary index `j` (inclusive)
          2: mention group
      mentions_mask: Tensor with shape [batch_size, max_extracted_mentions].
        This tensor encodes the `pad` locations of the mentions tensor.
        
    Example (pseudocode):
      # fake sequences [ O B I O B I I O ]
      #                [ B I O O O B O B ]
      fake_tagged_seqs = np.array([[0,1,2,0,1,2,2,0],
                             [1,2,0,1,0,1,0,1]])
                             
      mentions = extract_mentions(seqs, outside_token) 
      print mentions[:,:,0] # starting boundaries
      >>> [[ 1  4 -1 -1]
           [ 0  3  5  7]]
      print mentions[:,:,1] # end boundaries (inclusive)
      >>> [[ 2  6 -1 -1]
           [ 1  3  5  7]]
            
    """ 
    def _sparse_update(update_mask, stitch_range, update_range, old_values, new_values):
        """ Return as sparsely updated tensor according to the mask.
        
        This allows for scattered updates to a `Tensor` (not just a `Variable`)
        by using dynamic stitch to overwrite values"""
        update_indices = tf.boolean_mask(update_range, update_mask)
        update_values = tf.boolean_mask(new_values, update_mask)
        return tf.dynamic_stitch([stitch_range, update_indices],
                                 [old_values, update_values])
    
    begin_tokens = tf.constant([[1]])
    outside_tokens = tf.constant([[0]])
    outside_token = tf.constant(0)
    def _start_mention(tags):
        start_mention = tf.reduce_any(tf.equal(tags, begin_tokens), 
                                      reduction_indices=[0])
        start_mention.set_shape((None,))
        return start_mention
    
    def _in_mention(tags):
#         in_mention = tf.reduce_any(tf.not_equal(tags, outside_tokens),
#                                    reduction_indices=[0])
        in_mention = tf.not_equal(tags, outside_token)
        print in_mention.get_shape()
#         in_mention.set_shape((batch_size,))
        return in_mention
    
    def _end_current_mention(in_mention, already_in_mention):
        end_current_mention = tf.logical_and(tf.logical_not(in_mention), already_in_mention)
        end_current_mention.set_shape((None,))
        return end_current_mention
        
    # first figure out the maximum number of mentions
    def _count_only_step(time, seqs_ta, batch_range, mention_counts, already_in_mention):
        """ Update mentions_count using the mention detection update rules. 
        
        This is done by marking if we are starting at a mention
        and then adding to the count to sequences in the batch 
        where we detect the end of a mention.
        
        This is done by stitching in the updates to overwrite the current
        values.  
        
        NOTE: Overwriting dynammic stitch may seem weird
          but it's due to some scatter/while_loop idiosyncrasies of tf
          (this finally works after my 4th implementation attempt)
        """
        tags = seqs_ta.read(time)
        in_mention = _in_mention(tags)
        end_current_mention = _end_current_mention(in_mention, already_in_mention)
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
        already_in_mention = in_mention
        return time+1, seqs_ta, batch_range, mention_counts, already_in_mention

        
#         in_mention = tf.not_equal(tags, outside_token)
#         end_current_mention = tf.logical_and(tf.logical_not(in_mention), already_in_mention)
#         end_current_mention.set_shape((None,))
#         update_indices = tf.boolean_mask(batch_range, end_current_mention)
#         update_values = tf.boolean_mask(mention_counts+1, end_current_mention)
#         mention_counts = tf.dynamic_stitch([batch_range, update_indices],
#                                            [mention_counts, update_values])
#         already_in_mention = in_mention
#         return time+1, seqs_ta, batch_range, mention_counts, already_in_mention
    
    # fill in empty bookkeeping tensors that encode the mentions
    def _extraction_step(time, 
                         seqs_ta,
                         outside_token,
                         linear_range,
                         linear_index,
                         batch_range,
                         mention_starts,
                         mention_ends,
                         mention_counts, 
                         already_in_mention):
        """ Extract mention boundaries at this timestep.
        
        Args:
          time: the current timestep in the batch of sequences
          seqs_ta: the batch of sequences
          mention_starts: the Tensor holding the start boundaries 
            of the extracted mentions so far
          mention_ends: the Tensor holding the end boundaries 
            of the extracted mentions so far
          mention_counts: the current number of mentions found
            for a single sequence in the batch.  This is used to 
            dynamically scatter extracted mentions into the `mentions`
            TensorArray so that it can be abstracted away from the 
            underlying sequence.
          in_mentions: whether or not the sequence is in a mention or not.
            As we scan, we extract mentions by looking for contiguous groups
            of non-'Outside' tags.  
            This way we can extract multi-token mentions into single elements.
        """
        # get the sequence tags at the current timestep
        tags = seqs_ta.read(time)
        
        # decide if they are in mention or not
        in_mention = _in_mention(tags)
        
        ### IF they are a mention but weren't before, start a new mention
        # do this by applying scattered updates (via dynamic_stitch) 
        # to the masked, linearly indexed locations of new mentions, 
        # where the updated values are the current timestep
        # eg, the location in the sequence where this mention is being started
        boundary = time * tf.ones_like(tags, dtype=tf.int32) 
        offsets = mention_counts + linear_index 
        start_new_mention = _start_mention(tags)
        mention_starts = _sparse_update(start_new_mention,
                                        linear_range, offsets,
                                        mention_starts, boundary)
#         start_new_mention = tf.logical_and(in_mention, tf.logical_not(already_in_mention))
#         start_new_mention.set_shape((None,))
#         update_indices = tf.boolean_mask(offsets, start_new_mention)
#         update_values = tf.boolean_mask(boundary, start_new_mention)
#         mention_starts = tf.dynamic_stitch([linear_range, update_indices],
#                                            [mention_starts, update_values])
        
        ### IF they aren't but were before, end the mention at t-1
        # first, compute the potential new end boundary for all mentions
        # this will be the previous timestep
        # and also increment the count if we finished the mention
#         boundary = boundary - 1
        end_current_mention = _end_current_mention(in_mention, already_in_mention)
#         end_current_mention = tf.logical_and(tf.logical_not(in_mention), already_in_mention)
#         end_current_mention.set_shape((None,))
        mention_ends = _sparse_update(end_current_mention,
                                      linear_range, offsets,
                                      mention_ends, boundary)
#         update_indices = tf.boolean_mask(offsets, end_current_mention)
#         update_values = tf.boolean_mask(boundary, end_current_mention)
#         mention_ends = tf.dynamic_stitch([linear_range, update_indices],
#                                            [mention_ends, update_values]

        # update mention counts where we've ended an extraction (same as `_count_step()`)
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
#         update_indices = tf.boolean_mask(batch_range, end_current_mention)
#         update_values = tf.boolean_mask(mention_counts+1, end_current_mention)
#         mention_counts = tf.dynamic_stitch([batch_range, update_indices],
#                                            [mention_counts, update_values])
        
        already_in_mention = in_mention
        
        return (time + 1, 
                seqs_ta,
                outside_token,
                linear_range,
                linear_index,
                batch_range,
                mention_starts, 
                mention_ends,
                mention_counts, 
                already_in_mention)
    
    # convert the sequences
    seqs_shape = tf.shape(seqs)
    batch_size = seqs_shape[0]
    time_steps = seqs_shape[1]
    
    # `TensorArray`'s read in time-major, so transpose
    seqs_ta = tf.TensorArray(dtype=seqs.dtype, size=time_steps, clear_after_read=False)
    seqs_ta = seqs_ta.unpack(tf.transpose(seqs, [1,0]))
    
    # bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    batch_range = tf.range(batch_size)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
        
    # find the maximum number of mentions in batch
    (_, _, _,mention_counts, already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                                            body=_count_only_step,
                                                            loop_vars=(time, 
                                                                       seqs_ta, 
                                                                       batch_range,
                                                                       mention_counts, 
                                                                       already_in_mention))
    # add 1 to counts where we never detected the end
    already_in_mention.set_shape((None,))
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
#     already_in_mention.set_shape((None,))
#     update_indices = tf.boolean_mask(batch_range, already_in_mention)
#     update_values = tf.boolean_mask(mention_counts+1, already_in_mention)
#     mention_counts = tf.dynamic_stitch([batch_range, update_indices],
#                                        [mention_counts, update_values])
    
    # now create the tensors we will extract mention data into
    max_num_mentions = tf.reduce_max(mention_counts)
    
    # create a linearized version of the mention statistics we'll gather
    # it'll be padded with -1's (can't use 0 as its a valid index)
    mention_starts = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
    mention_ends = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
    
    # we also need a full linear range and a liear index into the mention_stats
    # tensors so we can dynamically overwrite the values
    linear_range = tf.range(batch_size * max_num_mentions)
    linear_index = max_num_mentions * tf.range(batch_size)
    
    # reset the bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
    
    # extract the mentions
    (time, _, _, _, _, _,
     mention_starts, 
     mention_ends,
     mention_counts, 
     already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                         body=_extraction_step,
                                         loop_vars= (time, 
                                                     seqs_ta,
                                                     outside_token,
                                                     linear_range,
                                                     linear_index,
                                                     batch_range,
                                                     mention_starts, 
                                                     mention_ends,
                                                     mention_counts, 
                                                     already_in_mention))
    # if we ended on a mention, we need to compute final endpoints
    already_in_mention.set_shape((None,))
    boundary = time * tf.ones_like(mention_counts, dtype=tf.int32) 
    offsets = mention_counts + linear_index
    mention_ends = _sparse_update(already_in_mention,
                                  linear_range, offsets,
                                  mention_ends, boundary)
#     update_indices = tf.boolean_mask(offsets, already_in_mention)
#     update_values = tf.boolean_mask(boundary, already_in_mention)
#     mention_ends = tf.dynamic_stitch([linear_range, update_indices],
#                                      [mention_ends, update_values])
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
#     update_indices = tf.boolean_mask(batch_range, already_in_mention)
#     update_values = tf.boolean_mask(mention_counts+1, already_in_mention)
#     mention_counts = tf.dynamic_stitch([batch_range, update_indices],
#                                        [mention_counts, update_values])
    
    # finally concat and reshape extraction stats
    mention_starts = tf.reshape(mention_starts, (batch_size, max_num_mentions, 1))
    mention_ends = tf.reshape(mention_ends, (batch_size, max_num_mentions, 1))
    mentions = tf.concat(2, [mention_starts, mention_ends])
    return mentions
        

tf.reset_default_graph()
seqs = tf.Variable(fake_tagged_seqs, dtype=tf.int32)
outside_token = tf.constant(0, dtype=tf.int32, name='Outside_Mention_Token')

batch_size = tf.shape(seqs)[0]
max_num_mentions = tf.constant(5)

r = tf.range(0, batch_size) * max_num_mentions

a = tf.Variable([1,2,3,4], validate_shape=False, trainable=False)
z = tf.constant([True, False, False, True])
r = tf.range(4)
scatter_idxs = tf.boolean_mask(r, z)
scatter_vals = tf.Variable([0,0])

b = tf.dynamic_stitch([r, scatter_idxs], [a, scatter_vals])

# c = tf.select(z, a, b)
# c = tf.scatter_update(b, tf.Variable([1,2]), tf.Variable([5,3]))

# a = tf.Variable([0,1])
z = tf.constant(2)
z.set_shape(None)
# b = tf.tile(a, [z])

c = tf.constant(0)
d = c + 1
"""
New way to do the mention extraction:
1. Run through once, only updating mention counts
   to dynamically compute the max number of mentions and lengths
2. Create a linearly indexed zero filled tensor with total shape [batch_size x max_num_mentions]
   for each piece of metadata
3. Selectively update the linear index tensor with mention extractions
4. Finally reshape and stitch tensors together
"""

with tf.Session() as sess:
    print "Init"
    mentions = extract_mentions(seqs, outside_token)
    
#     writer = tf.train.SummaryWriter(".", sess.graph)
    tf.initialize_all_variables().run()
#     writer.flush()
#     writer.close()
    print "Done"
#     print c.eval()
#     print d.eval()
    print "Running"
    m = mentions.eval()
    print "Done"
    print fake_tagged_seqs
    print m[:,:,0]
    print m[:,:,1]

Init
(2,)
(2,)
Done
Running
Done
[[0 1 2 0 1 2 2 0]
 [1 2 0 1 0 1 0 1]]
[[ 1  4 -1 -1]
 [ 0  3  5  7]]
[[ 3  7 -1 -1]
 [ 2  4  6  8]]


# Funtion to extract the entire graph

In [111]:
def extract_mentions(seqs):#, features):
    """ Iterate over a batch of sequences, extracting the mentions encoded in them.
    
    Args:
      seqs: Tensor with shape [batch_size, max_timesteps]
      
    Returns:
      mentions: Tensor with shape [batch_size, max_extracted_mentions, 3].
        This tensor is padded to the max number of extracted mentions.
        The final dimension `3` encodes mention metadata where:
          0: mention left boundary index `i`
          1: mention right boundary index `j` (inclusive)
          2: mention group
      mentions_mask: Tensor with shape [batch_size, max_extracted_mentions].
        This tensor encodes the `pad` locations of the mentions tensor.
        
    Example (pseudocode):
      # fake sequences [ O B I O B I I O ]
      #                [ B I O O O B O B ]
      fake_tagged_seqs = np.array([[0,1,2,0,1,2,2,0],
                             [1,2,0,1,0,1,0,1]])
                             
      mentions = extract_mentions(seqs, outside_token) 
      print mentions[:,:,0] # starting boundaries
      >>> [[ 1  4 -1 -1]
           [ 0  3  5  7]]
      print mentions[:,:,1] # end boundaries (inclusive)
      >>> [[ 2  6 -1 -1]
           [ 1  3  5  7]]
            
    """ 
    def _sparse_update(update_mask, stitch_range, update_range, old_values, new_values):
        """ Return as sparsely updated tensor according to the mask.
        
        This allows for scattered updates to a `Tensor` (not just a `Variable`)
        by using dynamic stitch to overwrite values"""
        update_indices = tf.boolean_mask(update_range, update_mask)
        update_values = tf.boolean_mask(new_values, update_mask)
        return tf.dynamic_stitch([stitch_range, update_indices],
                                 [old_values, update_values])
    
    begin_tokens = tf.constant([[1]])
#     outside_tokens = tf.constant([[0]])
    outside_token = tf.constant(0)
    def _start_new_mention(tags):
        start_new_mention = tf.reduce_any(tf.equal(tags, begin_tokens), 
                                      reduction_indices=[0])
        start_new_mention.set_shape((None,))
        return start_new_mention
    
    def _in_mention(tags):
        in_mention = tf.not_equal(tags, outside_token)
        return in_mention
    
    def _end_current_mention(start_new_mention, in_mention, already_in_mention):
        start_or_out = tf.logical_or(start_new_mention, tf.logical_not(in_mention))
        end_current_mention = tf.logical_and(start_or_out, already_in_mention)
        end_current_mention.set_shape((None,))
        return end_current_mention
        
    # first figure out the maximum number of mentions
    def _count_only_step(time, mention_counts, already_in_mention):
        """ Update mentions_count using the mention detection update rules. 
        
        This is done by marking if we are starting at a mention
        and then adding to the count to sequences in the batch 
        where we detect the end of a mention.
        
        This is done by stitching in the updates to overwrite the current
        values.  
        
        NOTE: Overwriting dynammic stitch may seem weird
          but it's due to some scatter/while_loop idiosyncrasies of tf
          (this finally works after my 4th implementation attempt)
        """
        tags = seqs_ta.read(time)
        in_mention = _in_mention(tags)
        start_new_mention = _start_new_mention(tags)
        end_current_mention = _end_current_mention(start_new_mention, in_mention, already_in_mention)
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
        already_in_mention = in_mention
        return time+1, mention_counts, already_in_mention
    
    # fill in empty bookkeeping tensors that encode the mentions
    def _extraction_step(time, 
                         mention_starts,
                         mention_ends,
#                          mention_features,
                         mention_counts, 
                         mention_sizes,
                         already_in_mention):
        """ Extract mention boundaries at this timestep.
        
        Args:
          time: the current timestep in the batch of sequences
          seqs_ta: the batch of sequences
          mention_starts: the Tensor holding the start boundaries 
            of the extracted mentions so far
          mention_ends: the Tensor holding the end boundaries 
            of the extracted mentions so far
          mention_counts: the current number of mentions found
            for a single sequence in the batch.  This is used to 
            dynamically scatter extracted mentions into the `mentions`
            TensorArray so that it can be abstracted away from the 
            underlying sequence.
          in_mentions: whether or not the sequence is in a mention or not.
            As we scan, we extract mentions by looking for contiguous groups
            of non-'Outside' tags.  
            This way we can extract multi-token mentions into single elements.
        """        
        # get the sequence tags at the current timestep
        tags = seqs_ta.read(time)
        
        # decide if they are in mention or not
        in_mention = _in_mention(tags)
        
        ### IF they are a mention but weren't before, start a new mention
        # do this by applying scattered updates (via dynamic_stitch) 
        # to the masked, linearly indexed locations of new mentions, 
        # where the updated values are the current timestep
        # eg, the location in the sequence where this mention is being started
        boundary = time * tf.ones_like(tags, dtype=tf.int32) 
        start_new_mention = _start_new_mention(tags)
        end_current_mention = _end_current_mention(start_new_mention, in_mention, already_in_mention)
        
        ### IF they aren't but were before, end the mention at t
        mention_ends = _sparse_update(end_current_mention,
                                      linear_range, mention_counts + linear_index,
                                      mention_ends, boundary)
        
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
        
        mention_starts = _sparse_update(start_new_mention,
                                        linear_range, mention_counts + linear_index,
                                        mention_starts, boundary)

        # update mention counts where we've ended an extraction (same as `_count_step()`)
        mention_sizes = _sparse_update(end_current_mention,
                                       batch_range, batch_range,
                                       mention_sizes, tf.zeros_like(mention_sizes))
        mention_sizes = _sparse_update(in_mention,
                                       batch_range, batch_range,
                                       mention_sizes, mention_sizes+1)
        already_in_mention = in_mention
        
        return (time + 1, 
#                 seqs_ta,
#                 outside_token,
#                 linear_range,
#                 linear_index,
#                 batch_range,
                mention_starts, 
                mention_ends,
#                 mention_features,
                mention_counts, 
                mention_sizes,
                already_in_mention)
    
    # convert the sequences
    shape = tf.shape(seqs)
    batch_size = shape[0]
    time_steps = shape[1]
#     feature_size = shape[2]
    
    # `TensorArray`'s read in time-major, so transpose
    seqs_ta = tf.TensorArray(dtype=seqs.dtype, size=time_steps, clear_after_read=False)
    seqs_ta = seqs_ta.unpack(tf.transpose(seqs, [1,0]))
#     features_ta = tf.TensorArray(dtype=features.dtype, size=time_steps)
#     features_ta = features_ta.unpack(tf.transpose(features, [1,0,2]))
    
    # bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    batch_range = tf.range(batch_size)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
        
    # find the maximum number of mentions in batch
    (_, mention_counts, already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                                            body=_count_only_step,
                                                            loop_vars=(time,
                                                                       mention_counts, 
                                                                       already_in_mention))
    # add 1 to counts where we never detected the end
    already_in_mention.set_shape((None,))
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
    
    # now create the tensors we will extract mention data into
    max_num_mentions = tf.reduce_max(mention_counts)
    max_num_relations = max_num_mentions*(max_num_mentions-1)/2
    
    # create a linearized version of the mention statistics we'll gather
    # it'll be padded with -1's (can't use 0 as its a valid index)
    mention_starts = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
    mention_ends = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
#     mention_features = -1*tf.ones(shape=(batch_size * max_num_mentions, feature_size), 
#                                   dtype=features.dtype)
    
    # we also need a full linear range and a liear index into the mention_stats
    # tensors so we can dynamically overwrite the values
    linear_range = tf.range(batch_size * max_num_mentions)
    linear_index = max_num_mentions * tf.range(batch_size)
    
    # reset the bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    mention_sizes = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
    
    # extract the mentions
    (time,
     mention_starts, 
     mention_ends,
#      mention_features,
     mention_counts, 
     mention_sizes,
     already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                         body=_extraction_step,
                                         loop_vars= (time, 
                                                     mention_starts, 
                                                     mention_ends,
#                                                      mention_features,
                                                     mention_counts, 
                                                     mention_sizes,
                                                     already_in_mention))
    # if we ended on a mention, we need to compute final endpoints
    already_in_mention.set_shape((None,))
    boundary = time * tf.ones_like(mention_counts, dtype=tf.int32) 
    offsets = mention_counts + linear_index
    mention_ends = _sparse_update(already_in_mention,
                                  linear_range, offsets,
                                  mention_ends, boundary)
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
#     mention_sizes = _sparse_update(already_in_mention,
#                                    batch_range, batch_range,
#                                    mention_sizes, mention_sizes+1)
    
    # finally concat and reshape extraction stats
    mention_starts = tf.reshape(mention_starts, (batch_size, max_num_mentions, 1))
    mention_ends = tf.reshape(mention_ends, (batch_size, max_num_mentions, 1))
    mentions = tf.concat(2, [mention_starts, mention_ends])
    return mentions, mention_sizes, mention_counts#, max_num_mentions, max_num_relations
        

"""
New way to do the mention extraction:
1. Run through once, only updating mention counts
   to dynamically compute the max number of mentions and lengths
2. Create a linearly indexed zero filled tensor with total shape [batch_size x max_num_mentions]
   for each piece of metadata
3. Selectively update the linear index tensor with mention extractions
4. Finally reshape and stitch tensors together
"""
# create fake sequences over a 'batch'
# sequences are 
# [ O B I O B I B O ]
# [ B I O B O B O B ]
tf.reset_default_graph()
fake_tagged_seqs = np.array([[0,1,2,0,1,2,1,0],
                             [1,2,0,1,0,1,0,1]])
seqs = tf.Variable(fake_tagged_seqs, dtype=tf.int32)

with tf.Session() as sess:
    print "Init"
    mentions = extract_mentions(seqs)
    tf.initialize_all_variables().run()
    print "Done"
#     print c.eval()
#     print d.eval()
    print "Running"
    m, s, c = mentions
    print "Done"
#     print n.eval(), r.eval()
    print "Input tags"
    print fake_tagged_seqs
    print "Extracted Mentions"
    print s.eval()
    print c.eval()
    print m.eval()

Init
Done
Running
Done
Input tags
[[0 1 2 0 1 2 1 0]
 [1 2 0 1 0 1 0 1]]
Extracted Mentions
[0 1]
[3 4]
[[[ 1  3]
  [ 4  6]
  [ 6  7]
  [-1 -1]]

 [[ 0  2]
  [ 3  4]
  [ 5  6]
  [ 7  8]]]


In [108]:
tf.reset_default_graph()
begin_tokens = tf.constant([[1]])
tok = tf.constant([0, 2, 1])
eq = tf.equal(tok, begin_tokens)
eq2 = tf.reduce_any(eq, reduction_indices=[0])
with tf.Session():
    tf.initialize_all_variables().run()
    print eq.eval()
    print eq2.eval()

[[False False  True]]
[False False  True]


In [None]:
tf.scatt

In [1]:
!mkdir ../summaries

In [112]:
# tf.reset_default_graph()
# start = tf.Variable([0, 1])
# r = tf.range(start)
# with tf.Session():
#     tf.initialize_all_variables().run()
#     print r.eval()

In [10]:
def extract_mentions(seqs, seq_features):
    """ Iterate over a batch of sequences, extracting the mentions encoded in them.
    
    Args:
      seqs: Tensor with shape [batch_size, max_timesteps]
      
    Returns:
      mentions: Tensor with shape [batch_size, max_extracted_mentions, 3].
        This tensor is padded to the max number of extracted mentions.
        The final dimension `3` encodes mention metadata where:
          0: mention left boundary index `i`
          1: mention right boundary index `j` (inclusive)
          2: mention group
      mentions_mask: Tensor with shape [batch_size, max_extracted_mentions].
        This tensor encodes the `pad` locations of the mentions tensor.
        
    Example (pseudocode):
      # fake sequences [ O B I O B I I O ]
      #                [ B I O O O B O B ]
      fake_tagged_seqs = np.array([[0,1,2,0,1,2,2,0],
                             [1,2,0,1,0,1,0,1]])
                             
      mentions = extract_mentions(seqs, outside_token) 
      print mentions[:,:,0] # starting boundaries
      >>> [[ 1  4 -1 -1]
           [ 0  3  5  7]]
      print mentions[:,:,1] # end boundaries (inclusive)
      >>> [[ 2  6 -1 -1]
           [ 1  3  5  7]]
            
    """ 
    def _sparse_update(update_mask, stitch_range, update_range, old_values, new_values):
        """ Return as sparsely updated tensor according to the mask.
        
        This allows for scattered updates to a `Tensor` (not just a `Variable`)
        by using dynamic stitch to overwrite values"""
        update_indices = tf.boolean_mask(update_range, update_mask)
        update_values = tf.boolean_mask(new_values, update_mask)
        return tf.dynamic_stitch([stitch_range, update_indices],
                                 [old_values, update_values])
    
    begin_tokens = tf.constant([[1]])
#     outside_tokens = tf.constant([[0]])
    outside_token = tf.constant(0)
    def _start_new_mention(tags):
        start_new_mention = tf.reduce_any(tf.equal(tags, begin_tokens), 
                                      reduction_indices=[0])
        start_new_mention.set_shape((None,))
        return start_new_mention
    
    def _in_mention(tags):
        in_mention = tf.not_equal(tags, outside_token)
        return in_mention
    
    def _end_current_mention(start_new_mention, in_mention, already_in_mention):
        start_or_out = tf.logical_or(start_new_mention, tf.logical_not(in_mention))
        end_current_mention = tf.logical_and(start_or_out, already_in_mention)
        end_current_mention.set_shape((None,))
        return end_current_mention
        
    # first figure out the maximum number of mentions
    def _count_only_step(time, mention_counts, already_in_mention):
        """ Update mentions_count using the mention detection update rules. 
        
        This is done by marking if we are starting at a mention
        and then adding to the count to sequences in the batch 
        where we detect the end of a mention.
        
        This is done by stitching in the updates to overwrite the current
        values.  
        
        NOTE: Overwriting dynammic stitch may seem weird
          but it's due to some scatter/while_loop idiosyncrasies of tf
          (this finally works after my 4th implementation attempt)
        """
        tags = seqs_ta.read(time)
        in_mention = _in_mention(tags)
        start_new_mention = _start_new_mention(tags)
        end_current_mention = _end_current_mention(start_new_mention, in_mention, already_in_mention)
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
        already_in_mention = in_mention
        return time+1, mention_counts, already_in_mention
    
    # fill in empty bookkeeping tensors that encode the mentions
    def _extraction_step(time, 
                         mention_starts,
                         mention_ends,
                         mention_features,
                         mention_counts, 
                         mention_sizes,
                         sliding_feature,
                         already_in_mention):
        """ Extract mention boundaries at this timestep.
        
        Args:
          time: the current timestep in the batch of sequences
          seqs_ta: the batch of sequences
          mention_starts: the Tensor holding the start boundaries 
            of the extracted mentions so far
          mention_ends: the Tensor holding the end boundaries 
            of the extracted mentions so far
          mention_counts: the current number of mentions found
            for a single sequence in the batch.  This is used to 
            dynamically scatter extracted mentions into the `mentions`
            TensorArray so that it can be abstracted away from the 
            underlying sequence.
          in_mentions: whether or not the sequence is in a mention or not.
            As we scan, we extract mentions by looking for contiguous groups
            of non-'Outside' tags.  
            This way we can extract multi-token mentions into single elements.
        """        
        # get the sequence tags at the current timestep
        tags = seqs_ta.read(time)
        features = tf.tile(features_ta.read(time), [max_num_mentions, 1])
        
        # decide if they are in mention or not
        in_mention = _in_mention(tags)
        
        boundary = time * tf.ones_like(tags, dtype=tf.int32) 
        
        # whether to start new mention and/or end the previous
        start_new_mention = _start_new_mention(tags)
        end_current_mention = _end_current_mention(start_new_mention, in_mention, already_in_mention)
        
        
        mention_ends = _sparse_update(end_current_mention,
                                      linear_range, mention_counts + linear_index,
                                      mention_ends, boundary)
        
        mention_counts = _sparse_update(end_current_mention,
                                        batch_range, batch_range,
                                        mention_counts, mention_counts+1)
        
        mention_starts = _sparse_update(start_new_mention,
                                        linear_range, mention_counts + linear_index,
                                        mention_starts, boundary)
        
        repeated_index = tf.tile(tf.reshape(mention_counts+linear_index, (-1,1)), [1,2])
        repeated_batch_range = tf.tile(tf.reshape(batch_range, (-1,1)), [1,2])
        mention_features = _sparse_update(in_mention,
                                          repeated_batch_range, repeated_index,
                                          mention_features, sliding_feature + features)
        mention_features.set_shape((None, None))
        
        sliding_feature = _sparse_update(end_current_mention,
                                          batch_range, batch_range,
                                          sliding_feature, tf.zeros_like(sliding_feature))
        
        # update mention counts where we've ended an extraction (same as `_count_step()`)
        mention_sizes = _sparse_update(end_current_mention,
                                       batch_range, batch_range,
                                       mention_sizes, tf.zeros_like(mention_sizes))
        mention_sizes = _sparse_update(in_mention,
                                       batch_range, batch_range,
                                       mention_sizes, mention_sizes+1)
        
        already_in_mention = in_mention
        
        return (time + 1, 
#                 seqs_ta,
#                 outside_token,
#                 linear_range,
#                 linear_index,
#                 batch_range,
                mention_starts, 
                mention_ends,
                mention_features,
                mention_counts, 
                mention_sizes,
                sliding_feature,
                already_in_mention)
    
    # convert the sequences
    shape = tf.shape(seq_features)
    batch_size = shape[0]
    time_steps = shape[1]
    feature_size = shape[2]
    
    # `TensorArray`'s read in time-major, so transpose
    seqs_ta = tf.TensorArray(dtype=seqs.dtype, size=time_steps, clear_after_read=False)
    seqs_ta = seqs_ta.unpack(tf.transpose(seqs, [1,0]))
    features_ta = tf.TensorArray(dtype=seq_features.dtype, size=time_steps)
    features_ta = features_ta.unpack(tf.transpose(seq_features, [1,0,2]))
    
    # bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    batch_range = tf.range(batch_size)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
        
    # find the maximum number of mentions in batch
    (_, mention_counts, already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                                            body=_count_only_step,
                                                            loop_vars=(time,
                                                                       mention_counts, 
                                                                       already_in_mention))
    # add 1 to counts where we never detected the end
    already_in_mention.set_shape((None,))
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
    
    # now create the tensors we will extract mention data into
    max_num_mentions = tf.reduce_max(mention_counts)
    max_num_relations = max_num_mentions*(max_num_mentions-1)/2
    
    # create a linearized version of the mention statistics we'll gather
    # it'll be padded with -1's (can't use 0 as its a valid index)
    mention_starts = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
    mention_ends = -1*tf.ones(shape=(batch_size * max_num_mentions,), dtype=tf.int32)
    mention_features = tf.zeros(shape=(batch_size * max_num_mentions, feature_size),
                                dtype=seq_features.dtype)
#     mention_features = -1*tf.ones(shape=(batch_size * max_num_mentions, feature_size), 
#                                   dtype=features.dtype)
    
    # we also need a full linear range and a liear index into the mention_stats
    # tensors so we can dynamically overwrite the values
    linear_range = tf.range(batch_size * max_num_mentions)
    linear_index = max_num_mentions * tf.range(batch_size)
    
    # reset the bookkeeping tensors
    time = tf.constant(0, dtype=tf.int32)
    mention_counts = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    mention_sizes = tf.zeros(dtype=tf.int32, shape=(batch_size,))
    sliding_feature = tf.zeros(dtype=seq_features.dtype, shape=(batch_size,))
    already_in_mention = tf.cast(mention_counts, tf.bool)
    
    # extract the mentions
    (time,
     mention_starts, 
     mention_ends,
     mention_features,
     mention_counts, 
     mention_sizes,
     sliding_feature,
     already_in_mention) = tf.while_loop(cond=lambda time, *_: time < time_steps,
                                         body=_extraction_step,
                                         loop_vars= (time, 
                                                     mention_starts, 
                                                     mention_ends,
                                                     mention_features,
                                                     mention_counts, 
                                                     mention_sizes,
                                                     sliding_feature,
                                                     already_in_mention))
    # if we ended on a mention, we need to compute final endpoints
    already_in_mention.set_shape((None,))
    boundary = time * tf.ones_like(mention_counts, dtype=tf.int32) 
    offsets = mention_counts + linear_index
    mention_ends = _sparse_update(already_in_mention,
                                  linear_range, offsets,
                                  mention_ends, boundary)
    mention_counts = _sparse_update(already_in_mention,
                                    batch_range, batch_range,
                                    mention_counts, mention_counts+1)
#     mention_sizes = _sparse_update(already_in_mention,
#                                    batch_range, batch_range,
#                                    mention_sizes, mention_sizes+1)
    
    # finally concat and reshape extraction stats
    mention_starts = tf.reshape(mention_starts, (batch_size, max_num_mentions, 1))
    mention_ends = tf.reshape(mention_ends, (batch_size, max_num_mentions, 1))
    mentions = tf.concat(2, [mention_starts, mention_ends])
    return mentions, mention_features, mention_sizes, mention_counts#, max_num_mentions, max_num_relations

In [11]:
# [ O B I ]
# [ B O B ]

# rnn outputs
tf.reset_default_graph()
seqs = tf.Variable(np.array([[0,1,1],
                             [1,0,1]]) , dtype=tf.int32, trainable=False)
ys = tf.Variable(np.array([[0,1],
                           [1,0]]), trainable=False)
embed = tf.Variable(np.array([[2., 2.],
                              [3., 3.]]))

embedded_seqs = tf.nn.embedding_lookup(embed, seqs)
print embedded_seqs.get_shape()
# cell = tf.nn.rnn_cell.BasicRNNCell(5)
# outputs, _ = tf.nn.dynamic_rnn(cell, embedded_seqs, dtype=tf.float64)

# print outputs.get_shape()

# extracted mentions
tags = tf.Variable(np.array([[0,1,2],
                             [1,0,1]]), dtype=tf.int32, trainable=False)
mention_spans, mention_features, _, _ = extract_mentions(tags, embedded_seqs)

with tf.Session() as sess:
    tf.initialize_all_variables().run()
    print mention_spans.eval()
    print mention_features.eval()
#     print sizes.eval()
#     print counts.eval()

        



(2, 3, 2)


ValueError: Shapes (?,) and (?, ?) are not compatible

In [None]:
def extract_indices(tags):
    shape = tf.shape(tags)
    batch_size = shape[0]
    time_steps = shape[1]
    
    tags_ta = tf.TensorArray(dtype=tags.dtype, size=time_steps, clear_after_read=True)
    tags_ta = seqs_ta.unpack(tf.transpose(tags, [1,0]))
    index_ta = tf.TensorArray(dtype=tags.dtype, size=time_steps, clear_after_read=True)
    
    begin_tokens = tf.constant([[1]])
#     outside_tokens = tf.constant([[0]])
    outside_token = tf.constant(0)
    def _start_new_mention(tags):
        start_new_mention = tf.reduce_any(tf.equal(tags, begin_tokens), 
                                      reduction_indices=[0])
        start_new_mention.set_shape((None,))
        return start_new_mention
    
    def _in_mention(tags):
        in_mention = tf.not_equal(tags, outside_token)
        return in_mention
    
    def _end_current_mention(start_new_mention, in_mention, already_in_mention):
        start_or_out = tf.logical_or(start_new_mention, tf.logical_not(in_mention))
        end_current_mention = tf.logical_and(start_or_out, already_in_mention)
        end_current_mention.set_shape((None,))
        return end_current_mention

    
    

# [ O B I ]
# [ B O B ]

# rnn outputs
tf.reset_default_graph()
seqs = tf.Variable(np.array([[0,1,1],
                             [1,0,1]]) , dtype=tf.int32, trainable=False)
ys = tf.Variable(np.array([[0,1],
                           [1,0]]), trainable=False)
embed = tf.Variable(np.array([[2., 2.],
                              [3., 3.]]))

embedded_seqs = tf.nn.embedding_lookup(embed, seqs)
print embedded_seqs.get_shape()

# extracted mentions
tags = tf.Variable(np.array([[0,1,2],
                             [1,0,1]]), dtype=tf.int32, trainable=False)

with tf.Session() as sess:
    tf.initialize_all_variables().run()
    print mention_spans.eval()
    print mention_features.eval()