Skip to content

Commit

Permalink
Merge pull request #2692 from tombstone/fix_example_decoder
Browse files Browse the repository at this point in the history
temporarily change tf_example_decoder to not depend on BackupHandler.
  • Loading branch information
jch1 committed Nov 2, 2017
2 parents 59b96e9 + 64f0761 commit 1e2ada2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 60 deletions.
22 changes: 4 additions & 18 deletions research/object_detection/data_decoders/tf_example_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,10 @@ def __init__(self,
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks))
if label_map_proto_file:
label_map = label_map_util.get_label_map_dict(label_map_proto_file,
use_display_name)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
label_handler = slim_example_decoder.BackupHandler(
slim_example_decoder.LookupTensor(
'image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
# TODO: Add label_handler that decodes from 'image/object/class/text'
# primarily after the recent tf.contrib.slim changes make into a release
# supported by cloudml.
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler

Expand Down
42 changes: 0 additions & 42 deletions research/object_detection/data_decoders/tf_example_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,48 +168,6 @@ def testDecodeObjectLabel(self):
self.assertAllEqual(bbox_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes])

def testDecodeObjectLabelWithMapping(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
bbox_classes_text = ['cat', 'dog']
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
self._BytesFeature(encoded_jpeg),
'image/format':
self._BytesFeature('jpeg'),
'image/object/class/text':
self._BytesFeature(bbox_classes_text),
})).SerializeToString()

label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
example_decoder = tf_example_decoder.TfExampleDecoder(
label_map_proto_file=label_map_path)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))

self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
.get_shape().as_list()), [None])

with self.test_session() as sess:
sess.run(tf.tables_initializer())
tensor_dict = sess.run(tensor_dict)

self.assertAllEqual([3, 1],
tensor_dict[fields.InputDataFields.groundtruth_classes])

def testDecodeObjectArea(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
Expand Down

0 comments on commit 1e2ada2

Please sign in to comment.