
[Small_Bert] Could not find matching concrete function to call loaded from the SavedModel. #837

Closed
jeankhawand opened this issue Jan 5, 2022 · 9 comments

@jeankhawand

jeankhawand commented Jan 5, 2022

I am trying to build a text classification program with Small BERT using the following code:

with tf.device('/cpu:0'):
    preprocessing_layer =  hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    train_data = tf.data.Dataset.from_tensor_slices((preprocessing_layer(data_train.data), tf.keras.utils.to_categorical(data_train.target)))
    valid_data = tf.data.Dataset.from_tensor_slices((preprocessing_layer(data_test.data), tf.keras.utils.to_categorical(data_test.target)))
    for text,label in train_data.take(1):
        print(text)
        print(label)

def build_classifier_model():
  input_word_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_word_ids")
  input_mask = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_mask")
  input_type_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_type_ids")
  encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2", trainable=True, name='BERT_encoder')
  bert_inputs = dict(
    input_word_ids=input_word_ids,
    input_mask=input_mask,
    input_type_ids=input_type_ids)
  outputs = encoder(bert_inputs)
  net = outputs['pooled_output']
  net = tf.keras.layers.Dropout(0.1)(net)
  net = tf.keras.layers.Dense(NB_CLASSES, activation=None, name='classifier')(net)
  return tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=net)

classifier_model = build_classifier_model()

classifier_model.compile(optimizer="adam",
                         loss="categorical_crossentropy",
                         metrics="accuracy")

history = classifier_model.fit(x=train_data,
                               validation_data=valid_data,
                               epochs=10)

This is what the text and label output from the train_data Dataset looks like:

{'input_mask': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)>, 'input_word_ids': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([  101,  1045,  2064,  2425,  2017,  2008,  2043,  2572, 16846,
        3390,  2070,  5055,  2247,  1037,  3962,  5871,  1006,  2413,
        1007,  1010,  2008,  2076,  8272,  1997,  2070,  5693,  2006,
        3962,  1016,  1010,  2045,  4600,  4273,  8009, 14737,  2015,
        2040,  2018,  1037,  1036,  2202,  2053,  5895,  1005,  2298,
        2006,  2045,  5344,  1012,  3962, 14549,  2024,  3294,  5214,
        1997,  2725,  2070,  2200,  2204,  2006,  8753,  9867,  1012,
        1038, 12458,  1011,  1011,   102,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)>}
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 1. 0. 0.], shape=(10,), dtype=float32)
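Note that every element printed above has shape (128,) with no batch dimension: `from_tensor_slices` slices along axis 0, and nothing in the pipeline batches the result again, so `fit()` feeds the model one unbatched example at a time. A minimal sketch with hypothetical all-zero features (standing in for the preprocessor output) shows what `.batch()` changes:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the preprocessor output: 4 examples of length 128.
features = {
    "input_word_ids": np.zeros((4, 128), dtype=np.int32),
    "input_mask": np.zeros((4, 128), dtype=np.int32),
    "input_type_ids": np.zeros((4, 128), dtype=np.int32),
}
labels = np.zeros((4, 10), dtype=np.float32)

ds = tf.data.Dataset.from_tensor_slices((features, labels))
# Each element is a single example: no leading batch dimension.
unbatched_shape = ds.element_spec[0]["input_word_ids"].shape

# Batching adds the leading (None) dimension that Keras models expect.
batched = ds.batch(2)
batched_shape = batched.element_spec[0]["input_word_ids"].shape

print(unbatched_shape, batched_shape)  # (128,) (None, 128)
```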

While running this example, I get the following error:

Epoch 1/10
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_word_ids'), name='input_word_ids', description="created by layer 'input_word_ids'"), but it was called on an input with incompatible shape (128,).
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_word_ids'), name='input_word_ids', description="created by layer 'input_word_ids'"), but it was called on an input with incompatible shape (128,).
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_mask'), name='input_mask', description="created by layer 'input_mask'"), but it was called on an input with incompatible shape (128,).
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_mask'), name='input_mask', description="created by layer 'input_mask'"), but it was called on an input with incompatible shape (128,).
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_type_ids'), name='input_type_ids', description="created by layer 'input_type_ids'"), but it was called on an input with incompatible shape (128,).
WARNING:tensorflow:Model was constructed with shape (None, 128) for input KerasTensor(type_spec=TensorSpec(shape=(None, 128), dtype=tf.int32, name='input_type_ids'), name='input_type_ids', description="created by layer 'input_type_ids'"), but it was called on an input with incompatible shape (128,).
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-d6615a9fa20b> in <module>()
     30 history = classifier_model.fit(x=train_data,
     31                                validation_data=valid_data,
---> 32                                epochs=10)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1127           except Exception as e:  # pylint:disable=broad-except
   1128             if hasattr(e, "ag_error_metadata"):
-> 1129               raise e.ag_error_metadata.to_exception(e)
   1130             else:
   1131               raise

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 878, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 867, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 808, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer "BERT_encoder" (type KerasLayer).
    
    in user code:
    
        File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 237, in call  *
            result = smart_cond.smart_cond(training,
    
        ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
          Positional arguments (3 total):
            * {'input_word_ids': <tf.Tensor 'inputs_2:0' shape=(128,) dtype=int32>, 'input_mask': <tf.Tensor 'inputs:0' shape=(128,) dtype=int32>, 'input_type_ids': <tf.Tensor 'inputs_1:0' shape=(128,) dtype=int32>}
            * True
            * None
          Keyword arguments: {}
        
         Expected these arguments to match one of the following 4 option(s):
        
        Option 1:
          Positional arguments (3 total):
            * {'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_word_ids'), 'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_mask'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_type_ids')}
            * False
            * None
          Keyword arguments: {}
        
        Option 2:
          Positional arguments (3 total):
            * {'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_word_ids'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_type_ids'), 'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_mask')}
            * False
            * None
          Keyword arguments: {}
        
        Option 3:
          Positional arguments (3 total):
            * {'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_type_ids'), 'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_mask'), 'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_word_ids')}
            * True
            * None
          Keyword arguments: {}
        
        Option 4:
          Positional arguments (3 total):
            * {'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_mask'), 'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_word_ids'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_type_ids')}
            * True
            * None
          Keyword arguments: {}
    
    
    Call arguments received:
      • inputs={'input_word_ids': 'tf.Tensor(shape=(128,), dtype=int32)', 'input_mask': 'tf.Tensor(shape=(128,), dtype=int32)', 'input_type_ids': 'tf.Tensor(shape=(128,), dtype=int32)'}
      • training=True
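The "Could not find matching concrete function" message means none of the signatures saved with the encoder accepts the specs it was called with. Every listed option expects rank-2 `(None, None)` int32 tensors, i.e. `[batch, seq_length]`. A rough way to see which shapes would match, assuming the lookup behaves like a `tf.TensorSpec` compatibility check (the real lookup also compares the `training` flag and the dict structure, so this is a simplification):

```python
import tensorflow as tf

# Spec each encoder input must satisfy, copied from the options in the error message.
expected = tf.TensorSpec(shape=(None, None), dtype=tf.int32)

candidates = {
    "unbatched (128,)": tf.TensorSpec(shape=(128,), dtype=tf.int32),
    "batched (None, 128)": tf.TensorSpec(shape=(None, 128), dtype=tf.int32),
    "rank 3 (None, None, 128)": tf.TensorSpec(shape=(None, None, 128), dtype=tf.int32),
}

# Compatible means same dtype and shapes that agree in rank and in every known dimension.
results = {name: expected.is_compatible_with(spec) for name, spec in candidates.items()}
for name, ok in results.items():
    print(f"{name:26s} -> {'matches' if ok else 'no match'}")
```

Only the rank-2 batched shape matches; a rank-1 `(128,)` input and a rank-3 `(None, None, 128)` input both fail for the same reason (wrong rank).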

Resources

Classify Text with Bert
Bert Layer
Bert Preprocessing Layer
Tensorflow v2.7.0
Tensorflow Hub v0.12.0
Tensorflow Text v2.7.3
Python v3.7.12

@WGierke
Collaborator

WGierke commented Jan 6, 2022

From a quick glance, it looks like your input layers are 1D (tf.keras.Input(shape=(128,))), while the model expects 2D inputs:

          Positional arguments (3 total):
            * {'input_type_ids': TensorSpec(shape=(None, None), ...

Could you try changing the shapes of the input layers to shape=(None, 128)?

@jeankhawand
Author


I changed these lines to:

input_word_ids = tf.keras.Input(shape=(None, 128), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.Input(shape=(None, 128), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.Input(shape=(None, 128), dtype=tf.int32, name="input_type_ids")

and I get a similar error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-6d8650da6e46> in <module>()
     22   return tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=net)
     23 
---> 24 classifier_model = build_classifier_model()
     25 
     26 classifier_model.compile(optimizer="adam",

2 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

ValueError: Exception encountered when calling layer "BERT_encoder" (type KerasLayer).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 237, in call  *
        result = smart_cond.smart_cond(training,

    ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (3 total):
        * {'input_word_ids': <tf.Tensor 'inputs_2:0' shape=(None, None, 128) dtype=int32>, 'input_mask': <tf.Tensor 'inputs:0' shape=(None, None, 128) dtype=int32>, 'input_type_ids': <tf.Tensor 'inputs_1:0' shape=(None, None, 128) dtype=int32>}
        * False
        * None
      Keyword arguments: {}
    
     Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (3 total):
        * {'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_word_ids'), 'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_mask'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_type_ids')}
        * False
        * None
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (3 total):
        * {'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_mask'), 'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_word_ids'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_type_ids')}
        * False
        * None
      Keyword arguments: {}
    
    Option 3:
      Positional arguments (3 total):
        * {'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_word_ids'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_type_ids'), 'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='inputs/input_mask')}
        * True
        * None
      Keyword arguments: {}
    
    Option 4:
      Positional arguments (3 total):
        * {'input_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_mask'), 'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_word_ids'), 'input_type_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_type_ids')}
        * True
        * None
      Keyword arguments: {}


Call arguments received:
  • inputs={'input_word_ids': 'tf.Tensor(shape=(None, None, 128), dtype=int32)', 'input_mask': 'tf.Tensor(shape=(None, None, 128), dtype=int32)', 'input_type_ids': 'tf.Tensor(shape=(None, None, 128), dtype=int32)'}
  • training=None

@akhorlin
Collaborator

Looking at the error, there is still a shape mismatch:

Call argument received:
'input_word_ids': 'tf.Tensor(shape=(None, None, 128), dtype=int32)
Expected:
'input_word_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name='input_word_ids')

(same for the other inputs)

The batch dimension gets added automatically, so you probably want to specify the input shape as (128,), which will result in (None, 128).
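To see the automatic batch dimension concretely, here is a small sketch comparing the two `tf.keras.Input` shapes tried in this thread (the layer names are only for illustration):

```python
import tensorflow as tf

# Keras prepends the batch dimension (None) to whatever `shape` is passed.
ids_128 = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_word_ids")
ids_none_128 = tf.keras.Input(shape=(None, 128), dtype=tf.int32, name="input_word_ids_3d")

print(tuple(ids_128.shape))       # (None, 128): rank 2, what the encoder expects
print(tuple(ids_none_128.shape))  # (None, None, 128): rank 3, rejected by the encoder
```

So `shape=(128,)` was already the right Input shape; the remaining problem in the original code is that the dataset itself was never batched.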

You can also follow a more streamlined approach and pass the output of the preprocessor directly into the encoder, assuming the preprocessor is one of the supported ones (e.g. https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3) or is compatible with the expected interface:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # Registers the TF Text ops used by the preprocessing model.

# Note that the input is an unparsed text string (dimension: ()); the effective input tensor
# dimensions will be (None,), where dimension 0 is the batch dimension.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
preprocessor = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
encoder_inputs = preprocessor(text_input)
encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    trainable=True)
outputs = encoder(encoder_inputs)
pooled_output = outputs["pooled_output"]      # [batch_size, 768].
sequence_output = outputs["sequence_output"]  # [batch_size, seq_length, 768].

Detailed example: https://colab.sandbox.google.com/github/tensorflow/text/blob/master/docs/tutorials/bert_glue.ipynb#scrollTo=KeHEYKXGqjAZ
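Applying the first suggestion to the original pre-tokenized pipeline means keeping `Input(shape=(128,))` and adding `.batch(...)` to the dataset. The sketch below runs offline because it swaps the TF Hub encoder for a hypothetical stand-in (small embeddings plus average pooling); to use real Small BERT, replace the stand-in layers with the `hub.KerasLayer` encoder from the question. `BATCH_SIZE` and the random data are assumptions. It also uses `from_logits=True`, since the classifier `Dense` layer has `activation=None` while the string loss `"categorical_crossentropy"` assumes probabilities.

```python
import numpy as np
import tensorflow as tf

SEQ_LEN = 128     # sequence length produced by the preprocessor in the issue
NB_CLASSES = 10   # matches the 10-way one-hot labels printed in the issue
BATCH_SIZE = 8    # assumption: any batch size works
VOCAB_SIZE = 30522  # vocabulary size of BERT uncased
NAMES = ("input_word_ids", "input_mask", "input_type_ids")


def build_stand_in_model():
    """Hypothetical stand-in for build_classifier_model(): same three (128,) int32
    inputs, but small embeddings replace the TF Hub BERT encoder."""
    inputs = {n: tf.keras.Input(shape=(SEQ_LEN,), dtype=tf.int32, name=n) for n in NAMES}
    word = tf.keras.layers.Embedding(VOCAB_SIZE, 16)(inputs["input_word_ids"])
    mask = tf.keras.layers.Embedding(2, 16)(inputs["input_mask"])
    segment = tf.keras.layers.Embedding(2, 16)(inputs["input_type_ids"])
    net = tf.keras.layers.Add()([word, mask, segment])
    net = tf.keras.layers.GlobalAveragePooling1D()(net)  # stand-in for pooled_output
    net = tf.keras.layers.Dropout(0.1)(net)
    logits = tf.keras.layers.Dense(NB_CLASSES, activation=None, name="classifier")(net)
    return tf.keras.Model(inputs=inputs, outputs=logits)


# Fake pre-tokenized data shaped like the preprocessor output shown in the issue.
rng = np.random.default_rng(0)
n_examples = 32
features = {
    "input_word_ids": rng.integers(0, VOCAB_SIZE, size=(n_examples, SEQ_LEN), dtype=np.int32),
    "input_mask": np.ones((n_examples, SEQ_LEN), dtype=np.int32),
    "input_type_ids": np.zeros((n_examples, SEQ_LEN), dtype=np.int32),
}
labels = tf.keras.utils.to_categorical(
    rng.integers(0, NB_CLASSES, size=n_examples), NB_CLASSES).astype("float32")

# The key change: batch the dataset so each element is (batch, 128) instead of (128,).
train_data = (tf.data.Dataset.from_tensor_slices((features, labels))
              .shuffle(n_examples)
              .batch(BATCH_SIZE))

model = build_stand_in_model()
model.compile(
    optimizer="adam",
    # The classifier emits raw logits, so the loss must not expect probabilities.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
history = model.fit(train_data, epochs=1, verbose=0)
```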

@jeankhawand
Author

jeankhawand commented Jan 27, 2022


I applied the following approach, which uses the preprocessing layer inside the model to turn the text into input_word_ids, input_mask and input_type_ids:

with tf.device('/cpu:0'):
    train_data = tf.data.Dataset.from_tensor_slices((data_train.data, tf.keras.utils.to_categorical(data_train.target)))
    valid_data = tf.data.Dataset.from_tensor_slices((data_test.data, tf.keras.utils.to_categorical(data_test.target)))
    for text,label in train_data.take(1):
        print(text)
        print(label)

def build_classifier_model():
  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
  preprocessing_layer =  hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
  encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2", trainable=True, name='BERT_encoder')
  encoder_inputs = preprocessing_layer(text_input)
  encoder = hub.KerasLayer(encoder, trainable=True, name='BERT_encoder')
  outputs = encoder(encoder_inputs)
  net = outputs['pooled_output']
  net = tf.keras.layers.Dropout(0.1)(net)
  net = tf.keras.layers.Dense(10, activation=None, name='classifier')(net)
  return tf.keras.Model(text_input, net)

classifier_model = build_classifier_model()

classifier_model.compile(optimizer="adam",
                         loss="categorical_crossentropy",
                         metrics="accuracy")

history = classifier_model.fit(x=train_data,
                               validation_data=valid_data,
                               epochs=10)

output:

tf.Tensor(b"I can tell you that when AMSAT launched some birds along a Spot satellite\n(French), that during installation of some instruments on Spot 2, there\nheavily armed legionaires who had a `take no prisoners' look on there faces.\nSpot satellites are completely capable of doing some very good on orbit\nsurveillance.\n\nBMc\n--", shape=(), dtype=string)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 1. 0. 0.], shape=(10,), dtype=float32)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-039563b58a99> in <module>()
     19   return tf.keras.Model(text_input, net)
     20 
---> 21 classifier_model = build_classifier_model()
     22 
     23 classifier_model.compile(optimizer="adam",

2 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

ValueError: Exception encountered when calling layer "keras_layer_4" (type KerasLayer).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 237, in call  *
        result = smart_cond.smart_cond(training,

    ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (3 total):
        * Tensor("inputs:0", shape=(None, None), dtype=string)
        * False
        * None
      Keyword arguments: {}
    
     Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
        * False
        * None
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
        * True
        * None
      Keyword arguments: {}
    
    Option 3:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
        * False
        * None
      Keyword arguments: {}
    
    Option 4:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
        * True
        * None
      Keyword arguments: {}


Call arguments received:
  • inputs=tf.Tensor(shape=(None, None), dtype=string)
  • training=None
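Independently of what produced the `(None, None)` tensor at build time, the training data in this second attempt has the same batching problem as the first: each element is a scalar string of shape `()`, whereas every signature of `bert_en_uncased_preprocess/3` expects a `(None,)` string vector. A small sketch with hypothetical texts and labels, using the same `TensorSpec` compatibility check as a simplified model of the signature lookup:

```python
import tensorflow as tf

# Hypothetical texts and one-hot labels (10 classes, as in the issue).
texts = ["first newsgroup post", "second newsgroup post", "third newsgroup post"]
labels = tf.one_hot([7, 1, 3], depth=10)

unbatched = tf.data.Dataset.from_tensor_slices((texts, labels))
batched = unbatched.batch(2)

# Spec advertised by the preprocessing model, copied from the error message.
expected = tf.TensorSpec(shape=(None,), dtype=tf.string)

text_unbatched = unbatched.element_spec[0]  # shape (): one scalar string per element
text_batched = batched.element_spec[0]      # shape (None,): a vector of strings

print(expected.is_compatible_with(text_unbatched), expected.is_compatible_with(text_batched))
```

Separately, `build_classifier_model()` above wraps the encoder twice (`hub.KerasLayer(encoder, ...)` around an existing `hub.KerasLayer`); the second wrapping line can most likely be dropped.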

@B-swagG

B-swagG commented May 4, 2022

I was having the same issue using tf.data.Dataset.from_tensor_slices() with the data in a single text file loaded into a dataframe. I restructured the data so I could load it with tf.keras.utils.text_dataset_from_directory, and everything worked fine.
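A likely reason this works: `text_dataset_from_directory` returns an already-batched dataset (`batch_size` defaults to 32), so the model receives `(None,)` string tensors without an explicit `.batch()` call. A sketch using a hypothetical temporary directory with one sub-directory per class:

```python
import pathlib
import tempfile

import tensorflow as tf

# Hypothetical layout: one sub-directory per class, one .txt file per example.
root = pathlib.Path(tempfile.mkdtemp())
for class_name in ("rec.autos", "sci.space"):
    (root / class_name).mkdir()
    for i in range(3):
        (root / class_name / f"{i}.txt").write_text(f"example post {i} about {class_name}")

# batch_size is set explicitly here only to make the batching visible.
train_data = tf.keras.utils.text_dataset_from_directory(
    str(root), batch_size=4, label_mode="categorical")

text_spec, label_spec = train_data.element_spec
print(text_spec.shape, label_spec.shape)  # (None,) (None, 2)
```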

@gaikwadrahul8

@jeankhawand

Apologies for the delay. When using the BERT preprocessing model from TF Hub, the TensorFlow and tensorflow_text versions should match, so please make sure both packages are on the same version. The error may be caused by using the latest tensorflow_text together with a different TensorFlow (and Python) version. Please try again with matching TensorFlow and tensorflow_text versions and let us know whether that resolves your issue.

Please use the commands below to install TensorFlow and TensorFlow Text:

!pip install -U tensorflow
!pip install -U tensorflow-text
import tensorflow as tf
import tensorflow_text as text

OR

# To install a specific version

!pip install -U tensorflow==2.11.*
!pip install -U tensorflow-text==2.11.*
import tensorflow as tf
import tensorflow_text as text
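To check whether the installed packages line up, a small hypothetical helper that compares major.minor versions, assuming (as the pinned `2.11.*` commands suggest) that matching major.minor is what matters. It also assumes TensorFlow was installed under the `tensorflow` distribution name, as in the pip commands above:

```python
from importlib import metadata
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a version string such as '2.11.0' or '2.11.0rc1'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def versions_match(tf_version: str, text_version: str) -> bool:
    """True when TensorFlow and TF Text share the same major.minor release."""
    return major_minor(tf_version) == major_minor(text_version)


if __name__ == "__main__":
    try:
        tf_version = metadata.version("tensorflow")
        text_version = metadata.version("tensorflow-text")
    except metadata.PackageNotFoundError as err:
        print(f"Not installed: {err}")
    else:
        status = "match" if versions_match(tf_version, text_version) else "MISMATCH"
        print(f"tensorflow {tf_version} / tensorflow-text {text_version}: {status}")
```

By this check, the versions listed in the original report (TensorFlow 2.7.0, TensorFlow Text 2.7.3) already match.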

If the above workaround resolves your issue, please feel free to close it.

Thank you!

@gaikwadrahul8

Hi, @jeankhawand

Closing this issue due to lack of recent activity for a couple of weeks. Please feel free to reopen it with more details (if possible, complete code and a dataset so we can troubleshoot and find the root cause) if the problem persists after trying the above workaround. Thank you!

@dgarnitz

@gaikwadrahul8 I have the exact same problem as @jeankhawand. I followed your advice to make sure tensorflow and tensorflow-text are the same version, but this did not fix the issue for me. Do you have any other ideas on how to resolve this error?


8 participants
@WGierke @jeankhawand @akhorlin @sanatmpa1 @B-swagG @gaikwadrahul8 @dgarnitz and others