Skip to content

Commit abcced0

Browse files
Prevent crashes when loading tensor slices with unsupported types.
Also fix the `Tensor(const TensorShape&)` constructor swapping the LOG(FATAL) messages for the unset and unsupported types. PiperOrigin-RevId: 392695027 Change-Id: I4beda7db950db951d273e3259a7c8534ece49354
1 parent 0622858 commit abcced0

File tree

5 files changed

+134
-11
lines changed

5 files changed

+134
-11
lines changed

Diff for: tensorflow/core/framework/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ tf_cuda_library(
835835
"//tensorflow/core/lib/strings:str_util",
836836
"//tensorflow/core/lib/strings:strcat",
837837
"//tensorflow/core/platform:abi",
838+
"//tensorflow/core/platform:errors",
838839
"//tensorflow/core/platform:logging",
839840
"//tensorflow/core/platform:macros",
840841
"//tensorflow/core/platform:platform_port",

Diff for: tensorflow/core/framework/tensor.cc

+18-8
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ limitations under the License.
5252
#include "tensorflow/core/lib/gtl/inlined_vector.h"
5353
#include "tensorflow/core/lib/strings/str_util.h"
5454
#include "tensorflow/core/lib/strings/strcat.h"
55+
#include "tensorflow/core/platform/errors.h"
5556
#include "tensorflow/core/platform/logging.h"
5657
#include "tensorflow/core/platform/macros.h"
5758
#include "tensorflow/core/platform/protobuf.h"
@@ -723,11 +724,11 @@ bool Tensor::RefCountIsOne() const {
723724
// The macro CASES() expands to a switch statement conditioned on
724725
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
725726
#define SINGLE_ARG(...) __VA_ARGS__
726-
#define CASE(TYPE, STMTS) \
727-
case DataTypeToEnum<TYPE>::value: { \
728-
typedef TYPE T; \
729-
STMTS; \
730-
break; \
727+
#define CASE(TYPE, STMTS) \
728+
case DataTypeToEnum<TYPE>::value: { \
729+
typedef TF_ATTRIBUTE_UNUSED TYPE T; \
730+
STMTS; \
731+
break; \
731732
}
732733
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
733734
switch (TYPE_ENUM) { \
@@ -763,9 +764,8 @@ bool Tensor::RefCountIsOne() const {
763764
}
764765

765766
#define CASES(TYPE_ENUM, STMTS) \
766-
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, \
767-
LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
768-
, LOG(FATAL) << "Type not set";)
767+
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
768+
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
769769

770770
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
771771
: shape_(shape), buf_(nullptr) {
@@ -795,6 +795,16 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
795795
}
796796
}
797797

798+
Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
799+
Tensor* out_tensor) {
800+
// Avoid crashes due to invalid or unsupported types.
801+
CASES_WITH_DEFAULT(
802+
type, {}, return errors::InvalidArgument("Type not set"),
803+
return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
804+
*out_tensor = Tensor(type, shape);
805+
return Status::OK();
806+
}
807+
798808
// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
799809
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
800810
// acquiring a lock, which guards initialization of the per-NUMA zone

Diff for: tensorflow/core/framework/tensor.h

+9
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,15 @@ class Tensor {
170170
/// for details.
171171
explicit Tensor(DataType type);
172172

173+
/// \brief Initializes a tensor with the input `type` and `shape`, or returns
174+
/// an error and leaves `out_tensor` unmodified. This factory method should be
175+
/// used instead of the corresponding constructor if calling code cannot
176+
/// validate that the `DataType` is valid and supported.
177+
///
178+
/// The underlying buffer is allocated using a `CPUAllocator`.
179+
static Status BuildTensor(DataType type, const TensorShape& shape,
180+
Tensor* out_tensor);
181+
173182
private:
174183
// A tag type for selecting the `Tensor` constructor overload that creates a
175184
// scalar tensor in host memory.

Diff for: tensorflow/core/util/tensor_slice_reader.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ Status TensorSliceReader::GetTensor(
248248
slice = tss->Slices().begin()->second.slice;
249249
}
250250

251-
std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
251+
std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
252+
Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
253+
if (!s.ok()) return s;
252254
bool success = false;
253255

254256
#define READER_COPY(dt) \

Diff for: tensorflow/core/util/tensor_slice_reader_test.cc

+103-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include <utility>
17-
1816
#include "tensorflow/core/util/tensor_slice_reader.h"
1917

18+
#include <utility>
19+
#include <vector>
20+
2021
#include "tensorflow/core/framework/types.h"
2122
#include "tensorflow/core/framework/versions.pb.h"
2223
#include "tensorflow/core/lib/core/status_test_util.h"
2324
#include "tensorflow/core/lib/core/stringpiece.h"
25+
#include "tensorflow/core/lib/io/iterator.h"
2426
#include "tensorflow/core/lib/io/path.h"
27+
#include "tensorflow/core/lib/io/table.h"
28+
#include "tensorflow/core/lib/io/table_builder.h"
2529
#include "tensorflow/core/lib/strings/str_util.h"
2630
#include "tensorflow/core/lib/strings/strcat.h"
2731
#include "tensorflow/core/platform/env.h"
@@ -30,6 +34,7 @@ limitations under the License.
3034
#include "tensorflow/core/platform/test.h"
3135
#include "tensorflow/core/platform/types.h"
3236
#include "tensorflow/core/public/version.h"
37+
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
3338
#include "tensorflow/core/util/saved_tensor_slice_util.h"
3439
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
3540
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -309,6 +314,102 @@ TEST_SIMPLE_INT(int16, int32)
309314
TEST_SIMPLE_INT(int8, int32)
310315
TEST_SIMPLE_INT(uint8, int32)
311316

317+
// Modifies the SavedTensorSlices messages in a checkpoint to allow creating
318+
// malformed or unsupported checkpoints.
319+
void MutateSavedTensorSlices(
320+
const std::string& fname,
321+
const std::function<std::string(SavedTensorSlices)>& mutator) {
322+
table::Options options;
323+
options.compression = table::kNoCompression;
324+
325+
// Read all entres from the table.
326+
std::vector<std::pair<std::string, std::string>> entries;
327+
{
328+
std::unique_ptr<RandomAccessFile> file;
329+
TF_CHECK_OK(Env::Default()->NewRandomAccessFile(fname, &file));
330+
uint64 file_size;
331+
TF_CHECK_OK(Env::Default()->GetFileSize(fname, &file_size));
332+
table::Table* t;
333+
TF_CHECK_OK(table::Table::Open(options, file.get(), file_size, &t));
334+
std::unique_ptr<table::Table> table(t);
335+
std::unique_ptr<table::Iterator> it(table->NewIterator());
336+
for (it->Seek(""); it->Valid(); it->Next()) {
337+
entries.emplace_back(it->key(), it->value());
338+
}
339+
TF_CHECK_OK(it->status());
340+
}
341+
342+
// Rewrite the table, mutating each value.
343+
{
344+
std::unique_ptr<WritableFile> file;
345+
TF_CHECK_OK(Env::Default()->NewWritableFile(fname, &file));
346+
table::TableBuilder builder(options, file.get());
347+
for (const auto& entry : entries) {
348+
SavedTensorSlices sts;
349+
CHECK(sts.ParseFromString(entry.second));
350+
builder.Add(entry.first, mutator(std::move(sts)));
351+
}
352+
TF_CHECK_OK(builder.Finish());
353+
TF_CHECK_OK(file->Close());
354+
}
355+
}
356+
357+
TEST(TensorSliceReaderTest, MissingTensorType) {
358+
const string fname = io::JoinPath(testing::TmpDir(), "invalid_checkpoint");
359+
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
360+
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
361+
TensorShape shape({4, 5});
362+
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
363+
TF_CHECK_OK(writer.Add("test", shape, slice, data));
364+
TF_CHECK_OK(writer.Finish());
365+
366+
MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
367+
if (sts.has_meta()) {
368+
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
369+
tensor.clear_type();
370+
}
371+
}
372+
return sts.SerializeAsString();
373+
});
374+
375+
TensorSliceReader reader(fname, OpenTableTensorSliceReader);
376+
TF_CHECK_OK(reader.status());
377+
378+
// The tensor should be present, but loading it should fail due to the
379+
// unset (invalid) type.
380+
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
381+
std::unique_ptr<Tensor> tensor;
382+
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
383+
}
384+
385+
TEST(TensorSliceReaderTest, UnsupportedTensorType) {
386+
const string fname = io::JoinPath(testing::TmpDir(), "int32_ref_checkpoint");
387+
TensorSliceWriter writer(fname, CreateTableTensorSliceBuilder);
388+
const int32 data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
389+
TensorShape shape({4, 5});
390+
TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
391+
TF_CHECK_OK(writer.Add("test", shape, slice, data));
392+
TF_CHECK_OK(writer.Finish());
393+
394+
MutateSavedTensorSlices(fname, [](SavedTensorSlices sts) {
395+
if (sts.has_meta()) {
396+
for (auto& tensor : *sts.mutable_meta()->mutable_tensor()) {
397+
tensor.set_type(DT_INT32_REF);
398+
}
399+
}
400+
return sts.SerializeAsString();
401+
});
402+
403+
TensorSliceReader reader(fname, OpenTableTensorSliceReader);
404+
TF_CHECK_OK(reader.status());
405+
406+
// The tensor should be present, but loading it should fail due to the
407+
// unsupported type.
408+
EXPECT_TRUE(reader.HasTensor("test", nullptr, nullptr));
409+
std::unique_ptr<Tensor> tensor;
410+
EXPECT_FALSE(reader.GetTensor("test", &tensor).ok());
411+
}
412+
312413
void CachedTensorSliceReaderTesterHelper(
313414
const TensorSliceWriter::CreateBuilderFunction& create_function,
314415
const TensorSliceReader::OpenTableFunction& open_function) {

0 commit comments

Comments
 (0)