36 changes: 36 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -719,6 +719,42 @@ TEST_F(AtenXlaTensorTest, TestCholesky) {
}
}

TEST_F(AtenXlaTensorTest, TestTriangularSolve) {
static const int dims[] = {4, 7};
for (bool batched_a : {true, false}) {
for (bool batched_b : {true, false}) {
for (auto m : dims) {
for (auto n : dims) {
for (bool upper : {true, false}) {
for (bool transpose : {true, false}) {
for (bool unitriangular : {true, false}) {
at::Tensor a = at::randn({m, m}, at::TensorOptions(at::kFloat));
at::Tensor b = at::randn({m, n}, at::TensorOptions(at::kFloat));
a = batched_a ? a.expand({3, m, m}).clone() : a;
b = batched_b ? b.expand({3, m, n}).clone() : b;
auto result = at::triangular_solve(
b, a, /*upper=*/upper, /*transpose=*/transpose,
/*unitriangular=*/unitriangular);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
auto xla_result = at::triangular_solve(
xla_b, xla_a, /*upper=*/upper, /*transpose=*/transpose,
/*unitriangular=*/unitriangular);
AllClose(std::get<0>(result), std::get<0>(xla_result),
/*rtol=*/1e-3, /*atol=*/1e-4);
AllClose(std::get<1>(result), std::get<1>(xla_result),
/*rtol=*/1e-3, /*atol=*/1e-4);
});
}
}
}
}
}
}
}
}
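
// For context, a minimal standalone sketch of the ATen API exercised above
// (eager CPU tensors; the variable names are illustrative, not part of the
// diff):
//
//   at::Tensor a = at::randn({4, 4}, at::TensorOptions(at::kFloat)).triu();
//   at::Tensor b = at::randn({4, 2}, at::TensorOptions(at::kFloat));
//   auto result = at::triangular_solve(b, a, /*upper=*/true,
//                                      /*transpose=*/false,
//                                      /*unitriangular=*/false);
//   at::Tensor x = std::get<0>(result);      // solution to a * x = b
//   at::Tensor a_out = std::get<1>(result);  // cloned coefficient matrix
//
// a.matmul(x) reproduces b up to numerical tolerance, which is what the
// AllClose checks in the test verify against the XLA backend.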

TEST_F(AtenXlaTensorTest, TestKthValue) {
at::Tensor a = at::rand({4, 5, 3}, at::TensorOptions(at::kFloat));
for (int k = 1; k <= 3; ++k) {
12 changes: 12 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2392,6 +2392,18 @@ at::Tensor& AtenXlaType::transpose_(at::Tensor& self, int64_t dim0,
return self;
}

std::tuple<at::Tensor, at::Tensor> AtenXlaType::triangular_solve(
const at::Tensor& b, const at::Tensor& A, bool upper, bool transpose,
bool unitriangular) const {
  // ATen currently has no left_side option; if one is added, this API
  // will need to change.
auto results = XLATensor::triangular_solve(
bridge::GetXlaTensor(b), bridge::GetXlaTensor(A), /*left_side=*/true,
upper, transpose, unitriangular);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
bridge::AtenFromXlaTensor(std::get<1>(results)));
}
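
// For reference, the flag combinations map onto the following systems
// (a notational summary, not part of the change):
//
//   left_side = true,  transpose = false  :  solve A X = B
//   left_side = true,  transpose = true   :  solve A^T X = B
//   left_side = false                     :  solve X A = B (not reachable
//                                            from ATen today)
//
// unitriangular additionally tells the solver to assume ones on the
// diagonal of A.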

at::Tensor AtenXlaType::tril(const at::Tensor& self, int64_t diagonal) const {
return bridge::AtenFromXlaTensor(
XLATensor::tril(bridge::GetXlaTensor(self), diagonal));
4 changes: 4 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -927,6 +927,10 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor trace(const at::Tensor& self) const override;

std::tuple<at::Tensor, at::Tensor> triangular_solve(
const at::Tensor& b, const at::Tensor& A, bool upper, bool transpose,
bool unitriangular) const override;

at::Tensor one_hot(const at::Tensor& self,
int64_t num_classes) const override;

102 changes: 102 additions & 0 deletions torch_xla/csrc/ops/triangular_solve.cpp
@@ -0,0 +1,102 @@
#include "torch_xla/csrc/ops/triangular_solve.h"

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

// This function plays two roles:
// - Computes the output shape.
// - Computes the broadcasted shape for the operands.
// NB: This currently infers shapes assuming left_side is true, as ATen does.
std::pair<xla::Shape, xla::Shape> InferTriangularSolveShape(
const xla::Shape& rhs_shape, const xla::Shape& lhs_shape) {
  // Obtain the number of right-hand sides and the dimension of the square
  // matrix.
xla::int64 nrhs = rhs_shape.dimensions(rhs_shape.rank() - 1);
xla::int64 n = lhs_shape.dimensions(lhs_shape.rank() - 1);
xla::Shape rhs_batch_shape(rhs_shape);
xla::Shape lhs_batch_shape(lhs_shape);
rhs_batch_shape.DeleteDimension(rhs_batch_shape.rank() - 1);
lhs_batch_shape.DeleteDimension(lhs_batch_shape.rank() - 1);
// If the shapes match in the batch dimensions, then we don't need to get
// the promoted shape, and can directly add the trailing dimension.
if (xla::ShapeUtil::Compatible(lhs_batch_shape, rhs_batch_shape)) {
rhs_batch_shape.add_dimensions(nrhs);
lhs_batch_shape.add_dimensions(n);
return std::pair<xla::Shape, xla::Shape>(rhs_batch_shape, lhs_batch_shape);
}
// Obtain the promoted shapes and add back the trailing dimension.
xla::Shape rhs_batch_promoted_shape =
XlaHelpers::GetPromotedShape(rhs_batch_shape, lhs_batch_shape);
xla::Shape lhs_batch_promoted_shape(rhs_batch_promoted_shape);
rhs_batch_promoted_shape.add_dimensions(nrhs);
lhs_batch_promoted_shape.add_dimensions(n);
return std::pair<xla::Shape, xla::Shape>(rhs_batch_promoted_shape,
lhs_batch_promoted_shape);
}
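
// A worked trace of the inference above, on hypothetical shapes:
//
//   rhs_shape = f32[3,4,7]  ->  rhs_batch_shape = f32[3,4], nrhs = 7
//   lhs_shape = f32[4,4]    ->  lhs_batch_shape = f32[4],   n    = 4
//   promoted batch shape    =   f32[3,4]
//   returned pair           =   (f32[3,4,7], f32[3,4,4])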

std::vector<xla::XlaOp> LowerTriangularSolve(const xla::XlaOp& rhs,
const xla::XlaOp& lhs,
bool left_side, bool lower,
bool transpose,
bool unit_diagonal) {
xla::Shape rhs_shape = XlaHelpers::ShapeOfXlaOp(rhs);
xla::Shape lhs_shape = XlaHelpers::ShapeOfXlaOp(lhs);
std::pair<xla::Shape, xla::Shape> broadcasted_shapes =
InferTriangularSolveShape(rhs_shape, lhs_shape);
xla::XlaOp rhs_broadcasted =
XlaHelpers::ImplicitBroadcast(rhs, rhs_shape, broadcasted_shapes.first);
xla::XlaOp lhs_broadcasted =
XlaHelpers::ImplicitBroadcast(lhs, lhs_shape, broadcasted_shapes.second);

xla::XlaOp solution = xla::TriangularSolve(
lhs_broadcasted, rhs_broadcasted, left_side, lower, unit_diagonal,
transpose ? xla::TriangularSolveOptions::TRANSPOSE
: xla::TriangularSolveOptions::NO_TRANSPOSE);
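  // The second element is the broadcasted coefficient matrix, mirroring
  // ATen's (solution, cloned_coefficient) return convention.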
return {solution, lhs_broadcasted};
}

xla::Shape NodeOutputShape(const Value& rhs, const Value& lhs) {
std::pair<xla::Shape, xla::Shape> broadcasted_shapes =
InferTriangularSolveShape(rhs.shape(), lhs.shape());
return xla::ShapeUtil::MakeTupleShape(
{broadcasted_shapes.first, broadcasted_shapes.second});
}

} // namespace

TriangularSolve::TriangularSolve(const Value& rhs, const Value& lhs,
bool left_side, bool lower, bool transpose,
bool unit_diagonal)
: Node(ir::OpKind(at::aten::triangular_solve), {rhs, lhs},
[&]() { return NodeOutputShape(rhs, lhs); },
/*num_outputs=*/2,
xla::util::MHash(left_side, lower, unit_diagonal, transpose)),
left_side_(left_side),
lower_(lower),
unit_diagonal_(unit_diagonal),
transpose_(transpose) {}

XlaOpVector TriangularSolve::Lower(LoweringContext* loctx) const {
xla::XlaOp rhs = loctx->GetOutputOp(operand(0));
xla::XlaOp lhs = loctx->GetOutputOp(operand(1));
return ReturnOps(LowerTriangularSolve(rhs, lhs, left_side_, lower_,
transpose_, unit_diagonal_),
loctx);
}

std::string TriangularSolve::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", left_side=" << left_side_ << ", lower=" << lower_
<< ", transpose=" << transpose_ << ", unit_diagonal=" << unit_diagonal_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
35 changes: 35 additions & 0 deletions torch_xla/csrc/ops/triangular_solve.h
@@ -0,0 +1,35 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class TriangularSolve : public Node {
public:
TriangularSolve(const Value& rhs, const Value& lhs, bool left_side,
bool lower, bool transpose, bool unit_diagonal);

std::string ToString() const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

bool left_side() const { return left_side_; }

bool lower() const { return lower_; }

bool transpose() const { return transpose_; }

bool unit_diagonal() const { return unit_diagonal_; }

private:
bool left_side_;
bool lower_;
bool transpose_;
bool unit_diagonal_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -760,6 +760,10 @@ class XLATensor {
// In-place version of the method above.
static void transpose_(XLATensor& input, xla::int64 dim0, xla::int64 dim1);

static std::tuple<XLATensor, XLATensor> triangular_solve(
const XLATensor& rhs, const XLATensor& lhs, bool left_side, bool upper,
bool transpose, bool unitriangular);

  // Returns the lower triangular part of a matrix (2-D tensor) or batch of
  // matrices; the other elements of the result tensor are set to 0.
static XLATensor tril(const XLATensor& input, xla::int64 diagonal);
12 changes: 12 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -81,6 +81,7 @@
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/tril.h"
#include "torch_xla/csrc/ops/triu.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
@@ -1746,6 +1747,17 @@ void XLATensor::transpose_(XLATensor& input, xla::int64 dim0, xla::int64 dim1) {
input.SetIrValue(ir::ops::TransposeOp(input.GetIrValue(), dim0, dim1));
}

std::tuple<XLATensor, XLATensor> XLATensor::triangular_solve(
const XLATensor& rhs, const XLATensor& lhs, bool left_side, bool upper,
bool transpose, bool unitriangular) {
// TriangularSolve takes lower instead of upper, hence the negation.
ir::NodePtr node = ir::MakeNode<ir::ops::TriangularSolve>(
rhs.GetIrValue(), lhs.GetIrValue(), left_side, !upper, transpose,
unitriangular);
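  // Output 0 is the solution; output 1 is the broadcasted coefficient
  // matrix, matching the (solution, cloned_coefficient) pair returned by
  // ATen's triangular_solve.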
return std::make_tuple(rhs.CreateFrom(ir::Value(node, 0)),
rhs.CreateFrom(ir::Value(node, 1)));
}

XLATensor XLATensor::tril(const XLATensor& input, xla::int64 diagonal) {
return input.CreateFrom(
ir::MakeNode<ir::ops::Tril>(input.GetIrValue(), diagonal));