Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,40 @@ TEST_F(AtenXlaTensorTest, TestIntegerAdd) {
});
}

// Compares at::kthvalue() on XLA tensors against the result computed on the
// original tensor, for every (k, dim, keepdim) combination of a rank-3 input.
TEST_F(AtenXlaTensorTest, TestKthValue) {
  at::Tensor input = at::rand({4, 5, 3}, at::TensorOptions(at::kFloat));
  for (int k = 1; k <= 3; ++k) {
    for (int dim = 0; dim < 3; ++dim) {
      for (bool keepdim : {false, true}) {
        // Reference (values, indices) pair from the non-XLA tensor.
        auto expected = at::kthvalue(input, k, dim, keepdim);
        ForEachDevice([&](const Device& device) {
          at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
          auto actual = at::kthvalue(xla_input, k, dim, keepdim);
          // Element 0 holds the values, element 1 the indices.
          AllClose(std::get<0>(expected), std::get<0>(actual));
          AllClose(std::get<1>(expected), std::get<1>(actual));
        });
      }
    }
  }
}

// Compares sorted at::topk() on XLA tensors against the result computed on the
// original tensor, for every (k, dim, largest) combination of a rank-3 input.
TEST_F(AtenXlaTensorTest, TestTopK) {
  at::Tensor input = at::rand({4, 5, 3}, at::TensorOptions(at::kFloat));
  for (int k = 1; k <= 3; ++k) {
    for (int dim = 0; dim < 3; ++dim) {
      for (bool largest : {false, true}) {
        // Reference (values, indices) pair from the non-XLA tensor.
        auto expected = at::topk(input, k, dim, largest, /*sorted=*/true);
        ForEachDevice([&](const Device& device) {
          at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
          auto actual = at::topk(xla_input, k, dim, largest, /*sorted=*/true);
          // Element 0 holds the values, element 1 the indices.
          AllClose(std::get<0>(expected), std::get<0>(actual));
          AllClose(std::get<1>(expected), std::get<1>(actual));
        });
      }
    }
  }
}

TEST_F(AtenXlaTensorTest, TestMin) {
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
Expand Down
23 changes: 23 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,29 @@ int64_t AtenXlaType::size(const at::Tensor& self, int64_t dim) const {
return bridge::GetXlaTensor(self).size(dim);
}

// Computes kthvalue on the XLA tensor behind `self` and returns the resulting
// (values, indices) pair wrapped as at::Tensor objects.
std::tuple<at::Tensor, at::Tensor> AtenXlaType::kthvalue(const at::Tensor& self,
                                                         int64_t k, int64_t dim,
                                                         bool keepdim) const {
  auto xla_outputs =
      XLATensor::kthvalue(bridge::GetXlaTensor(self), k, dim, keepdim);
  auto values = bridge::AtenFromXlaTensor(std::get<0>(xla_outputs));
  auto indices = bridge::AtenFromXlaTensor(std::get<1>(xla_outputs));
  return std::make_tuple(std::move(values), std::move(indices));
}

// Computes topk on the XLA tensor behind `self` when a sorted result is
// requested; the unsorted variant is forwarded to AtenXlaTypeBase::topk().
std::tuple<at::Tensor, at::Tensor> AtenXlaType::topk(const at::Tensor& self,
                                                     int64_t k, int64_t dim,
                                                     bool largest,
                                                     bool sorted) const {
  if (sorted) {
    auto xla_outputs =
        XLATensor::topk(bridge::GetXlaTensor(self), k, dim, largest, sorted);
    auto values = bridge::AtenFromXlaTensor(std::get<0>(xla_outputs));
    auto indices = bridge::AtenFromXlaTensor(std::get<1>(xla_outputs));
    return std::make_tuple(std::move(values), std::move(indices));
  }
  // TODO: Implement the non default not-sorted topk on the XLA side.
  return AtenXlaTypeBase::topk(self, k, dim, largest, sorted);
}

at::Tensor AtenXlaType::embedding(const at::Tensor& weight,
const at::Tensor& indices,
int64_t padding_idx, bool scale_grad_by_freq,
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ class AtenXlaType : public AtenXlaTypeBase {

int64_t size(const at::Tensor& self, int64_t dim) const override;

// Override of kthvalue. Returns a (values, indices) pair; the implementation
// forwards to XLATensor::kthvalue() and wraps both results as at::Tensor.
std::tuple<at::Tensor, at::Tensor> kthvalue(const at::Tensor& self, int64_t k,
int64_t dim,
bool keepdim) const override;

// Override of topk. Returns a (values, indices) pair. Only the sorted variant
// is computed through XLATensor::topk(); a call with sorted == false is
// forwarded to AtenXlaTypeBase::topk().
std::tuple<at::Tensor, at::Tensor> topk(const at::Tensor& self, int64_t k,
int64_t dim, bool largest,
bool sorted) const override;

at::Tensor embedding(const at::Tensor& weight, const at::Tensor& indices,
int64_t padding_idx, bool scale_grad_by_freq,
bool sparse) const override;
Expand Down
49 changes: 49 additions & 0 deletions torch_xla/csrc/ops/kth_value.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "torch_xla/csrc/ops/kth_value.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

// Infers the shape of the tuple produced by CreateKthValue() for `input` by
// running the lowering through InferOutputShape().
xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim,
                           bool keepdim) {
  auto shape_fn = [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
      -> xla::XlaOp {
    const xla::XlaOp& source = operands[0];
    return xla::Tuple(source.builder(),
                      CreateKthValue(source, k, dim, keepdim));
  };
  return InferOutputShape({input.shape()}, shape_fn);
}

} // namespace

// Builds an aten::kthvalue IR node over `input` with two outputs. The node
// shape comes from NodeOutputShape(), and xla::util::MHash(k, dim, keepdim) is
// handed to the Node base so the attributes take part in its hash.
KthValue::KthValue(const Value& input, xla::int64 k, xla::int64 dim,
bool keepdim)
: Node(ir::OpKind(at::aten::kthvalue), {input},
NodeOutputShape(input, k, dim, keepdim),
/*num_outputs=*/2, xla::util::MHash(k, dim, keepdim)),
k_(k),
dim_(dim),
keepdim_(keepdim) {}

// Lowers the node: fetches the XLA op of the single operand and emits the ops
// built by CreateKthValue() for it.
XlaOpVector KthValue::Lower(LoweringContext* loctx) const {
  xla::XlaOp operand_op = loctx->GetOutputOp(operand(0));
  std::vector<xla::XlaOp> outputs =
      CreateKthValue(operand_op, k_, dim_, keepdim_);
  return ReturnOps(outputs, loctx);
}

// Returns the base Node description followed by this node's k, dim and keepdim
// attributes.
std::string KthValue::ToString() const {
  std::stringstream out;
  out << Node::ToString();
  out << ", k=" << k_;
  out << ", dim=" << dim_;
  out << ", keepdim=" << keepdim_;
  return out.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
31 changes: 31 additions & 0 deletions torch_xla/csrc/ops/kth_value.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// IR node for aten::kthvalue. It has two outputs (values and indices) and
// stores the k, dim and keepdim attributes it was created with.
class KthValue : public Node {
 public:
  // Creates the node over `input`; `k` is 1-based.
  KthValue(const Value& input, xla::int64 k, xla::int64 dim, bool keepdim);

  // Node description extended with the k, dim and keepdim attributes.
  std::string ToString() const override;

  // Emits the XLA ops for this node into `loctx`.
  XlaOpVector Lower(LoweringContext* loctx) const override;

  xla::int64 k() const { return k_; }

  xla::int64 dim() const { return dim_; }

  bool keepdim() const { return keepdim_; }

 private:
  xla::int64 k_;
  xla::int64 dim_;
  bool keepdim_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
50 changes: 50 additions & 0 deletions torch_xla/csrc/ops/topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "torch_xla/csrc/ops/topk.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

// Infers the shape of the tuple produced by CreateTopK() for `input` by
// running the lowering through InferOutputShape().
xla::Shape NodeOutputShape(const Value& input, xla::int64 k, xla::int64 dim,
                           bool largest, bool sorted) {
  auto shape_fn = [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
      -> xla::XlaOp {
    const xla::XlaOp& source = operands[0];
    return xla::Tuple(source.builder(),
                      CreateTopK(source, k, dim, largest, sorted));
  };
  return InferOutputShape({input.shape()}, shape_fn);
}

} // namespace

// Builds an aten::topk IR node over `input` with two outputs. The node shape
// comes from NodeOutputShape(), and xla::util::MHash(k, dim, largest, sorted)
// is handed to the Node base so the attributes take part in its hash.
TopK::TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
bool sorted)
: Node(ir::OpKind(at::aten::topk), {input},
NodeOutputShape(input, k, dim, largest, sorted),
/*num_outputs=*/2, xla::util::MHash(k, dim, largest, sorted)),
k_(k),
dim_(dim),
largest_(largest),
sorted_(sorted) {}

// Lowers the node: fetches the XLA op of the single operand and emits the ops
// built by CreateTopK() for it.
XlaOpVector TopK::Lower(LoweringContext* loctx) const {
  xla::XlaOp operand_op = loctx->GetOutputOp(operand(0));
  std::vector<xla::XlaOp> outputs =
      CreateTopK(operand_op, k_, dim_, largest_, sorted_);
  return ReturnOps(outputs, loctx);
}

// Returns the base Node description followed by this node's k, dim, largest
// and sorted attributes.
std::string TopK::ToString() const {
  std::stringstream out;
  out << Node::ToString();
  out << ", k=" << k_;
  out << ", dim=" << dim_;
  out << ", largest=" << largest_;
  out << ", sorted=" << sorted_;
  return out.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
35 changes: 35 additions & 0 deletions torch_xla/csrc/ops/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// IR node for aten::topk. It has two outputs (values and indices) and stores
// the k, dim, largest and sorted attributes it was created with.
class TopK : public Node {
 public:
  // Creates the node over `input`, selecting `k` entries along `dim`.
  TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest,
       bool sorted);

  // Node description extended with the k, dim, largest and sorted attributes.
  std::string ToString() const override;

  // Emits the XLA ops for this node into `loctx`.
  XlaOpVector Lower(LoweringContext* loctx) const override;

  xla::int64 k() const { return k_; }

  xla::int64 dim() const { return dim_; }

  bool largest() const { return largest_; }

  bool sorted() const { return sorted_; }

 private:
  xla::int64 k_;
  xla::int64 dim_;
  bool largest_;
  bool sorted_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
21 changes: 21 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "torch_xla/csrc/ops/generic.h"
#include "torch_xla/csrc/ops/index_select.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/kth_value.h"
#include "torch_xla/csrc/ops/leaky_relu.h"
#include "torch_xla/csrc/ops/log_softmax.h"
#include "torch_xla/csrc/ops/log_softmax_backward.h"
Expand All @@ -66,6 +67,7 @@
#include "torch_xla/csrc/ops/sum.h"
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/tril.h"
#include "torch_xla/csrc/ops/triu.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
Expand Down Expand Up @@ -908,6 +910,25 @@ XLATensor XLATensor::select(const XLATensor& input, int64_t dim,
input.GetDevice());
}

std::tuple<XLATensor, XLATensor> XLATensor::kthvalue(const XLATensor& input,
xla::int64 k,
xla::int64 dim,
bool keepdim) {
ir::NodePtr node =
ir::MakeNode<ir::ops::KthValue>(input.GetIrValue(), k, dim, keepdim);
return std::make_tuple(Create(ir::Value(node, 0), input.GetDevice()),
Create(ir::Value(node, 1), input.GetDevice()));
}

std::tuple<XLATensor, XLATensor> XLATensor::topk(const XLATensor& input,
xla::int64 k, xla::int64 dim,
bool largest, bool sorted) {
ir::NodePtr node =
ir::MakeNode<ir::ops::TopK>(input.GetIrValue(), k, dim, largest, sorted);
return std::make_tuple(Create(ir::Value(node, 0), input.GetDevice()),
Create(ir::Value(node, 1), input.GetDevice()));
}

XLATensor XLATensor::dropout(const XLATensor& input, double p) {
return Create(ir::MakeNode<ir::ops::Dropout>(input.GetIrValue(), p),
input.GetDevice());
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ class XLATensor {

static XLATensor select(const XLATensor& input, int64_t dim, int64_t index);

// Returns the two outputs of a KthValue IR node built over `input`, as a pair
// of tensors on the input's device.
static std::tuple<XLATensor, XLATensor> kthvalue(const XLATensor& input,
xla::int64 k, xla::int64 dim,
bool keepdim);

// Returns the two outputs of a TopK IR node built over `input`, as a pair of
// tensors on the input's device.
static std::tuple<XLATensor, XLATensor> topk(const XLATensor& input,
xla::int64 k, xla::int64 dim,
bool largest, bool sorted);

static XLATensor dropout(const XLATensor& input, double p);

static XLATensor neg(const XLATensor& input);
Expand Down
72 changes: 72 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <algorithm>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
Expand Down Expand Up @@ -72,6 +73,77 @@ std::pair<xla::XlaOp, xla::XlaOp> DotBroadcast(const xla::XlaOp& lhs,

} // namespace

// Builds the XLA ops for aten::kthvalue(): the k-th smallest element of
// `input` along `dim` together with its position along that dimension.
//
// `k` is 1 based (1...). `dim` may be negative, in which case it counts from
// the last dimension. When `keepdim` is false the reduced dimension is dropped
// from both results, otherwise it is kept with size 1.
//
// Returns {values, indices}, where indices has S64 element type.
std::vector<xla::XlaOp> CreateKthValue(const xla::XlaOp& input, xla::int64 k,
                                       xla::int64 dim, bool keepdim) {
  xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
  // aten::kthvalue() defaults to dim=-1, so wrap negative dimensions before
  // they are used to index the shape and the slice bounds below.
  if (dim < 0) {
    dim += shape.rank();
  }
  XLA_CHECK(dim >= 0 && dim < shape.rank())
      << "Dimension out of range in CreateKthValue(): " << dim;
  // k == 0 (or negative) would turn into a negative slice start.
  XLA_CHECK(k >= 1) << "CreateKthValue() requires k >= 1: " << k;
  XLA_CHECK_LE(k, shape.dimensions(dim));

  // Sort the values along 'dim', carrying an iota along so that the original
  // positions of the sorted values are known.
  xla::Shape iota_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
  xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
  xla::XlaOp sort_result = xla::Sort(
      {input, iota},
      xla::CreateScalarLtComputation(
          {shape.element_type(), xla::PrimitiveType::S32}, input.builder()),
      dim);

  // Slice out the single entry [k - 1, k) along 'dim', and everything along
  // the other dimensions.
  std::vector<xla::int64> start_indices(shape.rank(), 0);
  start_indices[dim] = k - 1;
  std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
                                        shape.dimensions().end());
  limit_indices[dim] = k;
  std::vector<xla::int64> strides(shape.rank(), 1);

  xla::XlaOp values = xla::Slice(xla::GetTupleElement(sort_result, 0),
                                 start_indices, limit_indices, strides);
  xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
                                  start_indices, limit_indices, strides);
  if (!keepdim) {
    auto reshape_sizes = XlaHelpers::DropDimensions(shape.dimensions(), {dim});
    values = xla::Reshape(values, reshape_sizes);
    indices = xla::Reshape(indices, reshape_sizes);
  }
  // aten::kthvalue() wants Long tensors as indices.
  return {values, xla::ConvertElementType(indices, xla::PrimitiveType::S64)};
}

// Builds the XLA ops for sorted aten::topk(): the `k` largest (or smallest,
// when `largest` is false) elements of `input` along `dim`, in sorted order,
// together with their positions along that dimension.
//
// `dim` may be negative, in which case it counts from the last dimension.
// Only `sorted == true` is supported.
//
// Returns {values, indices}, where indices has S64 element type.
std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
                                   xla::int64 dim, bool largest, bool sorted) {
  // TODO: Implement the no sorted topk, which means emit winning K elements in
  // native order.
  XLA_CHECK(sorted) << "Not sorted CreateTopK() not implemented";

  xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
  // aten::topk() defaults to dim=-1, so wrap negative dimensions before they
  // are used to index the shape and the slice bounds below.
  if (dim < 0) {
    dim += shape.rank();
  }
  XLA_CHECK(dim >= 0 && dim < shape.rank())
      << "Dimension out of range in CreateTopK(): " << dim;
  XLA_CHECK(k >= 0) << "CreateTopK() requires k >= 0: " << k;
  XLA_CHECK_LE(k, shape.dimensions(dim));

  // Sort the values along 'dim', carrying an iota along so that the original
  // positions of the sorted values are known. The sort direction is selected
  // through the comparator rather than by negating the values, since negation
  // does not reverse the order for unsigned types or for the minimum signed
  // integer value.
  xla::Shape iota_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
  xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
  xla::XlaOp sort_result = xla::Sort(
      {input, iota},
      largest ? xla::CreateScalarGtComputation(
                    {shape.element_type(), xla::PrimitiveType::S32},
                    input.builder())
              : xla::CreateScalarLtComputation(
                    {shape.element_type(), xla::PrimitiveType::S32},
                    input.builder()),
      dim);

  // Keep the first 'k' entries along 'dim', and everything along the other
  // dimensions.
  std::vector<xla::int64> start_indices(shape.rank(), 0);
  std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
                                        shape.dimensions().end());
  limit_indices[dim] = k;
  std::vector<xla::int64> strides(shape.rank(), 1);

  xla::XlaOp values = xla::Slice(xla::GetTupleElement(sort_result, 0),
                                 start_indices, limit_indices, strides);
  xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
                                  start_indices, limit_indices, strides);
  // aten::topk() wants Long tensors as indices.
  return {values, xla::ConvertElementType(indices, xla::PrimitiveType::S64)};
}

xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs) {
const auto precision_level = XlaHelpers::mat_mul_precision();
xla::PrecisionConfig precision_config =
Expand Down
Loading