'Failed to apply delegate' error occurred when training with tflite using nnapi #51859

Open
hyeonsu94 opened this issue Sep 7, 2021 · 1 comment
@hyeonsu94

1. System information

Converter

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04.2 LTS
  • TensorFlow installation (pip package or built from source): pip package (pip install tf-nightly)
  • TensorFlow library (version, if pip package or github SHA, if built from source): [screenshot of the installed tf-nightly version, not reproduced here]

Mobile phone

  • Model: Samsung Galaxy Z Flip 3 (SM-F711N)
  • OS & SDK version: Android 11, API 30

2. Code

- Model & TFLite converter

import tensorflow as tf

IMG_SIZE = 28

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
    self._LOSS_FN = tf.keras.losses.CategoricalCrossentropy()
    self._OPTIM = tf.optimizers.SGD()

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self._LOSS_FN(prediction, y)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self._OPTIM.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    for grad in gradients:
      result[grad.name] = grad
    return result

  @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
  def predict(self, x):
    return {
        "output": self.model(x)
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

SAVED_MODEL_DIR = "saved_model"
m = Model()
tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train':
            m.train.get_concrete_function(),
        'infer':
            m.predict.get_concrete_function(),
        'save':
            m.save.get_concrete_function(),
        'restore':
            m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()
with open('model.tflite','wb') as f:
    f.write(tflite_model)
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print("main")

- Android code

package com.tvstorm.tflitetest
import android.os.Bundle
import android.util.Log
import androidx.appcompat.app.AppCompatActivity
import org.apache.commons.io.FileUtils.copyInputStreamToFile
import org.apache.commons.io.IOUtils
import org.tensorflow.lite.Interpreter
import java.io.File
import java.io.InputStream
import java.nio.FloatBuffer
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        runTFLite()
    }
    val NUM_EPOCHS = 100
    val NUM_TRAININGS = 60000
    val trainImages = Array(NUM_TRAININGS) { Array(28) { FloatArray(28) } }
    val trainLabels = Array(NUM_TRAININGS) { FloatArray(10) }
    var NUM_TESTS = 10
    var testImages = Array(NUM_TESTS) { Array(28) { FloatArray(28) } }
    var output = Array(NUM_TESTS) { FloatArray(10) }
    fun runTFLite() {
        val options = Interpreter.Options().apply { // use nnapi option
            setUseNNAPI(true)
        }
        val interpreter =
            Interpreter(convertInputStreamToFile(resources.openRawResource(R.raw.model)), options)
        for (i in 0 until NUM_EPOCHS) {
            val inputs: MutableMap<String, Any> = HashMap()
            inputs["x"] = trainImages
            inputs["y"] = trainLabels
            val outputs: MutableMap<String, Any> = HashMap()
            val loss: FloatBuffer = FloatBuffer.allocate(1)
            outputs["loss"] = loss
            interpreter.runSignature(inputs, outputs, "train") // run the exported "train" signature
            Log.d("LOSS", "loss: ${loss[0]}")
        }
    }
    fun convertInputStreamToFile(inputStream: InputStream): File {
        val tempFile = File.createTempFile(java.lang.String.valueOf(inputStream.hashCode()), ".tmp")
        tempFile.deleteOnExit()
        copyInputStreamToFile(inputStream, tempFile)
        return tempFile
    }

}

3. Failure after conversion

https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb

Following the document above, we have succeeded in training on an Android device with TFLite.

However, when we enable NNAPI through the interpreter option, it raises a 'Failed to apply delegate' error related to static tensor sizes.

Does on-device training not support NNAPI yet? Is there a roadmap for NNAPI support?

4. Error log

2021-09-07 11:28:23.521 20502-20502/com.tvstorm.tflitetest E/AndroidRuntime: FATAL EXCEPTION: main
    Process: com.tvstorm.tflitetest, PID: 20502
    java.lang.RuntimeException: Unable to start activity ComponentInfo{com.tvstorm.tflitetest/com.tvstorm.tflitetest.MainActivity}: java.lang.IllegalArgumentException: Internal error: Failed to apply delegate: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#3 is a dynamic-sized tensor).
        at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:3832)
        at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:4008)
        at android.app.servertransaction.LaunchActivityItem.execute(LaunchActivityItem.java:85)
        at android.app.servertransaction.TransactionExecutor.executeCallbacks(TransactionExecutor.java:135)
        at android.app.servertransaction.TransactionExecutor.execute(TransactionExecutor.java:95)
        at android.app.ActivityThread$H.handleMessage(ActivityThread.java:2317)
        at android.os.Handler.dispatchMessage(Handler.java:106)
        at android.os.Looper.loop(Looper.java:247)
        at android.app.ActivityThread.main(ActivityThread.java:8618)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:602)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1130)
     Caused by: java.lang.IllegalArgumentException: Internal error: Failed to apply delegate: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors (tensor#3 is a dynamic-sized tensor).
        at org.tensorflow.lite.NativeInterpreterWrapper.applyDelegate(Native Method)
        at org.tensorflow.lite.NativeInterpreterWrapper.applyDelegates(NativeInterpreterWrapper.java:487)
        at org.tensorflow.lite.NativeInterpreterWrapper.init(NativeInterpreterWrapper.java:88)
        at org.tensorflow.lite.NativeInterpreterWrapper.<init>(NativeInterpreterWrapper.java:51)
        at org.tensorflow.lite.NativeInterpreterWrapperExperimental.<init>(NativeInterpreterWrapperExperimental.java:40)
        at org.tensorflow.lite.Interpreter.<init>(Interpreter.java:196)
        at com.tvstorm.tflitetest.MainActivity.runTFLite(MainActivity.kt:32)
        at com.tvstorm.tflitetest.MainActivity.onCreate(MainActivity.kt:17)
        at android.app.Activity.performCreate(Activity.java:8215)
        at android.app.Activity.performCreate(Activity.java:8199)
        at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1309)
        at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:3805)
        at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:4008) 
        at android.app.servertransaction.LaunchActivityItem.execute(LaunchActivityItem.java:85) 
        at android.app.servertransaction.TransactionExecutor.executeCallbacks(TransactionExecutor.java:135) 
        at android.app.servertransaction.TransactionExecutor.execute(TransactionExecutor.java:95) 
        at android.app.ActivityThread$H.handleMessage(ActivityThread.java:2317) 
        at android.os.Handler.dispatchMessage(Handler.java:106) 
        at android.os.Looper.loop(Looper.java:247) 
        at android.app.ActivityThread.main(ActivityThread.java:8618)
@hyeonsu94 hyeonsu94 added the TFLiteConverter For issues related to TFLite converter label Sep 7, 2021
@abattery abattery added TFLiteNNAPIDelegate For issues related to TFLite NNAPI Delegate type:support Support issues and removed TFLiteConverter For issues related to TFLite converter labels Sep 7, 2021
@sushreebarsa sushreebarsa removed their assignment Sep 8, 2021
@miaowang14

The NNAPI delegate by default only supports models with static shapes.

Quick question: does the model require dynamic shapes? If not, there might be something needed in the conversion process to force the use of static shapes. @srjoglekar246
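
For illustration, a minimal sketch of one way to force static shapes at conversion time, assuming a fixed batch size is acceptable for the app (this is an assumption, not a confirmed fix):

BATCH_SIZE = 1  # assumption: the Android app always feeds exactly this many examples

# In the Model class above, replace the None batch dimension of each
# @tf.function input_signature with BATCH_SIZE, e.g. for train():
#   tf.TensorSpec([BATCH_SIZE, IMG_SIZE, IMG_SIZE], tf.float32),
#   tf.TensorSpec([BATCH_SIZE, 10], tf.float32),
# then re-export and re-convert exactly as before:
m = Model()
tf.saved_model.save(m, SAVED_MODEL_DIR, signatures={
    'train': m.train.get_concrete_function(),
    'infer': m.predict.get_concrete_function(),
    'save': m.save.get_concrete_function(),
    'restore': m.restore.get_concrete_function(),
})
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

With every signature input fixed to BATCH_SIZE, the exported graph should no longer contain dynamic-sized tensors (the constraint the NNAPI delegate reports); the Android side would then have to pass exactly BATCH_SIZE examples per runSignature call.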

@chunduriv chunduriv added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 1, 2021
@mohantym mohantym added stat:awaiting response Status - Awaiting response from author stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower stat:awaiting response Status - Awaiting response from author labels Nov 10, 2021