Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix SDCA optimizer crash.
Validates size of the dense_features and example state_data_inputs.
Other validation already verifies that sizes are otherwise consistent.

This looks to be a v1-only op that isn't used internally at all outside
of `contrib`, and no tests.

PiperOrigin-RevId: 478073762
  • Loading branch information
cantonios authored and tensorflower-gardener committed Sep 30, 2022
1 parent a628192 commit 80ff197
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/sdca_internal.cc
Expand Up @@ -389,6 +389,13 @@ Status Examples::Initialize(OpKernelContext* const context,
OpInputList dense_features_inputs;
TF_RETURN_IF_ERROR(
context->input_list("dense_features", &dense_features_inputs));
for (int i = 0; i < dense_features_inputs.size(); ++i) {
if (!TensorShapeUtils::IsMatrix(dense_features_inputs[i].shape())) {
return errors::InvalidArgument("Dense features at index ", i,
" must be rank 2 but is rank ",
dense_features_inputs[i].dims());
}
}

examples_.clear();
examples_.resize(num_examples);
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/sdca_ops.cc
Expand Up @@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
Expand Down Expand Up @@ -142,6 +143,10 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
const Tensor* example_state_data_t;
OP_REQUIRES_OK(context,
context->input("example_state_data", &example_state_data_t));
OP_REQUIRES(
context, TensorShapeUtils::IsMatrix(example_state_data_t->shape()),
errors::InvalidArgument("example_state_data must be rank 2 but is rank ",
example_state_data_t->dims()));
TensorShape expected_example_state_shape({examples.num_examples(), 4});
OP_REQUIRES(context,
example_state_data_t->shape() == expected_example_state_shape,
Expand Down

0 comments on commit 80ff197

Please sign in to comment.