/
convert.py
988 lines (874 loc) · 42.1 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a frozen graph into a TFLite FlatBuffer."""
import distutils.spawn
import enum # pylint: disable=g-bad-import-order
import os as _os
import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile
import warnings
import six
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import ConverterError
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.metrics.wrapper import metrics_wrapper as _metrics_wrapper
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _conversion_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export
def _is_quantized_input_stats_required(
    conversion_flags: _conversion_flags_pb2.TocoFlags) -> bool:
  """Checks if the `quantized_input_stats` flag is required for conversion.

  Args:
    conversion_flags: A protocol buffer describing the conversion process.

  Returns:
    True, if the `inference_type` or the `inference_input_type` is a quantized
    type and it is not post training quantization, else False.
  """
  # The annotation above names the message class itself; calling it there
  # would build a throwaway proto instance every time this module is imported.
  quantized_inference_types = (
      _types_pb2.QUANTIZED_UINT8,
      _types_pb2.QUANTIZED_INT8,
  )
  # Post-training quantization never needs user-supplied input statistics.
  if conversion_flags.post_training_quantize:
    return False
  return (conversion_flags.inference_type in quantized_inference_types or
          conversion_flags.inference_input_type in quantized_inference_types)
def convert_tensor_tf_type_to_tflite_type(tf_type: dtypes.DType,
                                          usage: str = ""
                                         ) -> _types_pb2.IODataType:
  """Maps a TensorFlow tensor dtype onto the matching TFLite tensor type.

  Args:
    tf_type: TensorFlow dtype to translate.
    usage: Short description of what the type is being used for. It is only
      interpolated into the error message.

  Returns:
    The TFLite type enum value. Refer to lite/toco/types.proto.

  Raises:
    ValueError: If `tf_type` has no entry in the lookup table below.
  """
  tflite_type_by_tf_type = {
      dtypes.float16: _types_pb2.FLOAT16,
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.float64: _types_pb2.FLOAT64,
      dtypes.int8: _types_pb2.INT8,
      dtypes.int16: _types_pb2.INT16,
      dtypes.uint16: _types_pb2.UINT16,
      dtypes.int32: _types_pb2.INT32,
      dtypes.int64: _types_pb2.INT64,
      dtypes.uint8: _types_pb2.UINT8,
      dtypes.uint32: _types_pb2.UINT32,
      dtypes.uint64: _types_pb2.UINT64,
      dtypes.string: _types_pb2.STRING,
      dtypes.bool: _types_pb2.BOOL,
      dtypes.complex64: _types_pb2.COMPLEX64,
      dtypes.complex128: _types_pb2.COMPLEX128,
  }
  converted = tflite_type_by_tf_type.get(tf_type)
  if converted is not None:
    return converted
  raise ValueError(
      "Unsupported TensorFlow type `{0}` provided for the {1}".format(
          tf_type, usage))
# Only a few restricted tensor types are allowed for explicitly setting
# inference/input/output types.
def convert_inference_tf_type_to_tflite_type(tf_type: dtypes.DType,
                                             usage: str = ""
                                            ) -> _types_pb2.IODataType:
  """Maps a TensorFlow dtype onto the TFLite inference type it selects.

  Unlike the general tensor-type conversion, the integer dtypes here map to the
  `QUANTIZED_*` enum values.

  Args:
    tf_type: TensorFlow dtype to translate.
    usage: Short description of what the type is being used for. It is only
      interpolated into the error message.

  Returns:
    The TFLite type enum value. Refer to lite/toco/types.proto.

  Raises:
    ValueError: If `tf_type` is not one of the four dtypes listed below.
  """
  inference_type_by_tf_type = {
      dtypes.float32: _types_pb2.FLOAT,
      dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
      dtypes.int8: _types_pb2.QUANTIZED_INT8,
      dtypes.int16: _types_pb2.QUANTIZED_INT16,
  }
  converted = inference_type_by_tf_type.get(tf_type)
  if converted is not None:
    return converted
  raise ValueError(
      "Unsupported TensorFlow type `{0}` provided for the {1}".format(
          tf_type, usage))
# Locate the deprecated conversion binary. When the toco API is used directly
# no binary is needed, which is signalled by an empty string. Otherwise try the
# resource loader first (bazel layout); if that path does not exist we are in a
# pip install, where console_scripts already exposes the tool by name.
if not lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _deprecated_conversion_binary = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")
  if not _os.path.exists(_deprecated_conversion_binary):
    _deprecated_conversion_binary = "toco_from_protos"
else:
  _deprecated_conversion_binary = ""
def _try_convert_to_unicode(output):
  """Best-effort conversion of captured process output to text.

  Args:
    output: Captured output; may be None, bytes, or any other object.

  Returns:
    An empty string when `output` is None; the UTF-8 decoded text when `output`
    is bytes that decode cleanly; otherwise `output` unchanged (including bytes
    that are not valid UTF-8).
  """
  if output is None:
    return ""

  if isinstance(output, bytes):
    try:
      # Native strict UTF-8 decode; no compatibility shim is needed for bytes.
      return output.decode("utf-8")
    except UnicodeDecodeError:
      # Deliberate fall-through: hand back the raw bytes rather than fail
      # while building an error message.
      pass
  return output
@_tf_export("lite.OpsSet")
class OpsSet(enum.Enum):
  """Enumerates the op sets that a generated TFLite model may draw from.

  WARNING: Experimental interface, subject to change.
  """

  # Use TensorFlow Lite builtin ops for the converted model.
  TFLITE_BUILTINS = "TFLITE_BUILTINS"

  # Use TensorFlow ops for the converted model. Not every TensorFlow op is
  # available.
  # WARNING: Experimental interface, subject to change.
  SELECT_TF_OPS = "SELECT_TF_OPS"

  # Use only TensorFlow Lite quantized int8 operations. Operations that do not
  # yet have a quantized implementation cause an error.
  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"

  # Use only TensorFlow Lite operations with quantized int8 weights, int16
  # activations and int64 bias. Operations that do not yet have a quantized
  # implementation cause an error.
  # This mode may be used in models for super-resolution, audio signal
  # processing or image de-noising. It improves accuracy significantly, but
  # only slightly increases the model size.
  # WARNING: These ops are currently experimental and have not yet been
  # finalized. They are only compatible with CPU execution, and have not been
  # optimized for production.
  EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8 = (
      "EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8")

  def __str__(self):
    """Returns the member's value as a string."""
    return str(self.value)

  @staticmethod
  def get_options():
    """Returns every OpsSet option, in definition order, as a list of strings."""
    return list(map(str, OpsSet))
@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.QUANTIZE)
def mlir_quantize(input_data_str,
                  disable_per_channel=False,
                  fully_quantize=False,
                  inference_type=_types_pb2.QUANTIZED_INT8,
                  input_data_type=dtypes.float32,
                  output_data_type=dtypes.float32,
                  enable_numeric_verify=False,
                  enable_whole_model_verify=False,
                  denylisted_ops=None,
                  denylisted_nodes=None):
  """Quantizes `input_data_str` using its calibration results.

  The input and output dtypes are translated to TFLite types and everything is
  forwarded to `wrap_toco.wrapped_experimental_mlir_quantize`.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model with
      calibration results).
    disable_per_channel: Bool indicating whether to do per-channel or
      per-tensor quantization.
    fully_quantize: Bool indicating whether to fully quantize the model.
      Besides the model body, the input/output will be quantized as well.
    inference_type: Data type for the activations. Defaults to
      `_types_pb2.QUANTIZED_INT8`.
    input_data_type: Data type for the inputs. Defaults to float32.
    output_data_type: Data type for the outputs. Defaults to float32.
    enable_numeric_verify: Experimental. Subject to change. Bool indicating
      whether to add NumericVerify ops into the debug mode quantized model.
    enable_whole_model_verify: Experimental. Subject to change. Bool indicating
      whether verification is done layer by layer or on the whole model. When
      disabled (per-layer), float and quantized ops are run from the same
      input (the output of the previous quantized layer). When enabled, float
      and quantized ops run with the respective float and quantized output of
      the previous ops.
    denylisted_ops: Experimental. Subject to change. Set of ops to denylist.
    denylisted_nodes: Experimental. Subject to change. Set of nodes to
      denylist.

  Returns:
    Quantized model in serialized form (e.g. a TFLITE model).

  Raises:
    ValueError: If `input_data_type` or `output_data_type` cannot be converted
      to a TFLite type.
  """
  tflite_input_type = convert_tensor_tf_type_to_tflite_type(input_data_type)
  tflite_output_type = convert_tensor_tf_type_to_tflite_type(output_data_type)
  return wrap_toco.wrapped_experimental_mlir_quantize(
      input_data_str,
      disable_per_channel,
      fully_quantize,
      inference_type,
      tflite_input_type,
      tflite_output_type,
      enable_numeric_verify,
      enable_whole_model_verify,
      denylisted_ops,
      denylisted_nodes)
@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.SPARSIFY)
def mlir_sparsify(input_data_str):
  """Sparsifies `input_data_str` so sparse tensors use the proper encoding.

  Args:
    input_data_str: Input data in serialized form (e.g. a TFLITE model).

  Returns:
    Sparsified model in serialized form (e.g. a TFLITE model).
  """
  sparsified_model = wrap_toco.wrapped_experimental_mlir_sparsify(
      input_data_str)
  return sparsified_model
def register_custom_opdefs(custom_opdefs_list):
  """Registers custom opdefs with the TensorFlow global op registry.

  Args:
    custom_opdefs_list: String representing the custom ops OpDefs that are
      included in the GraphDef.

  Returns:
    The result of `wrap_toco.wrapped_register_custom_opdefs`: True if the
    registration is successfully completed.
  """
  registered = wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list)
  return registered
def convert(model_flags_str,
            conversion_flags_str,
            input_data_str,
            debug_info_str=None,
            enable_mlir_converter=True):
  """Converts `input_data_str` to a TFLite model.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `model_flags.proto`.
    conversion_flags_str: Serialized proto describing conversion properties,
      see `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common,
      or it can be hlo text or proto).
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)
    enable_mlir_converter: Enables MLIR-based conversion. (default True)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When the in-process conversion fails. The message is that
      of the underlying exception, and every error collected by the metrics
      wrapper is appended to it. Errors raised by the out-of-process
      deprecated converter propagate unchanged.
  """
  # Historically, deprecated conversion failures would trigger a crash, so that
  # converter is run out-of-process. It is only used when MLIR conversion is
  # turned off and a deprecated binary has been configured.
  if not enable_mlir_converter and _deprecated_conversion_binary:
    return _run_deprecated_conversion_binary(model_flags_str,
                                             conversion_flags_str,
                                             input_data_str, debug_info_str)

  # The MLIR pipeline surfaces errors instead of crashing, so it can safely
  # run in-process.
  try:
    return wrap_toco.wrapped_toco_convert(model_flags_str,
                                          conversion_flags_str,
                                          input_data_str, debug_info_str,
                                          enable_mlir_converter)
  except Exception as e:
    wrapped_error = ConverterError(str(e))
    for collected in _metrics_wrapper.retrieve_collected_errors():
      wrapped_error.append_error(collected)
    raise wrapped_error
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER)
def _run_deprecated_conversion_binary(model_flags_str,
                                      conversion_flags_str,
                                      input_data_str,
                                      debug_info_str=None):
  """Convert `input_data_str` using deprecated conversion binary.

  The inputs are written to temporary files, the binary is run as a
  subprocess, and the output file it produces is read back. Every temporary
  file, including the debug-info file, is removed before returning.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `model_flags.proto`.
    conversion_flags_str: Serialized proto describing TFLite converter
      properties, see `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common)
    debug_info_str: Serialized `GraphDebugInfo` proto describing logging
      information. (default None)

  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).

  Raises:
    ConverterError: When the deprecated conversion binary cannot be found, or
      when it exits with a non-zero code (its captured output is embedded in
      the message).
  """
  if distutils.spawn.find_executable(_deprecated_conversion_binary) is None:
    raise ConverterError("""Could not find `toco_from_protos` binary, make sure
your virtualenv bin directory or pip local bin directory is in your path.
In particular, if you have installed TensorFlow with --user, make sure you
add the install directory to your path.
For example:
Linux: export PATH=$PATH:~/.local/bin/
Mac: export PATH=$PATH:~/Library/Python/<version#>/bin
Alternative, use virtualenv.""")

  # Windows and TemporaryFile are not that useful together,
  # since you cannot have two readers/writers. So we have to
  # make the temporaries and close and delete them explicitly.
  conversion_filename = None
  model_filename = None
  input_filename = None
  output_filename = None
  # Tracked like the other temporaries so the `finally` clause can remove it.
  debug_filename = None
  try:
    # Build all input files
    with _tempfile.NamedTemporaryFile(delete=False) as fp_conversion, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_model, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_input, \
         _tempfile.NamedTemporaryFile(delete=False) as fp_debug:
      conversion_filename = fp_conversion.name
      input_filename = fp_input.name
      model_filename = fp_model.name
      debug_filename = fp_debug.name

      fp_model.write(model_flags_str)
      fp_conversion.write(conversion_flags_str)
      fp_input.write(six.ensure_binary(input_data_str))
      debug_info_str = debug_info_str if debug_info_str else ""
      # The temp file is opened in binary mode, so a `str` debug_info_str
      # would fail with
      #
      #   TypeError: a bytes-like object is required, not 'str'
      #
      # Some subtests of the "convert_test" unit-test pass a str, so encode
      # it to bytes where needed.
      if not isinstance(debug_info_str, bytes):
        fp_debug.write(debug_info_str.encode("utf-8"))
      else:
        fp_debug.write(debug_info_str)

    # Reserve an output file
    with _tempfile.NamedTemporaryFile(delete=False) as fp:
      output_filename = fp.name

    # Run
    cmd = [
        _deprecated_conversion_binary,
        model_filename,
        conversion_filename,
        input_filename,
        output_filename,
        "--debug_proto_file={}".format(debug_filename),
    ]
    # NOTE(review): the command is joined into a single shell string, so a
    # path containing whitespace would be split by the shell. Confirm whether
    # that matters on the supported platforms.
    cmdline = " ".join(cmd)
    is_windows = _platform.system() == "Windows"
    proc = _subprocess.Popen(
        cmdline,
        shell=True,
        stdout=_subprocess.PIPE,
        stderr=_subprocess.STDOUT,
        close_fds=not is_windows)
    # stderr is redirected into stdout above, so `stderr` here is None.
    stdout, stderr = proc.communicate()
    exitcode = proc.returncode
    if exitcode == 0:
      with open(output_filename, "rb") as fp:
        return fp.read()
    else:
      stdout = _try_convert_to_unicode(stdout)
      stderr = _try_convert_to_unicode(stderr)
      raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr))
  finally:
    # Must manually cleanup files. Names still set to None were never created.
    for filename in (conversion_filename, input_filename, model_filename,
                     output_filename, debug_filename):
      if filename is None:
        continue
      try:
        _os.unlink(filename)
      except OSError:
        # Best-effort cleanup: a file that is already gone or cannot be
        # removed must not mask the conversion result or error.
        pass
def build_model_flags(change_concat_input_ranges=False,
                      allow_nonexistent_arrays=False,
                      saved_model_dir=None,
                      saved_model_version=0,
                      saved_model_tags=None,
                      saved_model_exported_names=None,
                      **_):
  """Builds the `ModelFlags` proto from the given parameters.

  Unrecognised keyword arguments are accepted and ignored.

  Args:
    change_concat_input_ranges: Boolean to change behavior of min/max ranges
      for inputs and outputs of the concat operator for quantized models.
      Changes the ranges of concat operator overlap when true. (default False)
    allow_nonexistent_arrays: Allow specifying array names that don't exist or
      are unused in the final graph. (default False)
    saved_model_dir: Filepath of the saved model to be converted. Non-empty
      only when the saved model import path is used; otherwise the graph
      def-based conversion is processed. Copied into the proto only when
      truthy.
    saved_model_version: SavedModel file format version of the saved model
      file to be converted. Set only when the SavedModel import path is used.
      (default 0)
    saved_model_tags: Set of string saved model tags. Set only when the
      SavedModel import path is used.
    saved_model_exported_names: Names to be exported (default: export all)
      when the saved model import path is on.

  Returns:
    model_flags: protocol buffer describing the model.
  """
  flags = _model_flags_pb2.ModelFlags()
  flags.change_concat_input_ranges = change_concat_input_ranges
  flags.allow_nonexistent_arrays = allow_nonexistent_arrays

  # Optional SavedModel settings; empty values leave the proto field unset.
  if saved_model_dir:
    flags.saved_model_dir = saved_model_dir
  flags.saved_model_version = saved_model_version
  if saved_model_tags:
    flags.saved_model_tags.extend(saved_model_tags)
  if saved_model_exported_names:
    flags.saved_model_exported_names.extend(saved_model_exported_names)
  return flags
def build_conversion_flags(inference_type=dtypes.float32,
                           inference_input_type=None,
                           input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                           output_format=lite_constants.TFLITE,
                           default_ranges_stats=None,
                           drop_control_dependency=True,
                           reorder_across_fake_quant=False,
                           allow_custom_ops=False,
                           post_training_quantize=False,
                           quantize_to_float16=False,
                           dump_graphviz_dir=None,
                           dump_graphviz_video=False,
                           target_ops=None,
                           conversion_summary_dir=None,
                           select_user_tf_ops=None,
                           allow_all_select_tf_ops=False,
                           enable_tflite_resource_variables=True,
                           unfold_batchmatmul=True,
                           lower_tensor_list_ops=True,
                           default_to_single_batch_in_tensor_list_ops=False,
                           accumulation_type=None,
                           allow_bfloat16=False,
                           unfold_large_splat_constant=False,
                           supported_backends=None,
                           disable_per_channel_quantization=False,
                           enable_mlir_dynamic_range_quantizer=False,
                           tf_quantization_mode=None,
                           disable_infer_tensor_range=False,
                           use_fake_quant_num_bits=False,
                           enable_dynamic_update_slice=False,
                           preserve_assert_op=False,
                           guarantee_all_funcs_one_use=False,
                           **_):
  """Builds the `TocoFlags` proto describing how a model is converted.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.
  Unrecognised keyword arguments are accepted and ignored.

  Args:
    inference_type: Data type of numeric arrays, excluding the input layer.
      Must be a type accepted by `convert_inference_tf_type_to_tflite_type`.
      (default tf.float32)
    inference_input_type: Data type of the numeric arrays in the input layer.
      If it is a quantized type, `quantized_input_stats` must be provided.
      When falsy, the value assigned to `inference_type` is used.
    input_format: Type of data to read. (default TENSORFLOW_GRAPHDEF)
    output_format: Output file format. (default TFLITE)
    default_ranges_stats: Tuple of integers representing (min, max) range
      values for all arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the
      graph. Results in a graph that differs from the quantized training
      graph, potentially causing differing arithmetic behavior. (default
      False)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When false any unknown operation is an error. When true, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    post_training_quantize: Boolean indicating whether to quantize the weights
      of the converted float model. Model size will be reduced and there will
      be latency improvements (at the cost of accuracy). (default False)
    quantize_to_float16: Boolean indicating whether to convert float buffers
      to float16. (default False)
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
      output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the graph after
      every graph transformation. (default False)
    target_ops: Experimental flag, subject to change. Set of OpsSet options
      indicating which converter to use. Only the presence of
      `OpsSet.SELECT_TF_OPS` is inspected here. (default None)
    conversion_summary_dir: A string, the path to the generated conversion
      logs.
    select_user_tf_ops: List of user's defined TensorFlow ops need to be
      supported in the TensorFlow Lite runtime. These ops will be supported as
      select TensorFlow ops.
    allow_all_select_tf_ops: If True, automatically add all TF ops (including
      custom TF ops) to the converted model as flex ops. (default False)
    enable_tflite_resource_variables: Experimental flag, subject to change.
      Enables conversion of resource variables. (default True)
    unfold_batchmatmul: Whether to unfold tf.BatchMatMul to a set of
      tfl.fully_connected ops. If not, translate to tfl.batch_matmul.
      (default True)
    lower_tensor_list_ops: Whether to lower tensor list ops to builtin ops. If
      not, use Flex tensor list ops. (default True)
    default_to_single_batch_in_tensor_list_ops: Whether to force to use batch
      size one when the tensor list ops has the unspecified batch size.
      (default False)
    accumulation_type: Data type of the accumulators in quantized inference.
      Typically used for float16 quantization and is either fp16 or fp32.
    allow_bfloat16: Whether the converted model supports reduced precision
      inference with the bfloat16 type. (default False)
    unfold_large_splat_constant: Whether to unfold large splat constant
      tensors in the flatbuffer model to reduce size. (default False)
    supported_backends: List of TFLite backends which needs to check
      compatibility.
    disable_per_channel_quantization: Disable per-channel quantized weights
      for dynamic range quantization. Only per-tensor quantization will be
      used. (default False)
    enable_mlir_dynamic_range_quantizer: Enable MLIR dynamic range
      quantization. If False, the old converter dynamic range quantizer is
      used. (default False)
    tf_quantization_mode: Indicates the mode of TF Quantization when the
      output model is used for TF Quantization.
    disable_infer_tensor_range: Disable inferring tensor ranges. (default
      False)
    use_fake_quant_num_bits: Allow quantization parameters to be calculated
      from num_bits attribute. (default False)
    enable_dynamic_update_slice: Enable to convert to DynamicUpdateSlice op.
      (default False)
    preserve_assert_op: Whether to preserve `TF::AssertOp`. (default False)
    guarantee_all_funcs_one_use: Whether to clone functions so that each
      function only has a single use. This option will be helpful if the
      conversion fails when the `PartitionedCall` or
      `StatefulPartitionedCall` can't be properly inlined. (default False)

  Returns:
    conversion_flags: protocol buffer describing the conversion process.

  Raises:
    ValueError, if the input tensor type is unknown.
  """
  flags = _conversion_flags_pb2.TocoFlags()

  # Inference types: the input layer inherits the body's type unless an
  # explicit input type is given.
  flags.inference_type = convert_inference_tf_type_to_tflite_type(
      inference_type, usage="inference_type flag")
  if not inference_input_type:
    flags.inference_input_type = flags.inference_type
  else:
    flags.inference_input_type = convert_inference_tf_type_to_tflite_type(
        inference_input_type, usage="inference_input_type flag")

  flags.input_format = input_format
  flags.output_format = output_format

  if default_ranges_stats:
    flags.default_ranges_min = default_ranges_stats[0]
    flags.default_ranges_max = default_ranges_stats[1]

  flags.drop_control_dependency = drop_control_dependency
  flags.reorder_across_fake_quant = reorder_across_fake_quant
  flags.allow_custom_ops = allow_custom_ops
  flags.post_training_quantize = post_training_quantize
  flags.quantize_to_float16 = quantize_to_float16

  if dump_graphviz_dir:
    flags.dump_graphviz_dir = dump_graphviz_dir
  flags.dump_graphviz_include_video = dump_graphviz_video

  if target_ops:
    select_tf_only = {OpsSet.SELECT_TF_OPS}
    if OpsSet.SELECT_TF_OPS in target_ops:
      flags.enable_select_tf_ops = True
    # Select TF ops are forced only when they are the sole requested op set.
    if set(target_ops) == select_tf_only:
      flags.force_select_tf_ops = True

  if conversion_summary_dir:
    flags.conversion_summary_dir = conversion_summary_dir
  if select_user_tf_ops:
    flags.select_user_tf_ops.extend(select_user_tf_ops)

  flags.allow_all_select_tf_ops = allow_all_select_tf_ops
  flags.enable_tflite_resource_variables = enable_tflite_resource_variables
  flags.unfold_batchmatmul = unfold_batchmatmul
  flags.lower_tensor_list_ops = lower_tensor_list_ops
  flags.default_to_single_batch_in_tensor_list_ops = (
      default_to_single_batch_in_tensor_list_ops)

  if accumulation_type:
    # Accumulators use the general tensor-type mapping, not the restricted
    # inference-type mapping used above.
    flags.accumulation_type = convert_tensor_tf_type_to_tflite_type(
        accumulation_type, usage="accumulation_type flag")

  flags.allow_bfloat16 = allow_bfloat16
  flags.unfold_large_splat_constant = unfold_large_splat_constant

  if supported_backends:
    flags.supported_backends.extend(supported_backends)

  flags.disable_per_channel_quantization = disable_per_channel_quantization
  flags.enable_mlir_dynamic_range_quantizer = (
      enable_mlir_dynamic_range_quantizer)
  flags.enable_dynamic_update_slice = enable_dynamic_update_slice
  flags.preserve_assert_op = preserve_assert_op
  flags.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use

  if tf_quantization_mode:
    flags.tf_quantization_mode = tf_quantization_mode
  flags.disable_infer_tensor_range = disable_infer_tensor_range
  flags.use_fake_quant_num_bits = use_fake_quant_num_bits
  return flags
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def convert_graphdef_with_arrays(input_data, input_arrays_with_shape,
                                 output_arrays, control_output_arrays,
                                 **kwargs):
  """Converts a frozen GraphDef that can't be loaded in TF.

  Conversion can be customized by providing arguments that are forwarded to
  `build_model_flags` and `build_conversion_flags` (see documentation).

  Args:
    input_data: Input data (i.e. often `sess.graph_def`). It must provide
      `SerializeToString()`.
    input_arrays_with_shape: Sequence of (name, shape) pairs, where name is
      the input tensor name and shape is a list of integers (e.g.,
      [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded into
      TensorFlow and when `input_tensors` is None.
    output_arrays: List of output tensors to freeze graph with. Use only when
      graph cannot be loaded into TensorFlow and when `output_tensors` is
      None.
    control_output_arrays: Control output node names. This is used when
      converting a Graph with no output tensors. For example, if the graph's
      last operation is a Print op, just specify that op's name in this field.
      This can be used together with the `output_arrays` parameter.
    **kwargs: See `build_model_flags` and `build_conversion_flags`. The keys
      `enable_mlir_converter` (default True) and `quantized_input_stats`
      (default None) are also read here.

  Returns:
    The converted data. For example if TFLite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    ValueError: If quantized input statistics are required by the conversion
      flags but `quantized_input_stats` was not supplied.
    Defined in `build_conversion_flags`.
  """
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  use_mlir = kwargs.get("enable_mlir_converter", True)
  input_stats = kwargs.get("quantized_input_stats", None)

  # The flags do not change while the inputs are registered, so decide once
  # whether per-input statistics are mandatory.
  stats_required = _is_quantized_input_stats_required(conversion_flags)

  for position, (array_name, array_shape) in enumerate(
      input_arrays_with_shape):
    input_array = model_flags.input_arrays.add()
    if stats_required:
      if not input_stats:
        raise ValueError(
            "The `quantized_input_stats` flag must be defined when either "
            "`inference_type` flag or `inference_input_type` flag is set to "
            "tf.int8 or tf.uint8.")
      input_array.mean_value, input_array.std_value = input_stats[position]
    input_array.name = array_name
    input_array.shape.dims.extend([int(dim) for dim in array_shape])

  if output_arrays:
    model_flags.output_arrays.extend(output_arrays)
  if control_output_arrays:
    model_flags.control_output_arrays.extend(control_output_arrays)

  return convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=None,
      enable_mlir_converter=use_mlir)
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_GRAPHDEF)
def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs):
  """Convert a frozen GraphDef model using the TF Lite converter.

  Model and conversion flags are built from `kwargs` through
  `build_model_flags` and `build_conversion_flags`. The input and output
  arrays are then filled in from the given tensors before everything is
  handed to `convert`.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`). Its
      `SerializeToString()` method is called.
    input_tensors: List of input tensors. The name and dtype are read from
      each tensor; so is the shape, unless `input_shapes` is given in
      `kwargs`, in which case the entry at the same position is used.
    output_tensors: List of output tensors (only the name is used from this).
    **kwargs: See `build_model_flags` and `build_conversion_flags`. The keys
      `saved_model_dir`, `input_shapes`, `enable_mlir_converter`,
      `quantized_input_stats` and `debug_info` are additionally read here.

  Returns:
    The converted data as returned by `convert`. For example if TFLite was the
    destination, then this will be a tflite flatbuffer in a bytes array.

  Raises:
    Defined in `build_conversion_flags`.
  """
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  saved_model_dir = kwargs.get("saved_model_dir", None)
  input_shapes = kwargs.get("input_shapes", None)
  enable_mlir_converter = kwargs.get("enable_mlir_converter", True)
  quantized_input_stats = kwargs.get("quantized_input_stats", None)
  debug_info = kwargs.get("debug_info", None)

  def _array_name(tensor):
    # With a saved model directory the tensor's own name is used verbatim;
    # otherwise the name is derived through `util.get_tensor_name`.
    return tensor.name if saved_model_dir else util.get_tensor_name(tensor)

  def _is_unknown_dim(dim):
    return dim is None or (
        isinstance(dim, tensor_shape.Dimension) and dim.value is None)

  for position, tensor in enumerate(input_tensors):
    flag_entry = model_flags.input_arrays.add()
    flag_entry.name = _array_name(tensor)
    flag_entry.data_type = convert_tensor_tf_type_to_tflite_type(
        tensor.dtype, usage="input type of the TensorFlow model")

    if _is_quantized_input_stats_required(conversion_flags):
      if not quantized_input_stats:
        # We should ideally raise an error here, but we don't as it would break
        # several models/projects that depend on this workflow.
        warnings.warn("Statistics for quantized inputs were expected, but not "
                      "specified; continuing anyway.")
      else:
        mean_value, std_value = quantized_input_stats[position]
        flag_entry.mean_value = mean_value
        flag_entry.std_value = std_value

    # An explicit `input_shapes` entry takes precedence over the tensor's shape.
    shape = tensor.shape if input_shapes is None else input_shapes[position]

    if shape.rank is None:
      flag_entry.shape.unknown_rank = True
    else:
      # Unknown dimensions are recorded as -1.
      flag_entry.shape.dims.extend(
          [-1 if _is_unknown_dim(dim) else int(dim) for dim in shape])
      flag_entry.shape.unknown_rank = False

  for tensor in output_tensors:
    model_flags.output_arrays.append(_array_name(tensor))

  return convert(
      model_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data.SerializeToString(),
      debug_info_str=debug_info.SerializeToString() if debug_info else None,
      enable_mlir_converter=enable_mlir_converter)
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_SAVED_MODEL)
def convert_saved_model(**kwargs):
  """Converts a SavedModel using TF Lite converter.

  Args:
    **kwargs: Forwarded to both `build_model_flags` and
      `build_conversion_flags`.

  Returns:
    Whatever `convert` returns for the serialized flags. No input data or
    debug info is passed, and the MLIR converter is always enabled.
  """
  model_flags = build_model_flags(**kwargs)
  conversion_flags = build_conversion_flags(**kwargs)
  serialized_model_flags = model_flags.SerializeToString()
  serialized_conversion_flags = conversion_flags.SerializeToString()
  return convert(
      serialized_model_flags,
      serialized_conversion_flags,
      input_data_str=None,
      debug_info_str=None,
      enable_mlir_converter=True)
@convert_phase(Component.CONVERT_TF_TO_TFLITE_MODEL,
               SubComponent.CONVERT_JAX_HLO)
def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs):
  """Converts a Jax hlo-based model using TFLite converter.

  Args:
    input_content: The HLO payload, passed to `convert` as `input_data_str`.
    input_names: Iterable of input names; one input array is registered per
      name, in order.
    is_proto_format: If truthy the HLO file type is set to `HLO_PROTO`,
      otherwise to `HLO_TEXT`.
    **kwargs: Forwarded to `build_conversion_flags`.

  Returns:
    Whatever `convert` returns. The MLIR converter is always enabled and no
    debug info is passed.
  """
  hlo_flags = _model_flags_pb2.ModelFlags()
  hlo_flags.use_hlo_import = True
  hlo_flags.hlo_file_type = (
      _model_flags_pb2.ModelFlags.HLO_PROTO
      if is_proto_format else _model_flags_pb2.ModelFlags.HLO_TEXT)

  # Register one named input array per requested input.
  for name in input_names:
    hlo_flags.input_arrays.add().name = name

  conversion_flags = build_conversion_flags(**kwargs)
  return convert(
      hlo_flags.SerializeToString(),
      conversion_flags.SerializeToString(),
      input_data_str=input_content,
      debug_info_str=None,
      enable_mlir_converter=True)
@_tf_export(v1=["lite.toco_convert"])
@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
  """Convert a TensorFlow GraphDef to TFLite.

  This function is deprecated. Please use `tf.lite.TFLiteConverter` API instead.
  Conversion can be customized by providing arguments that are forwarded to
  `build_model_flags` and `build_conversion_flags` (see documentation for
  details).

  Unless the caller passes `enable_mlir_converter` explicitly, it is set to
  False before delegating to `convert_graphdef`.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`).
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.shape` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    *args: See `build_model_flags` and `build_conversion_flags`.
    **kwargs: See `build_model_flags` and `build_conversion_flags`.

  Returns:
    The converted TensorFlow Lite model in a bytes array.

  Raises:
    Defined in `convert`.
  """
  # Only fills in the default; an explicit caller-supplied value is kept.
  kwargs.setdefault("enable_mlir_converter", False)
  return convert_graphdef(
      input_data, input_tensors, output_tensors, *args, **kwargs)
def deduplicate_readonly_buffers(tflite_model):
  """Generates a new model byte array after deduplicating readonly buffers.

  This function should be invoked after the model optimization toolkit. The
  model optimization toolkit assumes that each tensor object owns its own
  buffer separately.

  Buffers are considered for deduplication only when they back tensors that
  are used as operator inputs and are not variables, mutating variable inputs,
  operator outputs/intermediates, or subgraph inputs/outputs. Among those,
  buffers with identical contents are collapsed onto a single buffer and the
  redundant buffers have their data cleared.

  Args:
    tflite_model: TFLite flatbuffer in a byte array to be deduplicated.

  Returns:
    TFLite flatbuffer in a bytes array, processed with the deduplication method.
  """
  # Load TFLite Flatbuffer byte array into an object.
  model = flatbuffer_utils.convert_bytearray_to_object(tflite_model)

  def _operators_of(subgraph):
    # Treat a missing operator list as empty, mirroring the None checks applied
    # to the other optional vectors below.
    return [] if subgraph.operators is None else subgraph.operators

  # Get all the read-only buffers, which can be modified without causing any
  # issue in the graph invocation stage.
  read_only_buffer_indices = set()
  for subgraph in model.subgraphs:
    operators = _operators_of(subgraph)

    # To get all the read-only buffers:
    # (1) Get all read-only input tensors.
    # (2) Discard intermediate or output tensors.
    # (3) Discard the subgraph's input/output tensors.
    # (4) Gather the buffers of the read-only input tensors.

    # (1) Get read-only input tensors.
    read_only_input_tensor_indices = set()
    for op in operators:
      if op.inputs is None:
        continue
      for i, input_tensor_idx in enumerate(op.inputs):
        # Ignore negative indices: they do not name a tensor, and using them
        # to index `subgraph.tensors` would silently pick the last tensor.
        if input_tensor_idx < 0:
          continue
        # Ignore inputs flagged as mutating variable inputs. Positions beyond
        # the end of the flag list carry no flag and are not skipped here.
        if (op.mutatingVariableInputs is not None and
            i < len(op.mutatingVariableInputs) and
            op.mutatingVariableInputs[i]):
          continue
        # Ignore variable tensors.
        if subgraph.tensors[input_tensor_idx].isVariable:
          continue
        read_only_input_tensor_indices.add(input_tensor_idx)

    # (2) Discard intermediate or output tensors.
    for op in operators:
      if op.outputs is not None:
        read_only_input_tensor_indices.difference_update(op.outputs)
      if op.intermediates is not None:
        read_only_input_tensor_indices.difference_update(op.intermediates)

    # (3) Discard the subgraph's input and output tensors.
    if subgraph.inputs is not None:
      read_only_input_tensor_indices.difference_update(subgraph.inputs)
    if subgraph.outputs is not None:
      read_only_input_tensor_indices.difference_update(subgraph.outputs)

    # (4) Gather the buffers of the read-only input tensors.
    for tensor_idx in read_only_input_tensor_indices:
      read_only_buffer_indices.add(subgraph.tensors[tensor_idx].buffer)

  def _has_payload(buffer_idx):
    # Ignore invalid negative indices, and buffers whose data is missing,
    # stored as a plain list, or zero-sized.
    if buffer_idx < 0:
      return False
    data = model.buffers[buffer_idx].data
    return data is not None and not isinstance(data, list) and data.size != 0

  # Sort by buffer size so that equally sized buffers are adjacent; the scan
  # below relies on this ordering to stop early. The buffer index is the
  # tie-breaker, which makes the lowest index of each duplicate group the one
  # that is kept. `sorted` returns a new list, so its result must be bound.
  read_only_buffer_indices = sorted(
      (idx for idx in read_only_buffer_indices if _has_payload(idx)),
      key=lambda idx: (model.buffers[idx].data.size, idx))

  # Create a map of duplicate buffers (same size and same contents).
  # eg: In [1, 2, 3, 4, 5, 6] if (1, 4, 6) and (2, 5) are each groups of
  # identical buffers, then the map would be {4:1, 6:1, 5:2}
  duplicate_buffer_map = {}
  for i, buffer_i_idx in enumerate(read_only_buffer_indices):
    # This buffer is a duplicate.
    if buffer_i_idx in duplicate_buffer_map:
      continue
    # This buffer is unique. Scan rest of the list to find duplicates
    # of this buffer and mark them accordingly.
    buffer_i = model.buffers[buffer_i_idx]
    for buffer_j_idx in read_only_buffer_indices[i + 1:]:
      if buffer_j_idx in duplicate_buffer_map:
        continue
      buffer_j = model.buffers[buffer_j_idx]
      # The list is ordered by size, so no later buffer can match either.
      if buffer_i.data.size != buffer_j.data.size:
        break
      if buffer_i.data.data != buffer_j.data.data:
        continue
      # Found duplicate. Nullify j-th buffer and use i-th buffer instead.
      duplicate_buffer_map[buffer_j_idx] = buffer_i_idx

  # Make the duplicated tensors use the single shared buffer index.
  for subgraph in model.subgraphs:
    for op in _operators_of(subgraph):
      if op.inputs is None:
        continue
      for input_tensor_idx in op.inputs:
        # Negative indices do not name a tensor; see step (1) above.
        if input_tensor_idx < 0:
          continue
        tensor = subgraph.tensors[input_tensor_idx]
        if tensor.buffer in duplicate_buffer_map:
          tensor.buffer = duplicate_buffer_map[tensor.buffer]

  # Nullify the unused buffers.
  for idx in duplicate_buffer_map:
    model.buffers[idx].data = None

  # Return a TFLite flatbuffer as a byte array.
  return flatbuffer_utils.convert_object_to_bytearray(model)