@@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include < utility>
17-
1816#include " tensorflow/core/util/tensor_slice_reader.h"
1917
18+ #include < utility>
19+ #include < vector>
20+
2021#include " tensorflow/core/framework/types.h"
2122#include " tensorflow/core/framework/versions.pb.h"
2223#include " tensorflow/core/lib/core/status_test_util.h"
2324#include " tensorflow/core/lib/core/stringpiece.h"
25+ #include " tensorflow/core/lib/io/iterator.h"
2426#include " tensorflow/core/lib/io/path.h"
27+ #include " tensorflow/core/lib/io/table.h"
28+ #include " tensorflow/core/lib/io/table_builder.h"
2529#include " tensorflow/core/lib/strings/str_util.h"
2630#include " tensorflow/core/lib/strings/strcat.h"
2731#include " tensorflow/core/platform/env.h"
@@ -30,6 +34,7 @@ limitations under the License.
3034#include " tensorflow/core/platform/test.h"
3135#include " tensorflow/core/platform/types.h"
3236#include " tensorflow/core/public/version.h"
37+ #include " tensorflow/core/util/saved_tensor_slice.pb.h"
3338#include " tensorflow/core/util/saved_tensor_slice_util.h"
3439#include " tensorflow/core/util/tensor_slice_reader_cache.h"
3540#include " tensorflow/core/util/tensor_slice_writer.h"
@@ -309,6 +314,102 @@ TEST_SIMPLE_INT(int16, int32)
309314TEST_SIMPLE_INT (int8, int32)
310315TEST_SIMPLE_INT (uint8, int32)
311316
317+ // Modifies the SavedTensorSlices messages in a checkpoint to allow creating
318+ // malformed or unsupported checkpoints.
319+ void MutateSavedTensorSlices (
320+ const std::string& fname,
321+ const std::function<std::string(SavedTensorSlices)>& mutator) {
322+ table::Options options;
323+ options.compression = table::kNoCompression ;
324+
325+ // Read all entres from the table.
326+ std::vector<std::pair<std::string, std::string>> entries;
327+ {
328+ std::unique_ptr<RandomAccessFile> file;
329+ TF_CHECK_OK (Env::Default ()->NewRandomAccessFile (fname, &file));
330+ uint64 file_size;
331+ TF_CHECK_OK (Env::Default ()->GetFileSize (fname, &file_size));
332+ table::Table* t;
333+ TF_CHECK_OK (table::Table::Open (options, file.get (), file_size, &t));
334+ std::unique_ptr<table::Table> table (t);
335+ std::unique_ptr<table::Iterator> it (table->NewIterator ());
336+ for (it->Seek (" " ); it->Valid (); it->Next ()) {
337+ entries.emplace_back (it->key (), it->value ());
338+ }
339+ TF_CHECK_OK (it->status ());
340+ }
341+
342+ // Rewrite the table, mutating each value.
343+ {
344+ std::unique_ptr<WritableFile> file;
345+ TF_CHECK_OK (Env::Default ()->NewWritableFile (fname, &file));
346+ table::TableBuilder builder (options, file.get ());
347+ for (const auto & entry : entries) {
348+ SavedTensorSlices sts;
349+ CHECK (sts.ParseFromString (entry.second ));
350+ builder.Add (entry.first , mutator (std::move (sts)));
351+ }
352+ TF_CHECK_OK (builder.Finish ());
353+ TF_CHECK_OK (file->Close ());
354+ }
355+ }
356+
357+ TEST (TensorSliceReaderTest, MissingTensorType) {
358+ const string fname = io::JoinPath (testing::TmpDir (), " invalid_checkpoint" );
359+ TensorSliceWriter writer (fname, CreateTableTensorSliceBuilder);
360+ const int32 data[] = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
361+ TensorShape shape ({4 , 5 });
362+ TensorSlice slice = TensorSlice::ParseOrDie (" 0,2:-" );
363+ TF_CHECK_OK (writer.Add (" test" , shape, slice, data));
364+ TF_CHECK_OK (writer.Finish ());
365+
366+ MutateSavedTensorSlices (fname, [](SavedTensorSlices sts) {
367+ if (sts.has_meta ()) {
368+ for (auto & tensor : *sts.mutable_meta ()->mutable_tensor ()) {
369+ tensor.clear_type ();
370+ }
371+ }
372+ return sts.SerializeAsString ();
373+ });
374+
375+ TensorSliceReader reader (fname, OpenTableTensorSliceReader);
376+ TF_CHECK_OK (reader.status ());
377+
378+ // The tensor should be present, but loading it should fail due to the
379+ // unset (invalid) type.
380+ EXPECT_TRUE (reader.HasTensor (" test" , nullptr , nullptr ));
381+ std::unique_ptr<Tensor> tensor;
382+ EXPECT_FALSE (reader.GetTensor (" test" , &tensor).ok ());
383+ }
384+
385+ TEST (TensorSliceReaderTest, UnsupportedTensorType) {
386+ const string fname = io::JoinPath (testing::TmpDir (), " int32_ref_checkpoint" );
387+ TensorSliceWriter writer (fname, CreateTableTensorSliceBuilder);
388+ const int32 data[] = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
389+ TensorShape shape ({4 , 5 });
390+ TensorSlice slice = TensorSlice::ParseOrDie (" 0,2:-" );
391+ TF_CHECK_OK (writer.Add (" test" , shape, slice, data));
392+ TF_CHECK_OK (writer.Finish ());
393+
394+ MutateSavedTensorSlices (fname, [](SavedTensorSlices sts) {
395+ if (sts.has_meta ()) {
396+ for (auto & tensor : *sts.mutable_meta ()->mutable_tensor ()) {
397+ tensor.set_type (DT_INT32_REF);
398+ }
399+ }
400+ return sts.SerializeAsString ();
401+ });
402+
403+ TensorSliceReader reader (fname, OpenTableTensorSliceReader);
404+ TF_CHECK_OK (reader.status ());
405+
406+ // The tensor should be present, but loading it should fail due to the
407+ // unsupported type.
408+ EXPECT_TRUE (reader.HasTensor (" test" , nullptr , nullptr ));
409+ std::unique_ptr<Tensor> tensor;
410+ EXPECT_FALSE (reader.GetTensor (" test" , &tensor).ok ());
411+ }
412+
312413void CachedTensorSliceReaderTesterHelper (
313414 const TensorSliceWriter::CreateBuilderFunction& create_function,
314415 const TensorSliceReader::OpenTableFunction& open_function) {
0 commit comments