@@ -59,6 +59,9 @@ Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
5959template <typename T>
6060struct SaveTypeTraits ;
6161
62+ template <typename T>
63+ int TensorProtoDataSize (const TensorProto& t);
64+
6265template <typename T>
6366const typename SaveTypeTraits<T>::SavedType* TensorProtoData (
6467 const TensorProto& t);
@@ -95,6 +98,10 @@ void Fill(T* data, size_t n, TensorProto* t);
9598#define TENSOR_PROTO_EXTRACT_TYPE (TYPE, FIELD, FTYPE ) \
9699 TENSOR_PROTO_EXTRACT_TYPE_HELPER (TYPE, FIELD, FTYPE, FTYPE) \
97100 template <> \
101+ inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
102+ return t.FIELD ##_val_size (); \
103+ } \
104+ template <> \
98105 inline void Fill (const TYPE* data, size_t n, TensorProto* t) { \
99106 typename protobuf::RepeatedField<FTYPE> copy (data, data + n); \
100107 t->mutable_ ##FIELD##_val ()->Swap (©); \
@@ -104,6 +111,10 @@ void Fill(T* data, size_t n, TensorProto* t);
104111#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX (TYPE, FIELD, FTYPE ) \
105112 TENSOR_PROTO_EXTRACT_TYPE_HELPER (TYPE, FIELD, FTYPE, TYPE) \
106113 template <> \
114+ inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
115+ return t.FIELD ##_val_size () / 2 ; \
116+ } \
117+ template <> \
107118 inline void Fill (const TYPE* data, size_t n, TensorProto* t) { \
108119 const FTYPE* sub = reinterpret_cast <const FTYPE*>(data); \
109120 typename protobuf::RepeatedField<FTYPE> copy (sub, sub + 2 * n); \
@@ -136,6 +147,11 @@ TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
136147template <>
137148struct SaveTypeTraits <qint32> : SaveTypeTraits<int32> {};
138149
150+ template <>
151+ inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
152+ return t.int_val_size ();
153+ }
154+
139155template <>
140156inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
141157 static_assert (SaveTypeTraits<qint32>::supported,
@@ -158,6 +174,11 @@ struct SaveTypeTraits<Eigen::half> {
158174 typedef protobuf::RepeatedField<int32> RepeatedField;
159175};
160176
177+ template <>
178+ inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
179+ return t.half_val_size ();
180+ }
181+
161182template <>
162183inline const int * TensorProtoData<Eigen::half>(const TensorProto& t) {
163184 return t.half_val ().data ();
@@ -187,6 +208,11 @@ struct SaveTypeTraits<tstring> {
187208 typedef protobuf::RepeatedPtrField<string> RepeatedField;
188209};
189210
211+ template <>
212+ inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
213+ return t.string_val_size ();
214+ }
215+
190216template <>
191217inline const string* const * TensorProtoData<tstring>(const TensorProto& t) {
192218 static_assert (SaveTypeTraits<tstring>::supported,
0 commit comments