Skip to content

Commit b619c6f

Browse files
Use BuildTensorShapeBase when parsing unverified TensorShapes during checkpoint loading.
This avoids crashing when the TensorShape has negative dimensions. PiperOrigin-RevId: 392769882 Change-Id: Id1f7ae7fcf8142193556af47abfda81b13d3cce4
1 parent 24cb834 commit b619c6f

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

Diff for: tensorflow/core/util/tensor_slice_reader.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ void TensorSliceReader::LoadShard(int shard) const {
168168
"checkpoint");
169169
if (!status_.ok()) return;
170170
for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
171-
TensorShape ssm_shape(ssm.shape());
171+
TensorShape ssm_shape;
172+
status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
173+
if (!status_.ok()) return;
172174
for (const TensorSliceProto& tsp : ssm.slice()) {
173175
TensorSlice ss_slice(tsp);
174176
status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,

Diff for: tensorflow/core/util/tensor_slice_reader_test.cc

+26
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <utility>
1919
#include <vector>
2020

21+
#include "tensorflow/core/framework/tensor_shape.pb.h"
2122
#include "tensorflow/core/framework/types.h"
2223
#include "tensorflow/core/framework/versions.pb.h"
2324
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -410,6 +411,31 @@ TEST(TensorSliceReaderTest, UnsupportedTensorType) {
410411
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
411412
}
412413

414+
TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
415+
const string fname =
416+
io::JoinPath(testing::TmpDir(), "negative_dim_checkpoint");
417+
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
418+
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
419+
TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
420+
TensorSlice::ParseOrDie("0,2:-"), data));
421+
TF_CHECK_OK(writer.Finish());
422+
423+
MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
424+
if (sts.has_meta()) {
425+
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
426+
for (auto& dim : *tensor.mutable_shape()->mutable_dim()) {
427+
dim.set_size(-dim.size());
428+
}
429+
}
430+
}
431+
return sts.SerializeAsString();
432+
});
433+
434+
TensorSliceReader reader(fname, OpenTableTensorSliceReader);
435+
// The negative dimension should cause loading to fail.
436+
EXPECT_FALSE(reader.status().ok());
437+
}
438+
413439
void CachedTensorSliceReaderTesterHelper(
414440
const TensorSliceWriter::CreateBuilderFunction& create_function,
415441
const TensorSliceReader::OpenTableFunction& open_function) {

0 commit comments

Comments
 (0)