Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions torch_xla/csrc/ops/gather.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "torch_xla/csrc/ops/gather.h"

#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
Expand All @@ -11,11 +12,24 @@ namespace ir {
namespace ops {
namespace {

bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index,
xla::int64 dim) {
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::Shape index_shape = XlaHelpers::ShapeOfXlaOp(index);
xla::int64 input_elements = xla::ShapeUtil::ElementsIn(input_shape);
xla::int64 index_elements = xla::ShapeUtil::ElementsIn(index_shape);
// Simple heuristic. Might need fine tuning.
return index_elements < input_elements / 3;
}

xla::Shape NodeOutputShape(const Value& input, const Value& index,
xla::int64 dim) {
auto lower_for_shape_fn =
[&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
-> xla::XlaOp { return TorchGather(operands[0], operands[1], dim); };
-> xla::XlaOp {
return TorchGather(operands[0], operands[1], dim,
IsSparseGather(operands[0], operands[1], dim));
};
return InferOutputShape({input.shape(), index.shape()}, lower_for_shape_fn);
}

Expand All @@ -34,7 +48,9 @@ NodePtr Gather::Clone(OpList operands) const {
XlaOpVector Gather::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp index = loctx->GetOutputOp(operand(1));
return ReturnOp(TorchGather(input, index, dim_), loctx);
return ReturnOp(
TorchGather(input, index, dim_, IsSparseGather(input, index, dim_)),
loctx);
}

std::string Gather::ToString() const {
Expand Down