Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent crashes when loading tensor slices with unsupported types. #51918

Merged
merged 1 commit into from Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow/core/framework/BUILD
Expand Up @@ -785,6 +785,7 @@ tf_cuda_library(
"//tensorflow/core/lib/strings:str_util",
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:abi",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:platform_port",
Expand Down
26 changes: 18 additions & 8 deletions tensorflow/core/framework/tensor.cc
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
Expand Down Expand Up @@ -725,11 +726,11 @@ bool Tensor::RefCountIsOne() const {
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TYPE T; \
STMTS; \
break; \
#define CASE(TYPE, STMTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TF_ATTRIBUTE_UNUSED TYPE T; \
STMTS; \
break; \
}
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
switch (TYPE_ENUM) { \
Expand Down Expand Up @@ -765,9 +766,8 @@ bool Tensor::RefCountIsOne() const {
}

#define CASES(TYPE_ENUM, STMTS) \
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, \
LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
, LOG(FATAL) << "Type not set";)
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
: shape_(shape), buf_(nullptr) {
Expand Down Expand Up @@ -797,6 +797,16 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
}
}

Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
Tensor* out_tensor) {
// Avoid crashes due to invalid or unsupported types.
CASES_WITH_DEFAULT(
type, {}, return errors::InvalidArgument("Type not set"),
return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
*out_tensor = Tensor(type, shape);
return Status::OK();
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/framework/tensor.h
Expand Up @@ -164,6 +164,15 @@ class Tensor {
/// for details.
explicit Tensor(DataType type);

/// \brief Initializes a tensor with the input `type` and `shape`, or returns
/// an error and leaves `out_tensor` unmodified. This factory method should be
/// used instead of the corresponding constructor if calling code cannot
/// validate that the `DataType` is valid and supported.
///
/// The underlying buffer is allocated using a `CPUAllocator`.
static Status BuildTensor(DataType type, const TensorShape& shape,
Tensor* out_tensor);

private:
// A tag type for selecting the `Tensor` constructor overload that creates a
// scalar tensor in host memory.
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/util/tensor_slice_reader.cc
Expand Up @@ -248,7 +248,9 @@ Status TensorSliceReader::GetTensor(
slice = tss->Slices().begin()->second.slice;
}

std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
if (!s.ok()) return s;
bool success = false;

#define READER_COPY(dt) \
Expand Down
105 changes: 103 additions & 2 deletions tensorflow/core/util/tensor_slice_reader_test.cc
Expand Up @@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/core/util/tensor_slice_reader.h"

#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/iterator.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
Expand All @@ -30,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
Expand Down Expand Up @@ -309,6 +314,102 @@ TEST_SIMPLE_INT(int16, int32)
TEST_SIMPLE_INT(int8, int32)
TEST_SIMPLE_INT(uint8, int32)

// Modifies the SavedTensorSlices messages in a checkpoint to allow creating
// malformed or unsupported checkpoints.
void MutateSavedTensorSlices(
const std::string& fname,
const std::function<std::string(SavedTensorSlices)>& mutator) {
table::Options options;
options.compression = table::kNoCompression;

// Read all entres from the table.
std::vector<std::pair<std::string, std::string>> entries;
{
std::unique_ptr<RandomAccessFile> file;
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(fname, &file));
uint64 file_size;
TF_CHECK_OK(Env::Default()->GetFileSize(fname, &file_size));
table::Table* t;
TF_CHECK_OK(table::Table::Open(options, file.get(), file_size, &t));
std::unique_ptr<table::Table> table(t);
std::unique_ptr<table::Iterator> it(table->NewIterator());
for (it->Seek(""); it->Valid(); it->Next()) {
entries.emplace_back(it->key(), it->value());
}
TF_CHECK_OK(it->status());
}

// Rewrite the table, mutating each value.
{
std::unique_ptr<WritableFile> file;
TF_CHECK_OK(Env::Default()->NewWritableFile(fname, &file));
table::TableBuilder builder(options, file.get());
for (const auto& entry : entries) {
SavedTensorSlices sts;
CHECK(sts.ParseFromString(entry.second));
builder.Add(entry.first, mutator(std::move(sts)));
}
TF_CHECK_OK(builder.Finish());
TF_CHECK_OK(file->Close());
}
}

TEST(TensorSliceReaderTest, MissingTensorType) {
const string fname = io::JoinPath(testing::TmpDir(), "invalid_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TensorShape shape({4, 5});
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(writer.Add("test", shape, slice, data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
tensor.clear_type();
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
TF_CHECK_OK(reader.status());

// The tensor should be present, but loading it should fail due to the
// unset (invalid) type.
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
std::unique_ptr<Tensor> tensor;
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
}

TEST(TensorSliceReaderTest, UnsupportedTensorType) {
const string fname = io::JoinPath(testing::TmpDir(), "int32_ref_checkpoint");
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TensorShape shape({4, 5});
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(writer.Add("test", shape, slice, data));
TF_CHECK_OK(writer.Finish());

MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
if (sts.has_meta()) {
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
tensor.set_type(DT_INT32_REF);
}
}
return sts.SerializeAsString();
});

TensorSliceReader reader(fname, OpenTableTensorSliceReader);
TF_CHECK_OK(reader.status());

// The tensor should be present, but loading it should fail due to the
// unsupported type.
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
std::unique_ptr<Tensor> tensor;
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
}

void CachedTensorSliceReaderTesterHelper(
const TensorSliceWriter::CreateBuilderFunction& create_function,
const TensorSliceReader::OpenTableFunction& open_function) {
Expand Down