# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""
import enum
import functools
import pprint
import shutil
import tempfile
import time
import warnings
from absl import logging
import six
from six import PY2
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metdata_fb
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import convert_graphdef as _convert_graphdef
from tensorflow.lite.python.convert import convert_graphdef_with_arrays as _convert_graphdef_with_arrays
from tensorflow.lite.python.convert import convert_jax_hlo as _convert_jax_hlo
from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
from tensorflow.lite.python.convert import deduplicate_readonly_buffers as _deduplicate_readonly_buffers
from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import OpResolverType # pylint: disable=unused-import
from tensorflow.lite.python.metrics import metrics
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import _xla_computation
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_sparsity_modes as _get_sparsity_modes
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import populate_conversion_metadata as _populate_conversion_metadata
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugger # pylint: disable=unused-import
from tensorflow.lite.tools.optimize.debugging.python.debugger import QuantizationDebugOptions # pylint: disable=unused-import
from tensorflow.python import saved_model as _saved_model
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader_impl as _loader_impl
from tensorflow.python.saved_model import save_options as _save_options
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export
@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
"""Enum defining the optimizations to apply when generating a tflite model.
DEFAULT
Default optimization strategy that quantizes model weights. Enhanced
optimizations are gained by providing a representative dataset, which
allows biases and activations to be quantized as well.
The converter will do its best to reduce size and latency, while
minimizing the loss in accuracy.
OPTIMIZE_FOR_SIZE
Deprecated. Does the same as DEFAULT.
OPTIMIZE_FOR_LATENCY
Deprecated. Does the same as DEFAULT.
EXPERIMENTAL_SPARSITY
Experimental flag, subject to change.
Enable optimization by taking advantage of the sparse model weights
trained with pruning.
The converter will inspect the sparsity pattern of the model weights and
do its best to improve size and latency.
The flag can be used alone to optimize float32 models with sparse weights.
It can also be used together with the DEFAULT optimization mode to
optimize quantized models with sparse weights.
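Example usage, as a minimal sketch (`saved_model_dir` here is a
placeholder for your own model path):

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# DEFAULT quantizes model weights; adding a representative dataset
# additionally quantizes biases and activations.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
```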
"""
# Default optimization strategy that quantizes model weights. Enhanced
# optimizations are gained by providing a representative dataset, which
# allows biases and activations to be quantized as well.
# The converter will do its best to reduce size and latency, while
# minimizing the loss in accuracy.
DEFAULT = "DEFAULT"
# Deprecated. Does the same as DEFAULT.
OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
# Deprecated. Does the same as DEFAULT.
OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
# Experimental flag, subject to change.
# Enable optimization by taking advantage of the sparse model weights trained
# with pruning.
#
# The converter will inspect the sparsity pattern of the model weights and do
# its best to improve size and latency.
# The flag can be used alone to optimize float32 models with sparse weights.
# It can also be used together with the DEFAULT optimization mode to optimize
# quantized models with sparse weights.
# TODO(b/161560631): Add log message when this optimization is applied.
EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"
def __str__(self):
return str(self.value)
# TODO(b/198099651): move converter implementation out of lite.py
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
"""Representative dataset used to optimize the model.
This is a generator function that provides a small dataset to calibrate or
estimate the range, i.e., (min, max) of all floating-point arrays in the model
(such as model input, activation outputs of intermediate layers, and model
output) for quantization. Usually, this is a small subset of a few hundred
samples randomly chosen, in no particular order, from the training or
evaluation dataset.
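Example usage, as a minimal sketch (`calibration_samples` is a
placeholder for a few hundred samples drawn from your training or
evaluation data, and `converter` for an existing
`tf.lite.TFLiteConverter`):

```python
def representative_dataset():
  for sample in calibration_samples:
    # Yield a list with one array per model input, matching the
    # order, type and shape of the model inputs.
    yield [sample]

converter.representative_dataset = tf.lite.RepresentativeDataset(
    representative_dataset)
```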
"""
def __init__(self, input_gen):
"""Creates a representative dataset.
Args:
input_gen: A generator function that generates input samples for the
model and has the same order, type and shape as the inputs to the model.
Usually, this is a small subset of a few hundred samples randomly
chosen, in no particular order, from the training or evaluation dataset.
"""
self.input_gen = input_gen
@_tf_export("lite.TargetSpec")
class TargetSpec(object):
"""Specification of target device used to optimize the model.
Attributes:
supported_ops: Experimental flag, subject to change. Set of `tf.lite.OpsSet`
options, where each option represents a set of operators supported by the
target device. (default {tf.lite.OpsSet.TFLITE_BUILTINS})
supported_types: Set of `tf.dtypes.DType` data types supported on the target
device. If initialized, optimization might be driven by the smallest type
in this set. (default set())
experimental_select_user_tf_ops: Experimental flag, subject to change. Set
of user's TensorFlow operators' names that are required in the TensorFlow
Lite runtime. These ops will be exported as select TensorFlow ops in the
model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS flag). This is
an advanced feature that should only be used if the client is using TF ops
that may not be linked in by default with the TF ops that are provided
when using the SELECT_TF_OPS path. The client is responsible for linking
these ops into the target runtime.
experimental_supported_backends: Experimental flag, subject to change.
Set containing names of supported backends. Currently only "GPU" is
supported; more options will be available later.
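Example usage, as a minimal sketch (`converter` is assumed to be an
existing `tf.lite.TFLiteConverter`):

```python
# Restrict conversion to int8 TFLite builtin ops, for full integer
# quantization.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Or request float16 weights via supported_types instead:
# converter.target_spec.supported_types = [tf.float16]
```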
"""
def __init__(self,
supported_ops=None,
supported_types=None,
experimental_select_user_tf_ops=None,
experimental_supported_backends=None):
if supported_ops is None:
supported_ops = {OpsSet.TFLITE_BUILTINS}
self.supported_ops = supported_ops
if supported_types is None:
supported_types = set()
self.supported_types = supported_types
if experimental_select_user_tf_ops is None:
experimental_select_user_tf_ops = set()
self.experimental_select_user_tf_ops = experimental_select_user_tf_ops
self.experimental_supported_backends = experimental_supported_backends
self._experimental_custom_op_registerers = []
# Hint for the supported accumulation type used for inference. Typically
# used for fp16 post-training quantization, where some models can use fp16
# accumulators instead of the typical fp32 type.
# TODO(b/188185962): Provide full API and authoring support for
# reduced precision accumulation types.
self._experimental_supported_accumulation_type = None
class QuantizationMode(object):
"""QuantizationMode determines the quantization type from user options."""
def __init__(self,
optimizations,
target_spec,
representative_dataset,
graph_def,
disable_per_channel=False,
experimental_new_dynamic_range_quantizer=False,
experimental_low_bit_qat=False,
full_integer_quantization_bias_type=None):
self._optimizations = optimizations
for deprecated_optimization in [
Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
]:
if deprecated_optimization in self._optimizations:
logging.warning(
"Optimization option %s is deprecated, please use optimizations="
"[Optimize.DEFAULT] instead.", deprecated_optimization)
self._target_spec = target_spec
self._representative_dataset = representative_dataset
self._graph_def = graph_def
self._validate_int8_required()
self._disable_per_channel = disable_per_channel
self._enable_new_dynamic_range_quantizer = (
experimental_new_dynamic_range_quantizer)
# Allow training with lower than 8 bit weights to be converted
# to constants with trained scale.
self._experimental_low_bit_qat = experimental_low_bit_qat
self._full_integer_quantization_bias_type = full_integer_quantization_bias_type
self._validate_full_integer_quantization_bias_type()
def is_post_training_int8_only_quantization(self):
return (self.is_any_optimization_enabled() and
self._representative_dataset is not None and
not self._is_int16x8_target_required() and
not self.is_allow_float() and
self._is_int8_target_required())
def is_post_training_int8_quantization_with_float_fallback(self):
return (self.is_any_optimization_enabled() and
self._representative_dataset is not None and
not self._is_int16x8_target_required() and
self.is_allow_float() and
self._smallest_supported_type() == _dtypes.int8)
def is_post_training_int8_quantization(self):
return (self.is_post_training_int8_only_quantization() or
self.is_post_training_int8_quantization_with_float_fallback())
def is_post_training_int16x8_only_quantization(self):
return (self.is_any_optimization_enabled() and
self._representative_dataset is not None and
self._is_int16x8_target_required() and
not self.is_allow_float())
def is_post_training_int16x8_quantization_with_float_fallback(self):
return (self.is_any_optimization_enabled() and
self._representative_dataset is not None and
self._is_int16x8_target_required() and
self.is_allow_float())
def is_post_training_int16x8_quantization(self):
return (self.is_post_training_int16x8_only_quantization() or
self.is_post_training_int16x8_quantization_with_float_fallback())
def is_post_training_integer_quantization(self):
return (self.is_post_training_int8_quantization() or
self.is_post_training_int16x8_quantization())
def is_low_bit_quantize_aware_training(self):
return (self.is_any_optimization_enabled() and
self.is_quantization_aware_trained_model() and
self._experimental_low_bit_qat)
def is_quantization_aware_training(self):
return (self.is_any_optimization_enabled() and
self.is_quantization_aware_trained_model() and
not self.is_low_bit_quantize_aware_training())
def is_integer_quantization(self):
return (self.is_post_training_integer_quantization() or
self.is_quantization_aware_training() or
self.is_low_bit_quantize_aware_training())
def is_post_training_dynamic_range_quantization(self):
# Post-training dynamic range quantization is only enabled if neither
# post-training int8 quantization nor training-time quantization was done.
return (self.is_any_optimization_enabled() and
self._representative_dataset is None and
not self.is_quantization_aware_trained_model() and
self._smallest_supported_type() == _dtypes.int8)
def is_post_training_float16_quantization(self):
return (self.is_any_optimization_enabled() and
self._smallest_supported_type().size == 2 and
_dtypes.float16 in self._target_spec.supported_types)
def is_bfloat16_quantization(self):
return (self.is_any_optimization_enabled() and
self._smallest_supported_type().size == 2 and
_dtypes.bfloat16 in self._target_spec.supported_types)
def activations_type(self):
if self.is_integer_quantization():
if self._is_int16x8_target_required():
return _dtypes.int16
else:
return _dtypes.int8
else:
return _dtypes.float32
def bias_type(self):
if self._full_integer_quantization_bias_type:
return self._full_integer_quantization_bias_type
if self.activations_type() == _dtypes.int16:
return _dtypes.int64
elif self.activations_type() == _dtypes.int8:
return _dtypes.int32
else:
return _dtypes.float32
def converter_flags(self, inference_ty=None, inference_input_ty=None):
"""Flags to the converter."""
if self.is_integer_quantization():
is_low_bit_qat = self.is_low_bit_quantize_aware_training()
return {
"inference_type": (inference_ty if inference_ty is not None else
self.activations_type()),
"inference_input_type": _dtypes.float32,
"post_training_quantize": False, # disable dynamic range quantization
"quantize_to_float16": False, # disable float16 quantization
"disable_infer_tensor_range": is_low_bit_qat,
"use_fake_quant_num_bits": is_low_bit_qat,
}
elif self.is_post_training_dynamic_range_quantization():
return {
"inference_type": _dtypes.float32,
"inference_input_type": _dtypes.float32,
"post_training_quantize": True, # enable dynamic range quantization
"quantize_to_float16": False, # disable float16 quantization
# experimental: disable per-channel (per-axis) quantization.
"disable_per_channel_quantization":
self._disable_per_channel,
"enable_mlir_dynamic_range_quantizer":
self._enable_new_dynamic_range_quantizer
}
elif self.is_post_training_float16_quantization():
return {
"inference_type": _dtypes.float32,
"inference_input_type": _dtypes.float32,
"post_training_quantize": True,
"quantize_to_float16": True, # enable float16 quantization
"accumulation_type":
self._target_spec._experimental_supported_accumulation_type, # pylint: disable=protected-access
"allow_bfloat16":
self.is_bfloat16_quantization(),
"enable_mlir_dynamic_range_quantizer":
self._enable_new_dynamic_range_quantizer
}
else:
# Note that this might still trigger (uint8) quantization for
# compatibility with the old converter.
return {
"inference_type": (
inference_ty if inference_ty is not None else _dtypes.float32),
"inference_input_type": inference_input_ty,
"post_training_quantize": False, # enable dynamic range quantization
"quantize_to_float16": False, # disable float16 quantization
"allow_bfloat16": self.is_bfloat16_quantization()
}
# Below are helpers for the above functions.
def _validate_int8_required(self):
"""Int8 mode requires certain parameters to exist and be compatible."""
if not self._is_int8_target_required():
return
# Validate target_spec attribute.
if (set(self._target_spec.supported_ops) == {OpsSet.TFLITE_BUILTINS_INT8}
and not (set(self._target_spec.supported_types) == set() or
set(self._target_spec.supported_types) == {_dtypes.int8})):
raise ValueError(
"Full integer quantization has been enabled by setting "
"`target_spec.supported_ops` to {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}, "
"so `target_spec.supported_types` should be left uninitialized "
"or set to {tf.int8}.")
if set(self._target_spec.supported_types) == {_dtypes.int8}:
self._target_spec.supported_ops = {OpsSet.TFLITE_BUILTINS_INT8}
# Check if representative_dataset is specified.
if (not self._representative_dataset and
not self.is_quantization_aware_training()):
raise ValueError("For full integer quantization, a "
"`representative_dataset` must be specified.")
# Update representative dataset to the expected format.
if self._representative_dataset:
if not isinstance(self._representative_dataset, RepresentativeDataset):
self._representative_dataset = RepresentativeDataset(
self._representative_dataset)
def _validate_full_integer_quantization_bias_type(self):
"""Validates bias type for full interger quantization."""
bias_type = self._full_integer_quantization_bias_type
if not bias_type:
return
if self.activations_type() == _dtypes.float32:
raise ValueError(
"`full_integer_quantization_bias_type` is only supported for full integer quantization."
)
if self.activations_type() == _dtypes.int8 and bias_type != _dtypes.int32:
raise ValueError(
f"Expected bias type to be `dtypes.int32` for Int8Quant. "
f"Current bias type: {bias_type}.")
if self.activations_type(
) == _dtypes.int16 and bias_type != _dtypes.int32 and bias_type != _dtypes.int64:
raise ValueError(
f"Expected bias type to be `dtypes.int32` or `dtypes.int64` for "
f"Int16Quant. Current bias type: {bias_type}.")
def _is_int8_target_required(self):
return (OpsSet.TFLITE_BUILTINS_INT8 in set(
self._target_spec.supported_ops)) or (set(
self._target_spec.supported_types) == set([_dtypes.int8]))
def _is_int16x8_target_required(self):
return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
in set(self._target_spec.supported_ops))
def is_allow_float(self):
return (OpsSet.TFLITE_BUILTINS in set(
self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
self._target_spec.supported_ops))
def is_any_optimization_enabled(self):
return bool(
set(self._optimizations).intersection([
Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
Optimize.DEFAULT
]))
def _smallest_supported_type(self):
if self._target_spec.supported_types:
return min(self._target_spec.supported_types, key=lambda x: x.size)
else:
# The default smallest supported type is INT8.
return _dtypes.int8
def is_quantization_aware_trained_model(self):
"""Checks if the graph contains any training-time quantization ops."""
training_quant_ops = frozenset({
"FakeQuantWithMinMaxVars",
"FakeQuantWithMinMaxVarsPerChannel",
"FakeQuantWithMinMaxArgs",
"QuantizeAndDequantizeV2",
"QuantizeAndDequantizeV3",
})
if self._graph_def:
for node_def in self._graph_def.node:
if node_def.op in training_quant_ops:
return True
for function in self._graph_def.library.function:
for node_def in function.node_def:
if node_def.op in training_quant_ops:
return True
return False
class TFLiteConverterBase(object):
"""Converter subclass to share functionality between V1 and V2 converters."""
# Stores the original model type temporarily to transmit the information
# from the factory class methods to TFLiteConverterBase init function.
_original_model_type = conversion_metdata_fb.ModelType.NONE
def __init__(self):
self.optimizations = set()
self.representative_dataset = None
self.target_spec = TargetSpec()
self.allow_custom_ops = False
self.experimental_new_converter = True
self.experimental_new_quantizer = True
self.experimental_enable_resource_variables = True
self._experimental_calibrate_only = False
self._experimental_sparsify_model = False
self._experimental_disable_per_channel = False
self._debug_info = None # Contains the stack traces of the original
# nodes in the `GraphDef` passed to the converter.
self.saved_model_dir = None
self._saved_model_tags = None
self._saved_model_version = 0
self._saved_model_exported_names = []
self._tflite_metrics = metrics.TFLiteConverterMetrics()
self._collected_converter_params = {}
self._experimental_disable_batchmatmul_unfold = False
self._experimental_lower_tensor_list_ops = True
self._experimental_default_to_single_batch_in_tensor_list_ops = False
self._experimental_unfold_large_splat_constant = False
self._experimental_tf_quantization_mode = None
# If unset, bias defaults to int32, except for 16x8 quantization,
# where bias defaults to int64 to prevent overflow.
self._experimental_full_integer_quantization_bias_type = None
# Initializes conversion metadata.
self.exclude_conversion_metadata = False
self._metadata = conversion_metdata_fb.ConversionMetadataT()
self._metadata.environment = conversion_metdata_fb.EnvironmentT()
self._metadata.options = conversion_metdata_fb.ConversionOptionsT()
self._metadata.environment.tensorflowVersion = versions.__version__
self._metadata.environment.modelType = self._get_original_model_type()
self._experimental_enable_dynamic_update_slice = False
self._experimental_preserve_assert_op = False
self._experimental_guarantee_all_funcs_one_use = False
# When the value is true, the MLIR quantizer triggers dynamic range
# quantization in MLIR instead of the old quantizer. Used only if
# experimental_new_quantizer is on.
self.experimental_new_dynamic_range_quantizer = True
# Experimental flag to enable low-bit QAT in 8 bit.
self._experimental_low_bit_qat = False
def _grappler_config(self, optimizers=None):
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
Args:
optimizers: List of strings that represents the list of optimizers.
Returns:
tf.compat.v1.ConfigProto.
"""
if not optimizers:
optimizers = []
# MLIR converter will take care of constant folding instead of grappler.
if not self.experimental_new_converter:
optimizers.append("constfold")
is_only_flex_enabled = (
set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
if is_only_flex_enabled:
# The layout optimizer turns NHWC to NCHW. This provides performance
# optimizations when Flex mode is enabled. However, this is not compatible
# with builtin ops.
optimizers.append("layout")
return _get_grappler_config(optimizers)
def _quantize(self, result, input_type, output_type, activations_type,
bias_type, allow_float):
"""Quantize the model."""
# pylint: disable=protected-access
custom_op_registerers_by_name = [
x for x in self.target_spec._experimental_custom_op_registerers
if isinstance(x, str)
]
custom_op_registerers_by_func = [
x for x in self.target_spec._experimental_custom_op_registerers
if not isinstance(x, str)
]
# pylint: enable=protected-access
if not isinstance(self.representative_dataset, RepresentativeDataset):
self.representative_dataset = RepresentativeDataset(
self.representative_dataset)
# Add intermediate tensors to the model if needed.
result = _calibrator.add_intermediate_tensors(result)
calibrate_quantize = _calibrator.Calibrator(result,
custom_op_registerers_by_name,
custom_op_registerers_by_func)
if self._experimental_calibrate_only or self.experimental_new_quantizer:
calibrated = calibrate_quantize.calibrate(
self.representative_dataset.input_gen)
if self._experimental_calibrate_only:
return calibrated
elif self.experimental_new_quantizer and (
activations_type != _dtypes.int16):
# TODO(b/175659372): remove the activations_type restriction and enable
# it for all the activation types.
return _mlir_quantize(
calibrated,
self._experimental_disable_per_channel,
input_data_type=input_type,
output_data_type=output_type)
else:
return calibrate_quantize.calibrate_and_quantize(
self.representative_dataset.input_gen,
input_type,
output_type,
allow_float,
activations_type,
bias_type,
disable_per_channel=self._experimental_disable_per_channel)
def _is_unknown_shapes_allowed(self):
# Unknown dimensions are only allowed with the new converter.
return self.experimental_new_converter
def _get_base_converter_args(self):
"""Returns the base converter args.
Returns:
{key str: val}
"""
args = {
"input_format":
constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops":
self.allow_custom_ops,
"debug_info":
self._debug_info,
"target_ops":
self.target_spec.supported_ops,
"enable_mlir_converter":
self.experimental_new_converter,
"select_user_tf_ops":
self.target_spec.experimental_select_user_tf_ops,
"supported_backends":
self.target_spec.experimental_supported_backends,
"unfold_batchmatmul":
not self._experimental_disable_batchmatmul_unfold,
"lower_tensor_list_ops":
self._experimental_lower_tensor_list_ops,
"unfold_large_splat_constant":
self._experimental_unfold_large_splat_constant,
"default_to_single_batch_in_tensor_list_ops":
self._experimental_default_to_single_batch_in_tensor_list_ops,
"tf_quantization_mode":
self._experimental_tf_quantization_mode,
"experimental_enable_resource_variables":
self.experimental_enable_resource_variables,
"enable_dynamic_update_slice":
self._experimental_enable_dynamic_update_slice,
"preserve_assert_op":
self._experimental_preserve_assert_op,
"guarantee_all_funcs_one_use":
self._experimental_guarantee_all_funcs_one_use,
}
if self.saved_model_dir:
args.update({
"saved_model_dir": self.saved_model_dir,
"saved_model_version": self._saved_model_version,
"saved_model_tags": self._saved_model_tags,
"saved_model_exported_names": self._saved_model_exported_names,
})
return args
def _contains_function_with_implements_attr(self, saved_model_proto):
meta_graph = saved_model_proto.meta_graphs[0]
for function in meta_graph.graph_def.library.function:
if function.attr.get("_implements", None) or function.attr.get(
"api_implements", None):
return True
return False
def _parse_saved_model_args(self, always_enable_saved_model_import=False):
"""Parses SavedModel arguments from the given Keras/RNN SavedModel.
Args:
always_enable_saved_model_import: Bool. When true, enables the MLIR
SavedModel import path regardless of the other conditions.
"""
if not self.experimental_new_converter:
self.saved_model_dir = None
return
if self.saved_model_dir:
try:
saved_model_proto, _ = (
_parse_saved_model_with_debug_info(self.saved_model_dir))
except OSError:
# If it fails to read the given saved model, it will fall back to the
# frozen graph def path.
self.saved_model_dir = None
return
if (not always_enable_saved_model_import and
not self._contains_function_with_implements_attr(saved_model_proto)):
self.saved_model_dir = None
return
if not self._saved_model_exported_names:
self._saved_model_exported_names = []
self._saved_model_version = saved_model_proto.saved_model_schema_version
if self._saved_model_version == 0:
self.saved_model_dir = None
logging.warning("SavedModel schema version is zero.")
return
if self._saved_model_version not in [1, 2]:
raise ValueError("SavedModel file format({0}) is not supported".format(
self._saved_model_version))
def _sparsify_model(self):
return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations
def _increase_conversion_attempt_metric(self):
self._tflite_metrics.increase_counter_converter_attempt()
def _increase_conversion_success_metric(self):
self._tflite_metrics.increase_counter_converter_success()
@classmethod
def _set_original_model_type(cls, model_type):
"""Stores the original model type."""
if model_type == conversion_metdata_fb.ModelType.NONE:
raise ValueError("The original model type should be specified.")
cls._original_model_type = model_type
def _get_original_model_type(self):
"""One-time getter to return original model type and set it to NONE."""
model_type = TFLiteConverterBase._original_model_type
TFLiteConverterBase._original_model_type = conversion_metdata_fb.ModelType.NONE
return model_type
def _save_conversion_params_metric(self,
graph_def=None,
inference_type=None,
inference_input_type=None):
"""Set conversion parameter metrics."""
converter_kwargs = self._collected_converter_params
converter_kwargs.update(self._get_base_converter_args())
# Optimization parameters.
quant_mode = QuantizationMode(
self.optimizations, self.target_spec, self.representative_dataset,
graph_def, self._experimental_disable_per_channel,
self.experimental_new_dynamic_range_quantizer,
self._experimental_low_bit_qat,
self._experimental_full_integer_quantization_bias_type)
converter_kwargs.update({
"tf_version":
self._metadata.environment.tensorflowVersion,
"api_version":
self._metadata.environment.apiVersion,
"original_model_format":
self._metadata.environment.modelType,
"optimization_default":
quant_mode.is_any_optimization_enabled(),
"optimization_post_training_dynamic_range":
quant_mode.is_post_training_dynamic_range_quantization(),
"optimization_post_training_float16":
quant_mode.is_post_training_float16_quantization(),
"optimization_post_training_integer_quantize":
quant_mode.is_post_training_integer_quantization(),
"optimization_qat":
quant_mode.is_quantization_aware_training(),
"optimization_low_bit_qat":
quant_mode.is_low_bit_quantize_aware_training(),
"optimization_sparsify":
self._sparsify_model(),
"activations_type":
quant_mode.activations_type()
})
converter_kwargs.update(
quant_mode.converter_flags(inference_type, inference_input_type))
# pylint: disable=protected-access
if self.target_spec._experimental_supported_accumulation_type:
converter_kwargs.update({
"accumulation_type":
self.target_spec._experimental_supported_accumulation_type
})
# pylint: enable=protected-access
def format_element(elem):
if isinstance(elem, enum.Enum):
return str(elem.value)
return pprint.pformat(elem)
def format_param(param):
if isinstance(param, (list, tuple, set)):
if not param:
return "None" # Return None if empty.
string_list = [format_element(x) for x in param]
return ",".join(sorted(string_list))
return format_element(param)
for key, value in converter_kwargs.items():
self._tflite_metrics.set_converter_param(key, format_param(value))
self._tflite_metrics.set_export_required()
# Set conversion option metadata.
self._metadata.options.allowCustomOps = self.allow_custom_ops
self._metadata.options.enableSelectTfOps = (
OpsSet.SELECT_TF_OPS in self.target_spec.supported_ops)
self._metadata.options.forceSelectTfOps = (
set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
self._metadata.options.modelOptimizationModes = []
if quant_mode.is_post_training_float16_quantization():
self._metadata.options.modelOptimizationModes.append(
conversion_metdata_fb.ModelOptimizationMode.PTQ_FLOAT16)
if quant_mode.is_post_training_dynamic_range_quantization():
self._metadata.options.modelOptimizationModes.append(
conversion_metdata_fb.ModelOptimizationMode.PTQ_DYNAMIC_RANGE)
if quant_mode.is_post_training_int8_quantization():
self._metadata.options.modelOptimizationModes.append(
conversion_metdata_fb.ModelOptimizationMode.PTQ_FULL_INTEGER)
if quant_mode.is_post_training_int16x8_quantization():
self._metadata.options.modelOptimizationModes.append(
conversion_metdata_fb.ModelOptimizationMode.PTQ_INT16)
if quant_mode.is_quantization_aware_training():
self._metadata.options.modelOptimizationModes.append(
conversion_metdata_fb.ModelOptimizationMode
.QUANTIZATION_AWARE_TRAINING)
def _set_conversion_latency_metric(self, value):
self._tflite_metrics.set_converter_latency(value)
@convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
"""Apply optimizations on a TFLite model."""
if quant_mode.is_integer_quantization():
in_type, out_type = self.inference_input_type, self.inference_output_type
if quant_mode.is_post_training_integer_quantization():
q_in_type = in_type if in_type and quant_io else _dtypes.float32
q_out_type = out_type if out_type and quant_io else _dtypes.float32
q_activations_type = quant_mode.activations_type()
q_bias_type = quant_mode.bias_type()
q_allow_float = quant_mode.is_allow_float()
model = self._quantize(model, q_in_type, q_out_type, q_activations_type,
q_bias_type, q_allow_float)
m_in_type = in_type if in_type else _dtypes.float32
m_out_type = out_type if out_type else _dtypes.float32
# Skip updating model I/O types if the MLIR quantizer already handles it.
if not (quant_mode.is_post_training_integer_quantization() and
self.experimental_new_quantizer and quant_io and
(m_in_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32]) and
(m_out_type in [_dtypes.int8, _dtypes.uint8, _dtypes.float32])):
model = _modify_model_io_type(model, m_in_type, m_out_type)
if self._sparsify_model():
model = _mlir_sparsify(model)
try:
model = _deduplicate_readonly_buffers(model)
except Exception: # pylint: disable=broad-except
# Skip buffer deduplication when the flatbuffer library is not
# available.
logging.warning(
"Buffer deduplication will be skipped because the flatbuffer "
"library is not properly loaded.")
return model
def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
"""Wraps around convert function to export metrics.
Args:
convert_func: The convert function to wrap.
*args: Positional arguments of the convert function.
**kwargs: The keyword arguments of the convert function.
Returns:
The output of the convert function, with conversion metrics exported
and conversion metadata populated.
"""
self._increase_conversion_attempt_metric()
self._save_conversion_params_metric()
start_time = time.process_time()
result = convert_func(self, *args, **kwargs)
elapsed_time_ms = (time.process_time() - start_time) * 1000
if result:
self._increase_conversion_success_metric()
self._set_conversion_latency_metric(round(elapsed_time_ms))
self._tflite_metrics.export_metrics()
model_object = flatbuffer_utils.convert_bytearray_to_object(result)
# Populates the conversion metadata.
# TODO(b/202090541): Collects sparsity block size information.
sparsity_modes = _get_sparsity_modes(model_object)
self._metadata.options.modelOptimizationModes.extend(sparsity_modes)
if not self.exclude_conversion_metadata:
model_object = _populate_conversion_metadata(model_object, self._metadata)
return flatbuffer_utils.convert_object_to_bytearray(model_object)
def _export_metrics(convert_func):
"""The decorator around convert function to export metrics."""
@functools.wraps(convert_func)
def wrapper(self, *args, **kwargs):
# pylint: disable=protected-access
return self._convert_and_export_metrics(convert_func, *args, **kwargs)
# pylint: enable=protected-access
return wrapper
class TFLiteConverterBaseV2(TFLiteConverterBase):
"""Converter subclass to share functionality between V2 converters."""
def __init__(self):
"""Constructor for TFLiteConverter."""
super(TFLiteConverterBaseV2, self).__init__()
self.inference_input_type = _dtypes.float32
self.inference_output_type = _dtypes.float32
self._metadata.environment.apiVersion = 2
def _validate_inference_input_output_types(self, quant_mode):
"""Validate inference_input_type and inference_output_type flags."""
default_types = [_dtypes.float32]
# We support integer input/output for integer quantized models only.
if quant_mode.is_integer_quantization():
if quant_mode.is_post_training_int16x8_quantization():
all_types = default_types + [_dtypes.int16]
else:
all_types = default_types + [_dtypes.int8, _dtypes.uint8]
if (self.inference_input_type not in all_types or
self.inference_output_type not in all_types):
all_types_names = ["tf." + t.name for t in all_types]
raise ValueError("The inference_input_type and inference_output_type "
"must be in {}.".format(all_types_names))
elif (self.inference_input_type not in default_types or
self.inference_output_type not in default_types):
raise ValueError("The inference_input_type and inference_output_type "
"must be tf.float32.")
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL)
def _load_saved_model(self, saved_model_dir, saved_model_tags):
"""Load graph_def from saved model with the default serving signature key.
Args:
saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze.
Returns:
graph_def: The loaded GraphDef.
input_tensors: List of input tensors.
output_tensors: List of output tensors.
"""
graph = _ops.Graph()
saved_model = _loader_impl.SavedModelLoader(saved_model_dir)
saved_model.load_graph(graph, tags=saved_model_tags)
meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags)
graph_def = meta_graph.graph_def
signature_def = meta_graph.signature_def[
_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_tensors = [
graph.get_tensor_by_name(signature_def.inputs[key].name)
for key in signature_def.inputs
]
output_tensors = [
graph.get_tensor_by_name(signature_def.outputs[key].name)
for key in signature_def.outputs
]
return graph_def, input_tensors, output_tensors
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
def _validate_inputs(self, graph_def, input_tensors):
"""Validate the input parameters.
Args:
graph_def: The TensorFlow GraphDef.