diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 68209338071c..ebaae3277b40 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1156,6 +1156,92 @@ TEST_F(AtenXlaTensorTest, TestPermute) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestTriu) {
+  int size = 5;
+  at::Tensor input = GetTestTensor({size, size});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::triu(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::triu(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestTriuNonSquare) {
+  int size = 5;
+  at::Tensor input = GetTestTensor({size, size + 1});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::triu(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::triu(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestTriuBatch) {
+  int size = 5;
+  int batch_size = 3;
+  at::Tensor input = GetTestTensor({batch_size, size, size});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::triu(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::triu(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestTril) {
+  int size = 5;
+  at::Tensor input = GetTestTensor({size, size});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::tril(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::tril(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestTrilNonSquare) {
+  int size = 5;
+  at::Tensor input = GetTestTensor({size, size + 1});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::tril(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::tril(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
+TEST_F(AtenXlaTensorTest, TestTrilBatch) {
+  int size = 5;
+  int batch_size = 3;
+  at::Tensor input = GetTestTensor({batch_size, size, size});
+  // Test all diagonals and out of bounds (must be no-op).
+  for (int diagonal = -size; diagonal <= size; ++diagonal) {
+    at::Tensor output = at::tril(input, diagonal);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+      at::Tensor xla_output = at::tril(xla_input, diagonal);
+      AllClose(output, xla_output);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) {
   int kernel_size = 2;
   for (int stride = 1; stride <= 2; ++stride) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d2119e9196b6..a37bdf53a6e3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -792,6 +792,16 @@ at::Tensor AtenXlaType::where(const at::Tensor& condition,
                        bridge::GetXlaTensor(other)));
 }
 
+at::Tensor AtenXlaType::triu(const at::Tensor& self, int64_t diagonal) const {
+  return bridge::AtenFromXlaTensor(
+      XLATensor::triu(bridge::GetXlaTensor(self), diagonal));
+}
+
+at::Tensor AtenXlaType::tril(const at::Tensor& self, int64_t diagonal) const {
+  return bridge::AtenFromXlaTensor(
+      XLATensor::tril(bridge::GetXlaTensor(self), diagonal));
+}
+
 void AtenXlaType::SetFullConvPrecision(
     bool use_full_conv_precision /*= true*/) {
   s_use_full_conv_precision_ = use_full_conv_precision;
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 0d3ae1094b24..c41f9d82765f 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -263,6 +263,10 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor where(const at::Tensor& condition, const at::Tensor& self,
                    const at::Tensor& other) const override;
 
+  at::Tensor triu(const at::Tensor& self, int64_t diagonal) const override;
+
+  at::Tensor tril(const at::Tensor& self, int64_t diagonal) const override;
+
   static void SetFullConvPrecision(bool use_full_conv_precision = true);
 
   // Registers the ATEN types for the XLA tensors.
diff --git a/torch_xla/csrc/matrix.cpp b/torch_xla/csrc/matrix.cpp
new file mode 100644
index 000000000000..42bccc389a9c
--- /dev/null
+++ b/torch_xla/csrc/matrix.cpp
@@ -0,0 +1,17 @@
+#include "torch_xla/csrc/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+
+namespace torch_xla {
+
+xla::XlaOp BuildTriu(const xla::XlaOp& input, int diagonal) {
+  return xla::Select(xla::TriangleMask(input, diagonal - 1),
+                     xla::ZerosLike(input), input);
+}
+
+xla::XlaOp BuildTril(const xla::XlaOp& input, int diagonal) {
+  return xla::Select(xla::TriangleMask(input, diagonal), input,
+                     xla::ZerosLike(input));
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/matrix.h b/torch_xla/csrc/matrix.h
new file mode 100644
index 000000000000..e8dd00babd97
--- /dev/null
+++ b/torch_xla/csrc/matrix.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "torch/csrc/jit/ir.h"
+
+namespace torch_xla {
+
+xla::XlaOp BuildTriu(const xla::XlaOp& input, int diagonal);
+
+xla::XlaOp BuildTril(const xla::XlaOp& input, int diagonal);
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/tril.cpp b/torch_xla/csrc/ops/tril.cpp
new file mode 100644
index 000000000000..baa95632bd07
--- /dev/null
+++ b/torch_xla/csrc/ops/tril.cpp
@@ -0,0 +1,29 @@
+#include "torch_xla/csrc/ops/tril.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/matrix.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+Tril::Tril(const Value& input, xla::int64 diagonal)
+    : Node(ir::OpKind(at::aten::tril), {input}, input.shape(),
+           /*num_outputs=*/1, xla::util::MHash(diagonal)),
+      diagonal_(diagonal) {}
+
+XlaOpVector Tril::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp output = BuildTril(input, diagonal_);
+  return ReturnOp(output, loctx);
+}
+
+std::string Tril::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", diagonal=" << diagonal_;
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/tril.h b/torch_xla/csrc/ops/tril.h
new file mode 100644
index 000000000000..12198b4caf8b
--- /dev/null
+++ b/torch_xla/csrc/ops/tril.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+// Node for the lower triangular part of a matrix (2-D tensor) or batch of
+// matrices input.
+class Tril : public Node {
+ public:
+  Tril(const Value& input, xla::int64 diagonal);
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  xla::int64 diagonal() const { return diagonal_; }
+
+ private:
+  xla::int64 diagonal_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/triu.cpp b/torch_xla/csrc/ops/triu.cpp
new file mode 100644
index 000000000000..bddddf670a02
--- /dev/null
+++ b/torch_xla/csrc/ops/triu.cpp
@@ -0,0 +1,29 @@
+#include "torch_xla/csrc/ops/triu.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/matrix.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+Triu::Triu(const Value& input, xla::int64 diagonal)
+    : Node(ir::OpKind(at::aten::triu), {input}, input.shape(),
+           /*num_outputs=*/1, xla::util::MHash(diagonal)),
+      diagonal_(diagonal) {}
+
+XlaOpVector Triu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp output = BuildTriu(input, diagonal_);
+  return ReturnOp(output, loctx);
+}
+
+std::string Triu::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", diagonal=" << diagonal_;
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/triu.h b/torch_xla/csrc/ops/triu.h
new file mode 100644
index 000000000000..59eb2b75b9bc
--- /dev/null
+++ b/torch_xla/csrc/ops/triu.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+// Node for the upper triangular part of a matrix (2-D tensor) or batch of
+// matrices input.
+class Triu : public Node {
+ public:
+  Triu(const Value& input, xla::int64 diagonal);
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  std::string ToString() const override;
+
+  xla::int64 diagonal() const { return diagonal_; }
+
+ private:
+  xla::int64 diagonal_;
+};
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 46765e5be8b6..4644f21f6e5d 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -49,6 +49,8 @@
 #include "torch_xla/csrc/ops/squeeze.h"
 #include "torch_xla/csrc/ops/threshold.h"
 #include "torch_xla/csrc/ops/threshold_backward.h"
+#include "torch_xla/csrc/ops/tril.h"
+#include "torch_xla/csrc/ops/triu.h"
 #include "torch_xla/csrc/ops/unsqueeze.h"
 #include "torch_xla/csrc/ops/view.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -950,6 +952,16 @@ XLATensor XLATensor::unsqueeze(const XLATensor& input, int dim) {
                 input.GetDevice());
 }
 
+XLATensor XLATensor::triu(const XLATensor& input, xla::int64 diagonal) {
+  return Create(ir::MakeNode<ir::ops::Triu>(input.GetIrValue(), diagonal),
+                input.GetDevice());
+}
+
+XLATensor XLATensor::tril(const XLATensor& input, xla::int64 diagonal) {
+  return Create(ir::MakeNode<ir::ops::Tril>(input.GetIrValue(), diagonal),
+                input.GetDevice());
+}
+
 XLATensor XLATensor::where(const XLATensor& condition, const XLATensor& input,
                            const XLATensor& other) {
   return Create(ir::ops::Where(condition.GetIrValue(), input.GetIrValue(),
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index fa3c62846fb9..ec9f9356d8fc 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -315,11 +315,9 @@ class XLATensor {
 
   static XLATensor max(const XLATensor& input, const XLATensor& other);
 
-  static XLATensor argmax(const XLATensor& input, xla::int64 dim,
-                          bool keepdim);
+  static XLATensor argmax(const XLATensor& input, xla::int64 dim, bool keepdim);
 
-  static XLATensor argmin(const XLATensor& input, xla::int64 dim,
-                          bool keepdim);
+  static XLATensor argmin(const XLATensor& input, xla::int64 dim, bool keepdim);
 
   // Like batch_norm, but returns additional save_mean and save_invstd used by
   // the backward pass.
@@ -349,6 +347,14 @@ class XLATensor {
   // Insert a dimension of size one at the specified position.
   static XLATensor unsqueeze(const XLATensor& input, int dim);
 
+  // Returns the upper triangular part of a matrix (2-D tensor) or batch of
+  // matrices input; the other elements of the result are set to 0.
+  static XLATensor triu(const XLATensor& input, xla::int64 diagonal);
+
+  // Returns the lower triangular part of a matrix (2-D tensor) or batch of
+  // matrices input; the other elements of the result are set to 0.
+  static XLATensor tril(const XLATensor& input, xla::int64 diagonal);
+
   static XLATensor where(const XLATensor& condition, const XLATensor& input,
                          const XLATensor& other);
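
Note on the lowering in matrix.cpp (illustrative commentary, not part of the patch): xla::TriangleMask(x, d), from tensorflow/compiler/xla/client/lib/matrix.h, evaluates to true at positions where row + d >= col, i.e. on and below the d-th diagonal. BuildTril keeps exactly the masked elements, while BuildTriu masks with diagonal - 1 and swaps the select arms, zeroing everything strictly below the requested diagonal. A minimal standalone C++ sketch of the same semantics on a plain row-major matrix, assuming that mask definition:

#include <cstdio>
#include <vector>

// Reference for BuildTriu/BuildTril, assuming xla::TriangleMask(x, d)
// is true where row + d >= col (on and below the d-th diagonal).
std::vector<float> TriRef(const std::vector<float>& m, int rows, int cols,
                          int diagonal, bool upper) {
  std::vector<float> out(m.size(), 0.0f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // tril: Select(TriangleMask(input, diagonal), input, zeros)
      // triu: Select(TriangleMask(input, diagonal - 1), zeros, input)
      bool keep = upper ? !(r + diagonal - 1 >= c) : (r + diagonal >= c);
      if (keep) out[r * cols + c] = m[r * cols + c];
    }
  }
  return out;
}

int main() {
  std::vector<float> m = {1, 2, 3, 4, 5, 6, 7, 8, 9};  // 3x3, row-major
  std::vector<float> u = TriRef(m, 3, 3, /*diagonal=*/0, /*upper=*/true);
  for (int r = 0; r < 3; ++r) {
    std::printf("%.0f %.0f %.0f\n", u[r * 3], u[r * 3 + 1], u[r * 3 + 2]);
  }
  // Prints: 1 2 3 / 0 5 6 / 0 0 9, matching at::triu(m, 0).
  return 0;
}

For diagonal = 0 this keeps the upper triangle including the main diagonal. Diagonals beyond the matrix extent simply select everything or nothing, which is why the tests above can sweep diagonal from -size to size without special-casing the out-of-bounds values.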