From bd300def3576410ed029e08ee4cbdb2db4659276 Mon Sep 17 00:00:00 2001 From: "yueyu.lin" Date: Thu, 29 Jun 2017 11:37:03 +0800 Subject: [PATCH] Two fix: 1. for beam search use the shape.ndims to avoid out of index error 2. for greedy search, still use the shape.ndims to avoid the out of index error. Before that I misuse the slice operation:-( --- tensor2tensor/utils/t2t_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 7dfe258ef..3ab97238b 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -196,7 +196,10 @@ def symbols_to_logits_fn(ids): if last_position_only: return tf.squeeze(logits, axis=[1, 2, 3]) current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above. - logits = logits[:, current_output_position, :, :] + if current_output_position.shape.ndims >= 1: + logits = logits[:, current_output_position, :, :] + else: + logits = logits[:, -1 , :, :] return tf.squeeze(logits, axis=[1, 2]) batch_size = tf.shape(features["inputs"])[0] @@ -270,7 +273,7 @@ def infer_step(recent_output, _): cur_sample = samples[:, -1, :, :] else: #Avoid the out of index Error - if len(tf.shape(recent_output)) >= 2: + if tf.shape(recent_output).shape.ndims >= 2: cur_sample = samples[:, tf.shape(recent_output)[1], :, :] else: cur_sample = samples[:, -1, :, :]