Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding quantization support for deeplab #6681

Merged
merged 3 commits into from May 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 17 additions & 6 deletions research/deeplab/eval.py
Expand Up @@ -38,8 +38,8 @@
flags.DEFINE_integer('eval_batch_size', 1,
'The number of images in each batch during evaluation.')

flags.DEFINE_multi_integer('eval_crop_size', [513, 513],
'Image crop size [height, width] for evaluation.')
flags.DEFINE_list('eval_crop_size', '513,513',
'Image crop size [height, width] for evaluation.')

flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
Expand All @@ -61,6 +61,10 @@
flags.DEFINE_bool('add_flipped_images', False,
'Add flipped images for evaluation or not.')

flags.DEFINE_integer(
'quantize_delay_step', -1,
'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.

flags.DEFINE_string('dataset', 'pascal_voc_seg',
Expand All @@ -84,7 +88,7 @@ def main(unused_argv):
split_name=FLAGS.eval_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=FLAGS.eval_batch_size,
crop_size=FLAGS.eval_crop_size,
crop_size=map(int, FLAGS.eval_crop_size),
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
Expand All @@ -102,22 +106,26 @@ def main(unused_argv):

model_options = common.ModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
crop_size=FLAGS.eval_crop_size,
crop_size=map(int, FLAGS.eval_crop_size),
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)

# Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
samples[common.IMAGE].set_shape(
[FLAGS.eval_batch_size,
FLAGS.eval_crop_size[0],
FLAGS.eval_crop_size[1],
int(FLAGS.eval_crop_size[0]),
int(FLAGS.eval_crop_size[1]),
3])
if tuple(FLAGS.eval_scales) == (1.0,):
tf.logging.info('Performing single-scale test.')
predictions = model.predict_labels(samples[common.IMAGE], model_options,
image_pyramid=FLAGS.image_pyramid)
else:
tf.logging.info('Performing multi-scale test.')
if FLAGS.quantize_delay_step >= 0:
raise ValueError(
'Quantize mode is not supported with multi-scale test.')

predictions = model.predict_labels_multi_scale(
samples[common.IMAGE],
model_options=model_options,
Expand Down Expand Up @@ -154,6 +162,9 @@ def main(unused_argv):
if FLAGS.max_number_of_evaluations > 0:
num_eval_iters = FLAGS.max_number_of_evaluations

if FLAGS.quantize_delay_step >= 0:
tf.contrib.quantize.create_eval_graph()

tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(),
tfprof_options=tf.contrib.tfprof.model_analyzer.
Expand Down
12 changes: 11 additions & 1 deletion research/deeplab/export_model.py
Expand Up @@ -53,6 +53,10 @@
flags.DEFINE_bool('add_flipped_images', False,
'Add flipped images during inference or not.')

flags.DEFINE_integer(
'quantize_delay_step', -1,
'Steps to start quantized training. If < 0, will not quantize model.')

flags.DEFINE_bool('save_inference_graph', False,
'Save inference graph in text proto.')

Expand Down Expand Up @@ -124,6 +128,9 @@ def main(unused_argv):
image_pyramid=FLAGS.image_pyramid)
else:
tf.logging.info('Exported model performs multi-scale inference.')
if FLAGS.quantize_delay_step >= 0:
raise ValueError(
'Quantize mode is not supported with multi-scale test.')
predictions = model.predict_labels_multi_scale(
image,
model_options=model_options,
Expand All @@ -150,7 +157,10 @@ def _resize_label(label, label_size):
semantic_predictions = _resize_label(semantic_predictions, image_size)
semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)

saver = tf.train.Saver(tf.model_variables())
if FLAGS.quantize_delay_step >= 0:
tf.contrib.quantize.create_eval_graph()

saver = tf.train.Saver(tf.all_variables())

dirname = os.path.dirname(FLAGS.export_path)
tf.gfile.MakeDirs(dirname)
Expand Down
3 changes: 1 addition & 2 deletions research/deeplab/g3doc/ade20k.md
Expand Up @@ -57,8 +57,7 @@ python deeplab/train.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--train_crop_size=513 \
--train_crop_size=513 \
--train_crop_size="513,513" \
--train_batch_size=4 \
--min_resize_value=513 \
--max_resize_value=513 \
Expand Down
9 changes: 3 additions & 6 deletions research/deeplab/g3doc/cityscapes.md
Expand Up @@ -50,8 +50,7 @@ python deeplab/train.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--train_crop_size=769 \
--train_crop_size=769 \
--train_crop_size="769,769" \
--train_batch_size=1 \
--dataset="cityscapes" \
--tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
Expand Down Expand Up @@ -103,8 +102,7 @@ python deeplab/eval.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--eval_crop_size=1025 \
--eval_crop_size=2049 \
--eval_crop_size="1025,2049" \
--dataset="cityscapes" \
--checkpoint_dir=${PATH_TO_CHECKPOINT} \
--eval_logdir=${PATH_TO_EVAL_DIR} \
Expand All @@ -130,8 +128,7 @@ python deeplab/vis.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--vis_crop_size=1025 \
--vis_crop_size=2049 \
--vis_crop_size="1025,2049" \
--dataset="cityscapes" \
--colormap_type="cityscapes" \
--checkpoint_dir=${PATH_TO_CHECKPOINT} \
Expand Down
9 changes: 3 additions & 6 deletions research/deeplab/g3doc/pascal.md
Expand Up @@ -52,8 +52,7 @@ python deeplab/train.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--train_crop_size=513 \
--train_crop_size=513 \
--train_crop_size="513,513" \
--train_batch_size=1 \
--dataset="pascal_voc_seg" \
--tf_initial_checkpoint=${PATH_TO_INITIAL_CHECKPOINT} \
Expand Down Expand Up @@ -96,8 +95,7 @@ python deeplab/eval.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--eval_crop_size=513 \
--eval_crop_size=513 \
--eval_crop_size="513,513" \
--dataset="pascal_voc_seg" \
--checkpoint_dir=${PATH_TO_CHECKPOINT} \
--eval_logdir=${PATH_TO_EVAL_DIR} \
Expand All @@ -123,8 +121,7 @@ python deeplab/vis.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--vis_crop_size=513 \
--vis_crop_size=513 \
--vis_crop_size="513,513" \
--dataset="pascal_voc_seg" \
--checkpoint_dir=${PATH_TO_CHECKPOINT} \
--vis_logdir=${PATH_TO_VIS_DIR} \
Expand Down
110 changes: 110 additions & 0 deletions research/deeplab/g3doc/quantize.md
@@ -0,0 +1,110 @@
# Quantize DeepLab model for faster on-device inference

This page describes the steps required to quantize a DeepLab model and convert
it to TFLite for on-device inference. The main steps include:

1. Quantization-aware training
1. Exporting model
1. Converting to TFLite FlatBuffer

We provide details for each step below.

## Quantization-aware training

DeepLab supports two approaches to quantize your model.

1. **[Recommended]** Training a non-quantized model until convergence. Then
fine-tune the trained float model with quantization using a small learning
rate (on PASCAL we use the value of 3e-5). This fine-tuning step usually
takes 2k to 5k steps to converge.

1. Training a DeepLab float model with delayed quantization. Usually we delay
quantization until the last few thousand steps of training.

In the current implementation, quantization is only supported with 1)
`num_clones=1` for training and 2) single scale inference for evaluation,
visualization and model export. To get the best performance for the quantized
model, we strongly recommend training the float model with a larger `num_clones`
and then fine-tune the model with a single clone.

The following command line shows how to quantize a DeepLab model trained on the
PASCAL VOC dataset using fine-tuning:

```
# From tensorflow/models/research/
python deeplab/train.py \
--logtostderr \
--training_number_of_steps=3000 \
--train_split="train" \
--model_variant="mobilenet_v2" \
--output_stride=16 \
--train_crop_size="513,513" \
--train_batch_size=8 \
--base_learning_rate=3e-5 \
--dataset="pascal_voc_seg" \
--initialize_last_layer \
--quantize_delay_step=0 \
--tf_initial_checkpoint=${PATH_TO_TRAINED_FLOAT_MODEL} \
--train_logdir=${PATH_TO_TRAIN_DIR} \
--dataset_dir=${PATH_TO_DATASET}
```

## Converting to TFLite FlatBuffer

First, use the following command line to export your trained model.

```
# From tensorflow/models/research/
python deeplab/export_model.py \
--checkpoint_path=${CHECKPOINT_PATH} \
--quantize_delay_step=0 \
--export_path=${OUTPUT_DIR}/frozen_inference_graph.pb

```

The command line below shows how to convert the exported GraphDef to a TFLite model.

```
tflite_convert \
--graph_def_file=${OUTPUT_DIR}/frozen_inference_graph.pb \
--output_file=${OUTPUT_DIR}/frozen_inference_graph.tflite \
--output_format=TFLITE \
--input_shape=1,513,513,3 \
--input_arrays="MobilenetV2/MobilenetV2/input" \
--inference_type=QUANTIZED_UINT8 \
--inference_input_type=QUANTIZED_UINT8 \
--std_dev_values=128 \
--mean_values=128 \
--change_concat_input_ranges=true \
--output_arrays="ArgMax"
```

**[Important]** Note that the converted model expects a 513x513 RGB input and
doesn't include preprocessing (resizing and padding the input image) or
post-processing (cropping the padded region and resizing to the original input
size). These steps can be implemented outside of the TFLite model.

## Quantized model on PASCAL VOC

We provide float and quantized checkpoints that have been pretrained on VOC 2012
train_aug set, using MobileNet-v2 backbone with different depth multipliers.
Quantized models usually have a 1% decay in mIoU.

For quantized (8-bit) models, the un-tar'ed directory includes:

* a frozen inference graph (frozen_inference_graph.pb)

* a checkpoint (model.ckpt.data*, model.ckpt.index)

* a converted TFLite FlatBuffer file (frozen_inference_graph.tflite)

Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Quantize | PASCAL mIOU | File Size
-------------------------------------------------------------------------------------------------------------------------------------------- | :-----: | :---------: | :-------------: | :-----------: | :------: | :----------: | :-------:
[mobilenetv2_dm05_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_trainaug_2018_10_01.tar.gz) | 16 | [1.0] | No | 0.88B | No | 70.19% (val) | 7.6MB
[mobilenetv2_dm05_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_dm05_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 0.88B | Yes | 69.65% (val) | 8.2MB
[mobilenetv2_coco_voc_trainaug](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz) | 16 | [1.0] | No | 2.75B | No | 75.32% (val) | 23MB
[mobilenetv2_coco_voc_trainaug_8bit](http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_8bit_2019_04_26.tar.gz) | 16 | [1.0] | No | 2.75B | Yes | 74.26% (val) | 24MB

Note that you might need the nightly build of TensorFlow (see
[here](https://www.tensorflow.org/install) for install instructions) to convert
the above quantized models to TFLite.
9 changes: 3 additions & 6 deletions research/deeplab/local_test.sh
Expand Up @@ -82,8 +82,7 @@ python "${WORK_DIR}"/train.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--train_crop_size=513 \
--train_crop_size=513 \
--train_crop_size="513,513" \
--train_batch_size=4 \
--training_number_of_steps="${NUM_ITERATIONS}" \
--fine_tune_batch_norm=true \
Expand All @@ -103,8 +102,7 @@ python "${WORK_DIR}"/eval.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--eval_crop_size=513 \
--eval_crop_size=513 \
--eval_crop_size="513,513" \
--checkpoint_dir="${TRAIN_LOGDIR}" \
--eval_logdir="${EVAL_LOGDIR}" \
--dataset_dir="${PASCAL_DATASET}" \
Expand All @@ -120,8 +118,7 @@ python "${WORK_DIR}"/vis.py \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--vis_crop_size=513 \
--vis_crop_size=513 \
--vis_crop_size="513,513" \
--checkpoint_dir="${TRAIN_LOGDIR}" \
--vis_logdir="${VIS_LOGDIR}" \
--dataset_dir="${PASCAL_DATASET}" \
Expand Down
9 changes: 3 additions & 6 deletions research/deeplab/local_test_mobilenetv2.sh
Expand Up @@ -79,8 +79,7 @@ python "${WORK_DIR}"/train.py \
--train_split="trainval" \
--model_variant="mobilenet_v2" \
--output_stride=16 \
--train_crop_size=513 \
--train_crop_size=513 \
--train_crop_size="513,513" \
--train_batch_size=4 \
--training_number_of_steps="${NUM_ITERATIONS}" \
--fine_tune_batch_norm=true \
Expand All @@ -95,8 +94,7 @@ python "${WORK_DIR}"/eval.py \
--logtostderr \
--eval_split="val" \
--model_variant="mobilenet_v2" \
--eval_crop_size=513 \
--eval_crop_size=513 \
--eval_crop_size="513,513" \
--checkpoint_dir="${TRAIN_LOGDIR}" \
--eval_logdir="${EVAL_LOGDIR}" \
--dataset_dir="${PASCAL_DATASET}" \
Expand All @@ -107,8 +105,7 @@ python "${WORK_DIR}"/vis.py \
--logtostderr \
--vis_split="val" \
--model_variant="mobilenet_v2" \
--vis_crop_size=513 \
--vis_crop_size=513 \
--vis_crop_size="513,513" \
--checkpoint_dir="${TRAIN_LOGDIR}" \
--vis_logdir="${VIS_LOGDIR}" \
--dataset_dir="${PASCAL_DATASET}" \
Expand Down