Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add error checking to ImmutableConst OP that strings are not yet supp…
…orted.

PiperOrigin-RevId: 401065359
Change-Id: I9dd2bd2a2c36f22f4a05153daf6ebdc4613469d2
  • Loading branch information
SeeForTwo authored and tensorflower-gardener committed Oct 5, 2021
1 parent c29040f commit 1cb6bb6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/immutable_constant_op.cc
Expand Up @@ -100,6 +100,9 @@ void ImmutableConstantOp::Compute(OpKernelContext* ctx) {

OP_REQUIRES_OK(ctx,
allocator->InitializeFromRegion(region_name_, ctx->env()));
OP_REQUIRES(ctx, dtype_ != DT_STRING,
errors::Unimplemented("Sorry, DT_STRING is not currently "
"supported for ImmutableConstOp."));
ctx->set_output(0, Tensor(allocator.get(), dtype_, shape_));
OP_REQUIRES_OK(ctx, allocator->allocation_status());
// Allocator is owned by the tensor from this point.
Expand Down
41 changes: 38 additions & 3 deletions tensorflow/core/kernels/immutable_constant_op_test.cc
Expand Up @@ -146,7 +146,8 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
error::INTERNAL);
}

Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
Status CreateTempFileFloat(Env* env, float value, uint64 size,
string* filename) {
const string dir = testing::TmpDir();
*filename = io::JoinPath(dir, strings::StrCat("file_", value));
std::unique_ptr<WritableFile> file;
Expand All @@ -166,8 +167,8 @@ TEST(ImmutableConstantOpTest, FromFile) {
auto root = Scope::NewRootScope().ExitOnError();

string two_file, three_file;
TF_ASSERT_OK(CreateTempFile(env, 2.0f, 1000, &two_file));
TF_ASSERT_OK(CreateTempFile(env, 3.0f, 1000, &three_file));
TF_ASSERT_OK(CreateTempFileFloat(env, 2.0f, 1000, &two_file));
TF_ASSERT_OK(CreateTempFileFloat(env, 3.0f, 1000, &three_file));
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, two_file);
auto node2 =
ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, three_file);
Expand All @@ -190,5 +191,39 @@ TEST(ImmutableConstantOpTest, FromFile) {
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f);
}

Status CreateTempFileBadString(Env* env, char value, uint64 size,
const string suffix, string* filename) {
const string dir = testing::TmpDir();
*filename = io::JoinPath(dir, strings::StrCat("file_", suffix));
std::unique_ptr<WritableFile> file;
TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file));
TF_RETURN_IF_ERROR(file->Append(std::string(size, value)));
TF_RETURN_IF_ERROR(file->Close());
return Status::OK();
}

TEST(ImmutableConstantOpTest, FromFileStringUnimplmented) {
const TensorShape kFileTensorShape({1});
Env* env = Env::Default();
auto root = Scope::NewRootScope().ExitOnError();

string bad_file;
TF_ASSERT_OK(CreateTempFileBadString(env, '\xe2', 128, "bad_e2", &bad_file));
auto result =
ops::ImmutableConst(root, DT_STRING, kFileTensorShape, bad_file);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
SessionOptions session_options;
session_options.env = Env::Default();
std::unique_ptr<Session> session(NewSession(session_options));
ASSERT_TRUE(session != nullptr) << "Failed to create session";
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
// Check that the run returned error.
EXPECT_EQ(
session->Run({}, {result.node()->name() + ":0"}, {}, &outputs).code(),
error::UNIMPLEMENTED);
}

} // namespace
} // namespace tensorflow

0 comments on commit 1cb6bb6

Please sign in to comment.