@@ -18,6 +18,7 @@ limitations under the License.
1818#include < utility>
1919#include < vector>
2020
21+ #include " tensorflow/core/framework/tensor_shape.pb.h"
2122#include " tensorflow/core/framework/types.h"
2223#include " tensorflow/core/framework/versions.pb.h"
2324#include " tensorflow/core/lib/core/status_test_util.h"
@@ -410,6 +411,31 @@ TEST(TensorSliceReaderTest, UnsupportedTensorType) {
410411 EXPECT_FALSE (reader.GetTensor (" test" , &tensor).ok ());
411412}
412413
414+ TEST (TensorSliceReaderTest, NegativeTensorShapeDimension) {
415+ const string fname =
416+ io::JoinPath (testing::TmpDir (), " negative_dim_checkpoint" );
417+ TensorSliceWriter writer (fname, CreateTableTensorSliceBuilder);
418+ const int32 data[] = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
419+ TF_CHECK_OK (writer.Add (" test" , TensorShape ({4 , 5 }),
420+ TensorSlice::ParseOrDie (" 0,2:-" ), data));
421+ TF_CHECK_OK (writer.Finish ());
422+
423+ MutateSavedTensorSlices (fname, [](SavedTensorSlices sts) {
424+ if (sts.has_meta ()) {
425+ for (auto & tensor : *sts.mutable_meta ()->mutable_tensor ()) {
426+ for (auto & dim : *tensor.mutable_shape ()->mutable_dim ()) {
427+ dim.set_size (-dim.size ());
428+ }
429+ }
430+ }
431+ return sts.SerializeAsString ();
432+ });
433+
434+ TensorSliceReader reader (fname, OpenTableTensorSliceReader);
435+ // The negative dimension should cause loading to fail.
436+ EXPECT_FALSE (reader.status ().ok ());
437+ }
438+
413439void CachedTensorSliceReaderTesterHelper (
414440 const TensorSliceWriter::CreateBuilderFunction& create_function,
415441 const TensorSliceReader::OpenTableFunction& open_function) {
0 commit comments