Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add broadcasting support for tf.where #15982

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
eb856d1
Add broadcasting support for `tf.where`
yongtang Sep 23, 2018
5cb7665
Update shape function for `tf.where` / `SelectOp`
yongtang Sep 23, 2018
5ce52a0
Add template for BCastSelectFunctor
yongtang Sep 23, 2018
94b1e3e
Add GPU support for where_v2
yongtang Sep 23, 2018
2250c49
Add test case for broadcasting support of `where_v2`
yongtang Sep 23, 2018
399d5c2
Define where_v2 in array_ops.py
yongtang Sep 23, 2018
75b4505
Add broadcasting support for `tf.where`
yongtang Sep 23, 2018
0ec97a7
Fix `Experimental clang-format Check`
yongtang Sep 23, 2018
e5edd77
Update api compatibility test
yongtang Sep 23, 2018
ab4f60d
Update api_def
yongtang Sep 23, 2018
7a9fd72
Add additional test case for tf.where_v2, based on review feedback
yongtang Oct 8, 2018
fec417d
Fix broken tests
yongtang Nov 6, 2018
0d2c28c
Merge branch 'master' into 9284-tf.where-broadcasting
martinwicke Dec 6, 2018
e1f2b7c
Expose where_v2 as v1=["where_v2"], v2=["where"]
yongtang Dec 6, 2018
ad970c3
Add deprecation to tf.where
yongtang Dec 6, 2018
4059b95
Hide select_v2 API
yongtang Dec 6, 2018
bc3538f
Update api goldens
yongtang Dec 6, 2018
8d198f1
Update API to use @tf_export(v1=["where"]) for legacy `where`, and @t…
yongtang Dec 7, 2018
56aa488
Merge branch 'master' into 9284-tf.where-broadcasting
yongtang Mar 30, 2019
70e39f2
Update api compat and tf_upgrade_v2
yongtang Mar 30, 2019
d9b98c4
Rename tf.where_v2 to tf.compat.v2.where, as rename script has to
yongtang Apr 1, 2019
e75409c
Rename elem_bcast to then_else_bcast, and remove duplicate template s…
yongtang May 2, 2019
33cd7b8
Fix GPU build failure
yongtang May 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow/contrib/framework/__init__.py
Expand Up @@ -124,7 +124,7 @@
from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['nest']
_allowed_symbols = ['nest', 'broadcast_to', 'where']
_nest_allowed_symbols = [
'assert_same_structure',
'is_nested',
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt
@@ -0,0 +1,3 @@
op {
graph_op_name: "SelectV2"
}
4 changes: 4 additions & 0 deletions tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt
@@ -0,0 +1,4 @@
op {
graph_op_name: "SelectV2"
visibility: HIDDEN
}
29 changes: 25 additions & 4 deletions tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
Expand Up @@ -23,6 +23,22 @@ limitations under the License.
namespace tensorflow {
namespace functor {

template <typename T, int NDIMS>
struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
void operator()(const GPUDevice& d,
typename TTypes<T, NDIMS>::Tensor output_tensor,
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
.select(then_tensor.broadcast(then_bcast),
else_tensor.broadcast(else_bcast));
}
};

template <typename T>
struct SelectFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
Expand Down Expand Up @@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> {
}
};

#define SELECT_FUNCTOR(T) \
template struct SelectFunctor<GPUDevice, T>; \
template struct SelectScalarFunctor<GPUDevice, T>; \
template struct BatchSelectFunctor<GPUDevice, T>;
// Explicitly instantiate the select functors for type T in this CUDA
// translation unit, including the broadcasting variant for ranks 1-5
// (the ranks dispatched by SelectV2Op's HANDLE_DIM switch).
#define SELECT_FUNCTOR(T)                                  \
  template struct SelectFunctor<GPUDevice, T>;             \
  template struct SelectScalarFunctor<GPUDevice, T>;       \
  template struct BatchSelectFunctor<GPUDevice, T>;        \
  template struct BCastSelectFunctor<GPUDevice, T, 1>;     \
  template struct BCastSelectFunctor<GPUDevice, T, 2>;     \
  template struct BCastSelectFunctor<GPUDevice, T, 3>;     \
  template struct BCastSelectFunctor<GPUDevice, T, 4>;     \
  template struct BCastSelectFunctor<GPUDevice, T, 5>;

SELECT_FUNCTOR(bool);
SELECT_FUNCTOR(Eigen::half);
Expand Down
167 changes: 156 additions & 11 deletions tensorflow/core/kernels/cwise_op_select.cc
Expand Up @@ -143,21 +143,138 @@ class SelectOp : public OpKernel {
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
template <typename Device, typename T>
class SelectV2Op : public OpKernel {
 public:
  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  // Computes output[i] = cond[i] ? then[i] : else[i] after broadcasting
  // `cond`, `then` and `else` to a common shape (numpy semantics), unlike
  // the legacy SelectOp which requires equal or batch-compatible shapes.
  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond;
    const Tensor* then;
    const Tensor* else_;
    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));

    // The `cond`, `then`, and `else` must be broadcastable to a common
    // shape; this matches the behavior of numpy.
    // TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
    // 2-ary broadcast.
    // NOTE(review): the `false` passed to every BCast below presumably
    // disables the fewer-dims optimization so the per-dimension reshape /
    // broadcast vectors stay aligned across operands -- confirm against
    // the BCast constructor.

    // Combine `then` and `else`.
    BCast then_else_bcast(BCast::FromShape(then->shape()),
                          BCast::FromShape(else_->shape()), false);
    OP_REQUIRES(ctx, then_else_bcast.IsValid(),
                errors::InvalidArgument(
                    "then ", then->shape().DebugString(), " and else ",
                    else_->shape().DebugString(), " must be broadcastable"));
    // Combine `cond` with `then` and `else`.
    BCast bcast(
        BCast::FromShape(cond->shape()),
        BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
        false);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "condition ", cond->shape().DebugString(), ", then ",
                    then->shape().DebugString(), ", and else ",
                    else_->shape().DebugString(), " must be broadcastable"));

    // Broadcast each operand against the combined shape to obtain its
    // reshape (y_reshape) and per-dimension broadcast multipliers (y_bcast).
    BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
                     BCast::FromShape(cond->shape()), false);
    BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
                     BCast::FromShape(then->shape()), false);
    BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
                     BCast::FromShape(else_->shape()), false);
    OP_REQUIRES(ctx, cond_bcast.IsValid() && then_bcast.IsValid() &&
                    else_bcast.IsValid(),
                errors::InvalidArgument(
                    "condition ", cond->shape().DebugString(), ", then ",
                    then->shape().DebugString(), ", and else ",
                    else_->shape().DebugString(), " must be broadcastable"));

    // Combined shape should be the final shape: no operand may broadcast
    // the combined shape itself to something larger.
    OP_REQUIRES(
        ctx, cond_bcast.output_shape() == bcast.output_shape() &&
                 then_bcast.output_shape() == bcast.output_shape() &&
                 else_bcast.output_shape() == bcast.output_shape(),
        errors::InvalidArgument("condition ", cond->shape().DebugString(),
                                ", then ", then->shape().DebugString(),
                                ", and else ", else_->shape().DebugString(),
                                " must be broadcastable to the same shape"));

    Tensor* output = nullptr;
    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
    // Reuse the `t` or `e` input buffer for the output when possible.
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", output_shape, &output));

    // Nothing to compute for an empty output.
    if (output->NumElements() == 0) {
      return;
    }

// Dispatch to the rank-NDIMS instantiation of the broadcasting functor,
// viewing each operand through its BCast-computed reshape.
#define HANDLE_DIM(NDIMS)                                            \
  {                                                                  \
    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
    func(ctx->eigen_device<Device>(),                                \
         output->shaped<T, NDIMS>(bcast.result_shape()),             \
         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
  }

    const int ndims = static_cast<int>(bcast.result_shape().size());
    switch (ndims) {
      case 1:
        HANDLE_DIM(1);
        break;
      case 2:
        HANDLE_DIM(2);
        break;
      case 3:
        HANDLE_DIM(3);
        break;
      case 4:
        HANDLE_DIM(4);
        break;
      case 5:
        HANDLE_DIM(5);
        break;
      default:
        // NOTE(review): the message only reports inputs 0 and 1; the `e`
        // input is omitted -- consider including ctx->input(2) as well.
        ctx->SetStatus(errors::Unimplemented(
            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
            ctx->input(1).shape().DebugString(), " is not supported yet."));
        break;
    }
#undef HANDLE_DIM
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};

// Register both the legacy (no-broadcast) Select kernel and the
// broadcasting SelectV2 kernel on CPU for every supported type.
#define REGISTER_SELECT(type)                                        \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      SelectOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectV2Op<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SelectOp<GPUDevice, type>);
// Register both Select and the broadcasting SelectV2 on GPU; the
// BCastSelectFunctor instantiations live in cwise_op_gpu_select.cu.cc.
#define REGISTER_SELECT_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      SelectOp<GPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectV2Op<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
Expand All @@ -174,9 +291,12 @@ REGISTER_SELECT_GPU(complex128);

#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
// Register both Select and the broadcasting SelectV2 on SYCL.
// Fix: "SelectV2" was registered with the legacy SelectOp kernel, which
// does not implement broadcasting; use SelectV2Op to match the CPU and
// GPU registrations above.
#define REGISTER_SELECT_SYCL(type)                                    \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
      SelectOp<SYCLDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SelectV2Op<SYCLDevice, type>);

REGISTER_SELECT_SYCL(float);
Expand Down Expand Up @@ -324,10 +444,35 @@ struct BatchSelectFunctor<CPUDevice, T> {
}
};

template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
void operator()(const Device& d,
typename TTypes<T, NDIMS>::Tensor output_tensor,
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
.select(then_tensor.broadcast(then_bcast),
else_tensor.broadcast(else_bcast));
}
};

// CPU specialization simply reuses the device-generic base implementation.
template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};

#ifdef TENSORFLOW_USE_SYCL
// SYCL batch select reuses the device-generic base implementation.
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
    : BatchSelectFunctorBase<SYCLDevice, T> {};

// SYCL broadcasting select reuses the device-generic base implementation.
template <typename T, int NDIMS>
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
    : BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};

#endif // TENSORFLOW_USE_SYCL

} // namespace functor
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/cwise_ops.h
Expand Up @@ -1208,6 +1208,18 @@ struct BatchSelectFunctor {
typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};

// Functor computing a broadcasting select for SelectV2:
//   output = broadcast(cond).select(broadcast(then), broadcast(else)).
// The *_bcast arrays give the per-dimension broadcast multipliers for the
// rank-NDIMS views of each operand. Specializations are provided per
// device (CPU/SYCL in cwise_op_select.cc, GPU in cwise_op_gpu_select.cu.cc).
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
};

} // end namespace functor
} // end namespace tensorflow

Expand Down
51 changes: 51 additions & 0 deletions tensorflow/core/ops/math_ops.cc
Expand Up @@ -820,6 +820,57 @@ REGISTER_OP("Select")
return Status::OK();
});

REGISTER_OP("SelectV2")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // Shape function for the broadcasting select: the output shape is
      // the broadcast of all three input shapes, computed as two binary
      // broadcasts (then/else first, then condition).
      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
      // Merge handle shape and dtype if applicable (i.e. when `t` and `e`
      // carry resource/variant handle data); dtypes must match exactly and
      // shapes are merged element-wise.
      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
        const auto size = handle_data_1->size();
        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
        if (size != handle_data_2->size()) {
          return errors::InvalidArgument(
              "Trying to merge handles pointing to different numbers of "
              "tensors.");
        }

        for (int i = 0; i < size; ++i) {
          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
          if (s1.dtype != s2.dtype) {
            // TODO(apassos) resolve this in the manner of b/32476923
            return errors::InvalidArgument(
                "Trying to merge handles pointing to different dtypes.");
          }
          merged_handle_data[i].dtype = s1.dtype;
          TF_RETURN_IF_ERROR(
              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
        }

        c->set_output_handle_shapes_and_types(0, merged_handle_data);
      }

      // The inputs 'cond', 'then', and 'else' must be broadcastable.
      // TODO (yongtang): Consolidate 3-ary broadcast instead of
      // multiple 2-ary broadcast.
      ShapeHandle cond = c->input(0);
      ShapeHandle then = c->input(1);
      ShapeHandle else_ = c->input(2);
      ShapeHandle other;
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
      ShapeHandle output;
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
      c->set_output(0, output);
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("MatMul")
Expand Down