Skip to content

Commit e8dc637

Browse files
Add BuildTensorSlice for building from unvalidated TensorSliceProtos.
This avoids several sources of crashes and undefined behavior when loading invalid checkpoints. PiperOrigin-RevId: 392785704 Change-Id: Icd9713c768b882f3b58b427eddac376060696833
1 parent dd2c199 commit e8dc637

File tree

5 files changed

+107
-1
lines changed

5 files changed

+107
-1
lines changed

Diff for: tensorflow/core/framework/tensor_slice.cc

+31
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/core/framework/tensor_slice.h"
17+
18+
#include <limits>
1719
#include <vector>
20+
1821
#include "tensorflow/core/lib/core/errors.h"
1922
#include "tensorflow/core/lib/strings/numbers.h"
2023
#include "tensorflow/core/lib/strings/str_util.h"
@@ -44,6 +47,34 @@ TensorSlice::TensorSlice(
4447
}
4548
}
4649

50+
Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
51+
TensorSlice* output) {
52+
output->Clear();
53+
output->starts_.reserve(proto.extent_size());
54+
output->lengths_.reserve(proto.extent_size());
55+
for (const auto& e : proto.extent()) {
56+
int64_t l = GetExtentLength(e);
57+
if (e.start() != 0 || l != kFullExtent) {
58+
if (e.start() < 0 || l <= 0) {
59+
return errors::InvalidArgument(
60+
"Expected non-negative start and positive length but got start = ",
61+
e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
62+
}
63+
// Calculating the extent end must not cause signed integer overflow.
64+
if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
65+
std::numeric_limits<int64_t>::max()) {
66+
return errors::InvalidArgument(
67+
"Extent end exceeds the maximum possible size: extent = ",
68+
e.ShortDebugString());
69+
}
70+
}
71+
output->starts_.push_back(e.start());
72+
output->lengths_.push_back(l);
73+
}
74+
75+
return Status::OK();
76+
}
77+
4778
Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
4879
std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
4980
slice->starts_.reserve(items.size());

Diff for: tensorflow/core/framework/tensor_slice.h

+6
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ class TensorSlice {
4848
explicit TensorSlice(
4949
std::initializer_list<std::pair<int64_t, int64_t>> extents);
5050

51+
// This factory methods should be used instead of the constructor that takes a
52+
// `TensorSliceProto` if calling code cannot validate that the sizes specify a
53+
// valid `TensorSlice`.
54+
static Status BuildTensorSlice(const TensorSliceProto& proto,
55+
TensorSlice* output);
56+
5157
static Status Parse(const string& str, TensorSlice* output);
5258
static TensorSlice ParseOrDie(const string& str) {
5359
TensorSlice ret;

Diff for: tensorflow/core/framework/tensor_slice_test.cc

+44
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
#include "tensorflow/core/framework/tensor_slice.h"
1717

18+
#include <limits>
19+
1820
#include "tensorflow/core/lib/core/status_test_util.h"
1921
#include "tensorflow/core/platform/logging.h"
2022
#include "tensorflow/core/platform/protobuf.h"
@@ -125,6 +127,48 @@ TEST(TensorSliceTest, Serialization) {
125127
}
126128
}
127129

130+
// Testing `BuildTensorSlice` with valid and invalid input protos.
131+
TEST(TensorSliceTest, BuildTensorSlice) {
132+
TensorSliceProto proto;
133+
TensorSlice({{0, -1}, {0, 10}, {14, 1}}).AsProto(&proto);
134+
TensorSlice s;
135+
136+
// Successful building.
137+
{
138+
TF_ASSERT_OK(TensorSlice::BuildTensorSlice(proto, &s));
139+
EXPECT_EQ("-:0,10:14,1", s.DebugString());
140+
}
141+
142+
// Failed building due to negative extent start.
143+
{
144+
TensorSliceProto invalid_proto = proto;
145+
invalid_proto.mutable_extent(0)->set_start(-1);
146+
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
147+
}
148+
149+
// Failed building due to negative extent length.
150+
{
151+
TensorSliceProto invalid_proto = proto;
152+
invalid_proto.mutable_extent(2)->set_length(-1);
153+
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
154+
}
155+
156+
// Failed building due to missing extent length.
157+
{
158+
TensorSliceProto invalid_proto = proto;
159+
invalid_proto.mutable_extent(2)->clear_length();
160+
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
161+
}
162+
163+
// Failed building due to extent end overflowing.
164+
{
165+
TensorSliceProto invalid_proto = proto;
166+
invalid_proto.mutable_extent(2)->set_length(
167+
std::numeric_limits<int64_t>::max());
168+
EXPECT_FALSE(TensorSlice::BuildTensorSlice(invalid_proto, &s).ok());
169+
}
170+
}
171+
128172
// Testing the slice intersection
129173
TEST(TensorSliceTest, Intersection) {
130174
// "EVERYTHING" intersects with everything

Diff for: tensorflow/core/util/tensor_slice_reader.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ void TensorSliceReader::LoadShard(int shard) const {
172172
status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
173173
if (!status_.ok()) return;
174174
for (const TensorSliceProto& tsp : ssm.slice()) {
175-
TensorSlice ss_slice(tsp);
175+
TensorSlice ss_slice;
176+
status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
177+
if (!status_.ok()) return;
176178
status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
177179
ss_slice, &tensors_);
178180
if (!status_.ok()) return;

Diff for: tensorflow/core/util/tensor_slice_reader_test.cc

+23
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,29 @@ TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
436436
EXPECT_FALSE(reader.status().ok());
437437
}
438438

439+
TEST(TensorSliceReaderTest, InvalidTensorSlice) {
440+
const string fname =
441+
io::JoinPath(testing::TmpDir(), "invalid_slice_checkpoint");
442+
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
443+
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
444+
TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
445+
TensorSlice::ParseOrDie("0,2:-"), data));
446+
TF_CHECK_OK(writer.Finish());
447+
448+
MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
449+
if (sts.has_meta()) {
450+
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
451+
tensor.mutable_slice(0)->mutable_extent(0)->set_length(-10);
452+
}
453+
}
454+
return sts.SerializeAsString();
455+
});
456+
457+
TensorSliceReader reader(fname, OpenTableTensorSliceReader);
458+
// The negative exent length should cause loading to fail.
459+
EXPECT_FALSE(reader.status().ok());
460+
}
461+
439462
void CachedTensorSliceReaderTesterHelper(
440463
const TensorSliceWriter::CreateBuilderFunction& create_function,
441464
const TensorSliceReader::OpenTableFunction& open_function) {

0 commit comments

Comments
 (0)