Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
/research/brain_coder/ @danabo
/research/cognitive_mapping_and_planning/ @s-gupta
/research/compression/ @nmjohn
/research/deeplab/ @aquariusjay @yknzhu @gpapan
/research/delf/ @andrefaraujo
/research/differential_privacy/ @panyx0718
/research/domain_adaptation/ @bousmalis @dmrd
Expand Down
159 changes: 159 additions & 0 deletions research/deeplab/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# DeepLab: Deep Labelling for Semantic Image Segmentation

DeepLab is a state-of-art deep learning model for semantic image segmentation,
where the goal is to assign semantic labels (e.g., person, dog, cat and so on)
to every pixel in the input image. Current implementation includes the following
features:

1. DeepLabv1 [1]: We use *atrous convolution* to explicitly control the
resolution at which feature responses are computed within Deep Convolutional
Neural Networks.

2. DeepLabv2 [2]: We use *atrous spatial pyramid pooling* (ASPP) to robustly
segment objects at multiple scales with filters at multiple sampling rates
and effective fields-of-views.

3. DeepLabv3 [3]: We augment the ASPP module with *image-level feature* [5, 6]
to capture longer range information. We also include *batch normalization*
[7] parameters to facilitate the training. In particular, we applying atrous
convolution to extract output features at different output strides during
training and evaluation, which efficiently enables training BN at output
stride = 16 and attains a high performance at output stride = 8 during
evaluation.

4. DeepLabv3+ [4]: We extend DeepLabv3 to include a simple yet effective
decoder module to refine the segmentation results especially along object
boundaries. Furthermore, in this encoder-decoder structure one can
arbitrarily control the resolution of extracted encoder features by atrous
convolution to trade-off precision and runtime.

If you find the code useful for your research, please consider citing our latest
work:

```
@article{deeplabv3plus2018,
title={Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation},
author={Liang-Chieh Chen and Yukun Zhu and George Papandreou and Florian Schroff and Hartwig Adam},
journal={arXiv:1802.02611},
year={2018}
}
```

In the current implementation, we support adopting the following network
backbones:

1. MobileNetv2 [8]: A fast network structure designed for mobile devices. **We
will provide MobileNetv2 support in the next update. Please stay tuned.**

2. Xception [9, 10]: A powerful network structure intended for server-side
deployment.

This directory contains our TensorFlow [11] implementation. We provide codes
allowing users to train the model, evaluate results in terms of mIOU (mean
intersection-over-union), and visualize segmentation results. We use PASCAL VOC
2012 [12] and Cityscapes [13] semantic segmentation benchmarks as an example in
the code.

Some segmentation results on Flickr images:
<p align="center">
<img src="g3doc/img/vis1.png" width=600></br>
<img src="g3doc/img/vis2.png" width=600></br>
<img src="g3doc/img/vis3.png" width=600></br>
</p>

## Contacts (Maintainers)

* Liang-Chieh Chen, github: [aquariusjay](https://github.com/aquariusjay)
* YuKun Zhu, github: [yknzhu](https://github.com/YknZhu)
* George Papandreou, github: [gpapan](https://github.com/gpapan)

## Tables of Contents

Demo:

* <a href='deeplab_demo.ipynb'>Jupyter notebook for off-the-shelf inference.</a><br>

Running:

* <a href='g3doc/installation.md'>Installation.</a><br>
* <a href='g3doc/pascal.md'>Running DeepLab on PASCAL VOC 2012 semantic segmentation dataset.</a><br>
* <a href='g3doc/cityscapes.md'>Running DeepLab on Cityscapes semantic segmentation dataset.</a><br>

Models:

* <a href='g3doc/model_zoo.md'>Checkpoints and frozen inference graphs.</a><br>

Misc:

* Please check <a href='g3doc/faq.md'>FAQ</a> if you have some questions before reporting the issues.<br>

## Getting Help

To get help with issues you may encounter while using the DeepLab Tensorflow
implementation, create a new question on
[StackOverflow](https://stackoverflow.com/) with the tags "tensorflow" and
"deeplab".

Please report bugs (i.e., broken code, not usage questions) to the
tensorflow/models GitHub [issue
tracker](https://github.com/tensorflow/models/issues), prefixing the issue name
with "deeplab".

## References

1. **Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs**<br />
Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille (+ equal
contribution). <br />
[[link]](https://arxiv.org/abs/1412.7062). In ICLR, 2015.

2. **DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,**
**Atrous Convolution, and Fully Connected CRFs** <br />
Liang-Chieh Chen+, George Papandreou+, Iasonas Kokkinos, Kevin Murphy, and Alan L Yuille (+ equal
contribution). <br />
[[link]](http://arxiv.org/abs/1606.00915). TPAMI 2017.

3. **Rethinking Atrous Convolution for Semantic Image Segmentation**<br />
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.<br />
[[link]](http://arxiv.org/abs/1706.05587). arXiv: 1706.05587, 2017.

4. **Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation**<br />
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam. arXiv: 1802.02611.<br />
[[link]](https://arxiv.org/abs/1802.02611). arXiv: 1802.02611, 2018.

5. **ParseNet: Looking Wider to See Better**<br />
Wei Liu, Andrew Rabinovich, Alexander C Berg<br />
[[link]](https://arxiv.org/abs/1506.04579). arXiv:1506.04579, 2015.

6. **Pyramid Scene Parsing Network**<br />
Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia<br />
[[link]](https://arxiv.org/abs/1612.01105). In CVPR, 2017.

7. **Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate shift**<br />
Sergey Ioffe, Christian Szegedy <br />
[[link]](https://arxiv.org/abs/1502.03167). In ICML, 2015.

8. **Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation**<br />
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen<br />
[[link]](https://arxiv.org/abs/1801.04381). arXiv:1801.04381, 2018.

9. **Xception: Deep Learning with Depthwise Separable Convolutions**<br />
François Chollet<br />
[[link]](https://arxiv.org/abs/1610.02357). In CVPR, 2017.

10. **Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge 2017 Entry**<br />
Haozhi Qi, Zheng Zhang, Bin Xiao, Han Hu, Bowen Cheng, Yichen Wei, Jifeng Dai<br />
[[link]](http://presentations.cocodataset.org/COCO17-Detect-MSRA.pdf). ICCV COCO Challenge
Workshop, 2017.

11. **Tensorflow: Large-Scale Machine Learning on Heterogeneous Distributed Systems**<br />
M. Abadi, A. Agarwal, et al. <br />
[[link]](https://arxiv.org/abs/1603.04467). arXiv:1603.04467, 2016.

12. **The Pascal Visual Object Classes Challenge – A Retrospective,** <br />
Mark Everingham, S. M. Ali Eslami, Luc Van Gool, Christopher K. I. Williams, John
Winn, and Andrew Zisserma. <br />
[[link]](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). IJCV, 2014.

13. **The Cityscapes Dataset for Semantic Urban Scene Understanding**<br />
Cordts, Marius, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele. <br />
[[link]](https://www.cityscapes-dataset.com/). In CVPR, 2016.
Empty file added research/deeplab/__init__.py
Empty file.
138 changes: 138 additions & 0 deletions research/deeplab/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides flags that are common to scripts.

Common flags from train/eval/vis/export_model.py are collected in this script.
"""
import collections

import tensorflow as tf

flags = tf.app.flags

# Flags for input preprocessing.

flags.DEFINE_integer('min_resize_value', None,
'Desired size of the smaller image side.')

flags.DEFINE_integer('max_resize_value', None,
'Maximum allowed size of the larger image side.')

flags.DEFINE_integer('resize_factor', None,
'Resized dimensions are multiple of factor plus one.')

# Model dependent flags.

flags.DEFINE_integer('logits_kernel_size', 1,
'The kernel size for the convolutional kernel that '
'generates logits.')

# We will support `mobilenet_v2' in the coming update. When using
# 'xception_65', we set atrous_rates = [6, 12, 18] (output stride 16) and
# decoder_output_stride = 4.
flags.DEFINE_enum('model_variant', 'xception_65', ['xception_65'],
'DeepLab model variants.')

flags.DEFINE_multi_float('image_pyramid', None,
'Input scales for multi-scale feature extraction.')

flags.DEFINE_boolean('add_image_level_feature', True,
'Add image level feature.')

flags.DEFINE_boolean('aspp_with_batch_norm', True,
'Use batch norm parameters for ASPP or not.')

flags.DEFINE_boolean('aspp_with_separable_conv', True,
'Use separable convolution for ASPP or not.')

flags.DEFINE_multi_integer('multi_grid', None,
'Employ a hierarchy of atrous rates for ResNet.')

# For `xception_65`, use decoder_output_stride = 4.
flags.DEFINE_integer('decoder_output_stride', None,
'The ratio of input to output spatial resolution when '
'employing decoder to refine segmentation results.')

flags.DEFINE_boolean('decoder_use_separable_conv', True,
'Employ separable convolution for decoder or not.')

flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'],
'Scheme to merge multi scale features.')

FLAGS = flags.FLAGS

# Constants

# Perform semantic segmentation predictions.
OUTPUT_TYPE = 'semantic'

# Semantic segmentation item names.
LABELS_CLASS = 'labels_class'
IMAGE = 'image'
HEIGHT = 'height'
WIDTH = 'width'
IMAGE_NAME = 'image_name'
LABEL = 'label'
ORIGINAL_IMAGE = 'original_image'

# Test set name.
TEST_SET = 'test'


class ModelOptions(
collections.namedtuple('ModelOptions', [
'outputs_to_num_classes',
'crop_size',
'atrous_rates',
'output_stride',
'merge_method',
'add_image_level_feature',
'aspp_with_batch_norm',
'aspp_with_separable_conv',
'multi_grid',
'decoder_output_stride',
'decoder_use_separable_conv',
'logits_kernel_size',
'model_variant'
])):
"""Immutable class to hold model options."""

__slots__ = ()

def __new__(cls,
outputs_to_num_classes,
crop_size=None,
atrous_rates=None,
output_stride=8):
"""Constructor to set default values.

Args:
outputs_to_num_classes: A dictionary from output type to the number of
classes. For example, for the task of semantic segmentation with 21
semantic classes, we would have outputs_to_num_classes['semantic'] = 21.
crop_size: A tuple [crop_height, crop_width].
atrous_rates: A list of atrous convolution rates for ASPP.
output_stride: The ratio of input to output spatial resolution.

Returns:
A new ModelOptions instance.
"""
return super(ModelOptions, cls).__new__(
cls, outputs_to_num_classes, crop_size, atrous_rates, output_stride,
FLAGS.merge_method, FLAGS.add_image_level_feature,
FLAGS.aspp_with_batch_norm, FLAGS.aspp_with_separable_conv,
FLAGS.multi_grid, FLAGS.decoder_output_stride,
FLAGS.decoder_use_separable_conv, FLAGS.logits_kernel_size,
FLAGS.model_variant)
Empty file.
Loading