Support TensorArray in BeamSearchDecoder state. #13312

Merged
merged 18 commits into tensorflow:master from guillaumekln:beam-search-with-tensorarray on Mar 18, 2018

Conversation

@guillaumekln
Contributor

guillaumekln commented Sep 26, 2017

#13208 attempted to fix #13154 by representing the alignment_history field with a Tensor instead of a TensorArray. However, @ebrevdo pointed out that this approach led to a quadratic time and space overhead.

This PR fixes the issue by directly adding the support for TensorArray in the BeamSearchDecoder state as proposed by @ebrevdo.

@ebrevdo Let me know what you think of this implementation. Thanks!
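For context, the combination this PR makes work looks roughly like the following. This is a minimal sketch against the TF 1.x contrib API, not code from the PR; the sizes, variable names, and token ids are illustrative:

import tensorflow as tf

batch_size, src_len, num_units, vocab_size, beam_width = 4, 7, 32, 50, 5
encoder_outputs = tf.random_normal([batch_size, src_len, num_units])
source_lengths = tf.fill([batch_size], src_len)
embedding = tf.get_variable("embedding", [vocab_size, num_units])

# Beam search decodes batch_size * beam_width hypotheses, so the attention
# memory must be tiled accordingly.
tiled_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(source_lengths, multiplier=beam_width)

attention = tf.contrib.seq2seq.LuongAttention(
    num_units, tiled_memory, memory_sequence_length=tiled_lengths)
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(num_units), attention,
    alignment_history=True)  # stores a TensorArray in the decoder state

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=cell,
    embedding=embedding,
    start_tokens=tf.fill([batch_size], 1),  # illustrative <s> id
    end_token=2,                            # illustrative </s> id
    initial_state=cell.zero_state(batch_size * beam_width, tf.float32),
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size))

Before this change, setting alignment_history=True here was not supported because the BeamSearchDecoder state could not carry a TensorArray.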

@googlebot googlebot added the cla: yes label Sep 26, 2017

@tensorflow-jenkins
Collaborator

tensorflow-jenkins commented Sep 26, 2017

Can one of the admins verify this patch?

@sb2nov sb2nov requested a review from ebrevdo Sep 26, 2017

@sb2nov
Member

sb2nov commented Sep 26, 2017

Jenkins, test this please.

@sb2nov
Member

sb2nov commented Sep 28, 2017

Jenkins, test this please.

@ebrevdo
Contributor

ebrevdo commented Oct 8, 2017

@steven-hh-ding

steven-hh-ding commented Oct 8, 2017

Thank you @ebrevdo. I deleted the previous post since it contained several errors from our old code. gather_tree only supports tensors of rank 3; the dtype is fine. We can use gather_tree to recover the correct beam index for each time step and collect the alignment history using those indices. I will post our code here once it is tested.

@ebrevdo
Contributor

ebrevdo commented Oct 8, 2017

@steven-hh-ding

steven-hh-ding commented Oct 9, 2017

from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops


def _gather_tree_for_array(t, parent_ids, sequence_length):
  """Converts a TensorArray of unsorted beams into a sorted one.

  Each element of `t` has shape [batch * beam, depth]. Returns a
  TensorArray whose elements have shape [batch, beam, depth], with
  beams sorted and padding set to 0.
  """
  # parent_ids has shape [time, batch, beam].
  time = array_ops.shape(parent_ids)[0]
  batch = array_ops.shape(parent_ids)[1]
  beam = array_ops.shape(parent_ids)[2]
  sorted_indx = array_ops.expand_dims(
      array_ops.expand_dims(math_ops.range(beam), 0), 0)
  sorted_indx = array_ops.tile(sorted_indx, [time, batch, 1])
  sorted_beams = beam_search_ops.gather_tree(
      step_ids=sorted_indx, parent_ids=parent_ids,
      sequence_length=sequence_length)
  # Increase all beam indices by 1 (the -1 padding index becomes 0).
  # Index 0 then denotes a padding beam, which has zero alignments.
  sorted_beams = sorted_beams + 1

  def collect(collector, i):
    # Concatenate a padding alignment (zeros) for each batch.
    value = array_ops.reshape(t.read(i), [batch, beam, -1])
    padding = array_ops.zeros([batch, 1, array_ops.shape(value)[-1]])
    value = array_ops.concat([padding, value], axis=1)

    # Collect values according to sorted_beams. We cannot use the gather
    # helper since the beam size was increased by one, so the final shape
    # has a different beam dimension.
    range_ = array_ops.expand_dims(
        math_ops.range(batch) * (beam + 1), 1)
    gather_indices = array_ops.reshape(sorted_beams[i] + range_, [-1])
    sorted_value = array_ops.gather(
        array_ops.reshape(value, [batch * (beam + 1), -1]),
        gather_indices)
    sorted_value = array_ops.reshape(sorted_value, [batch, beam, -1])

    collector = collector.write(i, sorted_value)
    return collector, i + 1

  collected = tensor_array_ops.TensorArray(
      size=t.size(), dynamic_size=True, dtype=dtypes.float32)
  collected, _ = control_flow_ops.while_loop(
      lambda _, i: i < t.size(),
      collect,
      loop_vars=(collected, 0),
      parallel_iterations=1)
  return collected

Test code (extended from testGatherTree):

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import clip_ops


def _transpose_batch_time(x):
  # numpy helper from the original test: [batch, time, ...] -> [time, batch, ...]
  return np.transpose(x, [1, 0, 2]).astype(np.int32)


step_ids = _transpose_batch_time(
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]],
     [[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
    [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]],
     [[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [[3, 3, 3], [3, 1, 3]]

# Make a dummy alignment history, which is the tiled step_ids.
step_ids = ops.convert_to_tensor(step_ids)
time = array_ops.shape(step_ids)[0]
batch = array_ops.shape(step_ids)[1]
beam = array_ops.shape(step_ids)[2]
alignment_length = constant_op.constant(10)
alignment_history = array_ops.tile(
    array_ops.expand_dims(step_ids, -1),
    [1, 1, 1, alignment_length])
alignment_history = math_ops.cast(alignment_history, dtypes.float32)
alignment_history = array_ops.reshape(
    alignment_history, [time, batch, beam, alignment_length])
alignment_history = tensor_array_ops.TensorArray(
    size=0, dynamic_size=True,
    dtype=dtypes.float32).unstack(alignment_history)

sorted_history = _gather_tree_for_array(
    alignment_history, parent_ids, sequence_length).stack()

sorted_step_ids = beam_search_ops.gather_tree(
    step_ids=step_ids, parent_ids=parent_ids,
    sequence_length=sequence_length)

sorted_step_ids = array_ops.tile(
    array_ops.expand_dims(sorted_step_ids, -1),
    [1, 1, 1, alignment_length])
sorted_step_ids = math_ops.cast(
    clip_ops.clip_by_value(
        sorted_step_ids, 0, math_ops.reduce_max(step_ids)),
    dtypes.float32)

# The sorted history should equal the tiled sorted step ids.
# (Run within a session, e.g. `with tf.Session():`, for eval() to work.)
print(np.array_equal(sorted_history.eval(), sorted_step_ids.eval()))

We also tested it with beam_search_decoder and dynamic_decode; it looks fine (applied in the finalize function).


@ebrevdo
Contributor

ebrevdo commented Oct 12, 2017

@guillaumekln
Contributor

guillaumekln commented Oct 12, 2017

Thank you @steven-hh-ding. I integrated your code with some minor restyling and renaming.

So now, TensorArrays are ignored in _maybe_split_batch_beams, _maybe_merge_batch_beams, and _maybe_tensor_gather_helper, and beams are eventually sorted in finalize(); see the sketch below. Is that more appropriate?

If so, @ebrevdo can review the details.
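To make that concrete, a minimal sketch of the pass-through behavior (illustrative code, not the PR's exact implementation; the reshape shown for ordinary tensors stands in for the decoder's real split logic):

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops


def _maybe_split_batch_beams(t, batch_size, beam_width):
  """Reshapes [batch * beam, ...] to [batch, beam, ...], skipping TensorArrays."""
  if isinstance(t, tensor_array_ops.TensorArray):
    return t  # left untouched here; beams are reordered once, in finalize()
  new_shape = array_ops.concat(
      [[batch_size, beam_width], array_ops.shape(t)[1:]], axis=0)
  return array_ops.reshape(t, new_shape)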

@steven-hh-ding

steven-hh-ding commented Oct 12, 2017

@guillaumekln Thank you! Please feel free to modify anything. It was done in a rush.

@guillaumekln
Contributor

guillaumekln commented Oct 23, 2017

The implementation should be revised after some changes on master. Will try to work on this again in the coming days.

@martinwicke
Member

martinwicke commented Nov 8, 2017

@guillaumekln any updates? It has accumulated another conflict.

@guillaumekln
Contributor

guillaumekln commented Nov 8, 2017

Yes, I looked at it again yesterday. I will update this thread again when it's ready for review. Thanks.

@guillaumekln
Contributor

guillaumekln commented Nov 8, 2017

I rebased the branch on master and revised the implementation following the recent changes made to gather_tree.

@ebrevdo Could you take a look when you have some time? Thanks.

@martinwicke
Member

martinwicke commented Dec 13, 2017

ping @ebrevdo

@ebrevdo
Contributor

ebrevdo commented Dec 21, 2017

@steven-hh-ding

steven-hh-ding commented Dec 21, 2017

Oh cool, good to know. Thank you @ebrevdo. @guillaumekln so we had better use range and tile to build the indices, as below. Please double check. Thanks!

# Assumed from the surrounding discussion: `a` is the stacked history of
# shape [time, batch, beam, depth] and `beam_ids` the sorted beam indices
# returned by gather_tree, of shape [time, batch, beam].
t_ind = tf.tile(tf.reshape(tf.range(time), [-1, 1, 1, 1]), [1, batch, beam, 1])
be_ind = tf.tile(tf.reshape(tf.range(batch), [-1, 1, 1, 1]), [1, time, beam, 1])
be_ind = tf.transpose(be_ind, perm=[1, 0, 2, 3])
# Full [time, batch, beam, 3] index: (time, batch, chosen beam) per position.
t_be_ind = tf.concat([t_ind, be_ind, tf.expand_dims(beam_ids, -1)], -1)

tf.gather_nd(a, t_be_ind)

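For reference, a self-contained toy version of that indexing (TF 1.x; the random values stand in for a real alignment history and for the beam ids that gather_tree would produce):

import tensorflow as tf

time, batch, beam, depth = 4, 2, 3, 5
a = tf.random_normal([time, batch, beam, depth])
beam_ids = tf.random_uniform([time, batch, beam], maxval=beam, dtype=tf.int32)

t_ind = tf.tile(tf.reshape(tf.range(time), [-1, 1, 1, 1]), [1, batch, beam, 1])
be_ind = tf.tile(tf.reshape(tf.range(batch), [-1, 1, 1, 1]), [1, time, beam, 1])
be_ind = tf.transpose(be_ind, perm=[1, 0, 2, 3])
t_be_ind = tf.concat([t_ind, be_ind, tf.expand_dims(beam_ids, -1)], -1)

sorted_a = tf.gather_nd(a, t_be_ind)  # same shape as a: [time, batch, beam, depth]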
@tensorflowbutler
Member

tensorflowbutler commented Feb 9, 2018

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@rmlarsen
Member

rmlarsen commented Feb 12, 2018

@guillaumekln any chance you could address the last minor comments?

@guillaumekln
Contributor

guillaumekln commented Feb 12, 2018

Yes, sorry for the delay. I had other priorities lately but I will finish this soon.

@ebrevdo Good call about unit testing the dynamic checks, as the code actually does not work as is. Since all logical operands are evaluated at graph build time (I missed that), expressions like:

math_ops.equal(array_ops.shape(t)[2], beam_width)

will fail if t has a rank lower than 3.
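A minimal reproduction of that failure mode, with one possible guard (TF 1.x graph mode; names are illustrative):

import tensorflow as tf

t = tf.placeholder(tf.float32, shape=[None, None])  # rank 2
beam_width = 5

# The next line raises at graph construction time: tf.shape(t) has static
# shape [2], so indexing element 2 is out of bounds before the graph runs.
# check = tf.equal(tf.shape(t)[2], beam_width)

# One way to guard: test the statically known rank in Python first.
if t.shape.ndims is not None and t.shape.ndims >= 3:
  check = tf.equal(tf.shape(t)[2], beam_width)
else:
  check = tf.constant(False)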

@ttrouill ttrouill referenced this pull request Feb 25, 2018

Attention Heatmap? #227 (Open)

@ttrouill

ttrouill commented Feb 27, 2018

Hi, no pressure, but is there a chance this will be merged soon? Many thanks

@guillaumekln
Contributor

guillaumekln commented Mar 5, 2018

Sorry, this was long overdue. I fixed the dynamic checks and added some tests for them. I also converted the error raised when static checks fail into a warning.

@rmlarsen
Member

rmlarsen commented Mar 16, 2018

@ebrevdo another look, please?

@ebrevdo

awesome! thanks!

@guillaumekln
Contributor

guillaumekln commented Mar 17, 2018

Thanks, I fixed them.

@rmlarsen
Member

rmlarsen commented Mar 17, 2018

@guillaumekln Thanks!

@ebrevdo ebrevdo merged commit 838a8f5 into tensorflow:master Mar 18, 2018

15 checks passed

Android Demo App: Internal CI build successful
GPU CC: Internal CI build successful
GPU Python3: Internal CI build successful
MacOS Contrib: Internal CI build successful
MacOS Python2 and CC: Internal CI build successful
Ubuntu CC: Internal CI build successful
Ubuntu Makefile: Internal CI build successful
Ubuntu Python2: Internal CI build successful
Ubuntu Python3: Internal CI build successful
Ubuntu Python3 PIP: Internal CI build successful
Ubuntu Sanity: Internal CI build successful
Ubuntu contrib: Internal CI build successful
Windows CMake: Internal CI build successful
XLA: Internal CI build successful
cla/google: All necessary CLAs are signed
@ebrevdo
Contributor

ebrevdo commented Mar 18, 2018

Thanks @guillaumekln for this PR. It's very nice and will help users get attention alignment histories with beam search.
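As a hedged follow-up sketch, retrieving the reordered history after decoding might look like this, continuing the illustrative setup from the opening comment (exact shapes depend on the attention mechanism and decode length):

outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)
# The alignment history lives in the wrapped cell state; after this PR it is
# a properly reordered TensorArray rather than an unsupported state field.
alignments = final_state.cell_state.alignment_history.stack()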

@guillaumekln guillaumekln deleted the guillaumekln:beam-search-with-tensorarray branch Mar 22, 2018

StanislawAntol pushed a commit to StanislawAntol/tensorflow that referenced this pull request Mar 23, 2018

Support TensorArray in BeamSearchDecoder state. (tensorflow#13312)
* Support TensorArray in BeamSearchDecoder state.

* Use gather_nd for reordering and test more shapes.

* Add a flag to disable TensorArrays reordering.

* Add shape checks before reordering a TensorArray.

* Directly use float32 member of dtypes

* Directly access dimension value if defined

* Add more TensorArrays reordering constraints

* Do not unstack reordered TensorArrays

* Improve warning for ignored TensorArrays

* Consistent static and runtime dimensions check

* Use comparison operators

* Fix dynamic checks and add tests

* Make static checks error a warning

* Fix pylint errors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment