Skip to content

Commit

Permalink
Merge pull request #49648 from geetachavan1/cherrypicks_5G8K9
Browse files Browse the repository at this point in the history
Add several missing validations in SDCA
  • Loading branch information
mihaimaruseac committed May 26, 2021
2 parents 6e235cc + 271fba6 commit 2ef6c20
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tensorflow/core/kernels/sdca_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,31 @@ Status ModelWeights::Initialize(OpKernelContext* const context) {
OpInputList sparse_weights_inputs;
TF_RETURN_IF_ERROR(
context->input_list("sparse_weights", &sparse_weights_inputs));
if (sparse_indices_inputs.size() != sparse_weights_inputs.size())
return errors::InvalidArgument(
"sparse_indices and sparse_weights must have the same length, got ",
sparse_indices_inputs.size(), " and ", sparse_weights_inputs.size());
OpInputList dense_weights_inputs;
TF_RETURN_IF_ERROR(
context->input_list("dense_weights", &dense_weights_inputs));

OpOutputList sparse_weights_outputs;
TF_RETURN_IF_ERROR(context->output_list("out_delta_sparse_weights",
&sparse_weights_outputs));
if (sparse_weights_outputs.size() != sparse_weights_inputs.size())
return errors::InvalidArgument(
"out_delta_sparse_weights and sparse_weights must have the same "
"length, got ",
sparse_weights_outputs.size(), " and ", sparse_weights_inputs.size());

OpOutputList dense_weights_outputs;
TF_RETURN_IF_ERROR(
context->output_list("out_delta_dense_weights", &dense_weights_outputs));
if (dense_weights_outputs.size() != dense_weights_inputs.size())
return errors::InvalidArgument(
"out_delta_dense_weights and dense_weights must have the same length, "
"got ",
dense_weights_outputs.size(), " and ", dense_weights_inputs.size());

for (int i = 0; i < sparse_weights_inputs.size(); ++i) {
Tensor* delta_t;
Expand Down Expand Up @@ -327,13 +341,28 @@ Status Examples::Initialize(OpKernelContext* const context,
OpInputList sparse_example_indices_inputs;
TF_RETURN_IF_ERROR(context->input_list("sparse_example_indices",
&sparse_example_indices_inputs));
if (sparse_example_indices_inputs.size() != num_sparse_features)
return errors::InvalidArgument(
"Expected ", num_sparse_features,
" tensors in sparse_example_indices but got ",
sparse_example_indices_inputs.size());
OpInputList sparse_feature_indices_inputs;
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_indices",
&sparse_feature_indices_inputs));
if (sparse_feature_indices_inputs.size() != num_sparse_features)
return errors::InvalidArgument(
"Expected ", num_sparse_features,
" tensors in sparse_feature_indices but got ",
sparse_feature_indices_inputs.size());
OpInputList sparse_feature_values_inputs;
if (num_sparse_features_with_values > 0) {
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_values",
&sparse_feature_values_inputs));
if (sparse_feature_values_inputs.size() != num_sparse_features_with_values)
return errors::InvalidArgument(
"Expected ", num_sparse_features_with_values,
" tensors in sparse_feature_values but got ",
sparse_feature_values_inputs.size());
}

const Tensor* example_weights_t;
Expand Down Expand Up @@ -400,6 +429,13 @@ Status Examples::CreateSparseFeatureRepresentation(
sparse_example_indices_inputs[i].template flat<int64>();
auto feature_indices =
sparse_feature_indices_inputs[i].template flat<int64>();
if (example_indices.size() != feature_indices.size()) {
mutex_lock l(mu);
result = errors::InvalidArgument(
"Found mismatched example_indices and feature_indices [",
example_indices, "] vs [", feature_indices, "]");
return;
}

// Parse features for each example. Features for a particular example
// are at the offsets (start_id, end_id]
Expand Down

0 comments on commit 2ef6c20

Please sign in to comment.