
TensorFlow Lite: Conversion from Keras works fine but execution on device fails: ByteBuffer is not a valid flatbuffer model #26083

@NicolasVIscool

Description


Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colaboratory / Android
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: Oneplus 6, Samsung Galaxy s8, ...
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): b'v1.13.0-rc1-2-gb6141b06f5' 1.13.0-rc1 (but also present in 1.12.0)
  • Firebase ML-Interpreter Version : 16.2.4
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: -
  • GPU model and memory: -

Describe the current behavior
We want to convert a Keras model to a TFLite model. Some architectures fail to convert at all (e.g. a Reshape layer followed by a BatchNormalization layer), while others convert without any error but then fail on device with "Caused by: java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model". A simple model demonstrating the issue is given in the reproduction code below.

Describe the expected behavior
If a model converts to TFLite without errors, we should be able to run it with the Firebase Model Interpreter.

Code to reproduce the issue

On Google Colaboratory, create the model:


from tensorflow.keras import Model
from tensorflow.keras.layers import (Input, BatchNormalization, Reshape, LSTM,
                                     Dense, Multiply, Activation)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import mse, categorical_crossentropy

state = Input((884,))
mask = Input((260,))

bn = BatchNormalization()(state)
rs1 = Reshape((52, 17))(bn)
lstm1 = LSTM(256, unroll=True)(rs1)

policy_raw = Dense(260, activation="sigmoid")(lstm1)
policy_masked = Multiply()([mask, policy_raw])
output_policy = Activation(activation="softmax")(policy_masked)

output_value = Dense(1)(state)

model = Model([state, mask], [output_policy, output_value])
model.compile(optimizer=Adam(), loss=[categorical_crossentropy, mse])

model.save("keras_model.keras")

Convert the model:

from tensorflow.lite.python.lite import TFLiteConverter

path = "keras_model.keras"

converter = TFLiteConverter.from_keras_model_file(path)
tflite_model = converter.convert()

with open(path + ".tflite", "wb") as f:
    f.write(tflite_model)
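Before shipping the file to a device, it can help to sanity-check that the written file actually carries the TFLite FlatBuffer file identifier ("TFL3", stored at byte offset 4 per the FlatBuffers file_identifier convention). This is a minimal sketch of such a check, not part of the original report:

```python
def looks_like_tflite(path):
    """Heuristic check: a TFLite flatbuffer carries the file
    identifier 'TFL3' at byte offset 4, right after the 4-byte
    root table offset."""
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == b"TFL3"
```

If this returns False for the file that is actually bundled in the APK's assets, the bytes were altered somewhere between conversion and loading.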

Then download the TFLite model:

from google.colab import files

files.download(path +'.tflite')

Then try to use the model on Android with TFLite through the Firebase ML interpreter:

FirebaseModelInputOutputOptions dataOptions = new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FLOAT32, new int[]{1, 884})
                .setInputFormat(1, FLOAT32, new int[]{1, 260})
                .setOutputFormat(0, FLOAT32, new int[]{1, 260})
                .setOutputFormat(1, FLOAT32, new int[]{1, 1})
                .build();

final FirebaseLocalModelSource modelSource = new FirebaseLocalModelSource.Builder("asset")
                .setAssetFilePath("keras_model.tflite").build();

FirebaseModelManager.getInstance().registerLocalModelSource(modelSource);

final FirebaseModelOptions options = new FirebaseModelOptions.Builder()
                .setLocalModelName("asset")
                .build();

FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

final FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
                .add(new float[][]{new float[884]})
                .add(new float[][]{new float[260]})
                .build();

interpreter.run(inputs, dataOptions)
                .continueWith(new Continuation<FirebaseModelOutputs, Object>() {
                    @Override
                    public Object then(@NonNull final Task<FirebaseModelOutputs> task) {
                        try {
                            task.getResult();
                        } catch (final Throwable e) {
                            Log.e("TFLITE", e.getMessage(), e);
                        }

                        return null;
                    }
                });

You can find a converted model here:
keras_model.zip

Other info / logs

Error log :

com.google.firebase.ml.common.FirebaseMLException: The load task failed
    com.google.android.gms.tasks.RuntimeExecutionException: com.google.firebase.ml.common.FirebaseMLException: The load task failed
        at com.google.android.gms.tasks.zzu.getResult(Unknown Source:15)
        at com.iscoolentertainment.unity.tflite.UnityTFLiteRunner$1.then(UnityTFLiteRunner.java:92)
        at com.google.android.gms.tasks.zzd.run(Unknown Source:5)
        at android.os.Handler.handleCallback(Handler.java:873)
        at android.os.Handler.dispatchMessage(Handler.java:99)
        at android.os.Looper.loop(Looper.java:193)
        at android.app.ActivityThread.main(ActivityThread.java:6863)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:537)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
     Caused by: com.google.firebase.ml.common.FirebaseMLException: The load task failed
        at com.google.android.gms.internal.firebase_ml.zzio.zzf(Unknown Source:65)
        at com.google.android.gms.internal.firebase_ml.zzij.call(Unknown Source:2)
        at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29)
        at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2)
        at android.os.Handler.handleCallback(Handler.java:873)
        at android.os.Handler.dispatchMessage(Handler.java:99)
        at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5)
        at android.os.Looper.loop(Looper.java:193)
        at android.os.HandlerThread.run(HandlerThread.java:65)
     Caused by: com.google.firebase.ml.common.FirebaseMLException: Local model load failed: 
        at com.google.android.gms.internal.firebase_ml.zzje.zza(Unknown Source:127)
        at com.google.android.gms.internal.firebase_ml.zzje.zzgd(Unknown Source:102)
        at com.google.android.gms.internal.firebase_ml.zziq.zzgg(Unknown Source:7)
        at com.google.android.gms.internal.firebase_ml.zziq.call(Unknown Source:23)
        at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29) 
        at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2) 
        at android.os.Handler.handleCallback(Handler.java:873) 
        at android.os.Handler.dispatchMessage(Handler.java:99) 
        at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5) 
        at android.os.Looper.loop(Looper.java:193) 
        at android.os.HandlerThread.run(HandlerThread.java:65) 
     Caused by: java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model
        at org.tensorflow.lite.NativeInterpreterWrapper.createModelWithBuffer(Native Method)
        at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:74)
        at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:54)
        at org.tensorflow.lite.Interpreter.<init>(Interpreter.java:114)
        at com.google.android.gms.internal.firebase_ml.zzje.zzb(Unknown Source:222)
        at com.google.android.gms.internal.firebase_ml.zzjf.zzc(Unknown Source:0)
        at com.google.android.gms.internal.firebase_ml.zzje.zzb(Unknown Source:148)
        at com.google.android.gms.internal.firebase_ml.zzje.zza(Unknown Source:116)
        at com.google.android.gms.internal.firebase_ml.zzje.zzgd(Unknown Source:102) 
        at com.google.android.gms.internal.firebase_ml.zziq.zzgg(Unknown Source:7) 
        at com.google.android.gms.internal.firebase_ml.zziq.call(Unknown Source:23) 
        at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29) 
        at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2) 
        at android.os.Handler.handleCallback(Handler.java:873) 
        at android.os.Handler.dispatchMessage(Handler.java:99) 
        at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5) 
        at android.os.Looper.loop(Looper.java:193) 
        at android.os.HandlerThread.run(HandlerThread.java:65) 
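Since the same file loads fine in the Python interpreter (see below), one thing worth checking is whether the bytes that reach the device are identical to the freshly converted file, e.g. whether the asset pipeline re-compresses or truncates the model. A hedged debugging sketch, assuming you have pulled the on-device copy back to the host (the file names here are hypothetical):

```python
import hashlib

def file_digest(path):
    """SHA-256 of a file's contents, read in 64 KiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            h.update(chunk)
    return h.hexdigest()

# Hypothetical comparison: the converted model vs. a copy
# extracted from the APK or pulled from the device.
# file_digest("keras_model.tflite") == file_digest("model_from_apk.tflite")
```

If the digests differ, the model is being corrupted during packaging rather than during conversion.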

For information, if we run the same model with the Python TFLite interpreter on Google Colaboratory, it works fine:


import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="keras_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_shape_2 = input_details[1]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
input_data = np.array(np.random.random_sample(input_shape_2), dtype=np.float32)
interpreter.set_tensor(input_details[1]['index'], input_data)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
output_data = interpreter.get_tensor(output_details[1]['index'])
print(output_data)
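The shapes passed to `setInputFormat`/`setOutputFormat` on the Firebase side must match the interpreter's tensor shapes exactly, so it can be worth cross-checking them from the Python side while debugging. A small helper sketch (`shapes_match` is our own name, not a TFLite or Firebase API):

```python
def shapes_match(details, expected):
    """Compare the interpreter's tensor details (each a dict with a
    'shape' entry) against the shapes declared on the Firebase side."""
    actual = [list(d["shape"]) for d in details]
    return actual == [list(s) for s in expected]

# e.g. shapes_match(input_details, [[1, 884], [1, 260]])
# and  shapes_match(output_details, [[1, 260], [1, 1]])
```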
