From 4d38f69a0e838c53e7a3bb9a4167cc7dbd481cc8 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sun, 11 Aug 2019 02:23:42 -0700
Subject: [PATCH] Wire in dense gather.

---
Review notes (placed after the three-dash separator, so they are not part
of the commit message):
 * IsSparseGather() returns true when the index holds fewer elements than
   one third of the input's element count (integer division). Its `dim`
   parameter is accepted but does not take part in the heuristic.
 * The result is passed as the fourth argument to TorchGather() both during
   shape inference (NodeOutputShape) and during lowering (Gather::Lower),
   so the two paths always agree. Presumably the flag selects the sparse
   vs. dense lowering; TorchGather is defined outside this patch, so
   confirm there.
 * The From: header in this copy carries no e-mail address; git-am may
   reject the patch until it is restored.

 torch_xla/csrc/ops/gather.cpp | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/torch_xla/csrc/ops/gather.cpp b/torch_xla/csrc/ops/gather.cpp
index aa91efda5d0d..e660dac6622a 100644
--- a/torch_xla/csrc/ops/gather.cpp
+++ b/torch_xla/csrc/ops/gather.cpp
@@ -1,6 +1,7 @@
 #include "torch_xla/csrc/ops/gather.h"
 
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/lowering_context.h"
@@ -11,11 +12,24 @@ namespace ir {
 namespace ops {
 namespace {
 
+bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index,
+                    xla::int64 dim) {
+  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::Shape index_shape = XlaHelpers::ShapeOfXlaOp(index);
+  xla::int64 input_elements = xla::ShapeUtil::ElementsIn(input_shape);
+  xla::int64 index_elements = xla::ShapeUtil::ElementsIn(index_shape);
+  // Simple heuristic. Might need fine tuning.
+  return index_elements < input_elements / 3;
+}
+
 xla::Shape NodeOutputShape(const Value& input, const Value& index,
                            xla::int64 dim) {
   auto lower_for_shape_fn =
       [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-      -> xla::XlaOp { return TorchGather(operands[0], operands[1], dim); };
+      -> xla::XlaOp {
+    return TorchGather(operands[0], operands[1], dim,
+                       IsSparseGather(operands[0], operands[1], dim));
+  };
   return InferOutputShape({input.shape(), index.shape()}, lower_for_shape_fn);
 }
 
@@ -34,7 +48,9 @@ NodePtr Gather::Clone(OpList operands) const {
 XlaOpVector Gather::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::XlaOp index = loctx->GetOutputOp(operand(1));
-  return ReturnOp(TorchGather(input, index, dim_), loctx);
+  return ReturnOp(
+      TorchGather(input, index, dim_, IsSparseGather(input, index, dim_)),
+      loctx);
 }
 
 std::string Gather::ToString() const {