From 3e3c47885fc06c2bec7fe183b500045d7a0a5e7d Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Sun, 24 Feb 2019 16:41:38 -0800 Subject: [PATCH 1/2] Added aten::kthvalue() operation. --- test/cpp/test_aten_xla_tensor.cpp | 17 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 9 ++++++ torch_xla/csrc/aten_xla_type.h | 4 +++ torch_xla/csrc/ops/kth_value.cpp | 49 +++++++++++++++++++++++++++++++ torch_xla/csrc/ops/kth_value.h | 31 +++++++++++++++++++ torch_xla/csrc/tensor.cpp | 11 +++++++ torch_xla/csrc/tensor.h | 4 +++ torch_xla/csrc/xla_lower_util.cpp | 38 ++++++++++++++++++++++++ torch_xla/csrc/xla_lower_util.h | 3 ++ 9 files changed, 166 insertions(+) create mode 100644 torch_xla/csrc/ops/kth_value.cpp create mode 100644 torch_xla/csrc/ops/kth_value.h diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index b5fbaf419052..7b909d082dc8 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -376,6 +376,23 @@ TEST_F(AtenXlaTensorTest, TestIntegerAdd) { }); } +TEST_F(AtenXlaTensorTest, TestKthValue) { + at::Tensor a = at::rand({4, 5, 3}, at::TensorOptions(at::kFloat)); + for (int k = 1; k <= 3; ++k) { + for (int dim = 0; dim < 3; ++dim) { + for (bool keepdim : {false, true}) { + auto b = at::kthvalue(a, k, dim, keepdim); + ForEachDevice([&](const Device& device) { + at::Tensor xla_a = bridge::CreateXlaTensor(a, device); + auto xla_b = at::kthvalue(xla_a, k, dim, keepdim); + AllClose(std::get<0>(b), std::get<0>(xla_b)); + AllClose(std::get<1>(b), std::get<1>(xla_b)); + }); + } + } + } +} + TEST_F(AtenXlaTensorTest, TestMin) { at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat)); at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat)); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 64c254029d0b..632449b3e325 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -666,6 +666,15 @@ int64_t AtenXlaType::size(const at::Tensor& self, int64_t dim) const { return bridge::GetXlaTensor(self).size(dim); } +std::tuple AtenXlaType::kthvalue(const at::Tensor& self, + int64_t k, int64_t dim, + bool keepdim) const { + auto results = + XLATensor::kthvalue(bridge::GetXlaTensor(self), k, dim, keepdim); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), + bridge::AtenFromXlaTensor(std::get<1>(results))); +} + at::Tensor AtenXlaType::embedding(const at::Tensor& weight, const at::Tensor& indices, int64_t padding_idx, bool scale_grad_by_freq, diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index 88564acc6e7a..8bcddc534881 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -232,6 +232,10 @@ class AtenXlaType : public AtenXlaTypeBase { int64_t size(const at::Tensor& self, int64_t dim) const override; + std::tuple kthvalue(const at::Tensor& self, int64_t k, + int64_t dim, + bool keepdim) const override; + at::Tensor embedding(const at::Tensor& weight, const at::Tensor& indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) const override; diff --git a/torch_xla/csrc/ops/kth_value.cpp b/torch_xla/csrc/ops/kth_value.cpp new file mode 100644 index 000000000000..fa5b1def4fe0 --- /dev/null +++ b/torch_xla/csrc/ops/kth_value.cpp @@ -0,0 +1,49 @@ +#include "torch_xla/csrc/ops/kth_value.h" + +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim, + bool keepdim) { + auto lower_for_shape_fn = + [&](tensorflow::gtl::ArraySlice operands) + -> xla::XlaOp { + return xla::Tuple(operands[0].builder(), + CreateKthValue(operands[0], k, dim, keepdim)); + }; + return InferOutputShape({input.shape()}, lower_for_shape_fn); +} + +} // namespace + +KthValue::KthValue(const Value& input, xla::int64 k, xla::int64 dim, + bool keepdim) + : Node(ir::OpKind(at::aten::kthvalue), {input}, + NodeOutputShape(input, k, dim, keepdim), + /*num_outputs=*/2, xla::util::MHash(k, dim, keepdim)), + k_(k), + dim_(dim), + keepdim_(keepdim) {} + +XlaOpVector KthValue::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + return ReturnOps(CreateKthValue(input, k_, dim_, keepdim_), loctx); +} + +std::string KthValue::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", k=" << k_ << ", dim=" << dim_ + << ", keepdim=" << keepdim_; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/kth_value.h b/torch_xla/csrc/ops/kth_value.h new file mode 100644 index 000000000000..95f039b1bc9f --- /dev/null +++ b/torch_xla/csrc/ops/kth_value.h @@ -0,0 +1,31 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class KthValue : public Node { + public: + KthValue(const Value& input, xla::int64 k, xla::int64 dim, bool keepdim); + + std::string ToString() const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + xla::int64 k() const { return k_; }; + + xla::int64 dim() const { return dim_; }; + + bool keepdim() const { return keepdim_; } + + private: + xla::int64 k_; + xla::int64 dim_; + bool keepdim_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 61815dd34cf6..b528d7f91c55 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -44,6 +44,7 @@ #include "torch_xla/csrc/ops/generic.h" #include "torch_xla/csrc/ops/index_select.h" #include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/ops/kth_value.h" #include "torch_xla/csrc/ops/leaky_relu.h" #include "torch_xla/csrc/ops/log_softmax.h" #include "torch_xla/csrc/ops/log_softmax_backward.h" @@ -908,6 +909,16 @@ XLATensor XLATensor::select(const XLATensor& input, int64_t dim, input.GetDevice()); } +std::tuple XLATensor::kthvalue(const XLATensor& input, + xla::int64 k, + xla::int64 dim, + bool keepdim) { + ir::NodePtr node = + ir::MakeNode(input.GetIrValue(), k, dim, keepdim); + return std::make_tuple(Create(ir::Value(node, 0), input.GetDevice()), + Create(ir::Value(node, 1), input.GetDevice())); +} + XLATensor XLATensor::dropout(const XLATensor& input, double p) { return Create(ir::MakeNode(input.GetIrValue(), p), input.GetDevice()); diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 23e9e98bc4e7..de222b75ad87 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -268,6 +268,10 @@ class XLATensor { static XLATensor select(const XLATensor& input, int64_t dim, int64_t index); + static std::tuple kthvalue(const XLATensor& input, + xla::int64 k, xla::int64 dim, + bool keepdim); + static XLATensor dropout(const XLATensor& input, double p); static XLATensor neg(const XLATensor& input); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 839dd2bf9962..26a595e74ffb 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -3,6 +3,7 @@ #include #include +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/util.h" @@ -72,6 +73,43 @@ std::pair DotBroadcast(const xla::XlaOp& lhs, } // namespace +std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, + xla::int64 dim, bool keepdim) { + // Here 'k' is 1 based (1...). + xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input); + XLA_CHECK_LE(k, shape.dimensions(dim)); + xla::Shape iota_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions()); + xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim); + // TODO: Remember to add is_stable=true as last Sort() argument when fetching + // the new TF head. + xla::XlaOp sort_result = xla::Sort( + {input, iota}, + xla::CreateScalarLtComputation( + {shape.element_type(), xla::PrimitiveType::S32}, input.builder()), + dim); + + std::vector start_indices(shape.rank(), 0); + start_indices[dim] = k - 1; + std::vector limit_indices(shape.dimensions().begin(), + shape.dimensions().end()); + limit_indices[dim] = k; + std::vector strides(shape.rank(), 1); + + xla::XlaOp values = xla::Slice(xla::GetTupleElement(sort_result, 0), + start_indices, limit_indices, strides); + xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), + start_indices, limit_indices, strides); + if (!keepdim) { + auto reshape_sizes = + XlaHelpers::DropDimensions(shape.dimensions(), {dim}); + values = xla::Reshape(values, reshape_sizes); + indices = xla::Reshape(indices, reshape_sizes); + } + // aten::kthvalue() wants Long tensors as indices. + return {values, xla::ConvertElementType(indices, xla::PrimitiveType::S64)}; +} + xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs) { const auto precision_level = XlaHelpers::mat_mul_precision(); xla::PrecisionConfig precision_config = diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 6ba6d8680759..1cf36c9413f2 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -5,6 +5,9 @@ namespace torch_xla { +std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, + xla::int64 dim, bool keepdim); + xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs); xla::XlaOp BuildDropout(const xla::XlaOp& input, float probability); From b570860531f15c4ba277c1f0b06e63205cac5a85 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Sun, 24 Feb 2019 17:51:16 -0800 Subject: [PATCH 2/2] Added aten::topk operation. --- test/cpp/test_aten_xla_tensor.cpp | 17 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 14 +++++++++ torch_xla/csrc/aten_xla_type.h | 4 +++ torch_xla/csrc/ops/topk.cpp | 50 +++++++++++++++++++++++++++++++ torch_xla/csrc/ops/topk.h | 35 ++++++++++++++++++++++ torch_xla/csrc/tensor.cpp | 10 +++++++ torch_xla/csrc/tensor.h | 4 +++ torch_xla/csrc/xla_lower_util.cpp | 42 +++++++++++++++++++++++--- torch_xla/csrc/xla_lower_util.h | 3 ++ 9 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 torch_xla/csrc/ops/topk.cpp create mode 100644 torch_xla/csrc/ops/topk.h diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 7b909d082dc8..850639dae641 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -393,6 +393,23 @@ TEST_F(AtenXlaTensorTest, TestKthValue) { } } +TEST_F(AtenXlaTensorTest, TestTopK) { + at::Tensor a = at::rand({4, 5, 3}, at::TensorOptions(at::kFloat)); + for (int k = 1; k <= 3; ++k) { + for (int dim = 0; dim < 3; ++dim) { + for (bool largest : {false, true}) { + auto b = at::topk(a, k, dim, largest, /*sorted=*/true); + ForEachDevice([&](const Device& device) { + at::Tensor xla_a = bridge::CreateXlaTensor(a, device); + auto xla_b = at::topk(xla_a, k, dim, largest, /*sorted=*/true); + AllClose(std::get<0>(b), std::get<0>(xla_b)); + AllClose(std::get<1>(b), std::get<1>(xla_b)); + }); + } + } + } +} + TEST_F(AtenXlaTensorTest, TestMin) { at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat)); at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat)); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 632449b3e325..528f09dbb393 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -675,6 +675,20 @@ std::tuple AtenXlaType::kthvalue(const at::Tensor& self, bridge::AtenFromXlaTensor(std::get<1>(results))); } +std::tuple AtenXlaType::topk(const at::Tensor& self, + int64_t k, int64_t dim, + bool largest, + bool sorted) const { + // TODO: Implement the non default not-sorted topk on the XLA side. + if (!sorted) { + return AtenXlaTypeBase::topk(self, k, dim, largest, sorted); + } + auto results = + XLATensor::topk(bridge::GetXlaTensor(self), k, dim, largest, sorted); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), + bridge::AtenFromXlaTensor(std::get<1>(results))); +} + at::Tensor AtenXlaType::embedding(const at::Tensor& weight, const at::Tensor& indices, int64_t padding_idx, bool scale_grad_by_freq, diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index 8bcddc534881..a4850b1b3e71 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -236,6 +236,10 @@ class AtenXlaType : public AtenXlaTypeBase { int64_t dim, bool keepdim) const override; + std::tuple topk(const at::Tensor& self, int64_t k, + int64_t dim, bool largest, + bool sorted) const override; + at::Tensor embedding(const at::Tensor& weight, const at::Tensor& indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) const override; diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp new file mode 100644 index 000000000000..91ab764a20ca --- /dev/null +++ b/torch_xla/csrc/ops/topk.cpp @@ -0,0 +1,50 @@ +#include "torch_xla/csrc/ops/topk.h" + +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim, + bool largest, bool sorted) { + auto lower_for_shape_fn = + [&](tensorflow::gtl::ArraySlice operands) + -> xla::XlaOp { + return xla::Tuple(operands[0].builder(), + CreateTopK(operands[0], k, dim, largest, sorted)); + }; + return InferOutputShape({input.shape()}, lower_for_shape_fn); +} + +} // namespace + +TopK::TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest, + bool sorted) + : Node(ir::OpKind(at::aten::topk), {input}, + NodeOutputShape(input, k, dim, largest, sorted), + /*num_outputs=*/2, xla::util::MHash(k, dim, largest, sorted)), + k_(k), + dim_(dim), + largest_(largest), + sorted_(sorted) {} + +XlaOpVector TopK::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + return ReturnOps(CreateTopK(input, k_, dim_, largest_, sorted_), loctx); +} + +std::string TopK::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", k=" << k_ << ", dim=" << dim_ + << ", largest=" << largest_ << ", sorted=" << sorted_; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/topk.h b/torch_xla/csrc/ops/topk.h new file mode 100644 index 000000000000..2c694d806227 --- /dev/null +++ b/torch_xla/csrc/ops/topk.h @@ -0,0 +1,35 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class TopK : public Node { + public: + TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest, + bool sorted); + + std::string ToString() const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + xla::int64 k() const { return k_; }; + + xla::int64 dim() const { return dim_; }; + + bool largest() const { return largest_; } + + bool sorted() const { return sorted_; } + + private: + xla::int64 k_; + xla::int64 dim_; + bool largest_; + bool sorted_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index b528d7f91c55..f8e5c6bcb303 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -67,6 +67,7 @@ #include "torch_xla/csrc/ops/sum.h" #include "torch_xla/csrc/ops/threshold.h" #include "torch_xla/csrc/ops/threshold_backward.h" +#include "torch_xla/csrc/ops/topk.h" #include "torch_xla/csrc/ops/tril.h" #include "torch_xla/csrc/ops/triu.h" #include "torch_xla/csrc/ops/unsqueeze.h" @@ -919,6 +920,15 @@ std::tuple XLATensor::kthvalue(const XLATensor& input, Create(ir::Value(node, 1), input.GetDevice())); } +std::tuple XLATensor::topk(const XLATensor& input, + xla::int64 k, xla::int64 dim, + bool largest, bool sorted) { + ir::NodePtr node = + ir::MakeNode(input.GetIrValue(), k, dim, largest, sorted); + return std::make_tuple(Create(ir::Value(node, 0), input.GetDevice()), + Create(ir::Value(node, 1), input.GetDevice())); +} + XLATensor XLATensor::dropout(const XLATensor& input, double p) { return Create(ir::MakeNode(input.GetIrValue(), p), input.GetDevice()); diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index de222b75ad87..e818d3f04e21 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -272,6 +272,10 @@ class XLATensor { xla::int64 k, xla::int64 dim, bool keepdim); + static std::tuple topk(const XLATensor& input, + xla::int64 k, xla::int64 dim, + bool largest, bool sorted); + static XLATensor dropout(const XLATensor& input, double p); static XLATensor neg(const XLATensor& input); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 26a595e74ffb..f62042e85eb2 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -81,8 +81,6 @@ std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions()); xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim); - // TODO: Remember to add is_stable=true as last Sort() argument when fetching - // the new TF head. xla::XlaOp sort_result = xla::Sort( {input, iota}, xla::CreateScalarLtComputation( @@ -101,8 +99,7 @@ std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); if (!keepdim) { - auto reshape_sizes = - XlaHelpers::DropDimensions(shape.dimensions(), {dim}); + auto reshape_sizes = XlaHelpers::DropDimensions(shape.dimensions(), {dim}); values = xla::Reshape(values, reshape_sizes); indices = xla::Reshape(indices, reshape_sizes); } @@ -110,6 +107,43 @@ std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, return {values, xla::ConvertElementType(indices, xla::PrimitiveType::S64)}; } +std::vector CreateTopK(const xla::XlaOp& input, xla::int64 k, + xla::int64 dim, bool largest, bool sorted) { + // TODO: Implement the no sorted topk, which means emit winning K elements in + // native order. + XLA_CHECK(sorted) << "Not sorted CreateTopK() not implemented"; + + auto identity = [](const xla::XlaOp& op) -> xla::XlaOp { return op; }; + auto neg = [](const xla::XlaOp& op) -> xla::XlaOp { return xla::Neg(op); }; + auto input_transform = largest ? neg : identity; + + // Here 'k' is 1 based (1...). + xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input); + XLA_CHECK_LE(k, shape.dimensions(dim)); + xla::Shape iota_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions()); + xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim); + xla::XlaOp sort_result = xla::Sort( + {input_transform(input), iota}, + xla::CreateScalarLtComputation( + {shape.element_type(), xla::PrimitiveType::S32}, input.builder()), + dim); + + std::vector start_indices(shape.rank(), 0); + std::vector limit_indices(shape.dimensions().begin(), + shape.dimensions().end()); + limit_indices[dim] = k; + std::vector strides(shape.rank(), 1); + + xla::XlaOp values = + input_transform(xla::Slice(xla::GetTupleElement(sort_result, 0), + start_indices, limit_indices, strides)); + xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), + start_indices, limit_indices, strides); + // aten::kthvalue() wants Long tensors as indices. + return {values, xla::ConvertElementType(indices, xla::PrimitiveType::S64)}; +} + xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs) { const auto precision_level = XlaHelpers::mat_mul_precision(); xla::PrecisionConfig precision_config = diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 1cf36c9413f2..421a8d6c113a 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -8,6 +8,9 @@ namespace torch_xla { std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, xla::int64 dim, bool keepdim); +std::vector CreateTopK(const xla::XlaOp& input, xla::int64 k, + xla::int64 dim, bool largest, bool sorted); + xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs); xla::XlaOp BuildDropout(const xla::XlaOp& input, float probability);