Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix check failure in SparseCrossV2Op by adding check for scalar value…
… for separator.

PiperOrigin-RevId: 461001180
  • Loading branch information
ishark authored and tensorflower-gardener committed Jul 14, 2022
1 parent 3ff9e48 commit 83dcb4d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/sparse_cross_op.cc
Expand Up @@ -24,12 +24,14 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
#include "tensorflow/core/util/work_sharder.h"
Expand Down Expand Up @@ -832,6 +834,10 @@ class SparseCrossV2Op : public OpKernel {

const Tensor* sep_t;
OP_REQUIRES_OK(context, context->input("sep", &sep_t));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(sep_t->shape()),
errors::InvalidArgument("Input separator should be a scalar. "
"Received: ",
sep_t->DebugString()));
const tstring separator = sep_t->scalar<tstring>()();

std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
Expand Down
Expand Up @@ -873,6 +873,14 @@ def test_all_columns_empty(self):
with self.cached_session():
self._assert_sparse_tensor_empty(self.evaluate(out))

def testNonScalarInput(self):
with self.assertRaisesRegex(errors.InvalidArgumentError,
'Input separator should be a scalar.'):
self.evaluate(sparse_ops.sparse_cross(
inputs=[],
name='a',
separator=constant_op.constant(['a', 'b'], dtype=dtypes.string)))


class SparseCrossHashedOpTest(BaseSparseCrossOpTest):

Expand Down

0 comments on commit 83dcb4d

Please sign in to comment.