Skip to content

Commit

Permalink
Changing the mapping between proto and TF types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 205185039
  • Loading branch information
jsimsa authored and tensorflower-gardener committed Jul 19, 2018
1 parent 2422a25 commit 874de86
Show file tree
Hide file tree
Showing 9 changed files with 682 additions and 485 deletions.
Expand Up @@ -106,34 +106,27 @@ def _compareProtos(self, batch_shape, sizes, fields, field_dict):
self.assertEqual(v, ev)
continue

# This can be a little confusing. For testing we are using TestValue in
# two ways: it's the proto that we decode for testing, and it's used in
# the expected value as a union type.
#
# The two cases are slightly different: this is the second case. We may be
# fetching the uint64_value from the test proto, but in the expected proto
# we store it in the int64_value field because TensorFlow doesn't support
# unsigned int64.
tf_type_to_primitive_value_field = {
dtypes.bool:
'bool_value',
dtypes.float32:
'float_value',
dtypes.float64:
'double_value',
dtypes.int32:
'int32_value',
dtypes.uint8:
'uint8_value',
dtypes.int8:
'int8_value',
dtypes.string:
'string_value',
dtypes.int32:
'int32_value',
dtypes.int64:
'int64_value',
dtypes.bool:
'bool_value',
# Unhandled TensorFlow types:
# DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
# DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
dtypes.string:
'string_value',
dtypes.uint8:
'uint8_value',
dtypes.uint32:
'uint32_value',
dtypes.uint64:
'uint64_value',
}
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
if tf_field_name is None:
Expand Down
72 changes: 42 additions & 30 deletions tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
Expand Up @@ -44,7 +44,7 @@ def named_parameters():
("minmax", ProtoOpTestBase.minmax_test_case()),
("nested", ProtoOpTestBase.nested_test_case()),
("optional", ProtoOpTestBase.optional_test_case()),
("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
("promote", ProtoOpTestBase.promote_test_case()),
("ragged", ProtoOpTestBase.ragged_test_case()),
("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
("simple", ProtoOpTestBase.simple_test_case()),
Expand Down Expand Up @@ -83,13 +83,13 @@ def defaults_test_case():
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint64_value_with_default"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(4)
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(4)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed64_value_with_default"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(6)
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(6)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "int32_value_with_default"
Expand All @@ -108,13 +108,13 @@ def defaults_test_case():
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint32_value_with_default"
field.dtype = types_pb2.DT_INT32
field.value.int32_value.append(9)
field.dtype = types_pb2.DT_UINT32
field.value.uint32_value.append(9)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed32_value_with_default"
field.dtype = types_pb2.DT_INT32
field.value.int32_value.append(7)
field.dtype = types_pb2.DT_UINT32
field.value.uint32_value.append(7)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "bool_value_with_default"
Expand Down Expand Up @@ -202,15 +202,15 @@ def minmax_test_case():
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint64_value"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(0)
field.value.int64_value.append(-1)
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(0)
field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed64_value"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(0)
field.value.int64_value.append(-1)
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(0)
field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "int32_value"
Expand All @@ -232,15 +232,15 @@ def minmax_test_case():
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint32_value"
field.dtype = types_pb2.DT_INT32
field.value.int32_value.append(0)
field.value.int32_value.append(-1)
field.dtype = types_pb2.DT_UINT32
field.value.uint32_value.append(0)
field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed32_value"
field.dtype = types_pb2.DT_INT32
field.value.int32_value.append(0)
field.value.int32_value.append(-1)
field.dtype = types_pb2.DT_UINT32
field.value.uint32_value.append(0)
field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "bool_value"
Expand Down Expand Up @@ -289,28 +289,40 @@ def optional_test_case():
return test_case

@staticmethod
def promote_unsigned_test_case():
def promote_test_case():
test_case = test_example_pb2.TestCase()
value = test_case.values.add()
value.sint32_value.append(2147483647)
value.sfixed32_value.append(2147483647)
value.int32_value.append(2147483647)
value.fixed32_value.append(4294967295)
value.uint32_value.append(4294967295)
test_case.shapes.append(1)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "fixed32_value"
field.name = "sint32_value"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(4294967295)
field.value.int64_value.append(2147483647)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "uint32_value"
field.name = "sfixed32_value"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(4294967295)
# Comes from an explicitly-specified default
test_case.sizes.append(0)
field.value.int64_value.append(2147483647)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "uint32_value_with_default"
field.name = "int32_value"
field.dtype = types_pb2.DT_INT64
field.value.int64_value.append(9)
field.value.int64_value.append(2147483647)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "fixed32_value"
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(4294967295)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "uint32_value"
field.dtype = types_pb2.DT_UINT64
field.value.uint64_value.append(4294967295)
return test_case

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/BUILD
Expand Up @@ -6320,6 +6320,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/proto:decode",
"//tensorflow/core/util/proto:descriptors",
"//tensorflow/core/util/proto:proto_utils",
"//third_party/eigen3",
],
)
Expand All @@ -6332,6 +6333,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/util/proto:descriptors",
"//tensorflow/core/util/proto:proto_utils",
"//third_party/eigen3",
],
)
Expand Down

0 comments on commit 874de86

Please sign in to comment.