Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions research/object_detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ Song Y, Guadarrama S, Murphy K, CVPR 2017
\[[link](https://arxiv.org/abs/1611.10012)\]\[[bibtex](
https://scholar.googleusercontent.com/scholar.bib?q=info:l291WsrB-hQJ:scholar.google.com/&output=citation&scisig=AAGBfm0AAAAAWUIIlnPZ_L9jxvPwcC49kDlELtaeIyU-&scisf=4&ct=citation&cd=-1&hl=en&scfhb=1)\]

<p align="center">
<img src="g3doc/img/tf-od-api-logo.png" width=140 height=195>
</p>

## Maintainers

* Jonathan Huang, github: [jch1](https://github.com/jch1)
Expand Down Expand Up @@ -59,6 +63,10 @@ Extras:
Defining your own model architecture</a><br>
* <a href='g3doc/using_your_own_dataset.md'>
Bringing in your own dataset</a><br>
* <a href='g3doc/oid_inference_and_evaluation.md'>
Inference and evaluation on the Open Images dataset</a><br>
* <a href='g3doc/evaluation_protocols.md'>
Supported object detection evaluation protocols</a><br>

## Getting Help

Expand All @@ -71,8 +79,21 @@ tensorflow/models Github
[issue tracker](https://github.com/tensorflow/models/issues), prefixing the
issue name with "object_detection".



## Release information

### November 17, 2017

As part of the Open Images V3 release, we have published:

* An implementation of the Open Images evaluation metric and the [protocol](g3doc/evaluation_protocols.md#open-images).
* Additional tools for running detection inference and evaluation as separate steps (see [this tutorial](g3doc/oid_inference_and_evaluation.md)).
* A new detection model trained on the Open Images V2 data release (see [Open Images model](g3doc/detection_model_zoo.md#open-images-models)).

See more information on the [Open Images website](https://github.com/openimages/dataset)!

<b>Thanks to contributors</b>: Stefan Popov, Alina Kuznetsova

### November 6, 2017

Expand Down Expand Up @@ -107,6 +128,7 @@ you to try out other detection models!

<b>Thanks to contributors</b>: Jonathan Huang, Andrew Harp


### June 15, 2017

In addition to our base Tensorflow detection model definitions, this
Expand All @@ -130,3 +152,4 @@ release includes:
<b>Thanks to contributors</b>: Jonathan Huang, Vivek Rathod, Derek Chow,
Chen Sun, Menglong Zhu, Matthew Tang, Anoop Korattikara, Alireza Fathi, Ian Fischer, Zbigniew Wojna, Yang Song, Sergio Guadarrama, Jasper Uijlings,
Viacheslav Kovalevskyi, Kevin Murphy

23 changes: 23 additions & 0 deletions research/object_detection/dataset_tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,26 @@ py_library(
"//tensorflow_models/object_detection/utils:dataset_util",
],
)

# Test target built from oid_tfrecord_creation_test.py. It depends on the
# local :oid_tfrecord_creation target plus contextlib2, pandas and tensorflow.
# NOTE(review): no rule named "oid_tfrecord_creation" is visible in this part
# of the file -- confirm it is defined elsewhere in this BUILD file.
py_test(
    name = "oid_tfrecord_creation_test",
    srcs = ["oid_tfrecord_creation_test.py"],
    deps = [
        ":oid_tfrecord_creation",
        "//third_party/py/contextlib2",
        "//third_party/py/pandas",
        "//third_party/py/tensorflow",
    ],
)

# Binary target built from create_oid_tf_record.py. It depends on the local
# :oid_tfrecord_creation target, contextlib2, pandas, tensorflow and the
# object_detection label_map_util library.
# NOTE(review): this rule references TensorFlow as "//tensorflow" while the
# oid_tfrecord_creation_test rule uses "//third_party/py/tensorflow" --
# confirm which label is intended.
py_binary(
    name = "create_oid_tf_record",
    srcs = ["create_oid_tf_record.py"],
    deps = [
        ":oid_tfrecord_creation",
        "//third_party/py/contextlib2",
        "//third_party/py/pandas",
        "//tensorflow",
        "//tensorflow_models/object_detection/utils:label_map_util",
    ],
)
104 changes: 104 additions & 0 deletions research/object_detection/dataset_tools/create_oid_tf_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Creates TFRecords of Open Images dataset for object detection.

Example usage:
./create_oid_tf_record \
--input_annotations_csv=/path/to/input/annotations-human-bbox.csv \
--input_images_directory=/path/to/input/image_pixels_directory \
--input_label_map=/path/to/input/labels_bbox_545.labelmap \
--output_tf_record_path_prefix=/path/to/output/prefix.tfrecord

CSVs with bounding box annotations and image metadata (including the image URLs)
can be downloaded from the Open Images GitHub repository:
https://github.com/openimages/dataset

This script will include every image found in the input_images_directory in the
output TFRecord, even if the image has no corresponding bounding box annotations
in the input_annotations_csv.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import contextlib2
import pandas as pd
import tensorflow as tf

from object_detection.dataset_tools import oid_tfrecord_creation
from object_detection.utils import label_map_util

# Command-line flags. The four string flags below default to None; main()
# treats each of them as required and raises ValueError when one is unset.
tf.flags.DEFINE_string('input_annotations_csv', None,
                       'Path to CSV containing image bounding box annotations')
tf.flags.DEFINE_string('input_images_directory', None,
                       'Directory containing the image pixels '
                       'downloaded from the OpenImages GitHub repository.')
tf.flags.DEFINE_string('input_label_map', None, 'Path to the label map proto')
tf.flags.DEFINE_string(
    'output_tf_record_path_prefix', None,
    'Path to the output TFRecord. The shard index and the number of shards '
    'will be appended for each output shard.')
# Number of output shards; main() picks the shard for an image from its
# hexadecimal image ID modulo this value.
tf.flags.DEFINE_integer('num_shards', 100, 'Number of TFRecord shards')

# Module-level handle through which main() reads the parsed flag values.
FLAGS = tf.flags.FLAGS


def main(_):
  """Writes the Open Images annotations and images to sharded TFRecords.

  Reads the bounding box annotation CSV, adds one annotation-less row for every
  `*.jpg` file found in the images directory (so that images without boxes are
  still grouped and written), and then emits one TF Example per image ID. The
  output shard of an image is its hexadecimal image ID modulo `--num_shards`.

  Args:
    _: Unused; placeholder for the argument supplied by the tf.app.run() entry
      point.

  Raises:
    ValueError: If a required flag is unset or `--num_shards` is not positive.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  required_flags = [
      'input_annotations_csv', 'input_images_directory', 'input_label_map',
      'output_tf_record_path_prefix'
  ]
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))
  # Fail early: a non-positive shard count would otherwise only surface as a
  # ZeroDivisionError (or bad index) once the first image is processed.
  if FLAGS.num_shards < 1:
    raise ValueError('Flag --num_shards must be a positive integer')

  label_map = label_map_util.get_label_map_dict(FLAGS.input_label_map)
  all_annotations = pd.read_csv(FLAGS.input_annotations_csv)
  all_images = tf.gfile.Glob(
      os.path.join(FLAGS.input_images_directory, '*.jpg'))
  all_image_ids = [os.path.splitext(os.path.basename(v))[0] for v in all_images]
  all_image_ids = pd.DataFrame({'ImageID': all_image_ids})
  # Appending the bare image IDs guarantees that every image on disk forms a
  # group below, even when it has no row in the annotations CSV.
  all_annotations = pd.concat([all_annotations, all_image_ids])

  tf.logging.log(tf.logging.INFO, 'Found %d images...', len(all_image_ids))

  with contextlib2.ExitStack() as tf_record_close_stack:
    output_tfrecords = oid_tfrecord_creation.open_sharded_output_tfrecords(
        tf_record_close_stack, FLAGS.output_tf_record_path_prefix,
        FLAGS.num_shards)

    for counter, image_data in enumerate(all_annotations.groupby('ImageID')):
      tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 1000,
                             counter)

      image_id, image_annotations = image_data
      # In OID image file names are formed by appending ".jpg" to the image ID.
      image_path = os.path.join(FLAGS.input_images_directory, image_id + '.jpg')
      # Binary mode: the JPEG payload must be read as raw bytes, not decoded
      # as text.
      with tf.gfile.Open(image_path, 'rb') as image_file:
        encoded_image = image_file.read()

      tf_example = oid_tfrecord_creation.tf_example_from_annotations_data_frame(
          image_annotations, label_map, encoded_image)
      if tf_example:
        # int() parses arbitrarily large hex IDs on both Python 2 and 3;
        # the Python 2-only long() builtin does not exist on Python 3.
        shard_idx = int(image_id, 16) % FLAGS.num_shards
        output_tfrecords[shard_idx].write(tf_example.SerializeToString())


# Script entry point: only runs when the module is executed directly, and
# hands control to tf.app.run().
if __name__ == '__main__':
  tf.app.run()
113 changes: 113 additions & 0 deletions research/object_detection/dataset_tools/oid_tfrecord_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utilities for creating TFRecords of TF examples for the Open Images dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from object_detection.core import standard_fields
from object_detection.utils import dataset_util


def _as_bytes(value):
  """Returns `value` unchanged if it is bytes, otherwise UTF-8 encodes it."""
  if isinstance(value, bytes):
    return value
  return value.encode('utf8')


def tf_example_from_annotations_data_frame(annotations_data_frame, label_map,
                                           encoded_image):
  """Populates a TF Example message with image annotations from a data frame.

  Only the rows whose `LabelName` is a key of `label_map` contribute boxes and
  classes. The image ID is taken from the first row of the unfiltered frame.
  The `IsGroupOf`, `IsOccluded`, `IsTruncated` and `IsDepiction` columns are
  optional and are only exported when present.

  Args:
    annotations_data_frame: Data frame containing the annotations for a single
      image.
    label_map: String to integer label map.
    encoded_image: The encoded image string

  Returns:
    The populated TF Example. An example is always returned; when none of the
    image's labels are present in label_map its box and class features are
    empty.
  """

  filtered_data_frame = annotations_data_frame[
      annotations_data_frame.LabelName.isin(label_map)]

  image_id = annotations_data_frame.ImageID.iloc[0]

  # `.values` is used instead of `.as_matrix()`, which newer pandas releases no
  # longer provide. Text values are converted to bytes because bytes features
  # do not accept unicode strings on Python 3.
  feature_map = {
      standard_fields.TfExampleFields.object_bbox_ymin:
          dataset_util.float_list_feature(filtered_data_frame.YMin.values),
      standard_fields.TfExampleFields.object_bbox_xmin:
          dataset_util.float_list_feature(filtered_data_frame.XMin.values),
      standard_fields.TfExampleFields.object_bbox_ymax:
          dataset_util.float_list_feature(filtered_data_frame.YMax.values),
      standard_fields.TfExampleFields.object_bbox_xmax:
          dataset_util.float_list_feature(filtered_data_frame.XMax.values),
      standard_fields.TfExampleFields.object_class_text:
          dataset_util.bytes_list_feature(
              [_as_bytes(label) for label in
               filtered_data_frame.LabelName.values]),
      standard_fields.TfExampleFields.object_class_label:
          dataset_util.int64_list_feature(
              filtered_data_frame.LabelName.map(lambda x: label_map[x])
              .values),
      standard_fields.TfExampleFields.filename:
          dataset_util.bytes_feature(_as_bytes('{}.jpg'.format(image_id))),
      standard_fields.TfExampleFields.source_id:
          dataset_util.bytes_feature(_as_bytes(image_id)),
      standard_fields.TfExampleFields.image_encoded:
          dataset_util.bytes_feature(encoded_image),
  }

  if 'IsGroupOf' in filtered_data_frame.columns:
    feature_map[standard_fields.TfExampleFields.
                object_group_of] = dataset_util.int64_list_feature(
                    filtered_data_frame.IsGroupOf.values.astype(int))
  if 'IsOccluded' in filtered_data_frame.columns:
    feature_map[standard_fields.TfExampleFields.
                object_occluded] = dataset_util.int64_list_feature(
                    filtered_data_frame.IsOccluded.values.astype(int))
  if 'IsTruncated' in filtered_data_frame.columns:
    feature_map[standard_fields.TfExampleFields.
                object_truncated] = dataset_util.int64_list_feature(
                    filtered_data_frame.IsTruncated.values.astype(int))
  if 'IsDepiction' in filtered_data_frame.columns:
    feature_map[standard_fields.TfExampleFields.
                object_depiction] = dataset_util.int64_list_feature(
                    filtered_data_frame.IsDepiction.values.astype(int))

  return tf.train.Example(features=tf.train.Features(feature=feature_map))


def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
  """Opens all TFRecord shards for writing and adds them to an exit stack.

  Shard k is written to `<base_path>-<k>-of-<num_shards>`, with both numbers
  zero-padded to five digits.

  Args:
    exit_stack: A contextlib2.ExitStack used to automatically close the
      TFRecords opened in this function.
    base_path: The base path for all shards
    num_shards: The number of shards

  Returns:
    The list of opened TFRecords. Position k in the list corresponds to shard k.
  """
  # range() rather than the Python 2-only xrange(), so the function also runs
  # on Python 3.
  tf_record_output_filenames = [
      '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
      for idx in range(num_shards)
  ]

  tfrecords = [
      exit_stack.enter_context(tf.python_io.TFRecordWriter(file_name))
      for file_name in tf_record_output_filenames
  ]

  return tfrecords
Loading