diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index 7f888da69d6c7f..6764deee21d286 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -38,12 +38,20 @@ class CreateSummaryFileWriterOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* tmp; OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), + errors::InvalidArgument("logdir must be a scalar")); const string logdir = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp)); - const int32 max_queue = tmp->scalar()(); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), + errors::InvalidArgument("max_queue must be a scalar")); + const int32_t max_queue = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp)); - const int32 flush_millis = tmp->scalar()(); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), + errors::InvalidArgument("flush_millis must be a scalar")); + const int32_t flush_millis = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), + errors::InvalidArgument("filename_suffix must be a scalar")); const string filename_suffix = tmp->scalar()(); core::RefCountPtr s; diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py index 19138b1372dea6..9fcac4952f6a25 100644 --- a/tensorflow/python/summary/writer/writer_test.py +++ b/tensorflow/python/summary/writer/writer_test.py @@ -34,6 +34,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -685,6 +686,16 @@ def testSharing_withExplicitSummaryFileWriters(self): # No more files self.assertRaises(StopIteration, lambda: next(event_paths)) + def testSummaryFileWritersInvalidInput(self): + # Test case for GitHub issue 46909 + logdir = self.get_temp_dir() + with session.Session() as sess: + with self.assertRaises(errors_impl.InvalidArgumentError): + writer = summary_ops_v2.create_file_writer( + logdir=logdir, flush_millis=[1, 2]) + sess.run(writer.init()) + sess.run(writer.flush()) + class FileWriterCacheTest(test.TestCase): """FileWriterCache tests."""