Skip to content

Commit

Permalink
Add BuildTensorSlice for building from unvalidated TensorSliceProtos.
Browse files Browse the repository at this point in the history
This avoids several sources of crashes and undefined behavior when loading
invalid checkpoints.

PiperOrigin-RevId: 392785704
Change-Id: Icd9713c768b882f3b58b427eddac376060696833
  • Loading branch information
tensorflower-gardener committed Aug 25, 2021
1 parent dd2c199 commit e8dc637
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 1 deletion.
31 changes: 31 additions & 0 deletions tensorflow/core/framework/tensor_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_slice.h"

#include <limits>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
Expand Down Expand Up @@ -44,6 +47,34 @@ TensorSlice::TensorSlice(
}
}

Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
TensorSlice* output) {
output->Clear();
output->starts_.reserve(proto.extent_size());
output->lengths_.reserve(proto.extent_size());
for (const auto& e : proto.extent()) {
int64_t l = GetExtentLength(e);
if (e.start() != 0 || l != kFullExtent) {
if (e.start() < 0 || l <= 0) {
return errors::InvalidArgument(
"Expected non-negative start and positive length but got start = ",
e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
}
// Calculating the extent end must not cause signed integer overflow.
if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
std::numeric_limits<int64_t>::max()) {
return errors::InvalidArgument(
"Extent end exceeds the maximum possible size: extent = ",
e.ShortDebugString());
}
}
output->starts_.push_back(e.start());
output->lengths_.push_back(l);
}

return Status::OK();
}

Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
slice->starts_.reserve(items.size());
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/framework/tensor_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class TensorSlice {
explicit TensorSlice(
std::initializer_list<std::pair<int64_t, int64_t>> extents);

// This factory methods should be used instead of the constructor that takes a
// `TensorSliceProto` if calling code cannot validate that the sizes specify a
// valid `TensorSlice`.
static Status BuildTensorSlice(const TensorSliceProto& proto,
TensorSlice* output);

static Status Parse(const string& str, TensorSlice* output);
static TensorSlice ParseOrDie(const string& str) {
TensorSlice ret;
Expand Down
44 changes: 44 additions & 0 deletions tensorflow/core/framework/tensor_slice_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/core/framework/tensor_slice.h"

#include <limits>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
Expand Down Expand Up @@ -125,6 +127,48 @@ TEST(TensorSliceTest, Serialization) {
}
}

// Testing `BuildTensorSlice` with valid and invalid input protos.
TEST(TensorSliceTest, BuildTensorSlice) {
TensorSliceProto proto;
TensorSlice({{0, -1}, {0, 10}, {14, 1}}).AsProto(&proto);
TensorSlice s;

// Successful building.
{
TF_ASSERT_OK(TensorSlice::BuildTensorSlice(proto, &s));
EXPECT_EQ("-:0,10:14,1", s.DebugString());
}

// Failed building due to negative extent start.
{
TensorSliceProto invalid_proto = proto;
invalid_proto.mutable_extent(0)->set_start(-1);
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
}

// Failed building due to negative extent length.
{
TensorSliceProto invalid_proto = proto;
invalid_proto.mutable_extent(2)->set_length(-1);
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
}

// Failed building due to missing extent length.
{
TensorSliceProto invalid_proto = proto;
invalid_proto.mutable_extent(2)->clear_length();
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
}

// Failed building due to extent end overflowing.
{
TensorSliceProto invalid_proto = proto;
invalid_proto.mutable_extent(2)->set_length(
std::numeric_limits<int64_t>::max());
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
}
}

// Testing the slice intersection
TEST(TensorSliceTest, Intersection) {
// "EVERYTHING" intersects with everything
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/util/tensor_slice_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ void TensorSliceReader::LoadShard(int shard) const {
status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
if (!status_.ok()) return;
for (const TensorSliceProto& tsp : ssm.slice()) {
TensorSlice ss_slice(tsp);
TensorSlice ss_slice;
status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
if (!status_.ok()) return;
status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
ss_slice, &tensors_);
if (!status_.ok()) return;
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/core/util/tensor_slice_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,29 @@ TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
EXPECT_FALSE(reader.status().ok());
}

TEST(TensorSliceReaderTest, InvalidTensorSlice) {
const string fname =
io::JoinPath(testing::TmpDir(), "invalid_slice_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
TensorSlice::ParseOrDie("0,2:-"), data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
tensor.mutable_slice(0)->mutable_extent(0)->set_length(-10);
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
// The negative exent length should cause loading to fail.
EXPECT_FALSE(reader.status().ok());
}

void CachedTensorSliceReaderTesterHelper(
const TensorSliceWriter::CreateBuilderFunction& create_function,
const TensorSliceReader::OpenTableFunction& open_function) {
Expand Down

0 comments on commit e8dc637

Please sign in to comment.