
TFLite: Support grouped convolutions #40044

Closed
lgeiger opened this issue Jun 1, 2020 · 18 comments
Assignees
Labels
comp:lite TF Lite related issues stat:awaiting response Status - Awaiting response from author type:feature Feature requests

Comments

@lgeiger
Contributor

lgeiger commented Jun 1, 2020

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (or github SHA if from source): tf-nightly==2.3.0.dev20200531

Motivation
#25818 added support for native grouped convolutions a year ago. This feature is now also available via Keras layers in the latest nightly release (#36773, #39516).
Converting a model using grouped convolutions to TFLite works fine, though the TFLite runtime currently doesn't support this feature and will throw an error when trying to allocate the Tensors.

It would be great to have this feature available in TFLite in order to have consistent behaviour across TensorFlow and TFLite. The PRs linked above provide more detail about why this feature is something that people would want to use.

Standalone code to reproduce the issue
The issue can be reproduced using this colab notebook.

Any other info / logs

I guess adding a reference implementation and implementing optimized kernels in the XNNPack delegate would be pretty straightforward, as it already has native support for grouped convolutions:

/*groups=*/1, static_cast<size_t>(input_channels),

Though, I am not sure how much effort it would take to support it in the other optimized code paths with Ruy.

If there is a fundamental reason why support for grouped convolutions cannot be added to TFLite, it would be great to handle this in the MLIR-based converter and translate native TF grouped convolutions to a naive loop-based implementation using tfl.split and tfl.concat. This would allow people to use TF grouped convolutions and fall back to a loop-based implementation in TFLite for now.
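The split/concat fallback described above can be sketched in plain NumPy (a toy illustration, not the converter's actual code): a grouped convolution is equivalent to splitting the input channels into `groups` slices, convolving each slice with its own block of filters, and concatenating the per-group outputs along the channel axis.

```python
import numpy as np

def conv2d_valid(x, w):
    """Minimal VALID 2D convolution. x: (H, W, Cin), w: (kH, kW, Cin, Cout)."""
    kh, kw, _, cout = w.shape
    H, W, _ = x.shape
    out = np.zeros((H - kh + 1, W - kw + 1, cout))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Contract the (kH, kW, Cin) patch against the filter.
            out[i, j] = np.tensordot(x[i:i + kh, j:j + kw], w, axes=3)
    return out

def grouped_conv2d(x, w, groups):
    """Grouped conv as split -> conv -> concat.

    x: (H, W, Cin), w: (kH, kW, Cin // groups, Cout) -- TF's filter layout,
    where filter output channels are partitioned across the groups.
    """
    xs = np.split(x, groups, axis=-1)   # split input channels
    ws = np.split(w, groups, axis=-1)   # split filter output channels
    return np.concatenate(
        [conv2d_valid(xi, wi) for xi, wi in zip(xs, ws)], axis=-1)

x = np.random.rand(8, 8, 4)      # H, W, Cin
w = np.random.rand(3, 3, 2, 6)   # kH, kW, Cin/groups, Cout (groups=2)
y = grouped_conv2d(x, w, groups=2)
print(y.shape)  # (6, 6, 6)
```

Each group only ever sees its own slice of the input channels, which is exactly why a tfl.split / per-group conv / tfl.concat rewrite reproduces the native op.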

@lgeiger lgeiger added the comp:lite TF Lite related issues label Jun 1, 2020
@Saduf2019 Saduf2019 added the type:feature Feature requests label Jun 1, 2020
@Saduf2019 Saduf2019 assigned ymodak and unassigned Saduf2019 Jun 1, 2020
@ymodak ymodak assigned srjoglekar246 and unassigned ymodak Jun 1, 2020
@lgeiger
Contributor Author

lgeiger commented Jul 6, 2020

@srjoglekar246 Do you know if anyone is looking into that?

@lgeiger
Contributor Author

lgeiger commented Jul 24, 2020

@srjoglekar246 @talumbau Do you have any information whether this is on the roadmap?

@reuvenperetz

Hi,
Any news?

@srjoglekar246
Contributor

Sorry for the late reply on this :-(. @talumbau is looking into this, and has a plan forward. We were busy with some other bugs last quarter, but we have a TODO for it to land this quarter. Thanks for your patience!

@simonmaurer
Contributor

Any news? Related: Group convolutions for CPUs

@will-rice

will-rice commented Jun 3, 2021

If I wanted to attempt this, where would I start? This seems relevant, since in Python it's just a different kernel shape: https://www.tensorflow.org/lite/guide/ops_version#change_c_structures_and_kernel_implementation

@gitE0Z9

gitE0Z9 commented Jun 21, 2021

Does this issue have any progress?
I have done a small experiment which showed that the grouped convolution is the reason the TFLite model fails at inference, with this error message.

TF version == 2.5.0


experiment model

test_data = tf.constant(0, shape=[1, 320, 320, 3])
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer([320, 320, 3]),
    tf.keras.layers.Conv2D(8, 3),
    tf.keras.layers.Conv2D(8, 3, groups=8),
])
model.predict(test_data).shape
tf.keras.models.save_model(model, '/content/test')

and convert command
tflite_convert --saved_model /content/test --output_file /content/test.tflite

@srjoglekar246 srjoglekar246 assigned thaink and unassigned talumbau Jun 21, 2021
@PINTO0309
Contributor

PINTO0309 commented Aug 12, 2021

I know this won't be helpful to everyone, but this is my special workaround so far. If this is confusing, please ignore it.

I implemented grouped convolution with the standard Conv2D plus Split and Concat, although I may have failed to transpose the weights. This model is chaotic.

PINTO0309/PINTO_model_zoo#15
https://github.com/PINTO0309/openvino2tensorflow

  • Midasnet - Float32 - GroupConvolution - TFLite(.tflite)

I have been waiting for this issue to be resolved for a long time.

@daeing

daeing commented Nov 3, 2021

Does this issue have any progress? I have done a small experiment which showed that group convolution is the reason why tflite model failed to do inference with this error message.

TF version == 2.5.0


experiment model

test_data = tf.constant(0,shape=[1,320,320,3])
model = tf.keras.Sequential([
                tf.keras.layers.InputLayer([320,320,3]),
                tf.keras.layers.Conv2D(8,3),
                tf.keras.layers.Conv2D(8,3,groups=8)
])
model.predict(test_data).shape
tf.keras.models.save_model(model,'/content/test')

and convert command tflite_convert --saved_model /content/test --output_file /content/test.tflite

I'm hitting the same problem: after converting my h5 model to tflite, inference with a grouped conv2d fails. Does this issue have any progress?

@ydshieh

ydshieh commented Nov 17, 2021

I also faced the same issue.

@leondgarse
Contributor

leondgarse commented Dec 22, 2021

I made a workaround in my repo leondgarse/keras_cv_attention_models, replacing Conv2D with groups != 1 by split -> conv -> concat, like

  • Basic test:
    !pip install keras-cv-attention-models
    
    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    from keras_cv_attention_models.imagenet import eval_func
    from keras_cv_attention_models import model_surgery
    
    mm = keras.Sequential([keras.layers.InputLayer([320, 320, 32]), keras.layers.Conv2D(64, 3, groups=8)])
    bb = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
    test_inputs = tf.random.uniform([1, *mm.input_shape[1:]])
    print(np.allclose(mm(test_inputs), bb(test_inputs)))
    # True
    
    converter = tf.lite.TFLiteConverter.from_keras_model(bb)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    open("aa_dynamic.tflite", "wb").write(converter.convert())
    print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf('aa_dynamic.tflite')(test_inputs), atol=1e-6))
    # True

We have a discussion here, Freezing a trained keras CV attention model #17, where you can find more detailed usage.

@reuvenperetz

Hi,
In Tensorflow 2.8 an error is thrown earlier (during conversion):

Traceback (most recent call last):
...
tensorflow.lite.python.convert_phase.ConverterError: /.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: error: 'tf.Conv2D' op is neither a custom op nor a flex op
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
/.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: note: Error code: ERROR_NEEDS_FLEX_OPS
<unknown>:0: error: failed while converting: 'main': 
Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select 
TF Select ops: Conv2D
Details:
	tf.Conv2D(tensor<?x26x26x32xf32>, tensor<3x3x8x32xf32>) -> (tensor<?x24x24x32xf32>) : {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}

@makcedward

makcedward commented Apr 7, 2022

Hi, In Tensorflow 2.8 an error is thrown earlier (during conversion):

Traceback (most recent call last):
...
tensorflow.lite.python.convert_phase.ConverterError: /.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: error: 'tf.Conv2D' op is neither a custom op nor a flex op
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
/.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: note: Error code: ERROR_NEEDS_FLEX_OPS
<unknown>:0: error: failed while converting: 'main': 
Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select 
TF Select ops: Conv2D
Details:
	tf.Conv2D(tensor<?x26x26x32xf32>, tensor<3x3x8x32xf32>) -> (tensor<?x24x24x32xf32>) : {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}

You need to use TF Select ops. Add the following code for that:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

@reuvenperetz

Hi, In Tensorflow 2.8 an error is thrown earlier (during conversion):

Traceback (most recent call last):
...
tensorflow.lite.python.convert_phase.ConverterError: /.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: error: 'tf.Conv2D' op is neither a custom op nor a flex op
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
/.../lib/python3.8/site-packages/tensorflow/python/saved_model/save.py:1369:0: note: Error code: ERROR_NEEDS_FLEX_OPS
<unknown>:0: error: failed while converting: 'main': 
Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select 
TF Select ops: Conv2D
Details:
	tf.Conv2D(tensor<?x26x26x32xf32>, tensor<3x3x8x32xf32>) -> (tensor<?x24x24x32xf32>) : {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}

You need to use TF Select ops. Add the following code for that:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

But inference still fails. I convert and infer this way:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
test_image = np.random.randn(1, 28, 28, 1).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()

It fails with:

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

@InSightSuen

I made a workaround in my repo leondgarse/keras_cv_attention_models, replacing Conv2D with groups != 1 by split -> conv -> concat, like

  • Basic test:
    !pip install keras-cv-attention-models
    
    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    from keras_cv_attention_models.imagenet import eval_func
    from keras_cv_attention_models import model_surgery
    
    mm = keras.Sequential([keras.layers.InputLayer([320, 320, 32]), keras.layers.Conv2D(64, 3, groups=8)])
    bb = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
    test_inputs = tf.random.uniform([1, *mm.input_shape[1:]])
    print(np.allclose(mm(test_inputs), bb(test_inputs)))
    # True
    
    converter = tf.lite.TFLiteConverter.from_keras_model(bb)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    open("aa_dynamic.tflite", "wb").write(converter.convert())
    print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf('aa_dynamic.tflite')(test_inputs), atol=1e-6))
    # True

We have a discussion here, Freezing a trained keras CV attention model #17, where you can find more detailed usage.

Works for me, thx!
My model contains a depthwise conv, but with different strides along rows and columns (strides=[2, 1]), which layers.DepthwiseConv2D doesn't currently implement.

Tensorflow Verison: 2.7.0

@smohan-ambarella

Any update on this issue ? Does tflite now support group convolution ?

@pjpratik pjpratik self-assigned this Dec 29, 2022
@pjpratik
Contributor

@lgeiger This issue seems to be resolved in TF 2.11. Can you please check this gist and let us know if it helps? Thank you!
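For anyone wanting to verify locally instead of via the gist: with TF >= 2.11 a model containing a grouped Conv2D should convert and run natively, roughly like the sketch below (shapes chosen for illustration only):

```python
import numpy as np
import tensorflow as tf

# Small model containing a grouped convolution.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer([32, 32, 8]),
    tf.keras.layers.Conv2D(16, 3, groups=4),
])

# Convert to TFLite without any TF Select fallback.
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# allocate_tensors() is where older runtimes used to fail for grouped convs.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

x = np.random.rand(1, 32, 32, 8).astype(np.float32)
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
lite_result = interpreter.get_tensor(out["index"])

# The TFLite output should closely match the Keras model.
print(np.allclose(model(x).numpy(), lite_result, atol=1e-5))
```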

@pjpratik pjpratik added the stat:awaiting response Status - Awaiting response from author label Dec 29, 2022
@lgeiger
Contributor Author

lgeiger commented Jan 3, 2023

@lgeiger This issue seems to be resolved in TF 2.11. Can you please check this gist and let us know if helps. Thank you!

👍 This is resolved now.
