Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.SaveSlices vulnerability with unsupported dtypes.
Check that given dtype is supported and emit a descriptive error if not.

PiperOrigin-RevId: 461660795
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Jul 18, 2022
1 parent 4419d10 commit 5dd7b86
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
13 changes: 11 additions & 2 deletions tensorflow/core/util/tensor_slice_writer.cc
Expand Up @@ -131,6 +131,16 @@ Status TensorSliceWriter::Finish() {

/* static */
size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
size_t max_bytes_per_element =
TensorSliceWriter::MaxBytesPerElementOrZero(dt);
if (max_bytes_per_element == 0) {
LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
}
return max_bytes_per_element;
}

/* static */
size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) {
switch (dt) {
case DT_FLOAT:
return 4;
Expand Down Expand Up @@ -170,9 +180,8 @@ size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
case DT_STRING:
case DT_BFLOAT16:
default:
LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
return 0;
}
return 0;
}

template <>
Expand Down
14 changes: 11 additions & 3 deletions tensorflow/core/util/tensor_slice_writer.h
Expand Up @@ -68,6 +68,8 @@ class TensorSliceWriter {
static size_t MaxBytesPerElement(DataType dt);

private:
static size_t MaxBytesPerElementOrZero(DataType dt);

static constexpr size_t kMaxMessageBytes = 1LL << 31;
// Filling in the TensorProto in a SavedSlice will add the following
// header bytes, in addition to the data:
Expand Down Expand Up @@ -162,9 +164,15 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
template <typename T>
Status TensorSliceWriter::SaveData(const T* data, int64_t num_elements,
SavedSlice* ss) {
size_t size_bound =
ss->ByteSize() + kTensorProtoHeaderBytes +
(MaxBytesPerElement(DataTypeToEnum<T>::value) * num_elements);
size_t max_bytes_per_element =
MaxBytesPerElementOrZero(DataTypeToEnum<T>::value);
if (max_bytes_per_element == 0) {
return errors::InvalidArgument(
"Tensor slice serialization not implemented for dtype ",
DataTypeToEnum<T>::value);
}
size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
(max_bytes_per_element * num_elements);
if (size_bound > kMaxMessageBytes) {
return errors::InvalidArgument(
"Tensor slice is too large to serialize (conservative estimate: ",
Expand Down
19 changes: 16 additions & 3 deletions tensorflow/core/util/tensor_slice_writer_test.cc
Expand Up @@ -15,17 +15,19 @@ limitations under the License.

#include "tensorflow/core/util/tensor_slice_writer.h"

#include <algorithm>
#include <array>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
Expand Down Expand Up @@ -362,6 +364,17 @@ TEST(TensorSliceWriteTest, SizeErrors) {
}
}

TEST(TensorSliceWriterTest, InvalidInput) {
SavedSlice ss;
std::array<uint32_t, 1> data;
std::fill(data.begin(), data.end(), 1234);
Status s = TensorSliceWriter::SaveData(data.data(), data.size(), &ss);
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
EXPECT_TRUE(absl::StrContains(
s.error_message(),
"Tensor slice serialization not implemented for dtype"));
}

} // namespace checkpoint

} // namespace tensorflow

0 comments on commit 5dd7b86

Please sign in to comment.