Skip to content

Commit 368af87

Browse files
Avoid buffer overflow when loading tensors with insufficient data from checkpoints.
`CopyDataFromTensorSliceToTensorSlice` does not (and cannot conveniently) provide any bounds checking on its own, so the size is instead checked prior to passing unvalidated data to that function. PiperOrigin-RevId: 392971286 Change-Id: If2073b36d4d5eedd386329f56729395fd7effee1
1 parent 8ccde23 commit 368af87

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

Diff for: tensorflow/core/util/saved_tensor_slice_util.h

+26
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
5959
template <typename T>
6060
struct SaveTypeTraits;
6161

62+
template <typename T>
63+
int TensorProtoDataSize(const TensorProto& t);
64+
6265
template <typename T>
6366
const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
6467
const TensorProto& t);
@@ -95,6 +98,10 @@ void Fill(T* data, size_t n, TensorProto* t);
9598
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \
9699
TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \
97100
template <> \
101+
inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
102+
return t.FIELD##_val_size(); \
103+
} \
104+
template <> \
98105
inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
99106
typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
100107
t->mutable_##FIELD##_val()->Swap(&copy); \
@@ -104,6 +111,10 @@ void Fill(T* data, size_t n, TensorProto* t);
104111
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \
105112
TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \
106113
template <> \
114+
inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
115+
return t.FIELD##_val_size() / 2; \
116+
} \
117+
template <> \
107118
inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
108119
const FTYPE* sub = reinterpret_cast<const FTYPE*>(data); \
109120
typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \
@@ -136,6 +147,11 @@ TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
136147
template <>
137148
struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
138149

150+
template <>
151+
inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
152+
return t.int_val_size();
153+
}
154+
139155
template <>
140156
inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
141157
static_assert(SaveTypeTraits<qint32>::supported,
@@ -158,6 +174,11 @@ struct SaveTypeTraits<Eigen::half> {
158174
typedef protobuf::RepeatedField<int32> RepeatedField;
159175
};
160176

177+
template <>
178+
inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
179+
return t.half_val_size();
180+
}
181+
161182
template <>
162183
inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
163184
return t.half_val().data();
@@ -187,6 +208,11 @@ struct SaveTypeTraits<tstring> {
187208
typedef protobuf::RepeatedPtrField<string> RepeatedField;
188209
};
189210

211+
template <>
212+
inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
213+
return t.string_val_size();
214+
}
215+
190216
template <>
191217
inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
192218
static_assert(SaveTypeTraits<tstring>::supported,

Diff for: tensorflow/core/util/tensor_slice_reader.h

+16
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,22 @@ bool TensorSliceReader::CopySliceData(const string& name,
181181
<< slice_s.DebugString() << ": computed key = " << key;
182182
return false;
183183
}
184+
// Ensure the TensorSlice contains the expected amount of data.
185+
TensorShape shp_s;
186+
Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s);
187+
if (!s.ok()) {
188+
VLOG(1) << "Failed to slice tensor " << name << ", slice "
189+
<< slice_s.DebugString() << ": " << s;
190+
return false;
191+
}
192+
if (checkpoint::TensorProtoDataSize<T>(sts.data().data()) !=
193+
shp_s.num_elements()) {
194+
VLOG(1) << "Tensor " << name << ", slice " << slice_s.DebugString()
195+
<< " had an unexpected amount of data: expected = "
196+
<< shp_s.num_elements() << ", got = "
197+
<< checkpoint::TensorProtoDataSize<T>(sts.data().data());
198+
return false;
199+
}
184200
CopyDataFromTensorSliceToTensorSlice(
185201
tss->shape(), slice_s, slice,
186202
checkpoint::TensorProtoData<T>(sts.data().data()), data);

Diff for: tensorflow/core/util/tensor_slice_reader_test.cc

+27
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,33 @@ TEST(TensorSliceReaderTest, InvalidTensorSlice) {
459459
EXPECT_FALSE(reader.status().ok());
460460
}
461461

462+
TEST(TensorSliceReaderTest, MissingTensorData) {
463+
const string fname =
464+
io::JoinPath(testing::TmpDir(), "missing_data_checkpoint");
465+
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
466+
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
467+
TF_ASSERT_OK(writer.Add("test", TensorShape({4, 5}),
468+
TensorSlice::ParseOrDie("0,2:-"), data));
469+
TF_ASSERT_OK(writer.Finish());
470+
471+
MutateSavedTensorSlices(fname, [&](SavedTensorSlices sts) {
472+
if (sts.has_data()) {
473+
// Replace the data with only 4 elements.
474+
Fill(data, 4, sts.mutable_data()->mutable_data());
475+
}
476+
return sts.SerializeAsString();
477+
});
478+
479+
TensorSliceReader reader(fname, OpenTableTensorSliceReader);
480+
TF_ASSERT_OK(reader.status());
481+
482+
// The tensor should be present, but loading it should fail due to the missing
483+
// data.
484+
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
485+
std::unique_ptr<Tensor> tensor;
486+
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
487+
}
488+
462489
void CachedTensorSliceReaderTesterHelper(
463490
const TensorSliceWriter::CreateBuilderFunction& create_function,
464491
const TensorSliceReader::OpenTableFunction& open_function) {

0 commit comments

Comments
 (0)