diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index e010e371a71644..ad916b6b5fa7b6 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -269,6 +269,10 @@ def _reverse_seq(input_seq, lengths): # Join into (time, batch_size, depth) s_joined = array_ops.pack(input_seq) + # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32 + if lengths is not None: + lengths = math_ops.to_int64(lengths) + # Reverse along dimension 0 s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) # Split again into list