Skip to content

Commit 1cb6bb6

Browse files
SeeForTwotensorflower-gardener
authored andcommitted
Add error checking to ImmutableConst OP that strings are not yet supported.
PiperOrigin-RevId: 401065359 Change-Id: I9dd2bd2a2c36f22f4a05153daf6ebdc4613469d2
1 parent c29040f commit 1cb6bb6

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

Diff for: tensorflow/core/kernels/immutable_constant_op.cc

+3
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ void ImmutableConstantOp::Compute(OpKernelContext* ctx) {
100100

101101
OP_REQUIRES_OK(ctx,
102102
allocator->InitializeFromRegion(region_name_, ctx->env()));
103+
OP_REQUIRES(ctx, dtype_ != DT_STRING,
104+
errors::Unimplemented("Sorry, DT_STRING is not currently "
105+
"supported for ImmutableConstOp."));
103106
ctx->set_output(0, Tensor(allocator.get(), dtype_, shape_));
104107
OP_REQUIRES_OK(ctx, allocator->allocation_status());
105108
// Allocator is owned by the tensor from this point.

Diff for: tensorflow/core/kernels/immutable_constant_op_test.cc

+38-3
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
146146
error::INTERNAL);
147147
}
148148

149-
Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
149+
Status CreateTempFileFloat(Env* env, float value, uint64 size,
150+
string* filename) {
150151
const string dir = testing::TmpDir();
151152
*filename = io::JoinPath(dir, strings::StrCat("file_", value));
152153
std::unique_ptr<WritableFile> file;
@@ -166,8 +167,8 @@ TEST(ImmutableConstantOpTest, FromFile) {
166167
auto root = Scope::NewRootScope().ExitOnError();
167168

168169
string two_file, three_file;
169-
TF_ASSERT_OK(CreateTempFile(env, 2.0f, 1000, &two_file));
170-
TF_ASSERT_OK(CreateTempFile(env, 3.0f, 1000, &three_file));
170+
TF_ASSERT_OK(CreateTempFileFloat(env, 2.0f, 1000, &two_file));
171+
TF_ASSERT_OK(CreateTempFileFloat(env, 3.0f, 1000, &three_file));
171172
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, two_file);
172173
auto node2 =
173174
ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, three_file);
@@ -190,5 +191,39 @@ TEST(ImmutableConstantOpTest, FromFile) {
190191
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f);
191192
}
192193

194+
Status CreateTempFileBadString(Env* env, char value, uint64 size,
195+
const string suffix, string* filename) {
196+
const string dir = testing::TmpDir();
197+
*filename = io::JoinPath(dir, strings::StrCat("file_", suffix));
198+
std::unique_ptr<WritableFile> file;
199+
TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file));
200+
TF_RETURN_IF_ERROR(file->Append(std::string(size, value)));
201+
TF_RETURN_IF_ERROR(file->Close());
202+
return Status::OK();
203+
}
204+
205+
TEST(ImmutableConstantOpTest, FromFileStringUnimplmented) {
206+
const TensorShape kFileTensorShape({1});
207+
Env* env = Env::Default();
208+
auto root = Scope::NewRootScope().ExitOnError();
209+
210+
string bad_file;
211+
TF_ASSERT_OK(CreateTempFileBadString(env, '\xe2', 128, "bad_e2", &bad_file));
212+
auto result =
213+
ops::ImmutableConst(root, DT_STRING, kFileTensorShape, bad_file);
214+
GraphDef graph_def;
215+
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
216+
SessionOptions session_options;
217+
session_options.env = Env::Default();
218+
std::unique_ptr<Session> session(NewSession(session_options));
219+
ASSERT_TRUE(session != nullptr) << "Failed to create session";
220+
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
221+
std::vector<Tensor> outputs;
222+
// Check that the run returned error.
223+
EXPECT_EQ(
224+
session->Run({}, {result.node()->name() + ":0"}, {}, &outputs).code(),
225+
error::UNIMPLEMENTED);
226+
}
227+
193228
} // namespace
194229
} // namespace tensorflow

0 commit comments

Comments
 (0)