diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 7dfe258ef..3ab97238b 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -196,7 +196,10 @@ def symbols_to_logits_fn(ids): if last_position_only: return tf.squeeze(logits, axis=[1, 2, 3]) current_output_position = tf.shape(ids)[1] - 1 # -1 due to the pad above. - logits = logits[:, current_output_position, :, :] + if current_output_position.shape.ndims >= 1: + logits = logits[:, current_output_position, :, :] + else: + logits = logits[:, -1 , :, :] return tf.squeeze(logits, axis=[1, 2]) batch_size = tf.shape(features["inputs"])[0] @@ -270,7 +273,7 @@ def infer_step(recent_output, _): cur_sample = samples[:, -1, :, :] else: #Avoid the out of index Error - if len(tf.shape(recent_output)) >= 2: + if tf.shape(recent_output).shape.ndims >= 2: cur_sample = samples[:, tf.shape(recent_output)[1], :, :] else: cur_sample = samples[:, -1, :, :]