Error on saving keras custom layer with tensorflow_text.BertTokenizer #224

galfridman opened this issue Feb 12, 2020 · 19 comments


galfridman commented Feb 12, 2020

Trying to save a Keras custom layer with a tokenizer in it fails.
versions info:


Code to reproduce:

import tensorflow_text
import tensorflow as tf

class TokenizationLayer(tf.keras.layers.Layer):
    def __init__(self, vocab_path, **kwargs):
        self.vocab_path =vocab_path
        self.tokenizer = tensorflow_text.BertTokenizer(vocab_path, token_out_type=tf.int64)
        super(TokenizationLayer, self).__init__(**kwargs)
    def get_config(self):
        config = super(TokenizationLayer, self).get_config()
        config.update({
            'vocab_path': self.vocab_path,
        })
        return config

    def call(self,inputs):
        return self.tokenizer.tokenize(inputs).to_tensor()

vocab_path = r"/home/resources/bert_en_uncased_L-12_H-768_A-12/1/assets/vocab.txt"
# tensorflow_text.BertTokenizer(vocab_lookup_table = vocab_path, token_out_type=tf.int64)
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
tokenization_layer = TokenizationLayer(vocab_path)
outputs = tokenization_layer(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.save("./test")

It also gives error on

    def call(self,inputs):
        return self.tokenizer.tokenize(inputs)


AssertionError                            Traceback (most recent call last)
<ipython-input-55-e49dd5ac9a41> in <module>
----> 1 model.save("./test")

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/ in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
   1006     """
   1007     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 1008                     signatures, options)
   1010   def save_weights(self, filepath, overwrite=True, save_format=None):

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/ in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    113   else:
    114, filepath, overwrite, include_optimizer,
--> 115                           signatures, options)

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model/ in save(model, filepath, overwrite, include_optimizer, signatures, options)
     76     # we use the default replica context here.
     77     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
---> 78, filepath, signatures, options)
     80   if not include_optimizer:

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/ in save(obj, export_dir, signatures, options)
    907   object_saver = util.TrackableSaver(checkpoint_graph_view)
    908   asset_info, exported_graph = _fill_meta_graph_def(
--> 909       meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
    910   saved_model.saved_model_schema_version = (
    911       constants.SAVED_MODEL_SCHEMA_VERSION)

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/ in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist)
    586   with exported_graph.as_default():
--> 587     signatures = _generate_signatures(signature_functions, resource_map)
    588     for concrete_function in saveable_view.concrete_functions:
    589       concrete_function.add_to_graph()

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/ in _generate_signatures(signature_functions, resource_map)
    456             argument_inputs, signature_key,
    457     outputs = _call_function_with_mapped_captures(
--> 458         function, mapped_inputs, resource_map)
    459     signatures[signature_key] = signature_def_utils.build_signature_def(
    460         _tensor_dict_to_tensorinfo(exterior_argument_placeholders),

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/ in _call_function_with_mapped_captures(function, args, resource_map)
    408   """Calls `function` in the exported graph, using mapped resource captures."""
    409   export_captures = _map_captures_to_created_tensors(
--> 410       function.graph.captures, resource_map)
    411   # Calls the function quite directly, since we have new captured resource
    412   # tensors we need to feed in which weren't part of the original function

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/ in _map_captures_to_created_tensors(original_captures, resource_map)
    330            "be tracked by assigning them to an attribute of a tracked object "
    331            "or assigned to an attribute of the main object directly.")
--> 332           .format(interior))
    333     export_captures.append(mapped_resource)
    334   return export_captures

AssertionError: Tried to export a function which references untracked object Tensor("StatefulPartitionedCall/args_1:0", shape=(), dtype=resource).TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
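
The assertion says that the exported function captured a resource that the SavedModel exporter could not reach from the object being saved (here, presumably the tokenizer's vocabulary lookup table). As a minimal sketch of the rule stated in the message, using only core TensorFlow rather than tensorflow_text: a `tf.lookup.StaticHashTable` assigned to an attribute of a `tf.Module` is tracked and exports cleanly. The class name `VocabLookup`, the method name `lookup`, the three-token vocabulary, and the temporary directory are illustrative assumptions, and the sketch does not reproduce the failure inside `BertTokenizer` itself.

```python
import tempfile

import tensorflow as tf


class VocabLookup(tf.Module):
    """Hypothetical module that owns a lookup table as a tracked attribute."""

    def __init__(self, tokens):
        super().__init__()
        init = tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(tokens),
            values=tf.range(len(tokens), dtype=tf.int64),
        )
        # Assigning the table to an attribute of the module is what makes it tracked.
        self.table = tf.lookup.StaticHashTable(init, default_value=-1)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def lookup(self, words):
        # Map each string to its vocabulary index, or -1 if it is unknown.
        return self.table.lookup(words)


module = VocabLookup(["[PAD]", "hello", "world"])
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# Reload and call the exported function to confirm the table came along.
restored = tf.saved_model.load(export_dir)
ids = restored.lookup(tf.constant(["world", "missing"])).numpy().tolist()
print(ids)
```

The same idea underlies the `tf.Module` workaround posted further down the thread: the table is held by an object that is itself an attribute of the thing being saved.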

Hi @galfridman -

We're still working on enabling saving of lookup tables in core Keras - that fix went in internally yesterday. Now that it's in, we will create a Keras layer that will wrap BertTokenizer- stay tuned!

yynil commented Mar 6, 2020

I have the same issue with a BertVocabLayer that I implemented:

import tensorflow as tf
import tensorflow_text as text
from tensorflow import keras

def readSpecialTokens(dict_file):
    pad = 0
    unk = 0
    cls = 0
    sep = 0
    with open(dict_file,'r') as dictfile:
        index = 0
        finger = 0
        for str in dictfile:
            stripped = str.strip()
            if stripped == '[PAD]':
                pad = index
                finger = finger + 1
            elif stripped == '[UNK]':
                unk = index
                finger = finger + 2
            elif stripped == '[CLS]':
                cls = index
                finger = finger + 4
            elif stripped == '[SEP]':
                sep = index
                finger = finger + 8
            index = index + 1
            if finger == 15:
                # All four special tokens have been found
                break
        print("finish at {0}".format(index))
    return pad,unk,cls,sep

class BertVocabLayer(keras.layers.Layer):
    def __init__(self,vocab_file,max_len=128, mask=False, **kwargs ):
        super(BertVocabLayer, self).__init__(**kwargs)
        self._mask = mask
        self._compute_output_and_mask_jointly = True
        self._supports_ragged_inputs = True
        self.trainable = False

        self._tokenizer = text.BertTokenizer(vocab_lookup_table=vocab_file)
        self._max_len = max_len
        pad,unk,cls,sep = readSpecialTokens(vocab_file)
        self._cls = tf.constant([cls],dtype=tf.int64)
        self._sep = tf.constant([sep],dtype=tf.int64)
        self._pad_value = pad

    def build(self, input_shape):
        super(BertVocabLayer, self).build(input_shape)

    def call(self,input_tensor,training=False):
        #squeezed_tensor = tf.squeeze(input_tensor,axis=0)
        squeezed_tensor = input_tensor
        index = 0
        size = squeezed_tensor.shape[0]
        tensors = tf.TensorArray(
            dtype=tf.dtypes.int64, size=size,
            element_shape=[self._max_len], colocate_with_first_write_call=True, name=None
        )
        for index in range(size):
            idfeature = self._tokenizer.tokenize(squeezed_tensor[index])
            idfeature = idfeature.merge_dims(0,-1)
            idfeature = idfeature[0:self._max_len-2]
            idfeature = tf.concat([self._cls,idfeature,self._sep],0)
            idfeature = tf.pad(idfeature,[[0,self._max_len-tf.shape(idfeature)[0]]])
            tensors = tensors.write(index,idfeature)
            index = index + 1
        outputs =  tensors.stack()
        outputs = tf.expand_dims(outputs,0)
        return outputs

I used that to build the model and train it successfully; however, it throws exactly the same error when saving.
The model build function looks like below:

def build_model(bert_params,dict_file,batch_size):
    l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")
    x = keras.layers.Input(shape=(1,),batch_size=batch_size,dtype='string')
    bertVocabLayer = BertVocabLayer(dict_file,max_len=max_len,dtype=tf.dtypes.string)
    l_input_ids = bertVocabLayer(x)
    l_input_ids = tf.squeeze(l_input_ids,axis=0)
    #l_input_ids = keras.layers.Input(shape=(128,), dtype='int32', name="input_ids")
    output = l_bert(l_input_ids)
    cls_out = keras.layers.Lambda(lambda x: x[:, 0, :])(output)
    cls_out = keras.layers.Dropout(0.5)(cls_out)
    logits = keras.layers.Dense(units=768, activation="tanh")(cls_out)
    logits = keras.layers.Dropout(0.5)(logits)
    logits = keras.layers.Dense(units=2, activation="softmax")(logits)
    model = keras.Model(inputs=x, outputs=logits)

    for weight in l_bert.weights:
        print(weight.name)

    return model, l_bert

Any updates for this issue resolving? Thanks a lot!

MoggeM commented Mar 10, 2020

I'm having the same issue, but with tensorflow_text.SentencepieceTokenizer. Is the fix for all types of tokenizers or only Bert?

Hi @galfridman -

We're still working on enabling saving of lookup tables in core Keras - that fix went in internally yesterday. Now that it's in, we will create a Keras layer that will wrap BertTokenizer- stay tuned!

Hi Mark, Is this in yet / can you help with a temporary workaround? Thanks!

peakji commented Apr 15, 2020

Can we expect to save tokenizers into SavedModels when TensorFlow and TF Text 2.2.0 go stable?

@markomernick Confirming that this is still an issue on


Please let us know when there's a fix. Thanks!

I can also confirm this error on this version:
Would love to know a fix for this.

Hey all,

Unfortunately we had to push this out to ensure compatibility with DistributionStrategy. I'm working on it now and will have a fix in the nightly as soon as possible.

@markomernick any tips on how to hack together a solution in the meantime, until the official fix lands?

@markomernick Any updates regarding this?

Hey - we are currently working on a BertTokenizer Keras layer, as well as a Wordpiece Keras layer. We expect these to be part of the TF.Text 2.3 release.

MoggeM commented Jun 11, 2020

Any news about the same for the SentencePiece tokenizer?

@MoggeM We are thinking about this as well. I certainly hope to have it for 2.3, barring any unforeseen issues.

tf-text-github-robot pushed a commit that referenced this issue Jun 12, 2020
This fixes #224 where tokenizers were unable to be saved from within custom Keras layers.

PiperOrigin-RevId: 315784812
tf-text-github-robot pushed a commit that referenced this issue Jun 16, 2020
This fixes #224 where tokenizers were unable to be saved from within custom Keras layers.

PiperOrigin-RevId: 315784812
philipp-eisen commented Jun 16, 2020

@Mistobaan It's a bit late now, so I'm not sure this is still helpful, but you can hack your way around it like this:


import tensorflow as tf
import tensorflow_text as tf_text

# Assumed value: the original snippet does not define this constant.
BERT_MAX_SEQ_LEN = 512

class BertTokenizer(tf.Module):
    def __init__(self, vocab_file_path, sequence_length=BERT_MAX_SEQ_LEN, lower_case=True):
        self.CLS_ID = tf.constant(101, dtype=tf.int64)
        self.SEP_ID = tf.constant(102, dtype=tf.int64)
        self.PAD_ID = tf.constant(0, dtype=tf.int64)

        self.sequence_length = tf.constant(sequence_length)

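        vocab = self.load_vocab(vocab_file_path)

        # These two lines are basically what makes it work:
        # assigning the vocab table to a tf.Module and then later assigning the
        # instantiated Module to e.g. a Keras Model
        self.create_vocab_table(vocab)
        self.bert_tokenizer = tf_text.BertTokenizer(
            vocab_lookup_table=self.vocab_table,
            token_out_type=tf.int64,
            lower_case=lower_case,
        )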

    def load_vocab(self, vocab_file):
        """Loads a vocabulary file into a list."""
        vocab = []
        with tf.io.gfile.GFile(vocab_file, "r") as reader:
            while True:
                token = reader.readline()
                if not token:
                    break
                token = token.strip()
                vocab.append(token)
        return vocab

    def create_vocab_table(self, vocab, num_oov=1):
        vocab_values = tf.range(tf.size(vocab, out_type=tf.int64), dtype=tf.int64)
        self.init = tf.lookup.KeyValueTensorInitializer(
            keys=vocab, values=vocab_values, key_dtype=tf.string, value_dtype=tf.int64
        )
        self.vocab_table = tf.lookup.StaticVocabularyTable(
            self.init, num_oov, lookup_key_dtype=tf.string
        )

    @tf.function
    def __call__(self, text: tf.Tensor) -> tf.Tensor:
        """Perform the BERT preprocessing from text -> input token ids."""
        # Convert text into token ids
        tokens = self.bert_tokenizer.tokenize(text)

        # Flatten the ragged tensors
        tokens = tokens.merge_dims(1, 2)

        # Add start and end token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], self.CLS_ID)
        end_tokens = tf.fill([tf.shape(text)[0], 1], self.SEP_ID)
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)

        # Truncate to sequence length
        tokens = tokens[:, : self.sequence_length]

        # Convert ragged tensor to tensor and pad with PAD_ID
        tokens = tokens.to_tensor(default_value=self.PAD_ID)

        # Pad to sequence length
        pad = self.sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=self.PAD_ID)

        return tf.reshape(tokens, [-1, self.sequence_length])

# Dummy model to show that serialization works
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,), dtype=tf.float32),
])

model.tokenizer = BertTokenizer(vocab_file_path='./test_file')
model.save('./saved_model', signatures=model.tokenizer.__call__.get_concrete_function(tf.TensorSpec(None, tf.string)))
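
To make the order of the steps in `__call__` concrete, here is a plain-Python sketch of the same pipeline on lists of ids, with no TensorFlow involved. The helper name `preprocess` and the sample ids are illustrative assumptions; 101, 102 and 0 are the special-token ids hard-coded in the class above. One thing the sketch makes visible: because truncation happens after `[SEP]` is appended, an input longer than `sequence_length - 2` wordpieces loses its `[SEP]`.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def preprocess(wordpiece_ids, sequence_length):
    """Mirror the tokenizer steps above for a single sentence.

    `wordpiece_ids` is a list of per-word lists of wordpiece ids, i.e. the
    nested structure the tokenizer yields for one sentence before flattening.
    """
    # Flatten words -> wordpieces (the merge_dims step).
    flat = [i for word in wordpiece_ids for i in word]
    # Add the start and end token ids.
    ids = [CLS_ID] + flat + [SEP_ID]
    # Truncate to the sequence length, then pad up to it.
    ids = ids[:sequence_length]
    ids += [PAD_ID] * (sequence_length - len(ids))
    return ids


short = preprocess([[7592], [2088]], sequence_length=6)
long = preprocess([[1], [2, 3], [4], [5], [6]], sequence_length=6)
print(short)  # [101, 7592, 2088, 102, 0, 0]
print(long)   # [101, 1, 2, 3, 4, 5]  (the trailing [SEP] was truncated)
```

If keeping `[SEP]` matters for the downstream model, truncating the flattened ids to `sequence_length - 2` before adding the special tokens would avoid that, as the BertVocabLayer earlier in the thread does.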

tf-text-github-robot pushed a commit that referenced this issue Jun 26, 2020
This fixes #224 where tokenizers were unable to be saved from within custom Keras layers.

PiperOrigin-RevId: 315784812
broken commented Jul 7, 2020

fyi: pr #328 will resolve this issue

peakji commented Jul 28, 2020

fyi: pr #328 will resolve this issue

It seems that this PR is not included in the 2.3 branch?

broken commented Jul 28, 2020

Yeah; we found an issue with it when Sentencepiece was used from within map_fn. That's fixed and it will be a part of the 2.3.0 release later today.

broken commented Jul 28, 2020

2.3.0 is now released, so I'm going to close this bug. Feel free to reopen if a problem arises.

@broken broken closed this as completed Jul 28, 2020
peakji commented Jul 29, 2020

Saving now works properly in v2.3.0, but still got some warnings while loading the SavedModel:

WARNING:tensorflow:Unresolved object in checkpoint: (root).tokenizer._bert_tokenizer._wordpiece_tokenizer._vocab_lookup_table._initializer
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See for details.

Is this a known issue? I can confirm that the restored model is working despite the warnings.
