
TFLite toco failed to convert quantized model (mobilenet_v1_1.0_224) to tflite format #19431

Closed
cefengxu opened this issue May 21, 2018 · 17 comments
Labels: comp:lite TF Lite related issues

@cefengxu commented May 21, 2018

Describe the problem

First, I downloaded the mobilenet_v1_1.0_224 model from http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz.

Then I used the command below to produce a quantized model (mobilenet_v1_1.0_224_frozen_quantized_graph.pb) successfully.

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
--inputs="input" \
--outputs="MobilenetV1/Predictions/Reshape_1" \
--out_graph=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen_quantized_graph.pb \
--transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,224,224,3") 
remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) 
fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes 
strip_unused_nodes sort_by_execution_order'

However, when I used the TFLite toco command to convert the .pb to .lite format, the following error was output.

TFLite toco command:

bazel run --config=opt \
  //tensorflow/contrib/lite/toco:toco -- \
  --input_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen_quantized_graph.pb \
  --output_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen_quantized_graph.lite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shapes=1,224,224,3 \
  --mean_values=128 \
  --std_values=128 \
  --input_arrays="input" \
  --output_arrays="MobilenetV1/Predictions/Reshape_1" \
  --inference_type=QUANTIZED_UINT8 \
  --default_ranges_min=0 \
  --default_ranges_max=6 

ERROR OUTPUT:

2018-05-21 17:32:50.603908: F tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc:42] Check failed: IsConstantParameterArray(*model, bn_op->inputs[1]) && IsConstantParameterArray(*model, bn_op->inputs[2]) && IsConstantParameterArray(*model, bn_op->inputs[3]) Batch normalization resolution requires that mean, multiplier and offset arrays be constant.

ERROR LOG:

: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizedConv2D
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: RequantizationRange
……
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizedReshape
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizeV2
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: QuantizedReshape
: I tensorflow/contrib/lite/toco/import_tensorflow.cc:1326] Converting unsupported operation: Dequantize
2018-05-21 17:32:50.581333: I tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc:39] Before Removing unused ops: 352 operators, 853 arrays (0 quantized)
2018-05-21 17:32:50.601042: I tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc:39] Before general graph transformations: 352 operators, 853 arrays (0 quantized)
2018-05-21 17:32:50.603908: F tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc:42] Check failed: IsConstantParameterArray(*model, bn_op->inputs[1]) && IsConstantParameterArray(*model, bn_op->inputs[2]) && IsConstantParameterArray(*model, bn_op->inputs[3]) Batch normalization resolution requires that mean, multiplier and offset arrays be constant.

@freedomtan (Contributor) commented May 24, 2018

You might want to read the TensorFlow guide on quantization and learn to use the fake-quantization technique instead of using transform_graph to do direct quantization.

It seems that link is gone. Adding another link here.
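For anyone looking for a concrete starting point, here is a minimal sketch of that fake-quantization (quantization-aware training) flow using the TF 1.x tf.contrib.quantize API. The toy model and names below are placeholders, not part of this issue; after training, the rewritten eval graph is frozen and fed to toco with --inference_type=QUANTIZED_UINT8, and no --default_ranges_* flags are needed because real ranges were recorded during training.

import tensorflow as tf

def toy_model(images):
    # Stand-in for MobileNet; any conv net with ReLU6 activations is handled the same way.
    net = tf.layers.conv2d(images, 8, 3, activation=tf.nn.relu6)
    return tf.layers.flatten(net)

# Training graph: insert fake-quant ops that record activation min/max ranges.
train_graph = tf.Graph()
with train_graph.as_default():
    images = tf.placeholder(tf.float32, [1, 224, 224, 3], name='input')
    logits = toy_model(images)
    tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=0)
    # ... train and save a checkpoint as usual ...

# Eval graph: the inference rewrite that toco can convert to a fully quantized model.
eval_graph = tf.Graph()
with eval_graph.as_default():
    images = tf.placeholder(tf.float32, [1, 224, 224, 3], name='input')
    logits = toy_model(images)
    tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)
    # ... restore the trained checkpoint, freeze this graph, then run toco on the frozen .pb ...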

@eslambakr

I have the same problem; any updates, please?

@andrehentz (Contributor)

See freedomtan's comment: transform_graph is not supported by TF Lite. See the guide for retraining. If you are looking into weight-only quantization (and are well aware of the possible accuracy degradation), you can pass --quantize_weights=true to toco.
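For example, a sketch of that weight-only path against the original (untransformed) frozen graph, with the quantization-specific flags dropped; the output filename is just an illustration, and exact flag names can vary between toco versions:

bazel run --config=opt \
  //tensorflow/contrib/lite/toco:toco -- \
  --input_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb \
  --output_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_weight_quant.lite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shapes=1,224,224,3 \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --inference_type=FLOAT \
  --quantize_weights=true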

andrehentz added the comp:lite TF Lite related issues label on Jun 12, 2018
@sxsxsx commented Oct 25, 2018

@cefengxu have you solved the problem?

@SubhashKsr

@andrehentz did you mean the following when you said, "If you are looking into weight-only quantization (and are well aware of the possible accuracy degradation) you can pass --quantize_weights=true to toco"?

Please help me understand this better.

bazel run --config=opt \
  //tensorflow/contrib/lite/toco:toco -- \
  --input_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen_quantized_graph.pb \
  --output_file=/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen_quantized_graph.lite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shapes=1,224,224,3 \
  --mean_values=128 \
  --std_values=128 \
  --input_arrays="input" \
  --output_arrays="MobilenetV1/Predictions/Reshape_1" \
  //--inference_type=QUANTIZED_UINT8 \
  --quantize_weights=true \
  --default_ranges_min=0 \
  --default_ranges_max=6

Please help me solve this error:

2018-11-07 16:44:16.082872: F tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc:45] Check failed: IsConstantParameterArray(*model, bn_op->inputs[1]) && IsConstantParameterArray(*model, bn_op->inputs[2]) && IsConstantParameterArray(*model, bn_op->inputs[3]) Batch normalization resolution requires that mean, multiplier and offset arrays be constant.
Abort trap: 6

I found a possible explanation for the issue in #15336.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS Mojave (10.14 (18A391))
  • TensorFlow installed from (source or binary): source
  • TensorFlow version: 1.11.0
  • Python version: 2.7.10 [GCC 4.2.1 Compatible Apple LLVM 10.0.0 (clang-1000.0.42)] on darwin
  • Bazel version: 0.19.0
  • CUDA/cuDNN version: N/A (built without CUDA support)
  • GPU model and memory: N/A (built without CUDA support)
  • Exact command to reproduce:
./toco --input_format=TENSORFLOW_GRAPHDEF --input_file=/Users/......./Documents/data/output_graph.pb --output_format=TFLITE --output_file=/Users/......./Documents/data/facenet_model_quantized.lite --quantize_weights=true --input_arrays=input --output_arrays=embeddings --input_shapes=1,160,160,3 --mean_values=128 --std_values=128 --default_ranges_min=0 --default_ranges_max=6

This is the model I have been trying to convert to .lite format.

@freedomtan (Contributor)

@SubhashKsr I think what @andrehentz referred to is weights-only post-training quantization. See the guide and the example.
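For reference, a minimal sketch of that weights-only post-training quantization using the Python converter (assuming a TF ~1.11/1.12-era API where post_training_quantize is available; the path and array names follow the earlier comments in this thread):

import tensorflow as tf

# Post-training weight quantization starts from the plain float frozen graph,
# not from a transform_graph-quantized one.
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    '/tmp/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb',
    input_arrays=['input'],
    output_arrays=['MobilenetV1/Predictions/Reshape_1'],
    input_shapes={'input': [1, 224, 224, 3]})
converter.post_training_quantize = True  # weights stored as 8-bit, compute stays float
tflite_model = converter.convert()
open('/tmp/mobilenet_v1_1.0_224/mobilenet_v1_weight_quant.tflite', 'wb').write(tflite_model)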

@milinddeore commented Nov 11, 2018

TFLite conversion can be done using a SavedModel; I have given the link to the model below. This is as per the documentation here.

I have a model here, which is exported as a SavedModel using the following code:

# SavedModel using simple_save()

ins = {"phase_train_placeholder":phase_train_placeholder}
outs = {"embeddings":embeddings}
tf.saved_model.simple_save(sess, '/content/generated/', ins, outs)

But while converting I am getting the same error too; in this case I am not freezing the model myself, nor quantizing it. Below is the code to convert the SavedModel to TFLite:

import tensorflow as tf

saved_model_dir = '/content/generated/'

converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir, input_arrays=['phase_train'], input_shapes=(1,160,160,3), 
                                                             output_arrays=['embeddings'])
tflite_model = converter.convert()
open("converted_model_savedModel.tflite", "wb").write(tflite_model)

Logs:

2018-11-11 02:27:06.150377: I tensorflow/lite/toco/import_tensorflow.cc:1280] Converting unsupported operation: RefSwitch
2018-11-11 02:27:06.150423: I tensorflow/lite/toco/import_tensorflow.cc:1280] Converting unsupported operation: AssignSub
2018-11-11 02:27:06.150476: I tensorflow/lite/toco/import_tensorflow.cc:1280] Converting unsupported operation: RefSwitch
2018-11-11 02:27:06.150510: I tensorflow/lite/toco/import_tensorflow.cc:1280] Converting unsupported operation: AssignSub
2018-11-11 02:27:06.790259: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before Removing unused ops: 5600 operators, 9398 arrays (0 quantized)
2018-11-11 02:27:07.384090: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] After Removing unused ops pass 1: 3582 operators, 6259 arrays (0 quantized)
2018-11-11 02:27:07.836694: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before general graph transformations: 3582 operators, 6259 arrays (0 quantized)
2018-11-11 02:27:07.839340: F tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc:45] Check failed: IsConstantParameterArray(*model, bn_op->inputs[1]) && IsConstantParameterArray(*model, bn_op->inputs[2]) && IsConstantParameterArray(*model, bn_op->inputs[3]) Batch normalization resolution requires that mean, multiplier and offset arrays be constant.
Aborted (core dumped)
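In case it helps: the failed check requires the batch-norm mean/multiplier/offset arrays to be constants, which they are not while the graph still contains variables. A hedged sketch of first freezing the SavedModel into a constant GraphDef (assuming the 'serve' tag written by simple_save and that the output node really is named 'embeddings'), then converting the frozen .pb instead:

import tensorflow as tf

saved_model_dir = '/content/generated/'
frozen_pb = '/content/frozen.pb'

with tf.Session(graph=tf.Graph()) as sess:
    # Load the SavedModel and fold all variables (including the batch-norm
    # moving mean/variance) into constants.
    tf.saved_model.loader.load(sess, ['serve'], saved_model_dir)
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['embeddings'])

with tf.gfile.GFile(frozen_pb, 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())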

@scm-ns commented Feb 13, 2019

@milinddeore Did you find a solution to this issue?

@scm-ns commented Feb 13, 2019

@andrehentz But is the freeze_graph script supported? Do you know what step causes this error to be thrown? Can you provide some internal context on why this error might be occurring?

andrehentz assigned jdduke and gargn and unassigned andrehentz on Feb 13, 2019
@andrehentz (Contributor)

I'm reopening this. Probably @jdduke or @gargn can help out with the issues you are facing. However, please consider opening separate issues, especially if your model is not 'mobilenet_v1_1.0_224'.

@maxjcohen

Working on the same issue; would you be interested in quickly solving this problem over Slack/appear/WhatsApp (real-time tools)?
We are trying to push FaceNet to Android as soon as possible.

@milinddeore

@andrehentz can you please reopen this issue? Thanks.

@milinddeore commented Feb 25, 2019

I was able to convert the FaceNet .pb to a .tflite model; the following are the instructions to do so:

We will quantize the pre-trained FaceNet model with 512-dimensional embeddings. This model is about 95 MB in size before quantization.

$ ls -l model_pc
total 461248
-rw-rw-r--@ 1 milinddeore  staff   95745767 Apr  9  2018 20180402-114759.pb

Create a file inference_graph.py with the following code:

import tensorflow as tf
from src.models import inception_resnet_v1
import click
from pathlib import Path

@click.command()
@click.argument('training_checkpoint_dir', type=click.Path(exists=True, file_okay=False, resolve_path=True))
@click.argument('eval_checkpoint_dir', type=click.Path(exists=True, file_okay=False, resolve_path=True))
def main(training_checkpoint_dir, eval_checkpoint_dir):
    training_checkpoint = Path(training_checkpoint_dir) / "model-20180402-114759.ckpt-275"
    eval_checkpoint = Path(eval_checkpoint_dir) / "imagenet_facenet.ckpt"
    # Build the inference graph (phase_train=False) with a fixed input placeholder.
    data_input = tf.placeholder(name='input', dtype=tf.float32, shape=[None, 160, 160, 3])
    output, _ = inception_resnet_v1.inference(data_input, keep_probability=0.8, phase_train=False, bottleneck_layer_size=512)
    label_batch = tf.identity(output, name='label_batch')
    embeddings = tf.identity(output, name='embeddings')
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Restore the training checkpoint and re-save it against the inference graph.
        saver = tf.train.Saver()
        saver.restore(sess, training_checkpoint.as_posix())
        save_path = saver.save(sess, eval_checkpoint.as_posix())
        print("Model saved in file: %s" % save_path)

if __name__ == "__main__":
    main()

Running this file on the pre-trained model generates a model for inference. Download the pre-trained model and unzip it to the model_pre_trained/ directory.
Make sure you have Python ≥ 3.4.

python3 inference_graph.py model_pre_trained/ model_inference/

FaceNet provides a freeze_graph.py file, which we will use to freeze the inference model.

python3  src/freeze_graph.py model_inference/  my_facenet.pb

Once the frozen model is generated, it is time to convert it to .tflite:

$ tflite_convert --output_file model_mobile/my_facenet.tflite --graph_def_file my_facenet.pb  --input_arrays "input" --input_shapes "1,160,160,3" --output_arrays embeddings --output_format TFLITE --mean_values 128 --std_dev_values 128 --default_ranges_min 0  --default_ranges_max 6 --inference_type QUANTIZED_UINT8 --inference_input_type QUANTIZED_UINT8

Let us check the quantized model size:

$ ls -l model_mobile/
total 47232
-rw-r--r--@ 1 milinddeore  staff  23667888 Feb 25 13:39 my_facenet.tflite

Interpreter code:

import numpy as np
import tensorflow as tf


# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="/Users/milinddeore/facenet/model_mobile/my_facenet.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

print('INPUTS: ')
print(input_details)
print('OUTPUTS: ')
print(output_details)

Interpreter output:

$ python inout.py
INPUTS:
[{'index': 451, 'shape': array([  1, 160, 160,   3], dtype=int32), 'quantization': (0.0078125, 128L), 'name': 'input', 'dtype': <type 'numpy.uint8'>}]
OUTPUTS:
[{'index': 450, 'shape': array([  1, 512], dtype=int32), 'quantization': (0.0235294122248888, 0L), 'name': 'embeddings', 'dtype': <type 'numpy.uint8'>}]

Hope this helps!

@jdduke (Member) commented Feb 28, 2019

Closing per the latest comment, glad you got it working!

jdduke closed this as completed on Feb 28, 2019
@Monikasinghjmi

@milinddeore Wouldn't the inference graph produced (using the code you mentioned) correspond to the pre-trained FaceNet protobuf file? If I want to deploy a TFLite model on iOS, it has to be trained on the training images I provide, which means I need a TFLite version of the trained model. I have gone through #23253, where it is mentioned that a training graph can't be converted. As per my understanding, on iOS I require the trained model and its corresponding TFLite file.

How can a pretrained model's .pb, converted to a TFLite version, be used to predict on my data?

@Monikasinghjmi

@cefengxu @jdduke Please reopen this issue.

@jdduke (Member) commented Jul 24, 2019

@Monikasinghjmi, it's hard to tell exactly what you're asking, but it doesn't sound like it's a bug/issue, and is probably best suited as a question on StackOverflow.
