
Add gradient for operation 'SparseSlice' #19663

Merged
merged 8 commits on Jun 22, 2018
40 changes: 40 additions & 0 deletions tensorflow/core/api_def/base_api/api_def_SparseSliceGrad.pbtxt
@@ -0,0 +1,40 @@
op {
graph_op_name: "SparseSliceGrad"
in_arg {
name: "backprop_val_grad"
description: <<END
1-D. The gradient with respect to
the non-empty values of the sliced `SparseTensor`.
END
}
in_arg {
name: "input_indices"
description: <<END
2-D. The `indices` of the input `SparseTensor`.
END
}
in_arg {
name: "input_start"
description: <<END
1-D. Tensor representing the start of the slice.
END
}
in_arg {
name: "output_indices"
description: <<END
2-D. The `indices` of the sliced `SparseTensor`.
END
}
out_arg {
name: "val_grad"
description: <<END
1-D. The gradient with respect to the non-empty values of the input `SparseTensor`.
END
}
summary: "The gradient operator for the SparseSlice op."
description: <<END
This op takes in the upstream gradient w.r.t. the non-empty values of
the sliced `SparseTensor`, and outputs the gradients w.r.t.
the non-empty values of the input `SparseTensor`.
END
}
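To make the op's contract concrete, here is a small numeric illustration (hypothetical values, not taken from this PR): slicing a 1x6 sparse tensor starting at column 2 re-bases the surviving indices, and SparseSliceGrad scatters the upstream gradient back onto the matching input positions.

# Hypothetical illustration of SparseSliceGrad's inputs and output.
input_indices     = [[0, 0], [0, 2], [0, 4]]  # 2-D indices of the input SparseTensor
input_start       = [0, 2]                    # start of the slice
output_indices    = [[0, 0], [0, 2]]          # indices of the sliced SparseTensor
backprop_val_grad = [1.0, 1.0]                # upstream gradient, one per sliced value
# Input entry i matches output entry j when
# input_indices[i] == output_indices[j] + input_start, so:
# val_grad == [0.0, 1.0, 1.0]   (the value at column 0 was sliced away)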
@@ -0,0 +1,4 @@
op {
graph_op_name: "SparseSliceGrad"
visibility: HIDDEN
}
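Because the op is marked HIDDEN it is not exported under the public tf.* namespace; Python code reaches it through the generated wrapper module, as the gradient function later in this PR does. A minimal sketch, assuming the standard generated-op layout:

from tensorflow.python.ops import gen_sparse_ops

# Called internally by the SparseSlice gradient registered in sparse_grad.py;
# there is no public tf.sparse_slice_grad endpoint.
val_grad = gen_sparse_ops.sparse_slice_grad(
    backprop_val_grad=[1.0, 1.0],
    input_indices=[[0, 0], [0, 2], [0, 4]],
    input_start=[0, 2],
    output_indices=[[0, 0], [0, 2]])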
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/BUILD
@@ -3886,6 +3886,7 @@ cc_library(
":sparse_reduce_op",
":sparse_reorder_op",
":sparse_reshape_op",
":sparse_slice_grad_op",
":sparse_slice_op",
":sparse_softmax",
":sparse_sparse_binary_op_shared",
@@ -3971,6 +3972,12 @@ tf_kernel_library(
],
)

tf_kernel_library(
name = "sparse_slice_grad_op",
prefix = "sparse_slice_grad_op",
deps = SPARSE_DEPS,
)

tf_kernel_library(
name = "sparse_slice_op",
prefix = "sparse_slice_op",
126 changes: 126 additions & 0 deletions tensorflow/core/kernels/sparse_slice_grad_op.cc
@@ -0,0 +1,126 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

template <typename T>
class SparseSliceGradOp : public OpKernel {
public:
explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext *ctx) override {
const Tensor *backprop_val_grad, *input_indices, *output_indices, *input_start;
OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad));
OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices));
OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start));
OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices));

OP_REQUIRES(ctx,
TensorShapeUtils::IsMatrix(input_indices->shape()) &&
TensorShapeUtils::IsMatrix(output_indices->shape()),
errors::InvalidArgument(
"Input and output indices should be matrices "
"but received shapes: ",
input_indices->shape().DebugString(), " and ",
output_indices->shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()),
errors::InvalidArgument(
"Input backprop_val_grad should be a vector but received shape: ",
backprop_val_grad->shape().DebugString()));
OP_REQUIRES(
ctx,
input_indices->dim_size(1) == output_indices->dim_size(1),
errors::InvalidArgument("The input and output should have the same "
"ndims: got: ", input_indices->dim_size(1), " and ",
output_indices->dim_size(1)));
OP_REQUIRES(
ctx, output_indices->dim_size(0) <= input_indices->dim_size(0),
errors::InvalidArgument("# rows of output_indices should be not greater "
"than of input_indices, got ",
output_indices->dim_size(0), " and ",
input_indices->dim_size(0)));
OP_REQUIRES(
ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0),
errors::InvalidArgument("# elements of backprop_val_grad and # rows of "
"output_indices should match (#nnz of sum): got ",
backprop_val_grad->NumElements(), " and ",
output_indices->dim_size(0)));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()),
errors::InvalidArgument(
"The input_start should be a vector but received shape ",
input_start->shape().DebugString()));

const int num_dims = input_indices->dim_size(1);
OP_REQUIRES(ctx, num_dims == input_start->NumElements(),
errors::InvalidArgument(
"Expected input_start to be a vector of length ", num_dims,
" but got length ", input_start->NumElements()));

const int64 input_nnz = input_indices->dim_size(0);

Tensor *val_grad;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad));

T *val_grad_flat = val_grad->flat<T>().data();
const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data();
memset(val_grad_flat, 0, sizeof(T) * input_nnz);

// Fill gradients for positions where the input and output indices match.
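// Both index matrices are assumed to be in canonical row-major order, as
// produced by SparseSlice, so a single forward pass with two cursors
// (i over input entries, j over output entries) is sufficient.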
const auto input_indices_mat = input_indices->matrix<int64>();
const auto output_indices_mat = output_indices->matrix<int64>();
const auto input_start_flat = input_start->flat<int64>();
int64 j = 0;
for (int64 i = 0; i < input_nnz && j < backprop_val_grad->NumElements();
++i) {
bool is_same = true;
for (int d = 0; d < num_dims; ++d) {
const int64 a = input_indices_mat(i, d);
const int64 b = output_indices_mat(j, d);
const int64 offset = input_start_flat(d);
if (a != b + offset) {
is_same = false;
break;
}
}
if (is_same) {
val_grad_flat[i] = backprop_val_grad_flat[j];
++j;
}
}
OP_REQUIRES(
ctx, backprop_val_grad->NumElements() == j,
errors::Internal("Elements of backprop_val_grad aren't all propagated. "
"Num elements:", backprop_val_grad->NumElements(),
", used: ", j));
}
};

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SparseSliceGradOp<type>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
} // namespace tensorflow
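For intuition, here is a minimal NumPy sketch of the same two-cursor scatter that the kernel above implements (illustrative only; the standalone function is hypothetical):

import numpy as np

def sparse_slice_grad(backprop_val_grad, input_indices, input_start,
                      output_indices):
  # Mirrors the C++ kernel: zero-initialize, then copy each upstream
  # gradient to the input entry whose index equals output index + start.
  val_grad = np.zeros(len(input_indices), dtype=backprop_val_grad.dtype)
  j = 0
  for i in range(len(input_indices)):
    if j == len(backprop_val_grad):
      break
    if np.array_equal(input_indices[i], output_indices[j] + input_start):
      val_grad[i] = backprop_val_grad[j]
      j += 1
  return val_grad

# sparse_slice_grad(np.array([1., 1.]),
#                   np.array([[0, 0], [0, 2], [0, 4]]),
#                   np.array([0, 2]),
#                   np.array([[0, 0], [0, 2]]))  # -> [0., 1., 1.]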
14 changes: 14 additions & 0 deletions tensorflow/core/ops/sparse_ops.cc
@@ -302,6 +302,20 @@ REGISTER_OP("SparseSplit")
return Status::OK();
});

REGISTER_OP("SparseSliceGrad")
.Input("backprop_val_grad: T")
.Input("input_indices: int64")
.Input("input_start: int64")
.Input("output_indices: int64")
.Output("val_grad: T")
.Attr("T: numbertype")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle indices;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
c->set_output(0, c->Vector(c->Dim(indices, 0)));
return Status::OK();
});

REGISTER_OP("SparseSlice")
.Input("indices: int64")
.Input("values: T")
12 changes: 12 additions & 0 deletions tensorflow/core/ops/sparse_ops_test.cc
@@ -52,6 +52,18 @@ TEST(SparseOpsTest, SparseAddGrad_ShapeFn) {
INFER_OK(op, "?;[?,?];[?,?];?", "[d1_0];[d2_0]");
}

TEST(SparseOpsTest, SparseSliceGrad_ShapeFn) {
ShapeInferenceTestOp op("SparseSliceGrad");

// Rank checks.
INFER_ERROR("must be rank 2", op, "?;[1];?;?");

INFER_OK(op, "?;?;?;?", "[?]");

// input[1].dim(0) (the number of rows of input_indices) determines the output.
INFER_OK(op, "?;[?,?];?;?", "[d1_0]");
}

TEST(SparseOpsTest, SparseReorder_ShapeFn) {
ShapeInferenceTestOp op("SparseReorder");

1 change: 1 addition & 0 deletions tensorflow/python/kernel_tests/BUILD
@@ -892,6 +892,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:sparse_grad",
"//tensorflow/python:sparse_ops",
],
)
22 changes: 20 additions & 2 deletions tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -21,13 +21,15 @@
import numpy as np

from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import sparse_ops
import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
from tensorflow.python.platform import test


class SparseSliceOpTest(test.TestCase):

def _SparseTensor_4x6(self):
def _SparseTensor_4x6(self, val_dtype=np.int64):
# [0 | |2 | |4 |5 ]
# [ |11| |13|14| ]
# [20| | |23| |25]
@@ -37,7 +39,7 @@ def _SparseTensor_4x6(self):
[2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
np.int64)
val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
np.int64)
val_dtype)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)

@@ -244,6 +246,22 @@ def testSliceAllColumns(self):
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])

def testGradients(self):
sp_input = self._SparseTensor_4x6(val_dtype=np.float32)
start_and_size = [([0, 0], [4, 2]),
([0, 2], [5, 2]),
([0, 4], [5, 3])]

with self.test_session(use_gpu=False):
for start, size in start_and_size:
sp_output = sparse_ops.sparse_slice(sp_input, start, size)
nnz_in = len(sp_input.values.eval())
nnz_out = len(sp_output.values.eval())

err = gradient_checker.compute_gradient_error(
[sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,))
self.assertLess(err, 1e-3)


if __name__ == '__main__':
test.main()
29 changes: 29 additions & 0 deletions tensorflow/python/ops/sparse_grad.py
@@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad):
None, None)


@ops.RegisterGradient("SparseSlice")
def _SparseSliceGrad(op, *grads):
"""The backward operator for the SparseSlice op.

This op takes in the upstream gradient w.r.t. the non-empty values of
the sliced `SparseTensor`, and outputs the gradients w.r.t.
the non-empty values of the input `SparseTensor`.

Args:
op: the SparseSlice op
*grads: the incoming gradients, one element per output of `op`

Returns:
Gradient for each of the 5 input tensors of SparseSlice:
(indices, values, shape, start, size)
The gradients for indices, shape, start, and size are None.
"""
backprop_val_grad = grads[1]
input_indices = op.inputs[0]
input_start = op.inputs[3]
output_indices = op.outputs[0]

val_grad = gen_sparse_ops.sparse_slice_grad(
backprop_val_grad, input_indices, input_start, output_indices)
val_grad.set_shape(op.inputs[1].get_shape())
# (indices, values, shape, start, size)
return (None, val_grad, None, None, None)
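A short end-to-end usage sketch (assuming TF 1.x graph mode, as in this PR's tests; the tensors are made up for illustration):

import tensorflow as tf
from tensorflow.python.ops import sparse_ops
import tensorflow.python.ops.sparse_grad  # registers the gradient above

sp = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1.0, 2.0],
                     dense_shape=[2, 2])
sliced = sparse_ops.sparse_slice(sp, [0, 0], [2, 1])  # keep column 0 only
grad = tf.gradients(sliced.values, sp.values)[0]
with tf.Session() as sess:
  print(sess.run(grad))  # [1. 0.]: only the value at (0, 0) survives the slice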


@ops.RegisterGradient("SparseTensorDenseMatMul")
def _SparseTensorDenseMatMulGrad(op, grad):
"""Gradients for the dense tensor in the SparseTensorDenseMatMul op.