Skip to content
Permalink
Browse files
Add a half type to TensorFlow core, based on Eigen::half.
Note that this is only the type, not support for it in any ops,
so it is not useful for anything yet. In particular,
neither TF_CALL_REAL_NUMBER_TYPES nor TF_CALL_GPU_NUMBER_TYPES
list Eigen::half, so even though a lot of ops will end up
declaring support for the new type, calling them will fail at
runtime.
Change: 117825461
  • Loading branch information
A. Unique TensorFlower authored and tensorflower-gardener committed Mar 22, 2016
1 parent 733321f commit 9d03824
Show file tree
Hide file tree
Showing 11 changed files with 2,862 additions and 312 deletions.
@@ -193,6 +193,7 @@ class Allocator {
template <typename T>
struct is_simple {
static const bool value = std::is_trivial<T>::value ||
std::is_same<T, Eigen::half>::value ||
std::is_same<T, complex64>::value ||
std::is_same<T, complex128>::value ||
is_quantized<T>::value;
@@ -110,15 +110,16 @@ TEST_F(OpDefBuilderTest, AttrFailure) {

TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
// Types with restrictions.
ExpectSuccess(b().Attr("a:numbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32] } } }");
ExpectSuccess(
b().Attr("a:numbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32] } } }");
ExpectSuccess(b().Attr("a:realnumbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8] } } }");
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
@@ -191,7 +192,7 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
b().Attr("a:list(realnumbertype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8] } } }");
"DT_UINT16, DT_INT8, DT_HALF] } } }");
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
@@ -264,6 +264,23 @@ struct ProtoHelper<bfloat16> {
}
};

template <>
struct ProtoHelper<Eigen::half> {
typedef Helper<float>::RepeatedFieldType FieldType;
static const Eigen::half* Begin(const TensorProto& proto) {
return reinterpret_cast<const Eigen::half*>(proto.int_val().data());
}
static size_t NumElements(const TensorProto& proto) {
return proto.int_val().size();
}
static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
proto->mutable_int_val()->Reserve(n);
for (size_t i = 0; i < n; ++i) {
proto->mutable_int_val()->AddAlreadyReserved(data[i].x);
}
}
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
: alloc_(a), data_(a->Allocate<T>(n)), elem_(n) {}
@@ -410,6 +427,7 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other,
CASE(quint16, SINGLE_ARG(STMTS)) \
CASE(qint16, SINGLE_ARG(STMTS)) \
CASE(bfloat16, SINGLE_ARG(STMTS)) \
CASE(Eigen::half, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
LOG(FATAL) << "Type not set"; \
break; \
@@ -82,6 +82,8 @@ string DataTypeString(DataType dtype) {
return "qint32";
case DT_BFLOAT16:
return "bfloat16";
case DT_HALF:
return "half";
default:
LOG(FATAL) << "Unrecognized DataType enum value " << dtype;
return "";
@@ -154,6 +156,9 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) {
} else if (sp == "bfloat16") {
*dt = DT_BFLOAT16;
return true;
} else if (sp == "half" || sp == "float16") {
*dt = DT_HALF;
return true;
}
return false;
}
@@ -170,33 +175,33 @@ string DataTypeSliceString(const DataTypeSlice types) {
}

DataTypeVector AllTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
DT_QUINT16, DT_QINT32};
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
DT_QUINT16, DT_QINT32, DT_HALF};
}

#if !defined(__ANDROID__)

DataTypeVector RealNumberTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64,
DT_UINT8, DT_INT16, DT_INT8, DT_UINT16};
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_INT16, DT_INT8, DT_UINT16, DT_HALF};
}

DataTypeVector QuantizedTypes() {
return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32};
}

DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8,
DT_QINT16, DT_QUINT16, DT_QINT32};
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8,
DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF};
}

DataTypeVector NumberTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8,
DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128,
DT_QINT8, DT_QUINT8, DT_QINT32 };
return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8,
DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128,
DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF};
}

#elif defined(__ANDROID_TYPES_FULL__)
@@ -256,6 +261,7 @@ bool DataTypeCanUseMemcpy(DataType dt) {
case DT_QUINT16:
case DT_QINT32:
case DT_BFLOAT16:
case DT_HALF:
return true;
default:
return false;
@@ -183,6 +183,7 @@ MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);

#undef MATCH_TYPE_AND_ENUM

@@ -31,6 +31,7 @@ enum DataType {
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;

// TODO(josh11b): DT_GENERIC_PROTO = ??;
// TODO(jeff,josh11b): DT_UINT64? DT_UINT32?
@@ -55,4 +56,5 @@ enum DataType {
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
}

0 comments on commit 9d03824

Please sign in to comment.