Merge pull request #51918 from pranve/cp_r2.420210909143715
Prevent crashes when loading tensor slices with unsupported types.
mihaimaruseac committed Oct 18, 2021
2 parents 26765ea + dabeecf commit f9ccdfd
Showing 5 changed files with 134 additions and 11 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/framework/BUILD
@@ -785,6 +785,7 @@ tf_cuda_library(
"//tensorflow/core/lib/strings:str_util",
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:abi",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:platform_port",
26 changes: 18 additions & 8 deletions tensorflow/core/framework/tensor.cc
@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -725,11 +726,11 @@ bool Tensor::RefCountIsOne() const {
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TYPE T; \
STMTS; \
break; \
#define CASE(TYPE, STMTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TF_ATTRIBUTE_UNUSED TYPE T; \
STMTS; \
break; \
}
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
switch (TYPE_ENUM) { \
@@ -765,9 +766,8 @@ bool Tensor::RefCountIsOne() const {
}

#define CASES(TYPE_ENUM, STMTS) \
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, \
LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
, LOG(FATAL) << "Type not set";)
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
: shape_(shape), buf_(nullptr) {
@@ -797,6 +797,16 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
}
}

Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
Tensor* out_tensor) {
// Avoid crashes due to invalid or unsupported types.
CASES_WITH_DEFAULT(
type, {}, return errors::InvalidArgument("Type not set"),
return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
*out_tensor = Tensor(type, shape);
return Status::OK();
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
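For orientation, here is a rough sketch of what the CASES_WITH_DEFAULT expansion inside BuildTensor reduces to once the per-type statement body is empty: the macro degenerates into a pure type-validity check. This is illustrative only; the real macro enumerates every supported DataType, and the helper name below is hypothetical.

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Hypothetical helper, shown only to illustrate the effective control flow of
// the CASES_WITH_DEFAULT expansion in Tensor::BuildTensor (a sketch, not the
// actual generated code; only a few of the supported types are listed).
Status ValidateTensorType(DataType type) {
  switch (type) {
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT32:
    case DT_STRING:
      // ...every other supported DataType also falls through to "valid"...
      break;
    case DT_INVALID:
      return errors::InvalidArgument("Type not set");
    default:
      // Unsupported values such as reference types (e.g. DT_INT32_REF) are
      // rejected with an error instead of reaching a LOG(FATAL).
      return errors::InvalidArgument("Unexpected type: ", DataType_Name(type));
  }
  return Status::OK();
}

}  // namespace tensorflow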
9 changes: 9 additions & 0 deletions tensorflow/core/framework/tensor.h
@@ -164,6 +164,15 @@ class Tensor {
/// for details.
explicit Tensor(DataType type);

/// \brief Initializes a tensor with the input `type` and `shape`, or returns
/// an error and leaves `out_tensor` unmodified. This factory method should be
/// used instead of the corresponding constructor if calling code cannot
/// validate that the `DataType` is valid and supported.
///
/// The underlying buffer is allocated using a `CPUAllocator`.
static Status BuildTensor(DataType type, const TensorShape& shape,
Tensor* out_tensor);

private:
// A tag type for selecting the `Tensor` constructor overload that creates a
// scalar tensor in host memory.
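A minimal usage sketch for the new factory (the wrapper function below is hypothetical, not part of the change): callers that receive the dtype from untrusted input, such as a checkpoint, should prefer BuildTensor over the Tensor(type, shape) constructor so that an invalid or unsupported enum value becomes an error Status instead of a crash.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

// Hypothetical caller: builds a tensor from a dtype/shape read out of a file.
tensorflow::Status MakeTensorFromCheckpoint(tensorflow::DataType dtype,
                                            const tensorflow::TensorShape& shape,
                                            tensorflow::Tensor* out) {
  // Returns InvalidArgument for DT_INVALID or unsupported (e.g. reference)
  // types; on success *out holds a tensor backed by the default CPU allocator.
  return tensorflow::Tensor::BuildTensor(dtype, shape, out);
}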
4 changes: 3 additions & 1 deletion tensorflow/core/util/tensor_slice_reader.cc
@@ -248,7 +248,9 @@ Status TensorSliceReader::GetTensor(
slice = tss->Slices().begin()->second.slice;
}

std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
if (!s.ok()) return s;
bool success = false;

#define READER_COPY(dt) \
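With this change, a checkpoint whose metadata carries a missing or unsupported dtype now surfaces as a non-OK Status from GetTensor rather than terminating the process. A small caller sketch (the helper name and error handling below are illustrative assumptions, not part of the change):

#include <memory>
#include <string>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

// Hypothetical helper: loads one tensor from a checkpoint, propagating errors.
tensorflow::Status LoadCheckpointTensor(
    const std::string& fname, const std::string& tensor_name,
    std::unique_ptr<tensorflow::Tensor>* out) {
  tensorflow::checkpoint::TensorSliceReader reader(
      fname, tensorflow::checkpoint::OpenTableTensorSliceReader);
  if (!reader.status().ok()) return reader.status();
  // May now return InvalidArgument for malformed type metadata instead of
  // crashing inside the Tensor constructor.
  return reader.GetTensor(tensor_name, out);
}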
105 changes: 103 additions & 2 deletions tensorflow/core/util/tensor_slice_reader_test.cc
@@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/core/util/tensor_slice_reader.h"

#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/iterator.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
@@ -30,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -309,6 +314,102 @@ TEST_SIMPLE_INT(int16, int32)
TEST_SIMPLE_INT(int8, int32)
TEST_SIMPLE_INT(uint8, int32)

// Modifies the SavedTensorSlices messages in a checkpoint to allow creating
// malformed or unsupported checkpoints.
void MutateSavedTensorSlices(
const std::string& fname,
const std::function<std::string(SavedTensorSlices)>& mutator) {
table::Options options;
options.compression = table::kNoCompression;

// Read all entries from the table.
std::vector<std::pair<std::string, std::string>> entries;
{
std::unique_ptr<RandomAccessFile> file;
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(fname, &file));
uint64 file_size;
TF_CHECK_OK(Env::Default()->GetFileSize(fname, &file_size));
table::Table* t;
TF_CHECK_OK(table::Table::Open(options, file.get(), file_size, &t));
std::unique_ptr<table::Table> table(t);
std::unique_ptr<table::Iterator> it(table->NewIterator());
for (it->Seek(""); it->Valid(); it->Next()) {
entries.emplace_back(it->key(), it->value());
}
TF_CHECK_OK(it->status());
}

// Rewrite the table, mutating each value.
{
std::unique_ptr<WritableFile> file;
TF_CHECK_OK(Env::Default()->NewWritableFile(fname, &file));
table::TableBuilder builder(options, file.get());
for (const auto& entry : entries) {
SavedTensorSlices sts;
CHECK(sts.ParseFromString(entry.second));
builder.Add(entry.first, mutator(std::move(sts)));
}
TF_CHECK_OK(builder.Finish());
TF_CHECK_OK(file->Close());
}
}

TEST(TensorSliceReaderTest, MissingTensorType) {
const string fname = io::JoinPath(testing::TmpDir(), "invalid_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TensorShape shape({4, 5});
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(writer.Add("test", shape, slice, data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
tensor.clear_type();
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
TF_CHECK_OK(reader.status());

// The tensor should be present, but loading it should fail due to the
// unset (invalid) type.
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
std::unique_ptr<Tensor> tensor;
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
}

TEST(TensorSliceReaderTest, UnsupportedTensorType) {
const string fname = io::JoinPath(testing::TmpDir(), "int32_ref_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TensorShape shape({4, 5});
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(writer.Add("test", shape, slice, data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
tensor.set_type(DT_INT32_REF);
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
TF_CHECK_OK(reader.status());

// The tensor should be present, but loading it should fail due to the
// unsupported type.
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
std::unique_ptr<Tensor> tensor;
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
}

void CachedTensorSliceReaderTesterHelper(
const TensorSliceWriter::CreateBuilderFunction& create_function,
const TensorSliceReader::OpenTableFunction& open_function) {