Skip to content
Permalink
Browse files Browse the repository at this point in the history
Avoid buffer overflow when loading tensors with insufficient data from checkpoints.

`CopyDataFromTensorSliceToTensorSlice` does not (and cannot conveniently)
provide any bounds checking on its own, so the size is instead checked prior
to passing unvalidated data to that function.

PiperOrigin-RevId: 392971286
Change-Id: If2073b36d4d5eedd386329f56729395fd7effee1
  • Loading branch information
tensorflower-gardener committed Aug 25, 2021
1 parent 8ccde23 commit 368af87
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tensorflow/core/util/saved_tensor_slice_util.h
Expand Up @@ -59,6 +59,9 @@ Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
// Per-type traits describing how values of type T are represented when saved
// in a TensorProto. Only declared here; the definitions are provided by
// specializations for each supported type (e.g. the qint32, Eigen::half and
// tstring specializations in this header).
template <typename T>
struct SaveTypeTraits;

// Returns the number of elements of type T stored in `t`.
//
// Only declared here; each supported type provides a specialization that reads
// the size of the matching repeated `*_val` field of the proto. Callers can use
// this to validate the amount of saved data before reading it through
// TensorProtoData<T>().
template <typename T>
int TensorProtoDataSize(const TensorProto& t);

// Returns a pointer to the saved values for type T held in `t`, expressed as
// SaveTypeTraits<T>::SavedType.
//
// Only declared here; the definitions are per-type specializations.
// NOTE(review): the returned pointer presumably aliases storage owned by `t`
// and carries no length information -- confirm against the specializations and
// pair with TensorProtoDataSize<T>() when bounds matter.
template <typename T>
const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
const TensorProto& t);
Expand Down Expand Up @@ -95,6 +98,10 @@ void Fill(T* data, size_t n, TensorProto* t);
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \
TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \
template <> \
inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
return t.FIELD##_val_size(); \
} \
template <> \
inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
t->mutable_##FIELD##_val()->Swap(&copy); \
Expand All @@ -104,6 +111,10 @@ void Fill(T* data, size_t n, TensorProto* t);
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \
TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \
template <> \
inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \
return t.FIELD##_val_size() / 2; \
} \
template <> \
inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
const FTYPE* sub = reinterpret_cast<const FTYPE*>(data); \
typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \
Expand Down Expand Up @@ -136,6 +147,11 @@ TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
// qint32 reuses the save traits of int32 by inheriting from
// SaveTypeTraits<int32>.
template <>
struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};

// Element count for qint32: the length of the proto's repeated `int_val`
// field.
template <>
inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
  const int num_saved = t.int_val_size();
  return num_saved;
}

template <>
inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
static_assert(SaveTypeTraits<qint32>::supported,
Expand All @@ -158,6 +174,11 @@ struct SaveTypeTraits<Eigen::half> {
typedef protobuf::RepeatedField<int32> RepeatedField;
};

// Element count for Eigen::half: the length of the proto's repeated
// `half_val` field.
template <>
inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
  const int num_saved = t.half_val_size();
  return num_saved;
}

template <>
inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
return t.half_val().data();
Expand Down Expand Up @@ -187,6 +208,11 @@ struct SaveTypeTraits<tstring> {
typedef protobuf::RepeatedPtrField<string> RepeatedField;
};

// Element count for tstring: the length of the proto's repeated `string_val`
// field.
template <>
inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
  const int num_saved = t.string_val_size();
  return num_saved;
}

template <>
inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
static_assert(SaveTypeTraits<tstring>::supported,
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/core/util/tensor_slice_reader.h
Expand Up @@ -181,6 +181,22 @@ bool TensorSliceReader::CopySliceData(const string& name,
<< slice_s.DebugString() << ": computed key = " << key;
return false;
}
// Ensure the TensorSlice contains the expected amount of data.
TensorShape shp_s;
Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s);
if (!s.ok()) {
VLOG(1) << "Failed to slice tensor " << name << ", slice "
<< slice_s.DebugString() << ": " << s;
return false;
}
if (checkpoint::TensorProtoDataSize<T>(sts.data().data()) !=
shp_s.num_elements()) {
VLOG(1) << "Tensor " << name << ", slice " << slice_s.DebugString()
<< " had an unexpected amount of data: expected = "
<< shp_s.num_elements() << ", got = "
<< checkpoint::TensorProtoDataSize<T>(sts.data().data());
return false;
}
CopyDataFromTensorSliceToTensorSlice(
tss->shape(), slice_s, slice,
checkpoint::TensorProtoData<T>(sts.data().data()), data);
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/core/util/tensor_slice_reader_test.cc
Expand Up @@ -459,6 +459,33 @@ TEST(TensorSliceReaderTest, InvalidTensorSlice) {
EXPECT_FALSE(reader.status().ok());
}

// Checks that a checkpoint whose stored slice holds fewer values than the
// slice requires is still opened successfully, but that fetching the affected
// tensor reports an error.
TEST(TensorSliceReaderTest, MissingTensorData) {
  const string checkpoint_path =
      io::JoinPath(testing::TmpDir(), "missing_data_checkpoint");
  TensorSliceWriter writer(checkpoint_path, CreateTableTensorSliceBuilder);

  // Write a well-formed checkpoint first: ten values for the "0,2:-" slice of
  // a {4, 5} tensor.
  const int32 values[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  TF_ASSERT_OK(writer.Add("test", TensorShape({4, 5}),
                          TensorSlice::ParseOrDie("0,2:-"), values));
  TF_ASSERT_OK(writer.Finish());

  // Rewrite the checkpoint in place so that every record carrying data ends up
  // with only the first 4 values; records without data are re-serialized
  // unchanged.
  MutateSavedTensorSlices(checkpoint_path, [&](SavedTensorSlices slices) {
    if (!slices.has_data()) {
      return slices.SerializeAsString();
    }
    Fill(values, 4, slices.mutable_data()->mutable_data());
    return slices.SerializeAsString();
  });

  // Opening the mutated checkpoint is still expected to succeed.
  TensorSliceReader reader(checkpoint_path, OpenTableTensorSliceReader);
  TF_ASSERT_OK(reader.status());

  // The tensor is listed in the checkpoint, but reading it has to fail because
  // part of its data is missing.
  EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
  std::unique_ptr<Tensor> loaded;
  EXPECT_FALSE(reader.GetTensor("test", &loaded).ok());
}

void CachedTensorSliceReaderTesterHelper(
const TensorSliceWriter::CreateBuilderFunction& create_function,
const TensorSliceReader::OpenTableFunction& open_function) {
Expand Down

0 comments on commit 368af87

Please sign in to comment.