Skip to content
Permalink
Browse files Browse the repository at this point in the history
Use BuildTensorShapeBase when parsing unverified TensorShapes during …
…checkpoint loading.

This avoids crashing when the TensorShape has negative dimensions.

PiperOrigin-RevId: 392769882
Change-Id: Id1f7ae7fcf8142193556af47abfda81b13d3cce4
  • Loading branch information
tensorflower-gardener committed Aug 24, 2021
1 parent 24cb834 commit b619c6f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tensorflow/core/util/tensor_slice_reader.cc
Expand Up @@ -168,7 +168,9 @@ void TensorSliceReader::LoadShard(int shard) const {
"checkpoint");
if (!status_.ok()) return;
for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
TensorShape ssm_shape(ssm.shape());
TensorShape ssm_shape;
status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
if (!status_.ok()) return;
for (const TensorSliceProto& tsp : ssm.slice()) {
TensorSlice ss_slice(tsp);
status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/core/util/tensor_slice_reader_test.cc
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
Expand Down Expand Up @@ -410,6 +411,31 @@ TEST(TensorSliceReaderTest, UnsupportedTensorType) {
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
}

TEST(TensorSliceReaderTest, NegativeTensorShapeDimension) {
const string fname =
io::JoinPath(testing::TmpDir(), "negative_dim_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TF_CHECK_OK(writer.Add("test", TensorShape({4, 5}),
TensorSlice::ParseOrDie("0,2:-"), data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
for (auto& dim : *tensor.mutable_shape()->mutable_dim()) {
dim.set_size(-dim.size());
}
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
// The negative dimension should cause loading to fail.
EXPECT_FALSE(reader.status().ok());
}

void CachedTensorSliceReaderTesterHelper(
const TensorSliceWriter::CreateBuilderFunction& create_function,
const TensorSliceReader::OpenTableFunction& open_function) {
Expand Down

0 comments on commit b619c6f

Please sign in to comment.