TensorFlow Lite: conversion from Keras works fine but execution on device fails: ByteBuffer is not a valid flatbuffer model #26083
Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colaboratory / Android
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: OnePlus 6, Samsung Galaxy S8, ...
- TensorFlow installed from (source or binary): binary
- TensorFlow version: b'v1.13.0-rc1-2-gb6141b06f5' 1.13.0-rc1 (the issue also occurs with 1.12.0)
- Firebase ML Model Interpreter version: 16.2.4
- Python version: 3.6
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: -
- GPU model and memory: -
Describe the current behavior
We want to convert a Keras model to a TFLite model. Some architectures fail to convert (e.g. a Reshape layer followed by a BatchNormalization layer). Others convert without any error, but loading the result on device fails with "Caused by: java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model". Here is a simple model that demonstrates the issue:
Describe the expected behavior
If the conversion to a TFLite model completes without errors, the resulting model should run with the Firebase ML Model Interpreter.
Code to reproduce the issue
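On Google Colaboratory, create the model:
from tensorflow.keras import Model
from tensorflow.keras.layers import (Activation, BatchNormalization, Dense,
                                     Input, LSTM, Multiply, Reshape)
from tensorflow.keras.losses import categorical_crossentropy, mse
from tensorflow.keras.optimizers import Adam

# Two inputs: an 884-value state vector and a 260-value mask.
state = Input((884,))
mask = Input((260,))
# Normalize the state, then view it as 52 steps x 17 features (52 * 17 = 884).
bn = BatchNormalization()(state)
rs1 = Reshape((52, 17))(bn)
lstm1 = LSTM(256, unroll=True)(rs1)
# Policy head: sigmoid scores, multiplied by the mask, then a softmax.
policy_raw = Dense(260, activation="sigmoid")(lstm1)
policy_masked = Multiply()([mask, policy_raw])
output_policy = Activation("softmax")(policy_masked)
# Value head: a single linear output computed directly from the state.
output_value = Dense(1)(state)
model = Model([state, mask], [output_policy, output_value])
model.compile(optimizer=Adam(), loss=[categorical_crossentropy, mse])
model.save("keras_model.keras")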
Convert the model:
import tensorflow as tf

path = "keras_model.keras"  # Keras model saved in the previous step
# Output name matching the asset loaded on Android and the Colab check below.
tflite_path = "keras_model.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model_file(path)
tflite_model = converter.convert()
with open(tflite_path, "wb") as f:
    f.write(tflite_model)
Then download the TFLite model:
from google.colab import files
files.download(tflite_path)
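Before copying the file into the Android assets, a rough look at its header can confirm that the converter produced a TFLite FlatBuffer at all. In the FlatBuffers format the first 4 bytes hold the offset of the root table and the next 4 bytes an optional file identifier. This minimal sketch assumes the TFLite schema layout (identifier `TFL3`, root table `Model` whose first field is `version`); `inspect_tflite_header` is a hypothetical helper, and it does not replicate the validation the TensorFlow Lite runtime performs when loading a model.

```python
import struct


def inspect_tflite_header(data: bytes) -> dict:
    """Rough sanity check of a serialized TFLite model.

    Assumes the TFLite FlatBuffer schema: file identifier "TFL3" and a root
    `Model` table whose field 0 is `version: uint`. Values that cannot be
    read (truncated or non-FlatBuffer data) are returned as None.
    """
    result = {"identifier": None, "version": None}
    if len(data) < 8:
        return result
    # Bytes 4-7 hold the optional FlatBuffers file identifier.
    result["identifier"] = data[4:8].decode("ascii", errors="replace")
    try:
        # Bytes 0-3: little-endian offset of the root table.
        (root,) = struct.unpack_from("<I", data, 0)
        # A table begins with a signed offset back to its vtable.
        (vtable_rel,) = struct.unpack_from("<i", data, root)
        vtable = root - vtable_rel
        if vtable < 0:
            return result
        # vtable layout: [vtable size, table size, one uint16 offset per field]
        (vtable_size,) = struct.unpack_from("<H", data, vtable)
        version = 0  # FlatBuffers default for an absent scalar field
        if vtable_size >= 6:  # the vtable has an entry for field 0
            (field_off,) = struct.unpack_from("<H", data, vtable + 4)
            if field_off:
                (version,) = struct.unpack_from("<I", data, root + field_off)
        result["version"] = version
    except struct.error:
        pass  # an offset points outside the buffer
    return result


# Usage in Colab, on the converted file:
# with open("keras_model.tflite", "rb") as f:
#     print(inspect_tflite_header(f.read()))
```

For a file produced by the 1.13 converter one would expect the identifier `TFL3` and schema version 3; anything else would suggest a problem with the file itself rather than with the on-device runtime.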
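Then try to run the model on Android with TensorFlow Lite through the Firebase ML Model Interpreter:
import static com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;
import android.support.annotation.NonNull;
import android.util.Log;
import com.google.android.gms.tasks.Continuation;
import com.google.android.gms.tasks.Task;
import com.google.firebase.ml.common.FirebaseMLException;
import com.google.firebase.ml.custom.FirebaseModelInputOutputOptions;
import com.google.firebase.ml.custom.FirebaseModelInputs;
import com.google.firebase.ml.custom.FirebaseModelInterpreter;
import com.google.firebase.ml.custom.FirebaseModelManager;
import com.google.firebase.ml.custom.FirebaseModelOptions;
import com.google.firebase.ml.custom.FirebaseModelOutputs;
import com.google.firebase.ml.custom.model.FirebaseLocalModelSource;

// The code below runs inside a method declared with `throws FirebaseMLException`.
// Tensor formats: inputs [1, 884] and [1, 260], outputs [1, 260] and [1, 1], all float32.
FirebaseModelInputOutputOptions dataOptions = new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FLOAT32, new int[]{1, 884})
        .setInputFormat(1, FLOAT32, new int[]{1, 260})
        .setOutputFormat(0, FLOAT32, new int[]{1, 260})
        .setOutputFormat(1, FLOAT32, new int[]{1, 1})
        .build();
// Register the model bundled in the app's assets under the name "asset".
final FirebaseLocalModelSource modelSource = new FirebaseLocalModelSource.Builder("asset")
        .setAssetFilePath("keras_model.tflite")
        .build();
FirebaseModelManager.getInstance().registerLocalModelSource(modelSource);
final FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setLocalModelName("asset")
        .build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
// Zero-filled inputs with a batch size of 1.
final FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(new float[][]{new float[884]})
        .add(new float[][]{new float[260]})
        .build();
// The model load failure shown below is reported through task.getResult().
interpreter.run(inputs, dataOptions)
        .continueWith(new Continuation<FirebaseModelOutputs, Object>() {
            @Override
            public Object then(@NonNull final Task<FirebaseModelOutputs> task) {
                try {
                    task.getResult();
                } catch (final Throwable e) {
                    Log.e("TFLITE", e.getMessage(), e);
                }
                return null;
            }
        });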
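The formats passed to `setInputFormat` and `setOutputFormat` are addressed by index, so they should agree with the tensor order and shapes of the converted model. A minimal Python sketch to check this in Colab, assuming the lists returned by `tf.lite.Interpreter.get_input_details()` / `get_output_details()` (each entry a dict with a `'shape'` array); `shape_mismatches`, `ANDROID_INPUTS` and `ANDROID_OUTPUTS` are hypothetical names, not part of any library.

```python
def shape_mismatches(details, expected_shapes):
    """Return (index, expected, actual) for each tensor whose shape differs.

    `details` is a list of dicts with a 'shape' entry, as returned by
    tf.lite.Interpreter.get_input_details() or get_output_details().
    `expected_shapes` maps a tensor index to the shape declared on Android.
    `actual` is None when the model has no tensor at that index.
    """
    problems = []
    for index, expected in sorted(expected_shapes.items()):
        expected = [int(dim) for dim in expected]
        if index >= len(details):
            problems.append((index, expected, None))
            continue
        actual = [int(dim) for dim in details[index]["shape"]]
        if actual != expected:
            problems.append((index, expected, actual))
    return problems


# Shapes declared in the Java snippet via setInputFormat / setOutputFormat.
ANDROID_INPUTS = {0: [1, 884], 1: [1, 260]}
ANDROID_OUTPUTS = {0: [1, 260], 1: [1, 1]}

# Usage in Colab, with a tf.lite.Interpreter created as in the Colab check
# at the end of this report:
# print(shape_mismatches(interpreter.get_input_details(), ANDROID_INPUTS))
# print(shape_mismatches(interpreter.get_output_details(), ANDROID_OUTPUTS))
```

Two empty lists mean the declared formats match the file. The check goes by index because the tensor order in the converted model may not follow the Keras input order.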
You can find a converted model here: keras_model.zip
Other info / logs
Error log:
com.google.firebase.ml.common.FirebaseMLException: The load task failed
com.google.android.gms.tasks.RuntimeExecutionException: com.google.firebase.ml.common.FirebaseMLException: The load task failed
at com.google.android.gms.tasks.zzu.getResult(Unknown Source:15)
at com.iscoolentertainment.unity.tflite.UnityTFLiteRunner$1.then(UnityTFLiteRunner.java:92)
at com.google.android.gms.tasks.zzd.run(Unknown Source:5)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:193)
at android.app.ActivityThread.main(ActivityThread.java:6863)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:537)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
Caused by: com.google.firebase.ml.common.FirebaseMLException: The load task failed
at com.google.android.gms.internal.firebase_ml.zzio.zzf(Unknown Source:65)
at com.google.android.gms.internal.firebase_ml.zzij.call(Unknown Source:2)
at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29)
at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5)
at android.os.Looper.loop(Looper.java:193)
at android.os.HandlerThread.run(HandlerThread.java:65)
Caused by: com.google.firebase.ml.common.FirebaseMLException: Local model load failed:
at com.google.android.gms.internal.firebase_ml.zzje.zza(Unknown Source:127)
at com.google.android.gms.internal.firebase_ml.zzje.zzgd(Unknown Source:102)
at com.google.android.gms.internal.firebase_ml.zziq.zzgg(Unknown Source:7)
at com.google.android.gms.internal.firebase_ml.zziq.call(Unknown Source:23)
at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29)
at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5)
at android.os.Looper.loop(Looper.java:193)
at android.os.HandlerThread.run(HandlerThread.java:65)
Caused by: java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model
at org.tensorflow.lite.NativeInterpreterWrapper.createModelWithBuffer(Native Method)
at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:74)
at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:54)
at org.tensorflow.lite.Interpreter.<init>(Interpreter.java:114)
at com.google.android.gms.internal.firebase_ml.zzje.zzb(Unknown Source:222)
at com.google.android.gms.internal.firebase_ml.zzjf.zzc(Unknown Source:0)
at com.google.android.gms.internal.firebase_ml.zzje.zzb(Unknown Source:148)
at com.google.android.gms.internal.firebase_ml.zzje.zza(Unknown Source:116)
at com.google.android.gms.internal.firebase_ml.zzje.zzgd(Unknown Source:102)
at com.google.android.gms.internal.firebase_ml.zziq.zzgg(Unknown Source:7)
at com.google.android.gms.internal.firebase_ml.zziq.call(Unknown Source:23)
at com.google.android.gms.internal.firebase_ml.zzie.zza(Unknown Source:29)
at com.google.android.gms.internal.firebase_ml.zzif.run(Unknown Source:2)
at android.os.Handler.handleCallback(Handler.java:873)
at android.os.Handler.dispatchMessage(Handler.java:99)
at com.google.android.gms.internal.firebase_ml.zze.dispatchMessage(Unknown Source:5)
at android.os.Looper.loop(Looper.java:193)
at android.os.HandlerThread.run(HandlerThread.java:65)
For reference, the same model runs fine in Google Colaboratory with the Python TFLite interpreter:
import numpy as np
import tensorflow as tf
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="keras_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on random input data.
input_shape = input_details[0]['shape']
input_shape_2 = input_details[1]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
input_data = np.array(np.random.random_sample(input_shape_2), dtype=np.float32)
interpreter.set_tensor(input_details[1]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
output_data = interpreter.get_tensor(output_details[1]['index'])
print(output_data)
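Since the same file loads in Colab, one way to rule out a difference in the file itself is to compare it with the copy packaged in the app. An APK is a ZIP archive whose assets are stored under `assets/`, so Python's `zipfile` module can read the bundled model. A minimal sketch; `compare_with_apk_asset` and the APK file name in the usage comment are hypothetical, and the default entry name assumes the asset path used in the Java snippet.

```python
import zipfile


def compare_with_apk_asset(local_model: bytes, apk_path,
                           asset_name="assets/keras_model.tflite"):
    """Compare the converted model bytes with the copy packaged in an APK.

    `apk_path` can be a path or a binary file object (anything zipfile
    accepts). `asset_name` assumes the model sits at the root of the app's
    assets directory, as in the setAssetFilePath("keras_model.tflite") call.
    """
    with zipfile.ZipFile(apk_path) as apk:
        # read() returns the decompressed bytes of the archive entry.
        packaged = apk.read(asset_name)
    return {
        "identical": packaged == local_model,
        "local_size": len(local_model),
        "packaged_size": len(packaged),
    }


# Usage in Colab, after uploading the built APK (file name is hypothetical):
# with open("keras_model.tflite", "rb") as f:
#     print(compare_with_apk_asset(f.read(), "app-debug.apk"))
```

If the bytes are identical, the difference between Colab and the device lies in the runtime that loads the file rather than in the file that was shipped.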