14 changes: 14 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1101,6 +1101,20 @@ TEST_F(AtenXlaTensorTest, TestTopK) {
}
}

TEST_F(AtenXlaTensorTest, TestTopKSymIntStatic) {
torch::Tensor a = torch::rand({10, 10}, torch::TensorOptions(torch::kFloat));
auto results = torch::topk(a, 5);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
auto xla_results = torch::topk_symint(xla_a, c10::SymInt(5));
AllClose(std::get<0>(results), std::get<0>(xla_results));
AllClose(std::get<1>(results), std::get<1>(xla_results));
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::topk_symint", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSort) {
torch::Tensor a = torch::rand({4, 5, 3}, torch::TensorOptions(torch::kFloat));
for (int k = 1; k <= 3; ++k) {
44 changes: 44 additions & 0 deletions test/test_dynamic_shapes.py
@@ -178,6 +178,50 @@ def test_expand_symint_correctness(self):
self.assertEqual(t3.shape[0], 2)
self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())

def test_topk_symint_ir_1(self):
t1 = torch.zeros([5, 2], device=dev)
t1[3][0] = 1
t1[3][1] = 1
t2 = torch.nonzero(t1)
t3 = torch.zeros([10, 2], device=dev)
values, indices = torch.topk(t3, t2.shape[0], dim=0)
[Review thread]
Collaborator: I wonder if your change can handle the case where t3 is dynamic, e.g. t3 = torch.nonzero(t1).
Contributor Author: I think this feature is already supported by XLA through the dynamic padder (static topk also works with a dynamic t3). I'll add a test case for it later.

self.assertIsInstance(t2.shape[0], torch.SymInt)
self.assertIsInstance(values.shape[0], torch.SymInt)
self.assertIsInstance(indices.shape[0], torch.SymInt)
self.assertEqual(str(t2.shape[0]), '<=10')
self.assertEqual(str(values.shape[0]), '<=10')
self.assertEqual(str(indices.shape[0]), '<=10')
self.assertEqual(t3.shape[1], 2)
self.assertEqual(values.shape[1], 2)
self.assertEqual(indices.shape[1], 2)
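
Following up on the review thread inside test_topk_symint_ir_1, a sketch of the test case the author promises, where the topk input itself is dynamic (hypothetical, not part of this PR; it assumes the dynamic-padder path behaves as described):

def test_topk_symint_dynamic_input(self):
    # Hypothetical: the input's leading dimension is itself bounded-dynamic.
    t1 = torch.zeros([5, 2], device=dev)
    t1[3][0] = 1
    t1[3][1] = 1
    t2 = torch.nonzero(t1)  # shape[0] is a SymInt, '<=10'
    values, indices = torch.topk(t2, t2.shape[0], dim=0)
    self.assertIsInstance(values.shape[0], torch.SymInt)
    self.assertIsInstance(indices.shape[0], torch.SymInt)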

def test_topk_symint_ir_2(self):
t1 = torch.ones(20, device=dev)
t1[::3] = 0
t2 = torch.ones(10, device=dev)
t2[::2] = 0
k = torch.nonzero(t2).shape[0]
values, indices = torch.topk(t1, k, dim=0)
self.assertIsInstance(values.shape[0], torch.SymInt)
self.assertIsInstance(indices.shape[0], torch.SymInt)
self.assertEqual(str(values.shape[0]), '<=10')
self.assertEqual(str(indices.shape[0]), '<=10')

def test_topk_symint_correctness(self):

def test_fn(*tensors):
torch.manual_seed(0)
x = torch.rand(10, 10).to(tensors[0].device)
results = []
for tensor in tensors:
k = torch.nonzero(tensor).shape[0]
for dim in range(2):
results += list(torch.topk(x, k=k, dim=dim))
return results

self.runAtenTest([torch.randint(0, 2, size=(10,)) for _ in range(5)],
test_fn)


if __name__ == '__main__':
assert os.environ['XLA_EXPERIMENTAL'] != ''
23 changes: 16 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2842,13 +2842,22 @@ at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output,
threshold.to<double>()));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(
const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
TORCH_LAZY_FN_COUNTER("xla::");
auto results = tensor_methods::topk(bridge::GetXlaTensor(self), k, dim,
largest, sorted, /*stable=*/false);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
bridge::AtenFromXlaTensor(std::get<1>(results)));
std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk_symint(
const at::Tensor& self, c10::SymInt k, int64_t dim, bool largest,
bool sorted) {
TORCH_LAZY_FN_COUNTER("xla::");
if (!k.is_symbolic()) {
auto results =
tensor_methods::topk(bridge::GetXlaTensor(self), k.expect_int(), dim,
largest, sorted, /*stable=*/false);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
bridge::AtenFromXlaTensor(std::get<1>(results)));
} else {
auto results = tensor_methods::topk_symint(
bridge::GetXlaTensor(self), k, dim, largest, sorted, /*stable=*/false);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
bridge::AtenFromXlaTensor(std::get<1>(results)));
}
}
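
For context, a minimal Python-level sketch of how the two branches above are exercised (assuming an XLA device dev and XLA_EXPERIMENTAL dynamic shapes enabled, as in test_dynamic_shapes.py):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
x = torch.rand(10, 10, device=dev)

# Non-symbolic k: dispatches to the static tensor_methods::topk path.
values, indices = torch.topk(x, 5)

# A data-dependent k from nonzero becomes a torch.SymInt on XLA and
# takes the TopKSymInt path instead.
mask = torch.randint(0, 2, (10,), device=dev)
k = torch.nonzero(mask).shape[0]
values, indices = torch.topk(x, k, dim=0)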

at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
67 changes: 67 additions & 0 deletions torch_xla/csrc/ops/topk_symint.cpp
@@ -0,0 +1,67 @@
#include "torch_xla/csrc/ops/topk_symint.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShapeSymInt(const torch::lazy::Value& input,
int64_t k_upper_bound, int64_t dim,
bool largest, bool sorted, bool stable) {
xla::Shape input_shape = GetXlaShape(input);
std::vector<int64_t> dimensions(input_shape.dimensions().begin(),
input_shape.dimensions().end());
XLA_CHECK_LT(dim, input_shape.rank());
dimensions[dim] = k_upper_bound;
xla::Shape values_shape =
xla::ShapeUtil::MakeShape(input_shape.element_type(), dimensions);
xla::Shape indices_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, dimensions);
values_shape.set_dynamic_dimension(dim, true);
indices_shape.set_dynamic_dimension(dim, true);
return xla::ShapeUtil::MakeTupleShape({values_shape, indices_shape});
}

} // namespace

TopKSymInt::TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
int64_t dim, bool largest, bool sorted, bool stable)
: XlaNode(torch::lazy::OpKind(at::aten::topk),
{input, torch::lazy::Value(k.GetSizeNodes().front())},
[&]() {
return NodeOutputShapeSymInt(input, k.GetUpperBounds().front(),
dim, largest, sorted, stable);
},
/*num_outputs=*/2,
torch::lazy::MHash(k.GetUpperBounds().front(), dim, largest,
sorted, stable)),
k_upper_bound_(k.GetUpperBounds().front()),
dim_(dim),
largest_(largest),
sorted_(sorted),
stable_(stable) {}

XlaOpVector TopKSymInt::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp size_op = loctx->GetOutputOp(operand(1));
std::vector<xla::XlaOp> results =
CreateTopK(input, k_upper_bound_, dim_, largest_, stable_);
std::vector<xla::XlaOp> resized_results;
std::transform(
results.begin(), results.end(), std::back_inserter(resized_results),
[&](xla::XlaOp op) { return xla::SetDimensionSize(op, size_op, dim_); });
return ReturnOps(resized_results, loctx);
}

std::string TopKSymInt::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", k<=" << k_upper_bound_ << ", dim=" << dim_
<< ", largest=" << largest_ << ", sorted=" << sorted_
<< ", stable=" << stable_;
return ss.str();
}

} // namespace torch_xla
25 changes: 25 additions & 0 deletions torch_xla/csrc/ops/topk_symint.h
@@ -0,0 +1,25 @@
#pragma once

#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {

class TopKSymInt : public XlaNode {
public:
TopKSymInt(const torch::lazy::Value& input, const SymIntElements& k,
int64_t dim, bool largest, bool sorted, bool stable);

std::string ToString() const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

private:
int64_t k_upper_bound_;
int64_t dim_;
bool largest_;
bool sorted_;
bool stable_;
};

} // namespace torch_xla
15 changes: 15 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -121,6 +121,7 @@
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/topk_symint.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
@@ -2505,6 +2506,20 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
}

std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
[Review thread]
Collaborator: Quick question: now that we have a symint version of topk, do we still need the original topk?
Contributor Author: I basically followed the expand_symint design, which keeps the static expand implementation. I guess the original topk implementation is still needed because the codegen for handling SymInt is slightly different (e.g. SetDimensionSize).
Collaborator: Yeah, let's keep the static version there.

c10::SymInt k, int64_t dim,
bool largest, bool sorted,
bool stable) {
SymIntElements k_symint = SymIntElements(k);
torch::lazy::NodePtr node = torch::lazy::MakeNode<TopKSymInt>(
input->GetIrValue(), k_symint,
torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
largest, sorted, stable);
return std::make_tuple(
input->CreateFrom(torch::lazy::Value(node, 0)),
input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
[Review thread]
Contributor Author: Some pytorch tests require the static topk's returned indices to be int64; otherwise PyTorch throws an error. I saw nonzero returns int32 indices. Why not just do an int64 (pytorch) -> int32 (xla) -> int64 (pytorch) conversion for handling indices?
Collaborator: What happens here is that the underlying XLA type is still xla::s32, but we mark the logical element type as at::ScalarType::Long. That creates a problem, since the underlying s32 cannot represent all values of long. We run into issues when we try to cast this kind of tensor (long logical type, real s32 type) to other dtypes, because the Cast op doesn't know how to handle such cases. The right thing to do is probably a manual s64->s32->s64 cast instead of playing the logical_element_type trick here.
Contributor Author: Thanks for the clarification! Do we want to add something like res = torch::lazy::MakeNode<Cast>(res, at::ScalarType::Long) here, or use xla::ConvertElementType(indices, xla::PrimitiveType::S64, /*device=*/nullptr) in CreateTopK like before?
Collaborator: Let's handle it inside CreateTopK so it is cleaner.
}
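
To make the dtype contract from the thread above concrete, a small sketch (assuming an XLA device dev; the s64->s32->s64 conversion the reviewers discuss would live inside CreateTopK rather than here):

x = torch.rand(10, device=dev)
_, idx = torch.topk(x, 3)
assert idx.dtype == torch.int64  # PyTorch expects Long indices from topk
# If the backing XLA buffer is s32 while the logical type is Long, a
# downstream cast like this is where the mismatch described above bites:
_ = idx.to(torch.float32)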

XLATensorPtr trace(const XLATensorPtr& input) {
auto input_shape_ref = input->shape();
XLA_CHECK_EQ((*input_shape_ref).rank(), 2)
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -828,6 +828,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
bool largest, bool sorted,
bool stable);

std::tuple<XLATensorPtr, XLATensorPtr> topk_symint(const XLATensorPtr& input,
c10::SymInt k, int64_t dim,
bool largest, bool sorted,
bool stable);

// Returns the sum of the elements of the diagonal of the input 2-D matrix.
XLATensorPtr trace(const XLATensorPtr& input);

1 change: 1 addition & 0 deletions xla_native_functions.yaml
@@ -364,6 +364,7 @@ symint:
- diagonal_backward
- narrow_copy
- select_backward
- topk
autograd:
- einsum
- max_pool2d
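
With topk listed under the symint: section, the codegen routes topk calls to XLANativeFunctions::topk_symint (which, as aten_xla_type.cpp shows above, falls back to the static path when k is not symbolic). A usage sketch under the same assumptions as test_dynamic_shapes.py:

t = torch.ones(10, device=dev)
t[::2] = 0
k = torch.nonzero(t).shape[0]       # torch.SymInt, upper bound 10
values, indices = torch.topk(t, k)  # dispatches to xla::topk_symint
print(values.shape[0])              # '<=10' under dynamic shapes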