Skip to content
Permalink
Browse files Browse the repository at this point in the history
[lite] add validation check for sparse fully connected
PiperOrigin-RevId: 417629354
Change-Id: If96171c4bd4f5fdb01d6368d6deab19d1c9beca7
  • Loading branch information
karimnosseir authored and tensorflower-gardener committed Dec 21, 2021
1 parent 1de4972 commit 6c0b2b7
Showing 1 changed file with 48 additions and 10 deletions.
58 changes: 48 additions & 10 deletions tensorflow/lite/kernels/fully_connected.cc
Expand Up @@ -928,6 +928,36 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

// Checks that the sparsity metadata is consistent with the given
// weight/input/output shapes, so that sparse kernels cannot read or write
// out of bounds. Returns true when all derived indices fit.
bool VerifySparsity(const RuntimeShape& weights_shape,
                    const RuntimeShape& input_shape,
                    const RuntimeShape& output_shape,
                    const TfLiteSparsity* sparsity) {
  const int weights_rank = weights_shape.DimensionsCount();
  const int output_rank = output_shape.DimensionsCount();
  const int dim0_dense_size = sparsity->dim_metadata[0].dense_size;
  const int depth = weights_shape.Dims(weights_rank - 1);
  const int num_output_values = output_shape.FlatSize();
  const int num_input_values = input_shape.FlatSize();
  const int num_batches = FlatSizeSkipDim(output_shape, output_rank - 1);
  const int output_depth = MatchingDim(weights_shape, weights_rank - 2,
                                       output_shape, output_rank - 1);
  const int last_batch = num_batches - 1;
  // Largest output slot the kernel may address: start of the final batch
  // plus the dense size of sparsity dimension 0.
  const int required_output = last_batch * output_depth + dim0_dense_size;
  // Base offset of the final batch within the flattened input.
  const int last_batch_offset = depth * last_batch;

  // The flat output buffer must cover the furthest write.
  if (num_output_values < required_output) return false;

  // Each sparse column index, offset to the final batch, must stay inside
  // the flat input buffer.
  const auto* indices = sparsity->dim_metadata[1].array_indices;
  const int num_indices = indices->size;
  for (int idx = 0; idx < num_indices; ++idx) {
    if (num_input_values <= last_batch_offset + indices->data[idx]) {
      return false;
    }
  }
  return true;
}

template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
Expand Down Expand Up @@ -968,24 +998,32 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
"Unsupported sparse fully-connected weight format.");
return kTfLiteError;
}
const auto& input_shape = GetTensorShape(input);
const auto& filter_shape = GetTensorShape(filter);
const auto& output_shape = GetTensorShape(output);
const auto& bias_shape = GetTensorShape(bias);
if (!VerifySparsity(filter_shape, input_shape, output_shape, &sparsity)) {
TF_LITE_KERNEL_LOG(context, "Invalid sparse fully-connected format.");
return kTfLiteError;
}

if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
// Random sparse.
optimized_ops::FullyConnectedSparseWeight(
sparsity, op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
sparsity, op_params, // Disable formatting
input_shape, GetTensorData<float>(input), // Disable formatting
filter_shape, GetTensorData<float>(filter), // Disable formatting
bias_shape, GetTensorData<float>(bias), // Disable formatting
output_shape, GetTensorData<float>(output));
} else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
sparsity.dim_metadata[2].dense_size == 4) {
// Block sparse with block size of 1x4.
optimized_ops::FullyConnectedSparseWeight1x4(
sparsity, op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output),
sparsity, op_params, // Disable formatting
input_shape, GetTensorData<float>(input), // Disable formatting
filter_shape, GetTensorData<float>(filter), // Disable formatting
bias_shape, GetTensorData<float>(bias), // Disable formatting
output_shape, GetTensorData<float>(output),
CpuBackendContext::GetFromContext(context));
} else {
TF_LITE_KERNEL_LOG(context,
Expand Down

0 comments on commit 6c0b2b7

Please sign in to comment.