Skip to content

Commit

Permalink
Change: 110592065
Browse files Browse the repository at this point in the history
  • Loading branch information
A. Unique TensorFlower authored and Vijay Vasudevan committed Dec 21, 2015
1 parent 57d1c8a commit ef50775
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 18 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/framework/op_def_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
"DT_INT8] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
ExpectSuccess(b().Attr("a:{string,int32}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_STRING, DT_INT32] } } }");
Expand Down Expand Up @@ -174,7 +174,7 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16] } } }");
ExpectSuccess(
b().Attr("a: list({float, string, bool})"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/framework/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS

template <>
Expand Down Expand Up @@ -387,6 +389,8 @@ void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
CASE(qint32, SINGLE_ARG(STMTS)) \
CASE(quint8, SINGLE_ARG(STMTS)) \
CASE(qint8, SINGLE_ARG(STMTS)) \
CASE(quint16, SINGLE_ARG(STMTS)) \
CASE(qint16, SINGLE_ARG(STMTS)) \
CASE(bfloat16, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
LOG(FATAL) << "Type not set"; \
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/core/framework/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ template <>
struct is_quantized<quint8> : true_type {};
template <>
struct is_quantized<qint32> : true_type {};
template <>
struct is_quantized<qint16> : true_type {};
template <>
struct is_quantized<quint16> : true_type {};

// All types not specialized are marked invalid.
template <class T>
Expand All @@ -68,6 +72,12 @@ template <>
class numeric_limits<tensorflow::quint8>
: public numeric_limits<tensorflow::uint8> {};
template <>
class numeric_limits<tensorflow::qint16>
: public numeric_limits<tensorflow::int16> {};
template <>
class numeric_limits<tensorflow::quint16>
: public numeric_limits<tensorflow::uint16> {};
template <>
class numeric_limits<tensorflow::qint32>
: public numeric_limits<tensorflow::int32> {};

Expand All @@ -77,6 +87,10 @@ struct is_signed<tensorflow::qint8> : public is_signed<tensorflow::int8> {};
template <>
struct is_signed<tensorflow::quint8> : public is_signed<tensorflow::uint8> {};
template <>
struct is_signed<tensorflow::qint16> : public is_signed<tensorflow::int16> {};
template <>
struct is_signed<tensorflow::quint16> : public is_signed<tensorflow::uint16> {};
template <>
struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {};

} // namespace std
Expand Down
35 changes: 27 additions & 8 deletions tensorflow/core/framework/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ string DataTypeString(DataType dtype) {
return "qint8";
case DT_QUINT8:
return "quint8";
case DT_QUINT16:
return "quint16";
case DT_QINT16:
return "qint16";
case DT_QINT32:
return "qint32";
case DT_BFLOAT16:
Expand Down Expand Up @@ -127,6 +131,12 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) {
} else if (sp == "quint8") {
*dt = DT_QUINT8;
return true;
} else if (sp == "qint16") {
*dt = DT_QINT16;
return true;
} else if (sp == "quint16") {
*dt = DT_QUINT16;
return true;
} else if (sp == "qint32") {
*dt = DT_QINT32;
return true;
Expand All @@ -149,9 +159,9 @@ string DataTypeSliceString(const DataTypeSlice types) {
}

DataTypeVector AllTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL,
DT_QINT8, DT_QUINT8, DT_QINT32};
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL,
DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32};
}

#if !defined(__ANDROID__)
Expand All @@ -160,11 +170,13 @@ DataTypeVector RealNumberTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8};
}

DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
DataTypeVector QuantizedTypes() {
return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32};
}

DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_INT16, DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT32};
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16,
DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32};
}

DataTypeVector NumberTypes() {
Expand All @@ -180,10 +192,13 @@ DataTypeVector NumberTypes() {
return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
}

DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
DataTypeVector QuantizedTypes() {
return {DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32};
}

DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8,
DT_QINT16, DT_QUINT16, DT_QINT32};
}

#endif // defined(__ANDROID__)
Expand All @@ -203,6 +218,8 @@ bool DataTypeCanUseMemcpy(DataType dt) {
case DT_BOOL:
case DT_QINT8:
case DT_QUINT8:
case DT_QINT16:
case DT_QUINT16:
case DT_QINT32:
case DT_BFLOAT16:
return true;
Expand All @@ -215,6 +232,8 @@ bool DataTypeIsQuantized(DataType dt) {
switch (dt) {
case DT_QINT8:
case DT_QUINT8:
case DT_QINT16:
case DT_QUINT16:
case DT_QINT32:
return true;
default:
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/framework/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ struct EnumToDataType {}; // Specializations below
typedef Eigen::QInt8 qint8;
typedef Eigen::QUInt8 quint8;
typedef Eigen::QInt32 qint32;
typedef Eigen::QInt16 qint16;
typedef Eigen::QUInt16 quint16;

MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
Expand All @@ -174,6 +176,8 @@ MATCH_TYPE_AND_ENUM(int64, DT_INT64);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
MATCH_TYPE_AND_ENUM(qint16, DT_QINT16);
MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16);
MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);

Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/framework/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ enum DataType {
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16

// TODO(josh11b): DT_GENERIC_PROTO = ??;
// TODO(jeff,josh11b): DT_UINT64? DT_UINT32? DT_UINT16?
Expand All @@ -45,4 +47,6 @@ enum DataType {
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
}
2 changes: 2 additions & 0 deletions tensorflow/core/public/tensor_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ typedef enum {
TF_QUINT8 = 12, // Quantized uint8
TF_QINT32 = 13, // Quantized int32
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
TF_QINT16 = 15, // Quantized int16
TF_QUINT16 = 16, // Quantized uint16
} TF_DataType;

// --------------------------------------------------------------------------
Expand Down
22 changes: 16 additions & 6 deletions tensorflow/python/client/tf_session_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
*out_tf_datatype = TF_QUINT8;
} else if (key == "qint8") {
*out_tf_datatype = TF_QINT8;
} else if (key == "qint16") {
*out_tf_datatype = TF_QINT16;
} else if (key == "quint16") {
*out_tf_datatype = TF_QUINT16;
} else if (key == "qint32") {
*out_tf_datatype = TF_QINT32;
} else {
Expand Down Expand Up @@ -98,12 +102,12 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
case NPY_UINT8:
*out_tf_datatype = TF_UINT8;
break;
case NPY_INT16:
*out_tf_datatype = TF_INT16;
break;
case NPY_INT8:
*out_tf_datatype = TF_INT8;
break;
case NPY_INT16:
*out_tf_datatype = TF_INT16;
break;
case NPY_INT64:
*out_tf_datatype = TF_INT64;
break;
Expand Down Expand Up @@ -143,12 +147,12 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
case TF_UINT8:
*out_pyarray_type = NPY_UINT8;
break;
case TF_INT16:
*out_pyarray_type = NPY_INT16;
break;
case TF_INT8:
*out_pyarray_type = NPY_INT8;
break;
case TF_INT16:
*out_pyarray_type = NPY_INT16;
break;
case TF_INT64:
*out_pyarray_type = NPY_INT64;
break;
Expand All @@ -170,6 +174,12 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
case TF_QUINT8:
*out_pyarray_type = NPY_UINT8;
break;
case TF_QINT16:
*out_pyarray_type = NPY_INT16;
break;
case TF_QUINT16:
*out_pyarray_type = NPY_UINT16;
break;
case TF_QINT32:
*out_pyarray_type = NPY_INT32;
break;
Expand Down
27 changes: 25 additions & 2 deletions tensorflow/python/framework/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class DType(object):
* `tf.qint8`: Quantized 8-bit signed integer.
* `tf.quint8`: Quantized 8-bit unsigned integer.
* `tf.qint16`: Quantized 16-bit signed integer.
* `tf.quint16`: Quantized 16-bit unsigned integer.
* `tf.qint32`: Quantized 32-bit signed integer.
In addition, variants of these types with the `_ref` suffix are
Expand Down Expand Up @@ -136,7 +138,7 @@ def is_floating(self):
@property
def is_quantized(self):
"""Returns whether this is a quantized data type."""
return self.base_dtype in [qint8, quint8, qint32, bfloat16]
return self.base_dtype in [qint8, quint8, qint16, quint16, qint32, bfloat16]

@property
def is_unsigned(self):
Expand Down Expand Up @@ -258,6 +260,8 @@ def __hash__(self):
bool = DType(types_pb2.DT_BOOL)
qint8 = DType(types_pb2.DT_QINT8)
quint8 = DType(types_pb2.DT_QUINT8)
qint16 = DType(types_pb2.DT_QINT16)
quint16 = DType(types_pb2.DT_QUINT16)
qint32 = DType(types_pb2.DT_QINT32)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
float32_ref = DType(types_pb2.DT_FLOAT_REF)
Expand All @@ -273,6 +277,8 @@ def __hash__(self):
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)

Expand All @@ -292,6 +298,8 @@ def __hash__(self):
types_pb2.DT_BOOL: bool,
types_pb2.DT_QINT8: qint8,
types_pb2.DT_QUINT8: quint8,
types_pb2.DT_QINT16: qint16,
types_pb2.DT_QUINT16: quint16,
types_pb2.DT_QINT32: qint32,
types_pb2.DT_BFLOAT16: bfloat16,
types_pb2.DT_FLOAT_REF: float32_ref,
Expand All @@ -306,6 +314,8 @@ def __hash__(self):
types_pb2.DT_BOOL_REF: bool_ref,
types_pb2.DT_QINT8_REF: qint8_ref,
types_pb2.DT_QUINT8_REF: quint8_ref,
types_pb2.DT_QINT16_REF: qint16_ref,
types_pb2.DT_QUINT16_REF: quint16_ref,
types_pb2.DT_QINT32_REF: qint32_ref,
types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
}
Expand All @@ -325,6 +335,8 @@ def __hash__(self):
types_pb2.DT_BOOL: "bool",
types_pb2.DT_QINT8: "qint8",
types_pb2.DT_QUINT8: "quint8",
types_pb2.DT_QINT16: "qint16",
types_pb2.DT_QUINT16: "quint16",
types_pb2.DT_QINT32: "qint32",
types_pb2.DT_BFLOAT16: "bfloat16",
types_pb2.DT_FLOAT_REF: "float32_ref",
Expand All @@ -339,6 +351,8 @@ def __hash__(self):
types_pb2.DT_BOOL_REF: "bool_ref",
types_pb2.DT_QINT8_REF: "qint8_ref",
types_pb2.DT_QUINT8_REF: "quint8_ref",
types_pb2.DT_QINT16_REF: "qint16_ref",
types_pb2.DT_QUINT16_REF: "quint16_ref",
types_pb2.DT_QINT32_REF: "qint32_ref",
types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
}
Expand All @@ -359,6 +373,8 @@ def __hash__(self):
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
_np_qint32 = np.dtype([("qint32", np.int32, 1)])

# Standard mappings between types_pb2.DataType values and numpy.dtypes.
Expand All @@ -375,6 +391,8 @@ def __hash__(self):
(np.bool, bool),
(_np_qint8, qint8),
(_np_quint8, quint8),
(_np_qint16, qint16),
(_np_quint16, quint16),
(_np_qint32, qint32),
# NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
])
Expand All @@ -393,6 +411,8 @@ def __hash__(self):
types_pb2.DT_BOOL: np.bool,
types_pb2.DT_QINT8: _np_qint8,
types_pb2.DT_QUINT8: _np_quint8,
types_pb2.DT_QINT16: _np_qint16,
types_pb2.DT_QUINT16: _np_quint16,
types_pb2.DT_QINT32: _np_qint32,
types_pb2.DT_BFLOAT16: np.uint16,

Expand All @@ -409,13 +429,16 @@ def __hash__(self):
types_pb2.DT_BOOL_REF: np.bool,
types_pb2.DT_QINT8_REF: _np_qint8,
types_pb2.DT_QUINT8_REF: _np_quint8,
types_pb2.DT_QINT16_REF: _np_qint16,
types_pb2.DT_QUINT16_REF: _np_quint16,
types_pb2.DT_QINT32_REF: _np_qint32,
types_pb2.DT_BFLOAT16_REF: np.uint16,
}


QUANTIZED_DTYPES = frozenset(
[qint8, quint8, qint32, qint8_ref, quint8_ref, qint32_ref])
[qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
quint16_ref, qint32_ref])


def as_dtype(type_value):
Expand Down

0 comments on commit ef50775

Please sign in to comment.