Bert2Bert model call fails when setting padded_decode to True #10222

@LoicDagnas

Description

@LoicDagnas

1. The entire URL of the file you are using

https://github.com/tensorflow/models/blob/master/official/nlp/nhnet/models.py

2. Describe the bug

When using a Bert2Bert model instance with the padded_decode parameter set to True (e.g. for TPU usage), I am forced to specify the batch size of the inputs when calling the model.

3. Steps to reproduce

You can simply run the following code:

import tensorflow as tf
from official.nlp.nhnet.configs import UNITTEST_CONFIG, BERT2BERTConfig
from official.nlp.nhnet.models import Bert2Bert, get_bert2bert_layers

bert2bert_config_dict = UNITTEST_CONFIG.copy()
bert2bert_config_dict["len_title"] = 32
bert2bert_config_dict["max_position_embeddings"] = 200
bert2bert_config_dict["padded_decode"] = True

bert2bert_config = BERT2BERTConfig.from_args(**bert2bert_config_dict)
bert_layer, decoder_layer = get_bert2bert_layers(params=bert2bert_config)

bert2bert = Bert2Bert(bert2bert_config, bert_layer, decoder_layer)

inputs = {
    "input_ids": tf.keras.layers.Input((200,), dtype=tf.int32, name="input_ids"),
    "input_mask": tf.keras.layers.Input((200,), dtype=tf.int32, name="input_mask"),
    "segment_ids": tf.keras.layers.Input((200,), dtype=tf.int32, name="segment_ids"),
    "target_ids": tf.keras.layers.Input((32,), dtype=tf.int32, name="target_ids")
}

output = bert2bert(inputs, mode='predict')

you'll get the following stack trace:

[...]
C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\official\nlp\nhnet\models.py:168 predict_decode  *
        decoded_ids, scores = beam_search.sequence_beam_search(
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\official\nlp\modeling\ops\beam_search.py:622 sequence_beam_search  *
        return sbs.search(initial_ids, initial_cache)
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\official\nlp\modeling\ops\beam_search.py:158 search  *
        state, state_shapes = self._create_initial_state(initial_ids, initial_cache,
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\official\nlp\modeling\ops\beam_search.py:419 _create_initial_state  *
        alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\ops\gen_array_ops.py:11530 tile  **
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\op_def_library.py:525 _apply_op_helper
        raise err
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\op_def_library.py:511 _apply_op_helper
        values = ops.convert_to_tensor(
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\profiler\trace.py:163 wrapped
        return func(*args, **kwargs)
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\ops.py:1566 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\constant_op.py:346 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\constant_op.py:271 constant
        return _constant_impl(value, dtype, shape, name, verify_shape=False,
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\constant_op.py:288 _constant_impl
        tensor_util.make_tensor_proto(
    C:\dev\ml\OnnxConversionLab\venv\lib\site-packages\tensorflow\python\framework\tensor_util.py:551 make_tensor_proto
        raise TypeError("Failed to convert object of type %s to Tensor. "

    TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [None, 1]. Consider casting elements to a supported type.
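For context (my own reading of the trace, not confirmed against the models code): beam_search appears to derive the batch size from the tensor's static shape, which is None for a Keras Input created without batch_size, and tf.tile cannot accept None as a multiple. The standalone snippet below reproduces the same failure mode; tile_static and tile_dynamic are illustrative names, not functions from the library:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.int32)])
def tile_static(ids):
    # Static shape: the batch dim is None when no fixed batch size is given,
    # so the multiples list becomes [None, 1] and tf.tile raises the same
    # "Failed to convert object of type <class 'list'> to Tensor" TypeError.
    batch_size = ids.shape[0]
    return tf.tile(tf.zeros([1, 1], tf.int32), [batch_size, 1])

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.int32)])
def tile_dynamic(ids):
    # Dynamic shape: tf.shape yields a scalar tensor, valid for any batch size.
    batch_size = tf.shape(ids)[0]
    return tf.tile(tf.zeros([1, 1], tf.int32), [batch_size, 1])

out = tile_dynamic(tf.zeros([3, 4], tf.int32))  # works: shape (3, 1)
```

Calling tile_static with the same input fails during tracing, which matches why providing a fixed batch_size (making the static dim concrete) avoids the error.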

but if you give the following input with the batch size specified:

inputs = {
    "input_ids": tf.keras.layers.Input((200,), dtype=tf.int32, name="input_ids", batch_size=8),
    "input_mask": tf.keras.layers.Input((200,), dtype=tf.int32, name="input_mask", batch_size=8),
    "segment_ids": tf.keras.layers.Input((200,), dtype=tf.int32, name="segment_ids", batch_size=8),
    "target_ids": tf.keras.layers.Input((32,), dtype=tf.int32, name="target_ids", batch_size=8)
}

it will work.

4. Expected behavior

I expected the call to work in both cases, i.e. whether or not the batch size is specified.

5. Additional context

X

6. System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10.0.19042
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.6.0
  • Python version: 3.7.6
